Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
202 changes: 147 additions & 55 deletions cg_mapping.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import numpy as np
import mdtraj
import re
import warnings

import mdtraj
import numpy as np
from moleculekit.molecule import Molecule
from aggforce import LinearMap, project_forces #type: ignore

Expand Down Expand Up @@ -56,69 +58,139 @@ def __init__(self, topology, map_def):

self.cg_topology = mdtraj.Topology()

for chain in topology.chains:
last_backbone_idx = None
mappable_residues = set(map_def.bead_atom_selection.keys())
dna_mods = getattr(map_def, "dna_modifications", None) or {}
dna_residue_names = frozenset(["DA", "DT", "DG", "DC"])

if not any([r.is_protein for r in chain.residues]):
continue # Skip water/ligand/ion chains
def strip_resname(rname: str) -> str:
    """Return ``rname`` with any run of digits at its end removed.

    For example ``"DA5"`` becomes ``"DA"``; a name without trailing digits
    is returned unchanged.
    """
    trailing_digits = r"\d+$"
    return re.sub(trailing_digits, "", rname)

result_chain = self.cg_topology.add_chain()
def clean_atom_name(name: str) -> str:
    """Normalise an atom name for lookup.

    The name is converted to ``str``, everything up to and including the last
    ``"-"`` is dropped, and every ``"*"`` is replaced by ``"'"``.
    """
    # rpartition yields the whole string as the tail when no "-" is present.
    _, _, tail = str(name).rpartition("-")
    return tail.replace("*", "'")

def _bead_indices(imap: dict, atom_names) -> list | None:
out: list = []
for an in atom_names:
if an not in imap:
return None
out.append(imap[an])
return out

for chain in topology.chains:
if not any(
(strip_resname(r.name) in mappable_residues) or (strip_resname(r.name) in dna_mods)
for r in chain.residues
):
continue

# Initially 0 when we haven't seen a protein residue
# Becomes 1 with the first protein residue
# Then 2 if we see a non-protein residue
# This is done to ensure the chain is contiguous while still allowing e.g. an ion at the end of a chain
result_chain = self.cg_topology.add_chain()
last_backbone_idx = None
chain_protein_mode = 0

for res in chain.residues:
if not res.is_protein:
res_name = strip_resname(res.name)
mapped_resname: str | None
if res_name in mappable_residues:
mapped_resname = res_name
elif res_name in dna_mods:
parent = dna_mods[res_name]
if parent in mappable_residues:
mapped_resname = str(parent)
else:
mapped_resname = None
else:
mapped_resname = None

if mapped_resname is None:
if chain_protein_mode == 1:
chain_protein_mode = 2
continue # Skip non-protein residues
continue
if chain_protein_mode == 0:
chain_protein_mode = 1
elif chain_protein_mode == 2:
warnings.warn(
f"Non-contiguous chain {chain.index}: residue {res.name} after a gap of "
f"unmappable residues",
RuntimeWarning,
stacklevel=2,
)

idx_mapping = {clean_atom_name(a.name): a.index for a in res.atoms}
bead_mapping = map_def.bead_atom_selection[mapped_resname]
backbone_idx = self.cg_topology.n_atoms
all_bead_indices: list
bbcand = getattr(map_def, "dna_backbone_atom_candidates", None)
if (
bbcand
and mapped_resname in dna_residue_names
and len(bead_mapping) >= 1
):
first: list | None = None
for alt in bbcand:
first = _bead_indices(idx_mapping, alt)
if first is not None:
break
if first is None:
continue
rest: list = []
ok = True
for bead in bead_mapping[1:]:
bidx = _bead_indices(idx_mapping, bead)
if bidx is None:
ok = False
break
rest.append(bidx)
if not ok:
continue
all_bead_indices = [first] + rest
else:
if chain_protein_mode == 0:
chain_protein_mode = 1
assert chain_protein_mode == 1

idx_mapping = {a.name: a.index for a in res.atoms}
bead_mapping = map_def.bead_atom_selection[res.name]

first_bead_idx = self.cg_topology.n_atoms
# Determine index of this residue's last backbone bead
# E.g. if the beads were N-CA-C--N-CA-C the offset would be 2 to make C the last backbone
backbone_idx = self.cg_topology.n_atoms + map_def.bead_backbone_idx[res.name]
all_bead_indices = []
valid = True
for bead in bead_mapping:
bidx = _bead_indices(idx_mapping, bead)
if bidx is None:
valid = False
break
all_bead_indices.append(bidx)
if not valid or len(all_bead_indices) != len(bead_mapping):
continue

result_res = self.cg_topology.add_residue(res.name, result_chain)

# for bead_name, bead_type, bead_mass in zip(map_def.bead_atom_name[res.name], map_def.bead_type[res.name], map_def.bead_mass[res.name]):
for bead_name in map_def.bead_atom_names[res.name]:
self.cg_topology.add_atom(bead_name, mdtraj.element.carbon, result_res)
self.bead_atom_names.extend(map_def.bead_atom_names[res.name])
self.bead_types.extend(map_def.bead_types[res.name])
self.bead_mass.extend(map_def.bead_masses[res.name])
self.embeddings.extend(map_def.bead_embeddings[res.name])

