Skip to content

Instantly share code, notes, and snippets.

@metric-space
Last active February 11, 2026 06:53
Show Gist options
  • Select an option

  • Save metric-space/cedce32ad2cacafe44a4a60a9bfb1a5f to your computer and use it in GitHub Desktop.

Select an option

Save metric-space/cedce32ad2cacafe44a4a60a9bfb1a5f to your computer and use it in GitHub Desktop.
def _compute_prolif_interactions(
    self,
    mol: Chem.Mol,
    pocket_complex: PocketComplex,
) -> np.ndarray:
    """Fingerprint a ligand against its protein pocket with ProLIF.

    The fingerprint is configured with the module-level
    ``PROLIF_INTERACTIONS`` and ``INTERACTION_PARAMETERS`` so the output is
    built the same way as the training-data pipeline's arrays.

    Args:
        mol: RDKit ligand to fingerprint.
        pocket_complex: Complex whose ``holo`` structure supplies the pocket.

    Returns:
        ``np.int8`` array of shape
        ``[n_pocket_atoms, n_ligand_atoms, n_interaction_types]``. An entry
        is 1 when ``BindingInteractions._interactions_from_ifp`` reports
        that interaction type between that pocket atom and ligand atom,
        otherwise 0. Reported index pairs that fall outside the array
        bounds are skipped.
    """
    pocket = pocket_complex.holo.to_prolif()
    ligand = plf.Molecule.from_rdkit(mol, resname="LIG", resnumber=-1)

    # count=True is passed to ProLIF, but entries are binarized to 1 below.
    fingerprint = plf.Fingerprint(
        interactions=PROLIF_INTERACTIONS,
        parameters=INTERACTION_PARAMETERS,
        count=True,
    )
    fingerprint.run_from_iterable(
        [ligand], pocket, residues="all", progress=False
    )
    # A single ligand was passed in, so its fingerprint is read from index 0.
    triples = BindingInteractions._interactions_from_ifp(fingerprint.ifp[0])

    type_to_channel = {
        name: channel for channel, name in enumerate(PROLIF_INTERACTIONS)
    }
    n_pocket = len(pocket_complex.holo)
    n_ligand = mol.GetNumAtoms()
    result = np.zeros(
        (n_pocket, n_ligand, len(PROLIF_INTERACTIONS)), dtype=np.int8
    )

    for kind, pocket_idx, ligand_idx in triples:
        channel = type_to_channel[kind]
        # Drop index pairs that do not fit inside the array.
        if pocket_idx >= n_pocket or ligand_idx >= n_ligand:
            continue
        result[pocket_idx, ligand_idx, channel] = 1

    return result
@staticmethod
def _check_geometric_fidelity(
    condition_mol: Chem.Mol,
    gen_mol: Chem.Mol,
    interaction_array: np.ndarray,
    epsilon: float = 1.0,
) -> Optional[float]:
    """Measure how well conditioned atoms keep their element and position.

    A ligand atom counts as conditioned when it takes part in at least one
    interaction in ``interaction_array`` (summed over the pocket and
    interaction-type axes). A conditioned atom is preserved when the atom
    with the same index in ``gen_mol`` has the same element symbol and lies
    strictly closer than ``epsilon`` Angstroms to the reference position.

    Args:
        condition_mol: Reference ligand that generation was conditioned on.
        gen_mol: Generated ligand; atom indices are assumed to correspond
            one-to-one with ``condition_mol``.
        interaction_array: Shape [n_pocket, n_ligand, n_interaction_types].
        epsilon: Distance tolerance in Angstroms (exclusive bound).

    Returns:
        Preserved fraction in [0.0, 1.0], or None when no atom is
        conditioned.
    """
    per_atom_counts = interaction_array.sum(axis=(0, 2))
    conditioned = np.where(per_atom_counts > 0)[0]
    if len(conditioned) == 0:
        return None

    ref_conf = condition_mol.GetConformer()
    gen_conf = gen_mol.GetConformer()

    preserved = 0
    for atom_idx in conditioned.tolist():
        ref_symbol = condition_mol.GetAtomWithIdx(atom_idx).GetSymbol()
        gen_symbol = gen_mol.GetAtomWithIdx(atom_idx).GetSymbol()
        # Element mismatch: skip without computing the distance.
        if ref_symbol != gen_symbol:
            continue
        displacement = np.array(
            ref_conf.GetAtomPosition(atom_idx)
        ) - np.array(gen_conf.GetAtomPosition(atom_idx))
        if np.linalg.norm(displacement) < epsilon:
            preserved += 1

    return preserved / len(conditioned)
def _check_interaction_fidelity(
    self,
    gen_mol: Chem.Mol,
    pocket_complex: PocketComplex,
    ref_interaction_array: np.ndarray,
) -> Optional[float]:
    """Check whether interaction-conditioned atoms preserve their interactions.

    Recomputes the generated ligand's interaction array with
    ``self._compute_prolif_interactions`` and compares it against the
    reference, restricted to conditioned ligand atoms. A ligand atom is
    conditioned when it has any interaction in the reference array.

    Args:
        gen_mol: Generated (or relaxed) ligand.
        pocket_complex: The PocketComplex containing the protein pocket.
        ref_interaction_array: Shape [n_pocket, n_ligand, n_types] from
            the reference ligand (used to determine conditioned atoms).

    Returns:
        Fraction of active reference interaction entries (at conditioned
        atoms) that are also active for the generated ligand, in
        [0.0, 1.0]. Returns None when there are no conditioned atoms, the
        ProLIF computation raises, the generated array's shape differs from
        the reference's, or there are no active reference interactions.
    """
    conditioned_mask = ref_interaction_array.sum(axis=(0, 2)) > 0
    if not conditioned_mask.any():
        return None

    try:
        gen_array = self._compute_prolif_interactions(
            gen_mol, pocket_complex
        )
    except Exception as e:
        # Best-effort metric: a ProLIF failure yields "no score" rather
        # than aborting the caller.
        print(f"Prolif interaction computation failed: {e}")
        return None

    # The boolean mask indexes the ligand axis of both arrays, and the
    # element-wise comparison below needs matching pocket/type axes. A
    # mismatch (e.g. the generated ligand gained or lost atoms) would
    # otherwise raise an uncaught IndexError/ValueError or broadcast a
    # degenerate axis into a meaningless score.
    if gen_array.shape != ref_interaction_array.shape:
        print(
            "Interaction array shape mismatch: "
            f"generated {gen_array.shape} vs reference "
            f"{ref_interaction_array.shape}"
        )
        return None

    # Compare at conditioned ligand atoms only
    ref_conditioned = ref_interaction_array[:, conditioned_mask, :]
    gen_conditioned = gen_array[:, conditioned_mask, :]

    ref_active = ref_conditioned > 0
    n_ref = ref_active.sum()
    # Defensive guard: a column selected by a positive sum normally has at
    # least one positive entry, so this should not trigger in practice.
    if n_ref == 0:
        return None

    n_preserved = (ref_active & (gen_conditioned > 0)).sum()
    return float(n_preserved) / float(n_ref)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment