Skip to main content
This page provides complete, runnable examples for different molecular optimization tasks.

Basic Example: TPSA + Weight Optimization

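This example from chemlactica/mol_opt/example_run.py optimizes molecules for high topological polar surface area (TPSA), subject to a molecular weight constraint (below 350) and a ring-count constraint (at least 2 rings). Molecules that violate either constraint receive a score of 0.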

Oracle Implementation

mol_opt/example_run.py
from typing import List
import yaml
import argparse
import os
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import numpy as np
from rdkit.Chem import rdMolDescriptors
from chemlactica.mol_opt.optimization import optimize
from chemlactica.mol_opt.utils import set_seed, MoleculeEntry


class TPSA_Weight_Oracle:
    """Oracle that scores molecules by TPSA under weight and ring constraints.

    The score of a molecule is its topological polar surface area (TPSA).
    The score is forced to 0 when the exact molecular weight is 350 or more,
    when the molecule has fewer than 2 rings, or when a descriptor
    calculation raises. Each unique SMILES is scored once and remembered in
    ``mol_buffer``.
    """

    def __init__(self, max_oracle_calls: int):
        # Maximum number of oracle calls to make
        self.max_oracle_calls = max_oracle_calls

        # The frequency (in unique molecules scored) with which to log
        self.freq_log = 100

        # The buffer to keep track of all unique molecules generated.
        # Maps SMILES -> [oracle_score, 1-based insertion index].
        self.mol_buffer = {}

        # The maximum possible oracle score or an upper bound
        self.max_possible_oracle_score = 800

        # Request MoleculeEntry objects instead of SMILES
        self.takes_entry = True

    def __call__(self, molecules: List[MoleculeEntry]):
        """
        Evaluate and return the oracle scores for molecules.
        Log the intermediate results if necessary.

        Args:
            molecules: entries exposing ``smiles`` and ``mol`` attributes.

        Returns:
            A list with one score per input molecule, in input order.
        """
        oracle_scores = []
        for molecule in molecules:
            cached = self.mol_buffer.get(molecule.smiles)
            if cached is not None:
                # Buffer entries are [score, index]; reuse the stored score.
                # (The score is a plain number, so it must not be summed.)
                oracle_scores.append(cached[0])
                continue

            try:
                tpsa = rdMolDescriptors.CalcTPSA(molecule.mol)
                weight = rdMolDescriptors.CalcExactMolWt(molecule.mol)
                num_rings = rdMolDescriptors.CalcNumRings(molecule.mol)

                # Apply constraints: too heavy or too few rings scores 0
                violates_constraints = weight >= 350 or num_rings < 2
                oracle_score = 0 if violates_constraints else tpsa
            except Exception as e:
                # Best effort: a molecule that cannot be scored gets 0
                # instead of aborting the optimization.
                print(e)
                oracle_score = 0

            self.mol_buffer[molecule.smiles] = [oracle_score, len(self.mol_buffer) + 1]
            if len(self.mol_buffer) % self.freq_log == 0:
                self.log_intermediate()
            oracle_scores.append(oracle_score)
        return oracle_scores

    def log_intermediate(self):
        """Print the best score and the top-10 / top-100 averages so far."""
        scores = [v[0] for v in self.mol_buffer.values()][-self.max_oracle_calls:]
        if not scores:
            # Nothing has been scored yet, so there is nothing to summarize.
            return
        scores_sorted = sorted(scores, reverse=True)[:100]
        n_calls = len(self.mol_buffer)

        score_avg_top1 = np.max(scores_sorted)
        score_avg_top10 = np.mean(scores_sorted[:10])
        score_avg_top100 = np.mean(scores_sorted)

        print(f"{n_calls}/{self.max_oracle_calls} | "
              f'avg_top1: {score_avg_top1:.3f} | '
              f'avg_top10: {score_avg_top10:.3f} | '
              f'avg_top100: {score_avg_top100:.3f}')

    def __len__(self):
        """Return the number of unique molecules scored so far."""
        return len(self.mol_buffer)

    @property
    def budget(self):
        """Total number of oracle calls allowed."""
        return self.max_oracle_calls

    @property
    def finish(self):
        """True once the number of unique molecules reaches the budget."""
        return len(self.mol_buffer) >= self.max_oracle_calls

Running the Optimization

def parse_arguments():
    """Parse the command-line options of the example run script.

    Returns:
        The argparse namespace with ``output_dir`` and ``config_default``
        (both required strings) and ``n_runs`` (int, defaults to 1).
    """
    cli = argparse.ArgumentParser()
    # Both path-like options are mandatory strings.
    for required_flag in ("--output_dir", "--config_default"):
        cli.add_argument(required_flag, type=str, required=True)
    cli.add_argument("--n_runs", type=int, required=False, default=1)
    return cli.parse_args()


if __name__ == "__main__":
    args = parse_arguments()

    # One predefined seed per run; fail early (before loading the model)
    # if more runs are requested than seeds are available.
    seeds = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31]
    if args.n_runs > len(seeds):
        raise ValueError(
            f"--n_runs must be at most {len(seeds)} (one predefined seed per run), "
            f"got {args.n_runs}"
        )

    # Use a context manager so the config file handle is closed promptly.
    with open(args.config_default) as config_file:
        config = yaml.safe_load(config_file)

    # Load model and tokenizer
    model = AutoModelForCausalLM.from_pretrained(
        config["checkpoint_path"],
        torch_dtype=torch.bfloat16
    ).to(config["device"])

    tokenizer = AutoTokenizer.from_pretrained(
        config["tokenizer_path"],
        padding_side="left"
    )

    # Run multiple optimization runs with different seeds
    for i in range(args.n_runs):
        set_seed(seeds[i])
        # A fresh oracle per run so each run gets its own buffer and budget.
        oracle = TPSA_Weight_Oracle(max_oracle_calls=1000)
        config["log_dir"] = os.path.join(
            args.output_dir,
            f"results_chemlactica_tpsa+weight+num_rings_{seeds[i]}.log"
        )
        config["max_possible_oracle_score"] = oracle.max_possible_oracle_score

        optimize(model, tokenizer, oracle, config)

Usage

python chemlactica/mol_opt/example_run.py \
    --output_dir results/ \
    --config_default chemlactica/mol_opt/chemlactica_125m_hparams.yaml \
    --n_runs 3
This will run 3 optimization runs with different random seeds and save logs to results/.

QED Optimization Example

Optimize molecules for drug-likeness using the QED (Quantitative Estimate of Drug-likeness) metric.
from rdkit import Chem
from rdkit.Chem import QED
import yaml
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from chemlactica.mol_opt.optimization import optimize


class QED_Oracle:
    """Oracle that scores SMILES strings by their QED value.

    Each unique SMILES is scored once and cached in ``mol_buffer``
    (SMILES -> score). ``success_count`` counts the scored molecules whose
    QED is at least ``target_qed``. Molecules that cannot be scored get 0.0.
    """

    def __init__(self, max_oracle_calls: int, target_qed: float = 0.9):
        self.max_oracle_calls = max_oracle_calls
        self.freq_log = 100
        self.mol_buffer = {}
        self.max_possible_oracle_score = 1.0  # QED ranges from 0 to 1
        self.target_qed = target_qed
        self.success_count = 0

    def __call__(self, smiles_list):
        """Return one QED-based score per SMILES string, in input order."""
        oracle_scores = []
        for smiles in smiles_list:
            if smiles in self.mol_buffer:
                oracle_scores.append(self.mol_buffer[smiles])
                continue

            try:
                mol = Chem.MolFromSmiles(smiles)
                qed_score = QED.qed(mol)

                # Track success rate (molecules above target)
                if qed_score >= self.target_qed:
                    self.success_count += 1

                oracle_score = qed_score
            except Exception as e:
                # Catch Exception rather than using a bare except, so that
                # KeyboardInterrupt / SystemExit still stop the run.
                print(f"Error scoring {smiles}: {e}")
                oracle_score = 0.0

            self.mol_buffer[smiles] = oracle_score
            if len(self.mol_buffer) % self.freq_log == 0:
                self.log_intermediate()
            oracle_scores.append(oracle_score)

        return oracle_scores

    def log_intermediate(self):
        """Print the best score, the top-10 average and the success rate."""
        n_calls = len(self.mol_buffer)
        if n_calls == 0:
            # Nothing scored yet: avoid max() on an empty list and a
            # division by zero in the success rate.
            return
        scores = list(self.mol_buffer.values())
        scores_sorted = sorted(scores, reverse=True)[:100]
        top10 = scores_sorted[:10]

        # Average over the scores actually present, so the value is also
        # correct when fewer than 10 molecules have been scored.
        print(f"{n_calls}/{self.max_oracle_calls} | "
              f"top1: {max(scores_sorted):.4f} | "
              f"top10: {sum(top10)/len(top10):.4f} | "
              f"success_rate: {self.success_count/n_calls*100:.1f}%")

    def __len__(self):
        """Return the number of unique SMILES scored so far."""
        return len(self.mol_buffer)

    @property
    def budget(self):
        """Total number of oracle calls allowed."""
        return self.max_oracle_calls

    @property
    def finish(self):
        """True once the number of unique SMILES reaches the budget."""
        return len(self.mol_buffer) >= self.max_oracle_calls


# Setup and run
# Use a context manager so the config file handle is closed promptly.
with open("chemlactica_125m_hparams.yaml") as config_file:
    config = yaml.safe_load(config_file)

model = AutoModelForCausalLM.from_pretrained(
    config["checkpoint_path"],
    torch_dtype=torch.bfloat16
).to(config["device"])

tokenizer = AutoTokenizer.from_pretrained(
    config["tokenizer_path"],
    padding_side="left"
)

oracle = QED_Oracle(max_oracle_calls=10000, target_qed=0.9)
config["max_possible_oracle_score"] = oracle.max_possible_oracle_score
config["log_dir"] = "results/qed_optimization.log"

optimize(model, tokenizer, oracle, config)

# Expected: ~99% success rate with 10K calls
n_evaluated = len(oracle)
if n_evaluated:
    print(f"Final success rate: {oracle.success_count/n_evaluated*100:.1f}%")
else:
    # Guard against a division by zero when nothing was scored.
    print("Final success rate: n/a (no molecules were evaluated)")

Similarity-Constrained Optimization

Generate molecules similar to a reference compound while optimizing a property.
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem, Descriptors
from chemlactica.mol_opt.utils import MoleculeEntry


class SimilarityConstrainedOracle:
    """Oracle that rewards logP only for molecules similar to a reference.

    A molecule whose Tanimoto similarity to the reference fingerprint is at
    least ``min_similarity`` scores ``logP + 5``; every other molecule, and
    every molecule whose scoring raises, scores 0. Each unique SMILES is
    scored once and cached in ``mol_buffer`` (SMILES -> score).
    """

    def __init__(self, max_oracle_calls: int, reference_smiles: str,
                 min_similarity: float = 0.6):
        """
        Args:
            max_oracle_calls: budget of unique molecules to score.
            reference_smiles: SMILES of the reference compound.
            min_similarity: minimum Tanimoto similarity required for a
                non-zero score.

        Raises:
            ValueError: if ``reference_smiles`` cannot be parsed.
        """
        self.max_oracle_calls = max_oracle_calls
        self.freq_log = 100
        self.mol_buffer = {}
        self.max_possible_oracle_score = 10.0  # logP typically ranges -5 to 5
        self.takes_entry = True

        # Reference molecule
        self.reference_mol = Chem.MolFromSmiles(reference_smiles)
        if self.reference_mol is None:
            # MolFromSmiles signals a parse failure by returning None; fail
            # here with a clear message instead of inside the fingerprint call.
            raise ValueError(
                f"Could not parse reference SMILES: {reference_smiles!r}"
            )
        self.reference_fp = AllChem.GetMorganFingerprintAsBitVect(
            self.reference_mol, 2, nBits=2048
        )
        self.min_similarity = min_similarity

    def __call__(self, molecules):
        """Return one score per molecule entry, in input order.

        Each entry is expected to expose ``smiles``, ``mol`` and
        ``fingerprint`` attributes.
        """
        oracle_scores = []
        for mol_entry in molecules:
            if mol_entry.smiles in self.mol_buffer:
                oracle_scores.append(self.mol_buffer[mol_entry.smiles])
                continue

            try:
                # Calculate similarity to reference
                # NOTE(review): assumes mol_entry.fingerprint is comparable
                # with the 2048-bit Morgan fingerprint built in __init__ -
                # confirm against MoleculeEntry.
                similarity = DataStructs.TanimotoSimilarity(
                    mol_entry.fingerprint,
                    self.reference_fp
                )

                # Calculate logP as objective
                logp = Descriptors.MolLogP(mol_entry.mol)

                # Apply similarity constraint
                if similarity >= self.min_similarity:
                    # Shift by 5 so typical logP values land in a positive range
                    oracle_score = logp + 5
                else:
                    oracle_score = 0  # Penalty for low similarity

            except Exception as e:
                # Catch Exception rather than using a bare except, so that
                # KeyboardInterrupt / SystemExit still stop the run.
                print(f"Error scoring {mol_entry.smiles}: {e}")
                oracle_score = 0

            self.mol_buffer[mol_entry.smiles] = oracle_score
            if len(self.mol_buffer) % self.freq_log == 0:
                self.log_intermediate()
            oracle_scores.append(oracle_score)

        return oracle_scores

    def log_intermediate(self):
        """Print how many unique molecules have been evaluated so far."""
        print(f"{len(self.mol_buffer)}/{self.max_oracle_calls} molecules evaluated")

    def __len__(self):
        """Return the number of unique molecules scored so far."""
        return len(self.mol_buffer)

    @property
    def budget(self):
        """Total number of oracle calls allowed."""
        return self.max_oracle_calls

    @property
    def finish(self):
        """True once the number of unique molecules reaches the budget."""
        return len(self.mol_buffer) >= self.max_oracle_calls


# Example: Optimize logP while maintaining similarity to aspirin.
# The oracle scores at most 5000 unique molecules; molecules whose Tanimoto
# similarity to the aspirin fingerprint is below 0.6 receive a score of 0.
oracle = SimilarityConstrainedOracle(
    max_oracle_calls=5000,
    reference_smiles="CC(=O)OC1=CC=CC=C1C(=O)O",  # Aspirin
    min_similarity=0.6
)

Multi-Property Optimization

Optimize for multiple properties at once using a weighted combination of per-property scores.
from rdkit.Chem import Descriptors, Crippen, QED
from rdkit.Chem.Scaffolds import MurckoScaffold
from chemlactica.mol_opt.utils import MoleculeEntry


class MultiPropertyOracle:
    """Oracle that combines five molecular properties into one weighted score.

    Each property contributes a value between 0 and its weight; the weights
    add up to 100, which is also ``max_possible_oracle_score``. Each unique
    SMILES is scored once and cached in ``mol_buffer`` (SMILES -> score).
    Molecules whose scoring raises get 0.
    """

    def __init__(self, max_oracle_calls: int):
        self.max_oracle_calls = max_oracle_calls
        self.freq_log = 100
        self.mol_buffer = {}
        self.max_possible_oracle_score = 100.0
        self.takes_entry = True

        # Property weights
        self.weights = {
            'qed': 40.0,        # Drug-likeness
            'logp': 10.0,       # Lipophilicity (target ~3)
            'mw': 20.0,         # Molecular weight (target ~400)
            'tpsa': 20.0,       # TPSA (target ~80)
            'rotatable': 10.0   # Rotatable bonds (fewer is better)
        }

    def __call__(self, molecules):
        """Return one combined score per molecule entry, in input order.

        Each entry is expected to expose ``smiles`` and ``mol`` attributes.
        """
        oracle_scores = []
        for mol_entry in molecules:
            if mol_entry.smiles in self.mol_buffer:
                oracle_scores.append(self.mol_buffer[mol_entry.smiles])
                continue

            try:
                # Calculate properties
                qed = QED.qed(mol_entry.mol)
                logp = Crippen.MolLogP(mol_entry.mol)
                mw = Descriptors.MolWt(mol_entry.mol)
                tpsa = Descriptors.TPSA(mol_entry.mol)
                n_rotatable = Descriptors.NumRotatableBonds(mol_entry.mol)

                # Normalize and score each property
                score_qed = qed * self.weights['qed']

                # LogP: penalize deviation from 3 (zero at 5 or more units away)
                score_logp = max(0, (1 - abs(logp - 3) / 5)) * self.weights['logp']

                # MW: penalize deviation from 400 (zero at 200 or more away)
                score_mw = max(0, (1 - abs(mw - 400) / 200)) * self.weights['mw']

                # TPSA: penalize deviation from 80 (zero at 80 or more away)
                score_tpsa = max(0, (1 - abs(tpsa - 80) / 80)) * self.weights['tpsa']

                # Rotatable bonds: fewer is better (zero at 10 or more bonds)
                score_rotatable = max(0, (1 - n_rotatable / 10)) * self.weights['rotatable']

                # Total score
                oracle_score = (
                    score_qed + score_logp + score_mw +
                    score_tpsa + score_rotatable
                )

            except Exception as e:
                # Catch Exception rather than using a bare except, so that
                # KeyboardInterrupt / SystemExit still stop the run.
                print(f"Error scoring {mol_entry.smiles}: {e}")
                oracle_score = 0

            self.mol_buffer[mol_entry.smiles] = oracle_score
            if len(self.mol_buffer) % self.freq_log == 0:
                self.log_intermediate()
            oracle_scores.append(oracle_score)

        return oracle_scores

    def log_intermediate(self):
        """Print the best score and the average of the top 10 scores."""
        n_calls = len(self.mol_buffer)
        if n_calls == 0:
            # Nothing scored yet: max() of an empty list would raise.
            return
        scores = list(self.mol_buffer.values())
        top10 = sorted(scores, reverse=True)[:10]

        # Average over the scores actually present, so the value is also
        # correct when fewer than 10 molecules have been scored.
        print(f"{n_calls}/{self.max_oracle_calls} | "
              f"max: {max(scores):.2f} | "
              f"top10_avg: {sum(top10)/len(top10):.2f}")

    def __len__(self):
        """Return the number of unique molecules scored so far."""
        return len(self.mol_buffer)

    @property
    def budget(self):
        """Total number of oracle calls allowed."""
        return self.max_oracle_calls

    @property
    def finish(self):
        """True once the number of unique molecules reaches the budget."""
        return len(self.mol_buffer) >= self.max_oracle_calls

Complete Script Template

Here’s a complete template you can adapt for your own objectives:
complete_template.py
import yaml
import argparse
import os
from typing import List
import torch
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from rdkit import Chem
from chemlactica.mol_opt.optimization import optimize
from chemlactica.mol_opt.utils import set_seed, MoleculeEntry


class CustomOracle:
    """Starting point for writing your own oracle.

    Scores each unique molecule once, remembers the result in ``mol_buffer``
    (SMILES -> score) and prints progress every ``freq_log`` new molecules.
    """

    def __init__(self, max_oracle_calls: int):
        self.mol_buffer = {}
        self.freq_log = 100
        self.max_oracle_calls = max_oracle_calls
        self.takes_entry = True  # Set False if you only need SMILES
        self.max_possible_oracle_score = 100.0  # Adjust for your objective

    def __call__(self, molecules: List[MoleculeEntry]):
        """Score every molecule and return the scores in input order."""
        results = []

        for entry in molecules:
            key = entry.smiles
            if key in self.mol_buffer:
                # Already scored: reuse the cached value
                value = self.mol_buffer[key]
            else:
                # New molecule: compute, cache, and possibly log
                try:
                    # Replace this with your own property,
                    # e.g. QED, SA score, or a custom metric
                    from rdkit.Chem import QED
                    value = QED.qed(entry.mol)
                except Exception as err:
                    print(f"Error scoring {entry.smiles}: {err}")
                    value = 0.0

                self.mol_buffer[key] = value
                if not len(self.mol_buffer) % self.freq_log:
                    self.log_intermediate()

            results.append(value)

        return results

    def log_intermediate(self):
        """Print the top-1 score and the top-10 / top-100 averages."""
        top = sorted(self.mol_buffer.values(), reverse=True)[:100]
        progress = f"{len(self.mol_buffer)}/{self.max_oracle_calls}"
        fields = [
            progress,
            f"top1: {max(top):.3f}",
            f"top10: {np.mean(top[:10]):.3f}",
            f"top100: {np.mean(top):.3f}",
        ]
        print(" | ".join(fields))

    def __len__(self):
        return len(self.mol_buffer)

    @property
    def budget(self):
        """Total number of oracle calls allowed."""
        return self.max_oracle_calls

    @property
    def finish(self):
        """True once the buffer holds at least ``max_oracle_calls`` molecules."""
        return not len(self.mol_buffer) < self.max_oracle_calls


def main():
    """Run one or more seeded optimization runs with ``CustomOracle``.

    Parses the command line, loads the YAML config, the model and the
    tokenizer, then runs ``optimize`` once per requested run and prints
    summary statistics for each run.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True,
                       help="Path to YAML config file")
    parser.add_argument("--output_dir", type=str, required=True,
                       help="Directory for output logs")
    parser.add_argument("--n_runs", type=int, default=1,
                       help="Number of runs with different seeds")
    parser.add_argument("--max_oracle_calls", type=int, default=5000,
                       help="Maximum oracle evaluations")
    args = parser.parse_args()

    # One predefined seed per run. Reject oversized requests up front rather
    # than failing with an IndexError after earlier runs have completed.
    seeds = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31]
    if args.n_runs > len(seeds):
        parser.error(
            f"--n_runs must be at most {len(seeds)} (one predefined seed per run)"
        )

    # Load config (context manager closes the file handle promptly)
    with open(args.config) as config_file:
        config = yaml.safe_load(config_file)

    # Load model
    print(f"Loading model {config['checkpoint_path']}...")
    model = AutoModelForCausalLM.from_pretrained(
        config["checkpoint_path"],
        torch_dtype=torch.bfloat16
    ).to(config["device"])

    tokenizer = AutoTokenizer.from_pretrained(
        config["tokenizer_path"],
        padding_side="left"
    )

    # Run optimization(s)
    os.makedirs(args.output_dir, exist_ok=True)

    for i in range(args.n_runs):
        print(f"\n=== Run {i+1}/{args.n_runs} (seed={seeds[i]}) ===")
        set_seed(seeds[i])

        # Create a fresh oracle so each run has its own buffer and budget
        oracle = CustomOracle(max_oracle_calls=args.max_oracle_calls)
        config["max_possible_oracle_score"] = oracle.max_possible_oracle_score
        config["log_dir"] = os.path.join(
            args.output_dir,
            f"optimization_run_{seeds[i]}.log"
        )

        # Run optimization
        optimize(model, tokenizer, oracle, config)

        # Print final statistics
        scores = list(oracle.mol_buffer.values())
        scores_sorted = sorted(scores, reverse=True)
        print(f"\nFinal Results (Run {i+1}):")
        print(f"  Total molecules evaluated: {len(oracle)}")
        if not scores_sorted:
            # Nothing was scored: there is no best score to index or average.
            print("  No molecules were evaluated.")
            continue
        print(f"  Best score: {scores_sorted[0]:.3f}")
        print(f"  Top 10 average: {np.mean(scores_sorted[:10]):.3f}")
        print(f"  Top 100 average: {np.mean(scores_sorted[:100]):.3f}")


if __name__ == "__main__":
    main()

Usage

python complete_template.py \
    --config chemlactica/mol_opt/chemlactica_125m_hparams.yaml \
    --output_dir results/my_optimization \
    --n_runs 5 \
    --max_oracle_calls 5000

Advanced: Docking Oracle Example

For computationally expensive oracles like docking, consider caching and parallelization. The example below caches scores on disk so they survive restarts; parallelization is not shown:
from rdkit import Chem
from rdkit.Chem import AllChem
import subprocess
import os
import pickle


class DockingOracle:
    """Oracle for AutoDock Vina docking scores.

    Scores are the negated docking result, so higher is better. Each unique
    SMILES is scored once; results live in ``mol_buffer`` (SMILES -> score)
    and are periodically pickled to ``cache_file`` so they survive restarts.
    """

    def __init__(self, max_oracle_calls: int, receptor_pdbqt: str,
                 cache_file: str = "docking_cache.pkl",
                 cache_save_freq: int = 10):
        """
        Args:
            max_oracle_calls: budget of unique molecules to score.
            receptor_pdbqt: path to the receptor file; stored for use by
                ``run_vina_docking``.
            cache_file: pickle file used to persist ``mol_buffer``.
            cache_save_freq: write the cache every this many buffered
                molecules.
        """
        self.max_oracle_calls = max_oracle_calls
        self.freq_log = 10  # Log more frequently for slow oracles
        self.receptor_pdbqt = receptor_pdbqt
        self.cache_file = cache_file
        self.cache_save_freq = cache_save_freq

        # Load cache if exists.
        # SECURITY: pickle.load can execute arbitrary code from the file.
        # Only point cache_file at files written by this oracle.
        # NOTE: cached entries are loaded straight into mol_buffer, so they
        # count toward len(self) and therefore toward `finish`.
        if os.path.exists(cache_file):
            with open(cache_file, 'rb') as f:
                self.mol_buffer = pickle.load(f)
            print(f"Loaded {len(self.mol_buffer)} cached results")
        else:
            self.mol_buffer = {}

        # Docking scores are negative (lower = better binding)
        # We negate so higher = better
        self.max_possible_oracle_score = 15.0

    def __call__(self, smiles_list):
        """Return one score per SMILES string, in input order."""
        oracle_scores = []

        for smiles in smiles_list:
            if smiles in self.mol_buffer:
                oracle_scores.append(self.mol_buffer[smiles])
                continue

            try:
                # Run docking
                docking_score = self.run_vina_docking(smiles)
                # Negate: lower binding energy = higher score
                oracle_score = -docking_score
            except NotImplementedError:
                # The docking method has not been implemented. Fail loudly
                # instead of scoring every molecule 0 and persisting those
                # zeros in the cache.
                raise
            except Exception as e:
                print(f"Docking failed for {smiles}: {e}")
                oracle_score = 0

            self.mol_buffer[smiles] = oracle_score

            # Save cache periodically
            if len(self.mol_buffer) % self.cache_save_freq == 0:
                self._save_cache()

            if len(self.mol_buffer) % self.freq_log == 0:
                self.log_intermediate()

            oracle_scores.append(oracle_score)

        return oracle_scores

    def _save_cache(self):
        """Write the current ``mol_buffer`` to ``cache_file``."""
        with open(self.cache_file, 'wb') as f:
            pickle.dump(self.mol_buffer, f)

    def run_vina_docking(self, smiles: str) -> float:
        """Run AutoDock Vina and return binding affinity.
        
        Note: This is a template. You'll need to:
        1. Install AutoDock Vina and required dependencies
        2. Convert SMILES to 3D structure using RDKit
        3. Save structure as PDBQT format
        4. Run Vina against your target protein
        5. Parse binding affinity from output
        
        See AutoDock Vina documentation: https://autodock-vina.readthedocs.io/
        """
        raise NotImplementedError("Implement Vina docking - see method docstring for steps")

    def log_intermediate(self):
        """Print the best binding energy and the top-10 average (kcal/mol)."""
        scores = list(self.mol_buffer.values())
        if not scores:
            # Nothing scored yet: max() of an empty list would raise.
            return
        scores_sorted = sorted(scores, reverse=True)[:100]
        top10 = scores_sorted[:10]

        # Scores are negated energies, so negate again for display. Average
        # over the scores actually present so the value is also correct when
        # fewer than 10 molecules have been scored.
        print(f"{len(self.mol_buffer)}/{self.max_oracle_calls} | "
              f"best_binding: {-max(scores_sorted):.2f} kcal/mol | "
              f"top10_avg: {-sum(top10)/len(top10):.2f} kcal/mol")

    def __len__(self):
        """Return the number of unique SMILES in the buffer."""
        return len(self.mol_buffer)

    @property
    def budget(self):
        """Total number of oracle calls allowed."""
        return self.max_oracle_calls

    @property
    def finish(self):
        """True once the buffer holds at least ``max_oracle_calls`` entries."""
        return len(self.mol_buffer) >= self.max_oracle_calls

Tips for Your Own Examples

Begin with a simple oracle (e.g., single RDKit property) to verify your setup works before implementing complex objectives.
Always check mol_buffer before computing scores. Docking and other expensive calculations should never be repeated.
Wrap scoring in try-except and return 0 for failures. Log errors for debugging but don’t crash the optimization.
Implement good logging in log_intermediate(). Track top-1, top-10, and top-100 scores to see if optimization is working.
Run optimization 3-5 times with different seeds and report average/best results for reproducibility.
Choose the oracle budget (max_oracle_calls) according to how expensive each oracle call is:
  • Fast oracles (RDKit properties): 10K-20K calls
  • Medium oracles (ML predictions): 5K-10K calls
  • Slow oracles (docking): 500-2K calls

Next Steps

Design Your Oracle

Learn more about oracle design patterns and best practices

Tune Hyperparameters

Optimize the configuration for your specific use case

Build docs developers (and LLMs) love