From a50e506d01549218fba38bd4fc0d9bd083730217 Mon Sep 17 00:00:00 2001 From: chrisjonesBSU Date: Thu, 14 Mar 2024 12:29:53 -0600 Subject: [PATCH 1/3] new func to trim snapshot mols --- cmeutils/gsd_utils.py | 96 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 96 insertions(+) diff --git a/cmeutils/gsd_utils.py b/cmeutils/gsd_utils.py index 88893d3..6f4b280 100644 --- a/cmeutils/gsd_utils.py +++ b/cmeutils/gsd_utils.py @@ -380,6 +380,102 @@ def xml_to_gsd(xmlfile, gsdfile): print(f"XML data written to {gsdfile}") +def trim_snapshot_molecules(parent_snapshot, mol_indices): + """Given a snapshot of a system, trim the snapshot to only include + a subset of the molecules. + + Parameters + ---------- + parent_snapshot : gsd.hoomd.Frame + The snapshot to read in. + mol_indices : list of np.ndarray + List of arrays where each array contains the indices + of the particles in a molecule to include. + + Returns + ------- + gsd.hoomd.Frame + The new snapshot with only the specified molecules. + + Notes + ----- + See cmeutils.gsd_utils.get_molecule_cluster for a method to obtain + mol_indices. 
+ + """ + new_snap = gsd.hoomd.Frame() + new_snap.configuration.box = parent_snapshot.configuration.box + new_snap.particles.N = sum(len(i) for i in mol_indices) + + # Write out particle info + for attr in ["position", "mass", "velocity", "orientation", "image", "diameter", "angmom", "typeid"]: + setattr( + new_snap.particles, + attr, + np.concatenate( + list(getattr(parent_snapshot.particles, attr)[i] for i in mol_indices) + ) + ) + new_snap.particles.types = parent_snapshot.particles.types + + particle_index_map = dict() + count = 0 + for indices in mol_indices: + for i in indices: + particle_index_map[i] = count + count += 1 + + # Write out bond info + mol_bond_groups = [] + mol_bond_ids = [] + for count, indices in enumerate(mol_indices): + mask = np.any(np.isin(parent_snapshot.bonds.group, indices.flatten()), axis=1) + parent_mol_bonds = parent_snapshot.bonds.group[np.where(mask)[0]] + parent_mol_bond_typeids = parent_snapshot.bonds.typeid[np.where(mask)[0]] + new_mol_bonds = np.vectorize(particle_index_map.get)(parent_mol_bonds) + mol_bond_groups.append(new_mol_bonds) + mol_bond_ids.append(parent_mol_bond_typeids) + + new_snap.bonds.types = parent_snapshot.bonds.types + new_snap.bonds.group = np.concatenate(mol_bond_groups) + new_snap.bonds.typeid = np.concatenate(mol_bond_ids) + new_snap.bonds.N = sum(len(i) for i in mol_bond_ids) + + # Write out angle info + mol_angle_groups = [] + mol_angle_ids = [] + for count, indices in enumerate(mol_indices): + mask = np.any(np.isin(parent_snapshot.angles.group, indices.flatten()), axis=1) + parent_mol_angles = parent_snapshot.angles.group[np.where(mask)[0]] + parent_mol_angle_typeids = parent_snapshot.angles.typeid[np.where(mask)[0]] + new_mol_angles = np.vectorize(particle_index_map.get)(parent_mol_angles) + mol_angle_groups.append(new_mol_angles) + mol_angle_ids.append(parent_mol_angle_typeids) + + new_snap.angles.types = parent_snapshot.angles.types + new_snap.angles.group = np.concatenate(mol_angle_groups) + 
new_snap.angles.typeid = np.concatenate(mol_angle_ids) + new_snap.angles.N = sum(len(i) for i in mol_angle_ids) + + # Write out dihedral info + mol_dihedral_groups = [] + mol_dihedral_ids = [] + for count, indices in enumerate(mol_indices): + mask = np.any(np.isin(parent_snapshot.dihedrals.group, indices.flatten()), axis=1) + parent_mol_dihedrals = parent_snapshot.dihedrals.group[np.where(mask)[0]] + parent_mol_dihedral_typeids = parent_snapshot.dihedrals.typeid[np.where(mask)[0]] + new_mol_dihedrals = np.vectorize(particle_index_map.get)(parent_mol_dihedrals) + mol_dihedral_groups.append(new_mol_dihedrals) + mol_dihedral_ids.append(parent_mol_dihedral_typeids) + + new_snap.dihedrals.types = parent_snapshot.dihedrals.types + new_snap.dihedrals.group = np.concatenate(mol_dihedral_groups) + new_snap.dihedrals.typeid = np.concatenate(mol_dihedral_ids) + new_snap.dihedrals.N = sum(len(i) for i in mol_dihedral_ids) + + new_snap.validate() + return new_snap + def identify_snapshot_connections(snapshot): """Identify angle and dihedral connections in a snapshot from bonds. 
From 3e6019baf712cd9f297067e52a0f32fe5e6cc7a3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 14 Mar 2024 18:51:39 +0000 Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- cmeutils/gsd_utils.py | 51 +++++++++++++++++++++++++++++++++---------- 1 file changed, 40 insertions(+), 11 deletions(-) diff --git a/cmeutils/gsd_utils.py b/cmeutils/gsd_utils.py index 6f4b280..0791460 100644 --- a/cmeutils/gsd_utils.py +++ b/cmeutils/gsd_utils.py @@ -408,13 +408,25 @@ def trim_snapshot_molecules(parent_snapshot, mol_indices): new_snap.particles.N = sum(len(i) for i in mol_indices) # Write out particle info - for attr in ["position", "mass", "velocity", "orientation", "image", "diameter", "angmom", "typeid"]: + for attr in [ + "position", + "mass", + "velocity", + "orientation", + "image", + "diameter", + "angmom", + "typeid", + ]: setattr( new_snap.particles, attr, np.concatenate( - list(getattr(parent_snapshot.particles, attr)[i] for i in mol_indices) - ) + list( + getattr(parent_snapshot.particles, attr)[i] + for i in mol_indices + ) + ), ) new_snap.particles.types = parent_snapshot.particles.types @@ -429,9 +441,13 @@ def trim_snapshot_molecules(parent_snapshot, mol_indices): mol_bond_groups = [] mol_bond_ids = [] for count, indices in enumerate(mol_indices): - mask = np.any(np.isin(parent_snapshot.bonds.group, indices.flatten()), axis=1) + mask = np.any( + np.isin(parent_snapshot.bonds.group, indices.flatten()), axis=1 + ) parent_mol_bonds = parent_snapshot.bonds.group[np.where(mask)[0]] - parent_mol_bond_typeids = parent_snapshot.bonds.typeid[np.where(mask)[0]] + parent_mol_bond_typeids = parent_snapshot.bonds.typeid[ + np.where(mask)[0] + ] new_mol_bonds = np.vectorize(particle_index_map.get)(parent_mol_bonds) mol_bond_groups.append(new_mol_bonds) mol_bond_ids.append(parent_mol_bond_typeids) @@ -445,9 +461,13 @@ def 
trim_snapshot_molecules(parent_snapshot, mol_indices): mol_angle_groups = [] mol_angle_ids = [] for count, indices in enumerate(mol_indices): - mask = np.any(np.isin(parent_snapshot.angles.group, indices.flatten()), axis=1) + mask = np.any( + np.isin(parent_snapshot.angles.group, indices.flatten()), axis=1 + ) parent_mol_angles = parent_snapshot.angles.group[np.where(mask)[0]] - parent_mol_angle_typeids = parent_snapshot.angles.typeid[np.where(mask)[0]] + parent_mol_angle_typeids = parent_snapshot.angles.typeid[ + np.where(mask)[0] + ] new_mol_angles = np.vectorize(particle_index_map.get)(parent_mol_angles) mol_angle_groups.append(new_mol_angles) mol_angle_ids.append(parent_mol_angle_typeids) @@ -461,10 +481,18 @@ def trim_snapshot_molecules(parent_snapshot, mol_indices): mol_dihedral_groups = [] mol_dihedral_ids = [] for count, indices in enumerate(mol_indices): - mask = np.any(np.isin(parent_snapshot.dihedrals.group, indices.flatten()), axis=1) - parent_mol_dihedrals = parent_snapshot.dihedrals.group[np.where(mask)[0]] - parent_mol_dihedral_typeids = parent_snapshot.dihedrals.typeid[np.where(mask)[0]] - new_mol_dihedrals = np.vectorize(particle_index_map.get)(parent_mol_dihedrals) + mask = np.any( + np.isin(parent_snapshot.dihedrals.group, indices.flatten()), axis=1 + ) + parent_mol_dihedrals = parent_snapshot.dihedrals.group[ + np.where(mask)[0] + ] + parent_mol_dihedral_typeids = parent_snapshot.dihedrals.typeid[ + np.where(mask)[0] + ] + new_mol_dihedrals = np.vectorize(particle_index_map.get)( + parent_mol_dihedrals + ) mol_dihedral_groups.append(new_mol_dihedrals) mol_dihedral_ids.append(parent_mol_dihedral_typeids) @@ -476,6 +504,7 @@ def trim_snapshot_molecules(parent_snapshot, mol_indices): new_snap.validate() return new_snap + def identify_snapshot_connections(snapshot): """Identify angle and dihedral connections in a snapshot from bonds. 
From 51c015d629a128b260a43b50fc5968d4608c3f0a Mon Sep 17 00:00:00 2001 From: chrisjonesBSU Date: Thu, 14 Mar 2024 12:54:51 -0600 Subject: [PATCH 3/3] remove un-used enumerations --- cmeutils/gsd_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cmeutils/gsd_utils.py b/cmeutils/gsd_utils.py index 0791460..60a6723 100644 --- a/cmeutils/gsd_utils.py +++ b/cmeutils/gsd_utils.py @@ -440,7 +440,7 @@ def trim_snapshot_molecules(parent_snapshot, mol_indices): # Write out bond info mol_bond_groups = [] mol_bond_ids = [] - for count, indices in enumerate(mol_indices): + for indices in mol_indices: mask = np.any( np.isin(parent_snapshot.bonds.group, indices.flatten()), axis=1 ) @@ -460,7 +460,7 @@ def trim_snapshot_molecules(parent_snapshot, mol_indices): # Write out angle info mol_angle_groups = [] mol_angle_ids = [] - for count, indices in enumerate(mol_indices): + for indices in mol_indices: mask = np.any( np.isin(parent_snapshot.angles.group, indices.flatten()), axis=1 ) @@ -480,7 +480,7 @@ def trim_snapshot_molecules(parent_snapshot, mol_indices): # Write out dihedral info mol_dihedral_groups = [] mol_dihedral_ids = [] - for count, indices in enumerate(mol_indices): + for indices in mol_indices: mask = np.any( np.isin(parent_snapshot.dihedrals.group, indices.flatten()), axis=1 )