Basic Example: TPSA + Weight Optimization
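This example from `chemlactica/mol_opt/example_run.py` optimizes molecules for high topological polar surface area (TPSA) under molecular-weight and ring-count constraints: molecules with an exact molecular weight of 350 or more, or with fewer than two rings, score 0.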
Oracle Implementation
mol_opt/example_run.py
from typing import List
import yaml
import argparse
import os
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import numpy as np
from rdkit.Chem import rdMolDescriptors
from chemlactica.mol_opt.optimization import optimize
from chemlactica.mol_opt.utils import set_seed, MoleculeEntry
class TPSA_Weight_Oracle:
    """Score molecules by TPSA under molecular-weight and ring-count constraints.

    A molecule scores its topological polar surface area (TPSA) when its exact
    molecular weight is below 350 and it has at least two rings; otherwise,
    or when any descriptor calculation fails, it scores 0.
    """

    def __init__(self, max_oracle_calls: int):
        # Maximum number of oracle calls to make
        self.max_oracle_calls = max_oracle_calls
        # How often (in unique molecules scored) to log progress
        self.freq_log = 100
        # SMILES -> [oracle_score, 1-based insertion order] for every unique molecule
        self.mol_buffer = {}
        # The maximum possible oracle score or an upper bound
        self.max_possible_oracle_score = 800
        # Request MoleculeEntry objects instead of SMILES
        self.takes_entry = True

    def __call__(self, molecules: List[MoleculeEntry]):
        """
        Evaluate and return the oracle scores for molecules.

        Each unique SMILES is scored once; repeats are answered from
        ``mol_buffer``. Progress is logged every ``freq_log`` unique molecules.
        """
        oracle_scores = []
        for molecule in molecules:
            cached = self.mol_buffer.get(molecule.smiles)
            if cached is not None:
                # Buffer entries are [score, order]; return only the score.
                oracle_scores.append(cached[0])
                continue
            oracle_score = self._score(molecule)
            self.mol_buffer[molecule.smiles] = [oracle_score, len(self.mol_buffer) + 1]
            if len(self.mol_buffer) % self.freq_log == 0:
                self.log_intermediate()
            oracle_scores.append(oracle_score)
        return oracle_scores

    @staticmethod
    def _score(molecule):
        """Return TPSA if weight < 350 and rings >= 2, else 0; print and return 0 on error."""
        try:
            tpsa = rdMolDescriptors.CalcTPSA(molecule.mol)
            weight = rdMolDescriptors.CalcExactMolWt(molecule.mol)
            num_rings = rdMolDescriptors.CalcNumRings(molecule.mol)
        except Exception as e:
            # Best effort: an unscorable molecule gets the minimum score.
            print(e)
            return 0
        # Apply constraints
        if weight >= 350 or num_rings < 2:
            return 0
        return tpsa

    def log_intermediate(self):
        """Print budget usage and the top-1/10/100 scores among the buffered molecules."""
        scores = [v[0] for v in self.mol_buffer.values()][-self.max_oracle_calls:]
        scores_sorted = sorted(scores, reverse=True)[:100]
        n_calls = len(self.mol_buffer)
        score_avg_top1 = np.max(scores_sorted)
        score_avg_top10 = np.mean(scores_sorted[:10])
        score_avg_top100 = np.mean(scores_sorted)
        print(f"{n_calls}/{self.max_oracle_calls} | "
              f'avg_top1: {score_avg_top1:.3f} | '
              f'avg_top10: {score_avg_top10:.3f} | '
              f'avg_top100: {score_avg_top100:.3f}')

    def __len__(self):
        return len(self.mol_buffer)

    @property
    def budget(self):
        return self.max_oracle_calls

    @property
    def finish(self):
        # True once the number of unique molecules scored reaches the budget.
        return len(self.mol_buffer) >= self.max_oracle_calls
Running the Optimization
def parse_arguments():
    """Read the command-line options for the TPSA + weight example.

    Returns an ``argparse.Namespace`` with ``output_dir``, ``config_default``
    and ``n_runs`` (defaulting to a single run).
    """
    cli = argparse.ArgumentParser()
    # Both paths are mandatory string options.
    for flag in ("--output_dir", "--config_default"):
        cli.add_argument(flag, type=str, required=True)
    cli.add_argument("--n_runs", type=int, required=False, default=1)
    return cli.parse_args()
if __name__ == "__main__":
    args = parse_arguments()
    # Run multiple optimization runs with different seeds (one seed per run)
    seeds = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31]
    # Fail fast, before loading the model, instead of hitting an IndexError
    # only after len(seeds) complete optimization runs.
    if args.n_runs > len(seeds):
        raise SystemExit(
            f"--n_runs must be at most {len(seeds)} (one predefined seed per run), "
            f"got {args.n_runs}"
        )
    # Close the config file deterministically instead of leaking the handle.
    with open(args.config_default) as config_file:
        config = yaml.safe_load(config_file)
    # Load model and tokenizer
    model = AutoModelForCausalLM.from_pretrained(
        config["checkpoint_path"],
        torch_dtype=torch.bfloat16
    ).to(config["device"])
    tokenizer = AutoTokenizer.from_pretrained(
        config["tokenizer_path"],
        padding_side="left"
    )
    # Create the log directory up front in case optimize() does not.
    os.makedirs(args.output_dir, exist_ok=True)
    for seed in seeds[:args.n_runs]:
        set_seed(seed)
        oracle = TPSA_Weight_Oracle(max_oracle_calls=1000)
        config["log_dir"] = os.path.join(
            args.output_dir,
            f"results_chemlactica_tpsa+weight+num_rings_{seed}.log"
        )
        config["max_possible_oracle_score"] = oracle.max_possible_oracle_score
        optimize(model, tokenizer, oracle, config)
Usage
python chemlactica/mol_opt/example_run.py \
--output_dir results/ \
--config_default chemlactica/mol_opt/chemlactica_125m_hparams.yaml \
--n_runs 3
This performs 3 optimization runs with different random seeds and saves the logs to `results/`.
QED Optimization Example
Optimize molecules for drug-likeness using the QED (Quantitative Estimate of Drug-likeness) metric.
from rdkit import Chem
from rdkit.Chem import QED
import yaml
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from chemlactica.mol_opt.optimization import optimize
class QED_Oracle:
    """Score SMILES strings by QED and count how many reach ``target_qed``.

    Unparsable or unscorable SMILES score 0.0. ``success_count`` counts unique
    molecules whose QED is at least ``target_qed``.
    """

    def __init__(self, max_oracle_calls: int, target_qed: float = 0.9):
        self.max_oracle_calls = max_oracle_calls
        self.freq_log = 100
        # SMILES -> oracle score for every unique molecule seen
        self.mol_buffer = {}
        self.max_possible_oracle_score = 1.0  # QED ranges from 0 to 1
        self.target_qed = target_qed
        self.success_count = 0

    def __call__(self, smiles_list):
        """Return one score per SMILES, scoring each unique SMILES only once."""
        oracle_scores = []
        for smiles in smiles_list:
            if smiles in self.mol_buffer:
                oracle_scores.append(self.mol_buffer[smiles])
                continue
            try:
                mol = Chem.MolFromSmiles(smiles)
                # MolFromSmiles returns None (rather than raising) on invalid SMILES.
                qed_score = None if mol is None else QED.qed(mol)
            except Exception:
                # Best effort: any scoring failure gets the minimum score.
                # (Exception, not a bare except, so Ctrl-C still stops the run.)
                qed_score = None
            if qed_score is None:
                oracle_score = 0.0
            else:
                oracle_score = qed_score
                # Track success rate (molecules above target)
                if qed_score >= self.target_qed:
                    self.success_count += 1
            self.mol_buffer[smiles] = oracle_score
            if len(self.mol_buffer) % self.freq_log == 0:
                self.log_intermediate()
            oracle_scores.append(oracle_score)
        return oracle_scores

    def log_intermediate(self):
        """Print budget usage, top-1/top-10 QED and the success rate so far."""
        n_calls = len(self.mol_buffer)
        if n_calls == 0:
            print(f"0/{self.max_oracle_calls} | no molecules evaluated yet")
            return
        scores_sorted = sorted(self.mol_buffer.values(), reverse=True)[:100]
        top10 = scores_sorted[:10]
        print(f"{n_calls}/{self.max_oracle_calls} | "
              f"top1: {scores_sorted[0]:.4f} | "
              f"top10: {sum(top10) / len(top10):.4f} | "
              f"success_rate: {self.success_count/n_calls*100:.1f}%")

    def __len__(self):
        return len(self.mol_buffer)

    @property
    def budget(self):
        return self.max_oracle_calls

    @property
    def finish(self):
        return len(self.mol_buffer) >= self.max_oracle_calls
# Setup and run
import os

# Close the config file deterministically instead of leaking the handle.
with open("chemlactica_125m_hparams.yaml") as config_file:
    config = yaml.safe_load(config_file)
model = AutoModelForCausalLM.from_pretrained(
    config["checkpoint_path"],
    torch_dtype=torch.bfloat16
).to(config["device"])
tokenizer = AutoTokenizer.from_pretrained(
    config["tokenizer_path"],
    padding_side="left"
)
oracle = QED_Oracle(max_oracle_calls=10000, target_qed=0.9)
config["max_possible_oracle_score"] = oracle.max_possible_oracle_score
config["log_dir"] = "results/qed_optimization.log"
# Create the log directory up front in case optimize() does not.
os.makedirs(os.path.dirname(config["log_dir"]), exist_ok=True)
optimize(model, tokenizer, oracle, config)
# Expected: ~99% success rate with 10K calls
if len(oracle):
    print(f"Final success rate: {oracle.success_count/len(oracle)*100:.1f}%")
else:
    # Avoid ZeroDivisionError when no molecule was evaluated.
    print("Final success rate: n/a (no molecules evaluated)")
Similarity-Constrained Optimization
Generate molecules similar to a reference compound while optimizing a property (here, logP).
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem, Descriptors
from chemlactica.mol_opt.utils import MoleculeEntry
class SimilarityConstrainedOracle:
    """Reward logP, but only for molecules similar enough to a reference.

    Molecules whose Tanimoto similarity to the reference fingerprint is at
    least ``min_similarity`` score ``logP + 5``, floored at 0. Dissimilar
    molecules, and molecules that fail to score, get 0.

    Raises:
        ValueError: if ``reference_smiles`` cannot be parsed.
    """

    def __init__(self, max_oracle_calls: int, reference_smiles: str,
                 min_similarity: float = 0.6):
        self.max_oracle_calls = max_oracle_calls
        self.freq_log = 100
        # SMILES -> oracle score for every unique molecule seen
        self.mol_buffer = {}
        # Nominal bound: logP + 5 for logP in roughly [-5, 5]. Molecules with
        # logP > 5 can exceed it.
        self.max_possible_oracle_score = 10.0
        self.takes_entry = True
        # Reference molecule. MolFromSmiles returns None for invalid SMILES,
        # which would otherwise fail obscurely inside the fingerprint call.
        self.reference_mol = Chem.MolFromSmiles(reference_smiles)
        if self.reference_mol is None:
            raise ValueError(f"Invalid reference SMILES: {reference_smiles!r}")
        # NOTE(review): assumes MoleculeEntry.fingerprint is also a radius-2,
        # 2048-bit Morgan bit vector so the Tanimoto comparison is meaningful.
        # Confirm this against chemlactica.mol_opt.utils.
        self.reference_fp = AllChem.GetMorganFingerprintAsBitVect(
            self.reference_mol, 2, nBits=2048
        )
        self.min_similarity = min_similarity

    def __call__(self, molecules):
        """Return one score per MoleculeEntry, scoring each unique SMILES once."""
        oracle_scores = []
        for mol_entry in molecules:
            if mol_entry.smiles in self.mol_buffer:
                oracle_scores.append(self.mol_buffer[mol_entry.smiles])
                continue
            try:
                # Calculate similarity to reference
                similarity = DataStructs.TanimotoSimilarity(
                    mol_entry.fingerprint,
                    self.reference_fp
                )
                # Calculate logP as objective
                logp = Descriptors.MolLogP(mol_entry.mol)
            except Exception:
                # Best effort: an unscorable molecule gets the minimum score.
                oracle_score = 0
            else:
                if similarity >= self.min_similarity:
                    # Shift to a positive range. Floor at 0 so a similar but very
                    # polar molecule (logP < -5) never scores below the penalty
                    # given to dissimilar molecules.
                    oracle_score = max(0, logp + 5)
                else:
                    oracle_score = 0  # Penalty for low similarity
            self.mol_buffer[mol_entry.smiles] = oracle_score
            if len(self.mol_buffer) % self.freq_log == 0:
                self.log_intermediate()
            oracle_scores.append(oracle_score)
        return oracle_scores

    def log_intermediate(self):
        print(f"{len(self.mol_buffer)}/{self.max_oracle_calls} molecules evaluated")

    def __len__(self):
        return len(self.mol_buffer)

    @property
    def budget(self):
        return self.max_oracle_calls

    @property
    def finish(self):
        return len(self.mol_buffer) >= self.max_oracle_calls
# Example: maximize logP while keeping Tanimoto similarity to aspirin >= 0.6.
# Pass the resulting oracle to optimize(model, tokenizer, oracle, config) to run it.
oracle = SimilarityConstrainedOracle(
    reference_smiles="CC(=O)OC1=CC=CC=C1C(=O)O",  # Aspirin
    min_similarity=0.6,
    max_oracle_calls=5000,
)
Multi-Property Optimization
Optimize for multiple properties using a weighted combination of per-property scores.
from rdkit.Chem import Descriptors, Crippen, QED
from rdkit.Chem.Scaffolds import MurckoScaffold
from chemlactica.mol_opt.utils import MoleculeEntry
class MultiPropertyOracle:
    """Weighted sum of per-property desirability scores.

    Each property is mapped into [0, 1] and multiplied by its weight, so the
    score never exceeds the sum of the weights (100 with the defaults).
    Molecules that fail to score get 0.
    """

    def __init__(self, max_oracle_calls: int, weights=None):
        """
        Args:
            max_oracle_calls: number of unique molecules to evaluate.
            weights: optional mapping overriding some or all of the default
                weights for 'qed', 'logp', 'mw', 'tpsa' and 'rotatable'.

        Raises:
            ValueError: if ``weights`` contains an unknown property name.
        """
        self.max_oracle_calls = max_oracle_calls
        self.freq_log = 100
        # SMILES -> oracle score for every unique molecule seen
        self.mol_buffer = {}
        self.takes_entry = True
        # Property weights
        self.weights = {
            'qed': 40.0,       # Drug-likeness
            'logp': 10.0,      # Lipophilicity (target ~3)
            'mw': 20.0,        # Molecular weight (target ~400)
            'tpsa': 20.0,      # TPSA (target ~80)
            'rotatable': 10.0  # Rotatable bonds (fewer is better)
        }
        if weights:
            unknown = set(weights) - set(self.weights)
            if unknown:
                raise ValueError(f"Unknown property weights: {sorted(unknown)}")
            self.weights.update(weights)
        # Every component is at most its weight, so their sum bounds the score.
        self.max_possible_oracle_score = sum(self.weights.values())

    def __call__(self, molecules):
        """Return one score per MoleculeEntry, scoring each unique SMILES once."""
        oracle_scores = []
        for mol_entry in molecules:
            if mol_entry.smiles in self.mol_buffer:
                oracle_scores.append(self.mol_buffer[mol_entry.smiles])
                continue
            try:
                # Calculate properties
                qed = QED.qed(mol_entry.mol)
                logp = Crippen.MolLogP(mol_entry.mol)
                mw = Descriptors.MolWt(mol_entry.mol)
                tpsa = Descriptors.TPSA(mol_entry.mol)
                n_rotatable = Descriptors.NumRotatableBonds(mol_entry.mol)
            except Exception:
                # Best effort: an unscorable molecule gets the minimum score.
                oracle_score = 0
            else:
                # Normalize each property to [0, 1], then weight it
                score_qed = qed * self.weights['qed']
                # LogP: penalize deviation from 3
                score_logp = max(0, (1 - abs(logp - 3) / 5)) * self.weights['logp']
                # MW: penalize deviation from 400
                score_mw = max(0, (1 - abs(mw - 400) / 200)) * self.weights['mw']
                # TPSA: penalize deviation from 80
                score_tpsa = max(0, (1 - abs(tpsa - 80) / 80)) * self.weights['tpsa']
                # Rotatable bonds: fewer is better
                score_rotatable = max(0, (1 - n_rotatable / 10)) * self.weights['rotatable']
                # Total score
                oracle_score = (
                    score_qed + score_logp + score_mw +
                    score_tpsa + score_rotatable
                )
            self.mol_buffer[mol_entry.smiles] = oracle_score
            if len(self.mol_buffer) % self.freq_log == 0:
                self.log_intermediate()
            oracle_scores.append(oracle_score)
        return oracle_scores

    def log_intermediate(self):
        """Print budget usage, the best score and the top-10 average so far."""
        n_calls = len(self.mol_buffer)
        if n_calls == 0:
            print(f"0/{self.max_oracle_calls} | no molecules evaluated yet")
            return
        top10 = sorted(self.mol_buffer.values(), reverse=True)[:10]
        print(f"{n_calls}/{self.max_oracle_calls} | "
              f"max: {top10[0]:.2f} | "
              f"top10_avg: {sum(top10) / len(top10):.2f}")

    def __len__(self):
        return len(self.mol_buffer)

    @property
    def budget(self):
        return self.max_oracle_calls

    @property
    def finish(self):
        return len(self.mol_buffer) >= self.max_oracle_calls
Complete Script Template
Here’s a complete template you can adapt for your own objectives:
complete_template.py
import yaml
import argparse
import os
from typing import List
import torch
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from rdkit import Chem
from chemlactica.mol_opt.optimization import optimize
from chemlactica.mol_opt.utils import set_seed, MoleculeEntry
class CustomOracle:
    """Skeleton oracle: copy it and replace ``_compute_score`` with your objective.

    Scores are cached per SMILES in ``mol_buffer``; a molecule whose scoring
    raises is reported on stdout and scored 0.0.
    """

    def __init__(self, max_oracle_calls: int):
        self.mol_buffer = {}  # SMILES -> score cache; its size is the call count
        self.max_oracle_calls = max_oracle_calls  # evaluation budget
        self.freq_log = 100  # log every N unique molecules
        self.takes_entry = True  # Set False if you only need SMILES
        self.max_possible_oracle_score = 100.0  # Adjust for your objective

    def _compute_score(self, molecule):
        """Objective for a single molecule -- swap in QED, SA score, or a custom metric."""
        from rdkit.Chem import QED
        return QED.qed(molecule.mol)

    def __call__(self, molecules: List[MoleculeEntry]):
        """Return one score per molecule, evaluating each unique SMILES only once."""
        results = []
        for molecule in molecules:
            key = molecule.smiles
            if key not in self.mol_buffer:
                try:
                    value = self._compute_score(molecule)
                except Exception as e:
                    print(f"Error scoring {molecule.smiles}: {e}")
                    value = 0.0
                self.mol_buffer[key] = value
                if len(self.mol_buffer) % self.freq_log == 0:
                    self.log_intermediate()
                results.append(value)
            else:
                results.append(self.mol_buffer[key])
        return results

    def log_intermediate(self):
        """Print budget usage plus the top-1, top-10 and top-100 mean scores."""
        best = sorted(self.mol_buffer.values(), reverse=True)[:100]
        summary = (
            f"{len(self.mol_buffer)}/{self.max_oracle_calls} | "
            f"top1: {max(best):.3f} | "
            f"top10: {np.mean(best[:10]):.3f} | "
            f"top100: {np.mean(best):.3f}"
        )
        print(summary)

    def __len__(self):
        """Number of unique molecules scored so far."""
        return len(self.mol_buffer)

    @property
    def budget(self):
        """Maximum number of unique molecules to score."""
        return self.max_oracle_calls

    @property
    def finish(self):
        """Whether the evaluation budget has been used up."""
        return self.max_oracle_calls <= len(self.mol_buffer)
def main():
    """Parse CLI options, load the model once, then run one optimization per seed.

    Exits with a usage error if ``--n_runs`` exceeds the number of predefined
    seeds; this is checked before the model is loaded.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True,
                        help="Path to YAML config file")
    parser.add_argument("--output_dir", type=str, required=True,
                        help="Directory for output logs")
    parser.add_argument("--n_runs", type=int, default=1,
                        help="Number of runs with different seeds")
    parser.add_argument("--max_oracle_calls", type=int, default=5000,
                        help="Maximum oracle evaluations")
    args = parser.parse_args()
    seeds = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31]
    # Reject unseedable run counts now rather than crashing with IndexError
    # after len(seeds) full optimization runs.
    if args.n_runs > len(seeds):
        parser.error(f"--n_runs must be at most {len(seeds)} (one predefined seed per run)")
    # Load config (closing the file handle deterministically)
    with open(args.config) as config_file:
        config = yaml.safe_load(config_file)
    # Load model
    print(f"Loading model {config['checkpoint_path']}...")
    model = AutoModelForCausalLM.from_pretrained(
        config["checkpoint_path"],
        torch_dtype=torch.bfloat16
    ).to(config["device"])
    tokenizer = AutoTokenizer.from_pretrained(
        config["tokenizer_path"],
        padding_side="left"
    )
    # Run optimization(s)
    os.makedirs(args.output_dir, exist_ok=True)
    for i, seed in enumerate(seeds[:args.n_runs]):
        print(f"\n=== Run {i+1}/{args.n_runs} (seed={seed}) ===")
        set_seed(seed)
        # Create a fresh oracle so each run has its own budget and cache
        oracle = CustomOracle(max_oracle_calls=args.max_oracle_calls)
        config["max_possible_oracle_score"] = oracle.max_possible_oracle_score
        config["log_dir"] = os.path.join(
            args.output_dir,
            f"optimization_run_{seed}.log"
        )
        # Run optimization
        optimize(model, tokenizer, oracle, config)
        # Print final statistics
        scores_sorted = sorted(oracle.mol_buffer.values(), reverse=True)
        print(f"\nFinal Results (Run {i+1}):")
        print(f" Total molecules evaluated: {len(oracle)}")
        if not scores_sorted:
            # Avoid IndexError (and NaN means) when nothing was scored.
            print(" No molecules were scored.")
            continue
        print(f" Best score: {scores_sorted[0]:.3f}")
        print(f" Top 10 average: {np.mean(scores_sorted[:10]):.3f}")
        print(f" Top 100 average: {np.mean(scores_sorted[:100]):.3f}")


if __name__ == "__main__":
    main()
Usage
python complete_template.py \
--config chemlactica/mol_opt/chemlactica_125m_hparams.yaml \
--output_dir results/my_optimization \
--n_runs 5 \
--max_oracle_calls 5000
Advanced: Docking Oracle Example
For computationally expensive oracles such as docking, consider caching and parallelization. The example below persists docking results to an on-disk cache:
from rdkit import Chem
from rdkit.Chem import AllChem
import subprocess
import os
import pickle
class DockingOracle:
    """Oracle for AutoDock Vina docking scores, with a persistent on-disk cache.

    ``mol_buffer`` holds only the molecules scored during this run, so results
    loaded from a previous run's cache do not consume this run's budget
    (``len``/``finish``). ``disk_cache`` maps SMILES to successful docking
    scores across runs and is pickled to ``cache_file``. It is saved every
    ``save_every`` new successful results; call ``save_cache()`` after
    optimization to persist the remainder.
    """

    def __init__(self, max_oracle_calls: int, receptor_pdbqt: str,
                 cache_file: str = "docking_cache.pkl", save_every: int = 10):
        self.max_oracle_calls = max_oracle_calls
        self.freq_log = 10  # Log more frequently for slow oracles
        self.receptor_pdbqt = receptor_pdbqt
        self.cache_file = cache_file
        # Persist the disk cache after this many new successful docking results.
        self.save_every = save_every
        self._unsaved = 0
        # Molecules evaluated in this run (counts toward the budget).
        self.mol_buffer = {}
        # SMILES -> score from previous runs and this one.
        # SECURITY: pickle.load can execute arbitrary code. Only load cache
        # files this code wrote itself, never ones from untrusted sources.
        self.disk_cache = {}
        if os.path.exists(cache_file):
            with open(cache_file, 'rb') as f:
                self.disk_cache = pickle.load(f)
            print(f"Loaded {len(self.disk_cache)} cached results")
        # Docking scores are negative (lower = better binding)
        # We negate so higher = better
        self.max_possible_oracle_score = 15.0

    def __call__(self, smiles_list):
        """Return one score per SMILES, docking each unique uncached SMILES once."""
        oracle_scores = []
        for smiles in smiles_list:
            if smiles in self.mol_buffer:
                oracle_scores.append(self.mol_buffer[smiles])
                continue
            if smiles in self.disk_cache:
                # Reuse an earlier docking result; it still counts as one
                # oracle call toward this run's budget.
                oracle_score = self.disk_cache[smiles]
            else:
                try:
                    # Run docking
                    docking_score = self.run_vina_docking(smiles)
                except Exception as e:
                    print(f"Docking failed for {smiles}: {e}")
                    # Not written to disk_cache, so a later run retries it.
                    oracle_score = 0
                else:
                    # Negate: lower binding energy = higher score
                    oracle_score = -docking_score
                    self.disk_cache[smiles] = oracle_score
                    self._unsaved += 1
                    if self._unsaved >= self.save_every:
                        self.save_cache()
            self.mol_buffer[smiles] = oracle_score
            if len(self.mol_buffer) % self.freq_log == 0:
                self.log_intermediate()
            oracle_scores.append(oracle_score)
        return oracle_scores

    def save_cache(self):
        """Atomically write ``disk_cache`` to ``cache_file``.

        The cache is written to a temporary file that is then renamed over the
        target, so an interruption mid-write cannot leave a truncated cache.
        """
        tmp_path = f"{self.cache_file}.tmp"
        with open(tmp_path, 'wb') as f:
            pickle.dump(self.disk_cache, f)
        os.replace(tmp_path, self.cache_file)
        self._unsaved = 0

    def run_vina_docking(self, smiles: str) -> float:
        """Run AutoDock Vina and return binding affinity.
        Note: This is a template. You'll need to:
        1. Install AutoDock Vina and required dependencies
        2. Convert SMILES to 3D structure using RDKit
        3. Save structure as PDBQT format
        4. Run Vina against your target protein
        5. Parse binding affinity from output
        See AutoDock Vina documentation: https://autodock-vina.readthedocs.io/
        """
        raise NotImplementedError("Implement Vina docking - see method docstring for steps")

    def log_intermediate(self):
        """Print budget usage plus best and top-10 average binding energies (kcal/mol)."""
        if not self.mol_buffer:
            print(f"0/{self.max_oracle_calls} | no molecules docked yet")
            return
        top10 = sorted(self.mol_buffer.values(), reverse=True)[:10]
        print(f"{len(self.mol_buffer)}/{self.max_oracle_calls} | "
              f"best_binding: {-top10[0]:.2f} kcal/mol | "
              f"top10_avg: {-sum(top10) / len(top10):.2f} kcal/mol")

    def __len__(self):
        return len(self.mol_buffer)

    @property
    def budget(self):
        return self.max_oracle_calls

    @property
    def finish(self):
        return len(self.mol_buffer) >= self.max_oracle_calls
Tips for Your Own Examples
Start simple
Begin with a simple oracle (e.g., a single RDKit property) to verify that your setup works before implementing complex objectives.
Cache everything
Always check `mol_buffer` before computing scores. Docking and other expensive calculations should never be repeated.
Handle errors gracefully
Wrap scoring in try-except and return 0 for failures. Log errors for debugging, but don’t let them crash the optimization.
Monitor progress
Implement good logging in `log_intermediate()`. Track the top-1, top-10, and top-100 scores to see whether the optimization is working.
Use multiple runs
Run the optimization 3-5 times with different seeds and report average/best results for reproducibility.
Set realistic budgets
- Fast oracles (RDKit properties): 10K-20K calls
- Medium oracles (ML predictions): 5K-10K calls
- Slow oracles (docking): 500-2K calls
Next Steps
Design Your Oracle
Learn more about oracle design patterns and best practices
Tune Hyperparameters
Optimize the configuration for your specific use case