From 4409f540847b21af3dc6f5304d4cec126bc8211b Mon Sep 17 00:00:00 2001 From: Akshitha Date: Wed, 15 Apr 2026 12:49:20 -0700 Subject: [PATCH 1/2] separate branch for preprocess refactor, main changes in make_delta forces for step 3 parallelization --- cg_mapping.py | 8 ++ make_deltaforces.py | 173 +++++++++++++++++++++++++++++++++++++------- 2 files changed, 154 insertions(+), 27 deletions(-) diff --git a/cg_mapping.py b/cg_mapping.py index a9b573f..bfd85fb 100644 --- a/cg_mapping.py +++ b/cg_mapping.py @@ -79,6 +79,14 @@ def __init__(self, topology, map_def): chain_protein_mode = 1 assert chain_protein_mode == 1 + if res.name not in map_def.bead_atom_selection: + warnings.warn( + f"Skipping unsupported protein residue '{res.name}' in chain {chain.index}", + RuntimeWarning, + ) + last_backbone_idx = None + continue + idx_mapping = {a.name: a.index for a in res.atoms} bead_mapping = map_def.bead_atom_selection[res.name] diff --git a/make_deltaforces.py b/make_deltaforces.py index b06e8d0..90dba68 100644 --- a/make_deltaforces.py +++ b/make_deltaforces.py @@ -1,24 +1,98 @@ -import torch +import multiprocessing as mp +import time + import numpy as np +import torch from tqdm import tqdm -# from torchmd.forcefields.forcefield import ForceField -from module.torchmd import tagged_forcefield -from torchmd.forces import Forces -from torchmd.systems import System from moleculekit.molecule import Molecule +from torchmd.forces import Forces from torchmd.parameters import Parameters -# from simulate import CalcWrapper -import time +from torchmd.systems import System + +# from torchmd.forcefields.forcefield import ForceField from module.external_nn import ExternalNN, ParametersNN +from module.torchmd import tagged_forcefield + +# from simulate import CalcWrapper # adapted from https://github.com/torchmd/torchmd-cg/blob/master/torchmd_cg/utils/make_deltaforces.py + +def _split_frame_indices(frames: list[int], n_workers: int) -> list[list[int]]: + n = len(frames) + if n_workers <= 1 or n <= 1: + return [frames] + n_workers = min(n_workers, n) + q, r = divmod(n, n_workers) + out: list[list[int]] = [] + idx = 0 + for w in range(n_workers): + take = q + (1 if w < r else 0) + out.append(frames[idx : idx + take]) + idx += take + return out + + +def _classical_prior_chunk_worker( + psf: str, + coords_npz: str, + box_npz: str | None, + forcefield: str, + exclusions: tuple, + forceterms: list, + frame_indices: list[int], +) -> tuple[list[int], np.ndarray, np.ndarray]: + torch.set_num_threads(1) + try: + import mkl # type: ignore[import-untyped] + + mkl.set_num_threads(1) + except Exception: + pass + + precision = torch.float32 + device = torch.device("cpu") + mol = Molecule(psf) + natoms = mol.numAtoms + coords = np.load(coords_npz) + box = np.load(box_npz) if box_npz else None + coords_t = torch.tensor(coords, dtype=precision) + if box is not None: + linearized = box.reshape(-1, 9).take([0, 4, 8], axis=1) + box_full = linearized.reshape(linearized.shape[0], 3, 1) + else: + box_full = torch.zeros(coords.shape[0], 3, 1) + + ff = tagged_forcefield.create(mol, forcefield) + parameters = Parameters(ff, mol, forceterms, precision=precision, device=device) # pyright: ignore[reportArgumentType] + system = System(natoms, 1, precision, device) + system.set_positions(np.zeros((natoms, 3, 1))) + system.set_velocities(torch.zeros(1, natoms, 3)) + forces = Forces(parameters, terms=forceterms, exclusions=exclusions) + + n_fr = len(frame_indices) + out_f = np.zeros((n_fr, natoms, 3), dtype=np.float32) + out_e = np.zeros((n_fr,), dtype=np.float32) + for k, i in enumerate(frame_indices): + co = coords_t[i] + system.set_box(box_full[i]) + pot = forces.compute(co.reshape([1, natoms, 3]), system.box, system.forces) + out_f[k] = system.forces.detach().cpu().reshape(natoms, 3).numpy() + p0 = pot[0] + out_e[k] = p0.item() if hasattr(p0, "item") else float(p0) + return frame_indices, out_f, out_e + + class DeltaForces: def __init__(self, device, psf, coords_npz, box_npz): self.device = torch.device(device) self.precision = torch.float32 self.replicas = 1 + self._psf_path = psf + self._coords_npz_path = coords_npz + self._box_npz_path = box_npz + self.mol = Molecule(psf) self.natoms = self.mol.numAtoms @@ -41,38 +115,83 @@ def __init__(self, device, psf, coords_npz, box_npz): self.parameters = None - def computePriorForces(self, + def computePriorForces( + self, forcefield, exclusions=("bonds"), forceterms=["Bonds", "Angles", "RepulsionCG"], - bar_position=0,frames=None + bar_position=0, + frames=None, + num_parallel_workers: int = 1, ): # if forceterms is empty list, then we exit if forceterms == []: return - ff = tagged_forcefield.create(self.mol, forcefield) - parameters = Parameters(ff, self.mol, forceterms, precision=self.precision, device=self.device) #pyright: ignore[reportArgumentType] - - system = System(self.natoms, self.replicas, self.precision, self.device) - system.set_positions(np.zeros((self.natoms, 3, self.replicas))) - system.set_velocities(torch.zeros(self.replicas, self.natoms, 3)) - - forces = Forces(parameters, terms=forceterms, exclusions=exclusions) - if frames is None: # if None, then process all frames + if frames is None: frames = range(0, self.coords.shape[0]) + frames_list = list(frames) + in_daemon = mp.current_process().daemon + effective_workers = num_parallel_workers + if in_daemon and num_parallel_workers > 1: + tqdm.write( + "Delta forces - Classical: running single-process because daemon workers cannot spawn child pools." + ) + effective_workers = 1 start_time = time.time() - for i in tqdm(frames, position=bar_position, dynamic_ncols=True, desc="Delta forces - Classical", leave=(bar_position==0)): - co = self.coords[i] - system.set_box(self.box_full[i]) - pot = forces.compute(co.reshape([1, self.natoms, 3]), system.box, system.forces) - fr = ( - system.forces.detach().cpu().reshape([self.natoms, 3]) + if effective_workers <= 1 or len(frames_list) <= 1: + ff = tagged_forcefield.create(self.mol, forcefield) + parameters = Parameters( + ff, self.mol, forceterms, precision=self.precision, device=self.device + ) # pyright: ignore[reportArgumentType] + + system = System(self.natoms, self.replicas, self.precision, self.device) + system.set_positions(np.zeros((self.natoms, 3, self.replicas))) + system.set_velocities(torch.zeros(self.replicas, self.natoms, 3)) + + forces = Forces(parameters, terms=forceterms, exclusions=exclusions) + for i in tqdm( + frames_list, + position=bar_position, + dynamic_ncols=True, + desc="Delta forces - Classical", + leave=(bar_position == 0), + ): + co = self.coords[i] + system.set_box(self.box_full[i]) + pot = forces.compute(co.reshape([1, self.natoms, 3]), system.box, system.forces) + fr = system.forces.detach().cpu().reshape([self.natoms, 3]) + self.prior_forces[i, :, :] += fr + assert len(pot) == 1 + self.prior_energies[i] += pot[0] + else: + chunks = _split_frame_indices(frames_list, effective_workers) + box_arg = self._box_npz_path + tasks = [ + ( + self._psf_path, + self._coords_npz_path, + box_arg, + forcefield, + tuple(exclusions), + list(forceterms), + chunk, + ) + for chunk in chunks + ] + ctx = mp.get_context("spawn") + with ctx.Pool(processes=len(chunks)) as pool: + results = pool.starmap(_classical_prior_chunk_worker, tasks) + for idxs, f_blk, e_blk in results: + idx_t = torch.tensor(idxs, dtype=torch.long) + self.prior_forces[idx_t, :, :] += torch.tensor(f_blk, dtype=self.precision) + self.prior_energies[idx_t] += torch.tensor(e_blk, dtype=self.precision) + tqdm.write( + f"Time taken for classical forces (parallel, {len(chunks)} workers) {time.time() - start_time:.2f}" ) - self.prior_forces[i,:,:] += fr - assert len(pot) == 1 - self.prior_energies[i] += pot[0] + return + tqdm.write(f"Time taken for classical forces {time.time() - start_time:.2f}") From e443bf723aa65761c20ab7f24bc908210576263e Mon Sep 17 00:00:00 2001 From: Akshitha Date: Sun, 26 Apr 2026 20:24:44 -0700 Subject: [PATCH 2/2] refactor changes to modules --- cg_mapping.py | 204 +++++++++++++++++++++++++++++++------------- make_deltaforces.py | 18 +--- 2 files changed, 146 insertions(+), 76 deletions(-) diff --git a/cg_mapping.py b/cg_mapping.py index bfd85fb..6c3eb2d 100644 --- a/cg_mapping.py +++ b/cg_mapping.py @@ -1,6 +1,8 @@ -import numpy as np -import mdtraj +import re import warnings + +import mdtraj +import numpy as np from moleculekit.molecule import Molecule from aggforce import LinearMap, project_forces #type: ignore @@ -56,77 +58,139 @@ def __init__(self, topology, map_def): self.cg_topology = mdtraj.Topology() - for chain in topology.chains: - last_backbone_idx = None + mappable_residues = set(map_def.bead_atom_selection.keys()) + dna_mods = getattr(map_def, "dna_modifications", None) or {} + dna_residue_names = frozenset(["DA", "DT", "DG", "DC"]) - if not any([r.is_protein for r in chain.residues]): - continue # Skip water/ligand/ion chains + def strip_resname(rname: str) -> str: + return re.sub(r"\d+$", "", rname) - result_chain = self.cg_topology.add_chain() + def clean_atom_name(name: str) -> str: + s = str(name).split("-")[-1] + return s.replace("*", "'") + + def _bead_indices(imap: dict, atom_names) -> list | None: + out: list = [] + for an in atom_names: + if an not in imap: + return None + out.append(imap[an]) + return out - # Initially 0 when we haven't seen a protein residue - # Becomes 1 with the first protein residue - # Then 2 if we see a non-protein residue - # This is done ensure the chain is contiguous while still allowing e.g. an ion at the end of a chain + for chain in topology.chains: + if not any( + (strip_resname(r.name) in mappable_residues) or (strip_resname(r.name) in dna_mods) + for r in chain.residues + ): + continue + + result_chain = self.cg_topology.add_chain() + last_backbone_idx = None chain_protein_mode = 0 + for res in chain.residues: - if not res.is_protein: - if chain_protein_mode == 1: - chain_protein_mode = 2 - continue # Skip non-protein residues + res_name = strip_resname(res.name) + mapped_resname: str | None + if res_name in mappable_residues: + mapped_resname = res_name + elif res_name in dna_mods: + parent = dna_mods[res_name] + if parent in mappable_residues: + mapped_resname = str(parent) + else: + mapped_resname = None else: - if chain_protein_mode == 0: - chain_protein_mode = 1 - assert chain_protein_mode == 1 + mapped_resname = None - if res.name not in map_def.bead_atom_selection: + if mapped_resname is None: + if chain_protein_mode == 1: + chain_protein_mode = 2 + continue + if chain_protein_mode == 0: + chain_protein_mode = 1 + elif chain_protein_mode == 2: warnings.warn( - f"Skipping unsupported protein residue '{res.name}' in chain {chain.index}", + f"Non-contiguous chain {chain.index}: residue {res.name} after a gap of " + f"unmappable residues", RuntimeWarning, + stacklevel=2, ) - last_backbone_idx = None - continue - idx_mapping = {a.name: a.index for a in res.atoms} - bead_mapping = map_def.bead_atom_selection[res.name] - - first_bead_idx = self.cg_topology.n_atoms - # Determine index of this residue's last backbone bead - # E.g. if the beads were N-CA-C--N-CA-C the offset would be 2 to make C the last backbone - backbone_idx = self.cg_topology.n_atoms + map_def.bead_backbone_idx[res.name] + idx_mapping = {clean_atom_name(a.name): a.index for a in res.atoms} + bead_mapping = map_def.bead_atom_selection[mapped_resname] + backbone_idx = self.cg_topology.n_atoms + all_bead_indices: list + bbcand = getattr(map_def, "dna_backbone_atom_candidates", None) + if ( + bbcand + and mapped_resname in dna_residue_names + and len(bead_mapping) >= 1 + ): + first: list | None = None + for alt in bbcand: + first = _bead_indices(idx_mapping, alt) + if first is not None: + break + if first is None: + continue + rest: list = [] + ok = True + for bead in bead_mapping[1:]: + bidx = _bead_indices(idx_mapping, bead) + if bidx is None: + ok = False + break + rest.append(bidx) + if not ok: + continue + all_bead_indices = [first] + rest + else: + all_bead_indices = [] + valid = True + for bead in bead_mapping: + bidx = _bead_indices(idx_mapping, bead) + if bidx is None: + valid = False + break + all_bead_indices.append(bidx) + if not valid or len(all_bead_indices) != len(bead_mapping): + continue result_res = self.cg_topology.add_residue(res.name, result_chain) - - # for bead_name, bead_type, bead_mass in zip(map_def.bead_atom_name[res.name], map_def.bead_type[res.name], map_def.bead_mass[res.name]): - for bead_name in map_def.bead_atom_names[res.name]: - self.cg_topology.add_atom(bead_name, mdtraj.element.carbon, result_res) - self.bead_atom_names.extend(map_def.bead_atom_names[res.name]) - self.bead_types.extend(map_def.bead_types[res.name]) - self.bead_mass.extend(map_def.bead_masses[res.name]) - self.embeddings.extend(map_def.bead_embeddings[res.name]) - - for bead in bead_mapping: - bead_idx = [] - for atom in bead: - if atom not in idx_mapping: - # FIXME: The martini mappings seem to have extra atoms (possibly to handle different naming schemes?) - raise RuntimeError(f"Missing atom: {res}, {atom}") - else: - bead_idx.append(idx_mapping[atom]) - + for bead_name in map_def.bead_atom_names[mapped_resname]: + if bead_name == "DBB": + element = mdtraj.element.phosphorus + elif str(bead_name).startswith("DB"): + element = mdtraj.element.nitrogen + elif str(bead_name).startswith("CA"): + element = mdtraj.element.carbon + else: + fc = str(bead_name)[0].upper() if bead_name else "C" + elmap = { + "P": mdtraj.element.phosphorus, + "N": mdtraj.element.nitrogen, + "C": mdtraj.element.carbon, + } + element = elmap.get(fc, mdtraj.element.carbon) + self.cg_topology.add_atom(bead_name, element, result_res) + self.bead_atom_names.extend(map_def.bead_atom_names[mapped_resname]) + self.bead_types.extend(map_def.bead_types[mapped_resname]) + self.bead_mass.extend(map_def.bead_masses[mapped_resname]) + self.embeddings.extend(map_def.bead_embeddings[mapped_resname]) + for bead_idx in all_bead_indices: self.src_idx.append(bead_idx) - # FIXME: Should use OpenMM masses not mdtraj's - bead_weights = np.array([topology.atom(i).element.mass for i in bead_idx]) - bead_weights = (bead_weights / np.sum(bead_weights)).tolist() - self.pos_weights.append(bead_weights) - self.force_weights.append(bead_weights) - + bead_w = np.array([topology.atom(i).element.mass for i in bead_idx]) + bead_w = (bead_w / np.sum(bead_w)).tolist() + self.pos_weights.append(bead_w) + self.force_weights.append(bead_w) if last_backbone_idx is not None: - # Add a backbone bond between the first bead of each residue - self.cg_topology.add_bond(self.cg_topology.atom(last_backbone_idx), self.cg_topology.atom(first_bead_idx)) - # Sequential bonds from the backbone to the other beads - for i in range(len(bead_mapping)-1): - self.cg_topology.add_bond(self.cg_topology.atom(first_bead_idx+i), self.cg_topology.atom(first_bead_idx+i+1)) + self.cg_topology.add_bond( + self.cg_topology.atom(last_backbone_idx), self.cg_topology.atom(backbone_idx) + ) + for i in range(len(all_bead_indices) - 1): + self.cg_topology.add_bond( + self.cg_topology.atom(backbone_idx + i), self.cg_topology.atom(backbone_idx + i + 1) + ) last_backbone_idx = backbone_idx def to_mol(self, bonds=True, angles=True, dihedrals=True): @@ -145,7 +209,27 @@ def to_mol(self, bonds=True, angles=True, dihedrals=True): mol.beta = np.full((self.cg_topology.n_atoms,), 0.0, dtype=np.float32) #pyright: ignore[reportAttributeAccessIssue] mol.record = np.full((self.cg_topology.n_atoms,), 'ATOM', dtype=object) #pyright: ignore[reportAttributeAccessIssue] mol.altloc = np.full((self.cg_topology.n_atoms,), '', dtype=object) #pyright: ignore[reportAttributeAccessIssue] - mol.element = np.full((self.cg_topology.n_atoms,), 'C', dtype=object) #pyright: ignore[reportAttributeAccessIssue] + # PSF: CG bead names; DNA DBB/DB* use P/N for writer compatibility (cgschnet) + mol.element = np.array( + [ + "P" if n == "DBB" else "N" if str(n).startswith("DB") else "C" if str(n).startswith("CA") else "C" + for n in self.bead_atom_names + ], + dtype=object, + ) # pyright: ignore[reportAttributeAccessIssue] + mol.atomicnumber = np.array( + [15 if n == "DBB" else 7 if str(n).startswith("DB") else 6 for n in self.bead_atom_names], + dtype=np.int32, + ) + disp_masses: list[float] = [] + for n in self.bead_atom_names: + if n == "DBB": + disp_masses.append(30.97) + elif str(n).startswith("DB"): + disp_masses.append(14.01) + else: + disp_masses.append(12.01) + mol.masses = np.array(disp_masses, dtype=np.float32) # pyright: ignore[reportAttributeAccessIssue] mol.formalcharge = np.full((self.cg_topology.n_atoms,), 0, dtype=np.int32) #pyright: ignore[reportAttributeAccessIssue] # The output psf contains resname=res_abbr, name=CA, atomtype=bead_type @@ -154,7 +238,7 @@ def to_mol(self, bonds=True, angles=True, dihedrals=True): mol.resname = np.array([a.residue.name for a in self.cg_topology.atoms], dtype=object) #pyright: ignore[reportAttributeAccessIssue] mol.charge = np.full((self.cg_topology.n_atoms,), 0) #pyright: ignore[reportAttributeAccessIssue] - mol.masses = np.array(self.bead_mass) #pyright: ignore[reportAttributeAccessIssue] + mol._cg_bead_masses = np.array(self.bead_mass, dtype=np.float32) # pyright: ignore[reportAttributeAccessIssue] mol.box = np.zeros((3, 0), dtype=np.float32) diff --git a/make_deltaforces.py b/make_deltaforces.py index 90dba68..82fb0eb 100644 --- a/make_deltaforces.py +++ b/make_deltaforces.py @@ -11,6 +11,7 @@ # from torchmd.forcefields.forcefield import ForceField from module.external_nn import ExternalNN, ParametersNN +from module.frame_utils import split_frame_indices from module.torchmd import tagged_forcefield # from simulate import CalcWrapper @@ -18,21 +19,6 @@ # adapted from https://github.com/torchmd/torchmd-cg/blob/master/torchmd_cg/utils/make_deltaforces.py -def _split_frame_indices(frames: list[int], n_workers: int) -> list[list[int]]: - n = len(frames) - if n_workers <= 1 or n <= 1: - return [frames] - n_workers = min(n_workers, n) - q, r = divmod(n, n_workers) - out: list[list[int]] = [] - idx = 0 - for w in range(n_workers): - take = q + (1 if w < r else 0) - out.append(frames[idx : idx + take]) - idx += take - return out - - def _classical_prior_chunk_worker( psf: str, coords_npz: str, @@ -166,7 +152,7 @@ def computePriorForces( assert len(pot) == 1 self.prior_energies[i] += pot[0] else: - chunks = _split_frame_indices(frames_list, effective_workers) + chunks = split_frame_indices(frames_list, effective_workers) box_arg = self._box_npz_path tasks = [ (