Source code for rdkit2ase.substructure

import typing as tp

import ase
import matplotlib.pyplot as plt
from ase.build import separate
from rdkit import Chem
from rdkit.Chem import Draw

from rdkit2ase.ase2x import ase2rdkit
from rdkit2ase.utils import find_connected_components


def match_substructure(
    atoms: ase.Atoms,
    smiles: str | None = None,
    smarts: str | None = None,
    mol: Chem.Mol | None = None,
    fragment: ase.Atoms | None = None,
    **kwargs,
) -> tuple[tuple[int, ...], ...]:
    """
    Find all matches of a substructure pattern in a given ASE Atoms object.

    Exactly one of ``smiles``, ``smarts``, ``mol`` or ``fragment`` must be
    given. The arguments are validated before anything is parsed or
    converted, so conflicting input is rejected up front.

    Parameters
    ----------
    atoms : ase.Atoms
        The molecule or structure in which to search for substructure matches.
    smiles : str, optional
        A SMILES string representing the substructure pattern to match.
        Explicit hydrogens are added to the parsed molecule before matching.
    smarts : str, optional
        A SMARTS string representing the substructure pattern to match.
    mol : Chem.Mol, optional
        An RDKit Mol object representing the substructure pattern to match.
    fragment : ase.Atoms, optional
        An ASE Atoms object representing the substructure pattern to match.
        It is converted to an RDKit Mol object with `ase2rdkit`.
    **kwargs
        Additional keyword arguments passed to `ase2rdkit`, both for `atoms`
        and for `fragment`.

    Returns
    -------
    tuple of tuple of int
        A tuple of atom index tuples, each corresponding to one match of the
        pattern.

    Raises
    ------
    ValueError
        If no pattern is given, if more than one pattern is given, or if the
        SMILES/SMARTS string cannot be parsed.
    """
    provided = sum(
        option is not None for option in (smiles, smarts, mol, fragment)
    )
    if provided == 0:
        raise ValueError("Must specify a pattern")
    if provided > 1:
        raise ValueError("Can only specify one pattern")

    if smiles is not None:
        pattern = Chem.MolFromSmiles(smiles)
        if pattern is None:
            # The parser signals failure by returning None; report it here
            # instead of letting AddHs fail with an unrelated error.
            raise ValueError(f"Invalid SMILES: {smiles!r}")
        pattern = Chem.AddHs(pattern)  # Ensure hydrogens are added for matching
    elif smarts is not None:
        pattern = Chem.MolFromSmarts(smarts)
        if pattern is None:
            raise ValueError(f"Invalid SMARTS: {smarts!r}")
    elif mol is not None:
        # NOTE(review): SanitizeMol below runs in place, so a caller-supplied
        # `mol` is modified by this call - confirm that this is acceptable.
        pattern = mol
    else:
        pattern = ase2rdkit(fragment, **kwargs)

    Chem.SanitizeMol(pattern)

    target = ase2rdkit(atoms, **kwargs)
    # NOTE(review): RDKit limits the number of returned matches through the
    # `maxMatches` argument of GetSubstructMatches - confirm whether very
    # large systems need a higher limit than the library default.
    return target.GetSubstructMatches(pattern)
def get_substructures(
    atoms: ase.Atoms,
    **kwargs,
) -> list[ase.Atoms]:
    """
    Extract all matched substructures from an ASE Atoms object.

    Every keyword argument is forwarded unchanged to `match_substructure`,
    which selects the pattern and performs the search.

    Parameters
    ----------
    atoms : ase.Atoms
        The structure to search in.
    **kwargs
        Keyword arguments passed to `match_substructure`, for example
        ``smarts``, ``smiles``, ``mol`` or ``fragment`` to define the pattern.

    Returns
    -------
    list of ase.Atoms
        One Atoms object per match, built by indexing `atoms` with the
        matched atom indices.
    """
    substructures = []
    for atom_indices in match_substructure(atoms, **kwargs):
        # Indexing with the match tuple selects exactly the matched atoms.
        substructures.append(atoms[atom_indices])
    return substructures
def iter_fragments(atoms: ase.Atoms) -> tp.Iterator[ase.Atoms]:
    """
    Iterate over connected molecular fragments in an ASE Atoms object.

    If a 'connectivity' field is present in `atoms.info`, it will be used to
    determine fragments. Otherwise, `ase.build.separate` will be used.

    This is a generator: fragments are produced lazily, so wrap the call in
    ``list(...)`` when a materialized list is required.

    Parameters
    ----------
    atoms : ase.Atoms
        A structure that may contain one or more molecular fragments.

    Yields
    ------
    ase.Atoms
        Each connected component (fragment) in the input structure.
    """
    if "connectivity" in atoms.info:
        # connectivity is a list of tuples (i, j, bond_type)
        connectivity = atoms.info["connectivity"]
        for component in find_connected_components(connectivity):
            yield atoms[list(component)]
    else:
        yield from separate(atoms)