for bead in bead_mapping:
bead_idx = []
for atom in bead:
if atom not in idx_mapping:
# FIXME: The martini mappings seem to have extra atoms (possibly to handle different naming schemes?)
raise RuntimeError(f"Missing atom: {res}, {atom}")
else:
bead_idx.append(idx_mapping[atom])

for bead_name in map_def.bead_atom_names[mapped_resname]:
if bead_name == "DBB":
element = mdtraj.element.phosphorus
elif str(bead_name).startswith("DB"):
element = mdtraj.element.nitrogen
elif str(bead_name).startswith("CA"):
element = mdtraj.element.carbon
else:
fc = str(bead_name)[0].upper() if bead_name else "C"
elmap = {
"P": mdtraj.element.phosphorus,
"N": mdtraj.element.nitrogen,
"C": mdtraj.element.carbon,
}
element = elmap.get(fc, mdtraj.element.carbon)
self.cg_topology.add_atom(bead_name, element, result_res)
self.bead_atom_names.extend(map_def.bead_atom_names[mapped_resname])
self.bead_types.extend(map_def.bead_types[mapped_resname])
self.bead_mass.extend(map_def.bead_masses[mapped_resname])
self.embeddings.extend(map_def.bead_embeddings[mapped_resname])
for bead_idx in all_bead_indices:
self.src_idx.append(bead_idx)
# FIXME: Should use OpenMM masses not mdtraj's
bead_weights = np.array([topology.atom(i).element.mass for i in bead_idx])
bead_weights = (bead_weights / np.sum(bead_weights)).tolist()
self.pos_weights.append(bead_weights)
self.force_weights.append(bead_weights)

bead_w = np.array([topology.atom(i).element.mass for i in bead_idx])
bead_w = (bead_w / np.sum(bead_w)).tolist()
self.pos_weights.append(bead_w)
self.force_weights.append(bead_w)
if last_backbone_idx is not None:
# Add a backbone bond between the first bead of each residue
self.cg_topology.add_bond(self.cg_topology.atom(last_backbone_idx), self.cg_topology.atom(first_bead_idx))
# Sequential bonds from the backbone to the other beads
for i in range(len(bead_mapping)-1):
self.cg_topology.add_bond(self.cg_topology.atom(first_bead_idx+i), self.cg_topology.atom(first_bead_idx+i+1))
self.cg_topology.add_bond(
self.cg_topology.atom(last_backbone_idx), self.cg_topology.atom(backbone_idx)
)
for i in range(len(all_bead_indices) - 1):
self.cg_topology.add_bond(
self.cg_topology.atom(backbone_idx + i), self.cg_topology.atom(backbone_idx + i + 1)
)
last_backbone_idx = backbone_idx

def to_mol(self, bonds=True, angles=True, dihedrals=True):
Expand All @@ -137,7 +209,27 @@ def to_mol(self, bonds=True, angles=True, dihedrals=True):
mol.beta = np.full((self.cg_topology.n_atoms,), 0.0, dtype=np.float32) #pyright: ignore[reportAttributeAccessIssue]
mol.record = np.full((self.cg_topology.n_atoms,), 'ATOM', dtype=object) #pyright: ignore[reportAttributeAccessIssue]
mol.altloc = np.full((self.cg_topology.n_atoms,), '', dtype=object) #pyright: ignore[reportAttributeAccessIssue]
mol.element = np.full((self.cg_topology.n_atoms,), 'C', dtype=object) #pyright: ignore[reportAttributeAccessIssue]
# PSF: CG bead names; DNA DBB/DB* use P/N for writer compatibility (cgschnet)
mol.element = np.array(
[
"P" if n == "DBB" else "N" if str(n).startswith("DB") else "C" if str(n).startswith("CA") else "C"
for n in self.bead_atom_names
],
dtype=object,
) # pyright: ignore[reportAttributeAccessIssue]
mol.atomicnumber = np.array(
[15 if n == "DBB" else 7 if str(n).startswith("DB") else 6 for n in self.bead_atom_names],
dtype=np.int32,
)
disp_masses: list[float] = []
for n in self.bead_atom_names:
if n == "DBB":
disp_masses.append(30.97)
elif str(n).startswith("DB"):
disp_masses.append(14.01)
else:
disp_masses.append(12.01)
mol.masses = np.array(disp_masses, dtype=np.float32) # pyright: ignore[reportAttributeAccessIssue]
mol.formalcharge = np.full((self.cg_topology.n_atoms,), 0, dtype=np.int32) #pyright: ignore[reportAttributeAccessIssue]

# The output psf contains resname=res_abbr, name=CA, atomtype=bead_type
Expand All @@ -146,7 +238,7 @@ def to_mol(self, bonds=True, angles=True, dihedrals=True):
mol.resname = np.array([a.residue.name for a in self.cg_topology.atoms], dtype=object) #pyright: ignore[reportAttributeAccessIssue]

mol.charge = np.full((self.cg_topology.n_atoms,), 0) #pyright: ignore[reportAttributeAccessIssue]
mol.masses = np.array(self.bead_mass) #pyright: ignore[reportAttributeAccessIssue]
mol._cg_bead_masses = np.array(self.bead_mass, dtype=np.float32) # pyright: ignore[reportAttributeAccessIssue]

mol.box = np.zeros((3, 0), dtype=np.float32)

Expand Down
Loading