@richwhitjr
Created December 14, 2025 20:12
Synthon Spark Distributed Search

import argparse
from typing import Iterator

import pandas
from pyspark.sql import SparkSession, types
from rdkit import Chem

# Output schema for the rows produced by the mapInPandas enumeration step.
SYNTHON_SCHEMA = types.StructType([
    types.StructField("synthon_key", types.StringType(), False),
    types.StructField("synthon_smiles", types.StringType(), False),
])


def _generate_synthons_tuples() -> pandas.DataFrame:
    # Placeholder input: each row pairs a tuple of synthon identifiers with a key.
    return pandas.DataFrame(
        [
            (("Synthon1", "Synthon10", "Synthon266"), 0),
            (("Synthon1", "Synthon11", "Synthon266"), 1),
        ],
        columns=["synthons", "synthon_key"],
    )


def _filter_mol(mol: Chem.Mol) -> bool:
    # Placeholder for actual filtering logic.
    return True


def _iterate_synthons(synthons) -> Iterator[Chem.Mol]:
    # Placeholder for actual synthon iteration logic.
    return iter([])


def _enumerate_synthons(iterator: Iterator[pandas.DataFrame]) -> Iterator[pandas.DataFrame]:
    # Runs on each partition: expand every synthon tuple into filtered product
    # molecules and emit them as (synthon_key, synthon_smiles) rows.
    for pdf in iterator:
        for _, row in pdf.iterrows():
            rows = []
            synthons = row["synthons"]
            for mol in _iterate_synthons(synthons):
                if _filter_mol(mol):
                    synthon_smiles = Chem.MolToSmiles(mol)
                    rows.append({
                        # The output schema declares synthon_key as a string,
                        # so cast the key explicitly.
                        "synthon_key": str(row["synthon_key"]),
                        "synthon_smiles": synthon_smiles,
                    })
            # Always pass the column names so empty batches still match the schema.
            yield pandas.DataFrame(rows, columns=["synthon_key", "synthon_smiles"])


def _parse_args():
    p = argparse.ArgumentParser()
    p.add_argument("--num-synthon-partitions", type=int, default=400)
    p.add_argument("--out", required=True)
    return p.parse_args()


def main():
    args = _parse_args()
    spark = (
        SparkSession
        .builder
        .appName("RDKitDistributedSynthonSearch")
        .config("spark.sql.execution.arrow.pyspark.enabled", "true")
        .getOrCreate()
    )

    # Generate the synthon tuples to enumerate.
    synthons_tuples = spark.createDataFrame(_generate_synthons_tuples())

    # Repartition by synthon key so that rows with the same key land in the same partition.
    synthons_shuffled = synthons_tuples.repartition(args.num_synthon_partitions, "synthon_key")

    # Enumerate and filter molecules in parallel across the partitions.
    molecules = synthons_shuffled.mapInPandas(_enumerate_synthons, schema=SYNTHON_SCHEMA)

    # Coalesce to a single partition so the results are written as one CSV file.
    molecules.coalesce(1).write.mode("overwrite").option("header", "true").csv(args.out)

    spark.stop()


if __name__ == "__main__":
    main()
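
One possible way to submit the script to a cluster is sketched below; the file name, master setting, and output path are illustrative assumptions, not part of the gist:

    spark-submit \
        --master yarn \
        synthon_search.py \
        --num-synthon-partitions 400 \
        --out /data/synthon-output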
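
The filtering and enumeration functions are left as stubs in the gist. Purely as a hypothetical illustration (the property choices and thresholds below are assumptions, not the author's logic), a concrete _filter_mol could apply simple RDKit property filters:

    from rdkit import Chem
    from rdkit.Chem import Descriptors, Lipinski

    def _filter_mol(mol: Chem.Mol) -> bool:
        # Hypothetical example filter: keep roughly drug-sized, rule-of-five-ish molecules.
        return (
            Descriptors.MolWt(mol) <= 500.0
            and Lipinski.NumHDonors(mol) <= 5
            and Lipinski.NumHAcceptors(mol) <= 10
        )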