def _ordered_mapped_indices(patt: Chem.Mol) -> list[int]:
    """Return pattern atom indices that carry an atom map, sorted by map number."""
    index_by_map_num: dict[int, int] = {}
    for atom in patt.GetAtoms():
        map_num = atom.GetAtomMapNum()
        if map_num <= 0:
            continue
        if map_num in index_by_map_num:
            raise ValueError(f"Label '{map_num}' is used multiple times")
        index_by_map_num[map_num] = atom.GetIdx()
    return [index_by_map_num[map_num] for map_num in sorted(index_by_map_num)]


def _attached_hydrogens(mol: Chem.Mol, atom_idx: int) -> list[int]:
    """Return sorted indices of hydrogens bonded to a heavy atom (none for a hydrogen)."""
    atom = mol.GetAtomWithIdx(atom_idx)
    if atom.GetAtomicNum() == 1:
        return []
    return sorted(
        neighbor.GetIdx()
        for neighbor in atom.GetNeighbors()
        if neighbor.GetAtomicNum() == 1
    )


def _resolve_hydrogens(
    mol: Chem.Mol, core_indices: list[int], hydrogens: str
) -> list[int]:
    """Apply the ``hydrogens`` mode to an ordered list of core atom indices."""
    if hydrogens == "exclude":
        # Only heavy atoms from the core selection
        return [
            idx
            for idx in core_indices
            if mol.GetAtomWithIdx(idx).GetAtomicNum() != 1
        ]

    # "include" keeps each core atom followed by its hydrogens;
    # "isolated" keeps only the hydrogens of the core heavy atoms.
    selected: list[int] = []
    for idx in core_indices:
        if hydrogens == "include":
            selected.append(idx)
        selected.extend(_attached_hydrogens(mol, idx))
    return selected


def select_atoms_grouped(
    mol: Chem.Mol,
    smarts_or_smiles: str,
    hydrogens: tp.Literal["include", "exclude", "isolated"] = "exclude",
) -> list[list[int]]:
    """Selects atom indices using SMARTS or SMILES, grouped by disconnected fragments.

    This function identifies all substructure matches and returns a list of
    atom index lists. Each inner list corresponds to a unique, disconnected
    molecular fragment that contained at least one match.

    If the pattern contains atom maps (e.g., "[C:1]", "[C:2]"), only the mapped
    atoms are returned, ordered by their map numbers. Map numbers must be
    unique within the pattern. Otherwise, all atoms in the matched
    substructures are returned in ascending index order.

    Parameters
    ----------
    mol : rdchem.Mol
        RDKit molecule, which can contain multiple disconnected fragments and
        explicit hydrogens.
    smarts_or_smiles : str
        SMARTS pattern (e.g., "[F]") or SMILES with atom maps
        (e.g., "CC(=O)N[C:1]([C:2])[C:3](=O)[N:4]C"). The string is parsed as
        SMARTS first and as SMILES only if that fails. When using mapped
        atoms, map numbers must be unique.
    hydrogens : {'include', 'exclude', 'isolated'}, default='exclude'
        How to handle hydrogens in the final returned list for each group:
          - 'include': Keep every selected atom and insert the hydrogens
            bonded to each selected heavy atom directly after it.
          - 'exclude': Remove all hydrogens from the selection.
          - 'isolated': Return only the hydrogens that are bonded to selected
            heavy atoms.

    Returns
    -------
    list[list[int]]
        A list of integer lists. Each inner list contains the atom indices for
        a matched, disconnected fragment, in fragment order. For mapped
        patterns, atoms are ordered by their map numbers. Fragments with no
        matches, or whose selection is empty after hydrogen handling, are
        omitted from the output.

    Raises
    ------
    ValueError
        If `hydrogens` is not one of the supported values, if the provided
        SMARTS/SMILES pattern is invalid, or if atom map labels are used
        multiple times within the same pattern.

    Examples
    --------
    >>> # Molecule with two disconnected fragments: ethanol and fluoromethane
    >>> mol = Chem.MolFromSmiles("CCO.CF")  # Indices: C(0)C(1)O(2) . C(3)F(4)
    >>>
    >>> # Select all carbon atoms
    >>> select_atoms_grouped(mol, "[C]")
    [[0, 1], [3]]
    >>>
    >>> # Select fluorine; 'include' only adds hydrogens bonded to the match,
    >>> # and this molecule has no explicit hydrogens
    >>> select_atoms_grouped(mol, "[F]", hydrogens="include")
    [[4]]
    """
    # Validate up front so a wrong mode is reported even when nothing matches.
    if hydrogens not in ("include", "exclude", "isolated"):
        raise ValueError(
            f"Invalid value for `hydrogens`: {hydrogens!r}. "
            "Expected one of 'include', 'exclude', 'isolated'."
        )

    patt = Chem.MolFromSmarts(smarts_or_smiles)
    if patt is None:
        # Support mapped SMILES patterns too
        patt = Chem.MolFromSmiles(smarts_or_smiles)
    if patt is None:
        raise ValueError(f"Invalid SMARTS/SMILES: {smarts_or_smiles}")

    mapped_pattern_indices = _ordered_mapped_indices(patt)

    # Find all matches in the entire molecule just once for efficiency.
    # NOTE(review): RDKit limits the number of returned matches through the
    # `maxMatches` argument of GetSubstructMatches - confirm whether very
    # large systems need a higher limit than the library default.
    all_matches = mol.GetSubstructMatches(patt)
    if not all_matches:
        return []

    # Look up the owning fragment of every atom once instead of testing each
    # match against each fragment.
    fragments = Chem.GetMolFrags(mol, asMols=False)
    fragment_of_atom = {
        atom_idx: fragment_id
        for fragment_id, fragment in enumerate(fragments)
        for atom_idx in fragment
    }

    # Keep only matches that lie entirely inside a single fragment.
    matches_by_fragment: dict[int, list[tuple[int, ...]]] = {}
    for match in all_matches:
        owners = {fragment_of_atom[atom_idx] for atom_idx in match}
        if len(owners) == 1:
            matches_by_fragment.setdefault(owners.pop(), []).append(match)

    grouped_indices = []
    for fragment_id in sorted(matches_by_fragment):
        fragment_matches = matches_by_fragment[fragment_id]

        if mapped_pattern_indices:
            # Mapped atoms only, in map-number order per match; duplicates
            # across matches are dropped while keeping the first occurrence.
            core_atom_indices = list(
                dict.fromkeys(
                    match[pattern_idx]
                    for match in fragment_matches
                    for pattern_idx in mapped_pattern_indices
                )
            )
        else:
            core_atom_indices = sorted(
                {idx for match in fragment_matches for idx in match}
            )

        final_indices = _resolve_hydrogens(mol, core_atom_indices, hydrogens)

        # Only add the group if it contains any atoms after processing
        if final_indices:
            grouped_indices.append(final_indices)

    return grouped_indices
def select_atoms_flat_unique(
    mol: Chem.Mol,
    smarts_or_smiles: str,
    hydrogens: tp.Literal["include", "exclude", "isolated"] = "exclude",
) -> list[int]:
    """
    Selects a unique list of atom indices in a molecule using SMARTS or mapped SMILES.

    The selection itself is delegated to `select_atoms_grouped`; this function
    merges the per-fragment groups into one sorted list without duplicates.

    If the pattern contains atom maps (e.g., [C:1]), only the mapped atoms are
    returned. Otherwise, all atoms in the matched substructure are returned.

    Parameters
    ----------
    mol : Chem.Mol
        RDKit molecule, which can contain explicit hydrogens.
    smarts_or_smiles : str
        SMARTS (e.g., "[F]") or SMILES with atom maps
        (e.g., "C1[C:1]OC(=[O:2])O1"). Map numbers must be unique within the
        pattern.
    hydrogens : {"include", "exclude", "isolated"}, default "exclude"
        How to handle hydrogens in the final returned list.
          - "include": Include hydrogens attached to matched heavy atoms
          - "exclude": Exclude all hydrogens from results (default)
          - "isolated": Return only hydrogens attached to matched heavy atoms

    Returns
    -------
    list[int]
        A single, flat list of unique integer atom indices matching the
        criteria, sorted in ascending order. Empty if nothing is selected.

    Raises
    ------
    ValueError
        Propagated from `select_atoms_grouped`, e.g. when the SMARTS/SMILES
        pattern is invalid.
    """
    groups = select_atoms_grouped(mol, smarts_or_smiles, hydrogens=hydrogens)
    # A set comprehension flattens and de-duplicates in one step; sorting an
    # empty set yields the empty list returned when nothing matched.
    return sorted({atom_idx for group in groups for atom_idx in group})
def _collect_highlighted_fragments(mol, args, alpha):
    """Gather the fragments of ``mol`` touched by any selection, with highlights.

    ``args`` is a sequence of atom-index lists (indices refer to ``mol``).
    Each list gets one color from matplotlib's tab10 colormap, extended by
    ``alpha``. Returns three parallel lists: the fragment molecules, the
    highlighted atom indices within each fragment, and a per-fragment dict
    mapping those fragment-local indices to their color.
    """
    fragment_mols = Chem.GetMolFrags(mol, asMols=True)
    fragment_atom_ids = Chem.GetMolFrags(mol, asMols=False)

    # Union of every selected atom, used to skip untouched fragments quickly
    selected_anywhere = set().union(*args)

    # One RGBA color per selection list, cycling through the tab10 palette
    palette = plt.cm.tab10.colors
    selection_colors = [
        palette[position % len(palette)] + (alpha,)
        for position, _ in enumerate(args)
    ]

    candidate_mols = []
    candidate_highlights = []
    candidate_colors = []

    for fragment, atom_ids in zip(fragment_mols, fragment_atom_ids):
        if selected_anywhere.isdisjoint(atom_ids):
            continue

        # Translate indices of the full molecule to indices inside the fragment
        local_index = {
            original: local for local, original in enumerate(atom_ids)
        }

        # Keys keep first-insertion order, so the dict doubles as the ordered
        # highlight list; re-assignment lets later selections win the color.
        color_by_atom = {}
        for color, atom_list in zip(selection_colors, args):
            for idx in atom_list:
                if idx in local_index:
                    color_by_atom[local_index[idx]] = color

        candidate_mols.append(fragment)
        candidate_highlights.append(list(color_by_atom))
        candidate_colors.append(color_by_atom)

    return candidate_mols, candidate_highlights, candidate_colors


def _filter_unique_molecules(candidate_mols, candidate_highlights, candidate_colors):
    """Keep the first candidate for each distinct structure.

    Structures are compared by the canonical SMILES of the molecule with
    hydrogens removed. The highlight list and color dict at the same position
    as a kept molecule are kept with it.
    """
    kept_mols = []
    kept_highlights = []
    kept_colors = []
    seen_smiles = set()

    for position, candidate in enumerate(candidate_mols):
        # Canonical SMILES without hydrogens identifies the structure
        smiles_key = Chem.MolToSmiles(Chem.RemoveHs(candidate), canonical=True)
        if smiles_key in seen_smiles:
            continue
        seen_smiles.add(smiles_key)
        kept_mols.append(candidate)
        kept_highlights.append(candidate_highlights[position])
        kept_colors.append(candidate_colors[position])

    return kept_mols, kept_highlights, kept_colors
def visualize_selected_molecules(
    mol: Chem.Mol,
    *args,
    mols_per_row: int = 4,
    sub_img_size: tuple[int, int] = (200, 200),
    legends: list[str] | None = None,
    alpha: float = 0.5,
):
    """
    Visualizes molecules with optional atom highlighting.

    Without any atom selections the molecule is drawn as-is. With selections,
    only fragments containing at least one selected atom are drawn, and
    duplicate molecular structures are plotted once.

    Parameters
    ----------
    mol : Chem.Mol
        The RDKit molecule object, which may contain multiple fragments.
    *args : list[int]
        Variable number of lists containing atom indices to be highlighted.
        Each list will be assigned a different color from matplotlib's tab10
        colormap. If no arguments are provided, the molecule is displayed
        without highlights.
    mols_per_row : int, default 4
        Number of molecules per row in the grid.
    sub_img_size : tuple[int, int], default (200, 200)
        Size of each molecule image.
    legends : list[str] | None, default None
        Custom legends for each molecule. If None, legends of the form
        "Molecule <i>" are used.
    alpha : float, default 0.5
        Transparency level for the highlighted atoms
        (0.0 = fully transparent, 1.0 = opaque).

    Returns
    -------
    PIL.Image or None
        The grid image produced by `Draw.MolsToGridImage`, or None (after
        printing a message) when no fragment contains a selected atom.
    """
    # No selections: draw the whole molecule without highlights
    if len(args) == 0:
        plain_legends = ["Molecule 0"] if legends is None else legends
        return Draw.MolsToGridImage(
            [mol],
            molsPerRow=mols_per_row,
            subImgSize=sub_img_size,
            legends=plain_legends,
        )

    # Fragments that contain at least one selected atom, with their highlights
    candidates = _collect_highlighted_fragments(mol, args, alpha)
    candidate_mols, candidate_highlights, candidate_colors = candidates

    if not candidate_mols:
        print("No molecules to draw with the given selections.")
        return None

    # Drop repeated structures so each one appears once in the grid
    mols_to_draw, highlight_lists, highlight_colors = _filter_unique_molecules(
        candidate_mols, candidate_highlights, candidate_colors
    )

    if legends is None:
        final_legends = [f"Molecule {i}" for i in range(len(mols_to_draw))]
    else:
        final_legends = legends

    return Draw.MolsToGridImage(
        mols_to_draw,
        molsPerRow=mols_per_row,
        subImgSize=sub_img_size,
        legends=final_legends,
        highlightAtomLists=highlight_lists,
        highlightAtomColors=highlight_colors,
    )