
Quick start

This guide walks you through creating a complete molecular optimization pipeline using ChemLactica. You’ll learn how to define an oracle function, configure the optimization algorithm, and generate molecules optimized for specific properties.

Overview

Molecular optimization with ChemLactica requires two main components:
  1. Oracle: Evaluates the quality of generated molecules based on your target properties
  2. Configuration: Defines hyperparameters for the optimization algorithm
The optimization process uses a genetic-like algorithm that iteratively generates and refines molecules to maximize the oracle score.
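To make the iterative loop concrete, here is a toy sketch of a generate, score, and select cycle. It is not ChemLactica's implementation: `toy_oracle`, `propose`, and `toy_optimize` are hypothetical stand-ins, and the "molecules" are plain strings so the example runs with the standard library only.

```python
import random

def toy_oracle(candidate: str) -> float:
    """Hypothetical score: number of 'N' and 'O' characters in the string."""
    return float(candidate.count("N") + candidate.count("O"))

def propose(parent: str, rng: random.Random) -> str:
    """Hypothetical generator: overwrite one random position of the parent."""
    pos = rng.randrange(len(parent))
    return parent[:pos] + rng.choice("CNO") + parent[pos + 1:]

def toy_optimize(seed_pool, pool_size=3, gens_per_iter=20, n_iters=10, seed=0):
    rng = random.Random(seed)
    # Every unique candidate scored so far, similar in spirit to mol_buffer
    scored = {c: toy_oracle(c) for c in seed_pool}
    for _round in range(n_iters):
        # Keep only the best candidates as parents for this round
        pool = sorted(scored, key=scored.get, reverse=True)[:pool_size]
        for _ in range(gens_per_iter):
            child = propose(rng.choice(pool), rng)
            if child not in scored:
                scored[child] = toy_oracle(child)
    best = max(scored, key=scored.get)
    return best, scored[best]

best, score = toy_optimize(["CCCCCC"])
print(best, score)
```

In the real pipeline the language model plays the role of `propose`, and your oracle class plays the role of `toy_oracle`.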

Step 1: Define your oracle

The oracle evaluates molecules and assigns scores. Here’s a complete example that maximizes TPSA (Topological Polar Surface Area) while constraining molecular weight and ring count: a molecule scores 0 if its weight is 350 or more, or if it has fewer than two rings.
from typing import List
import numpy as np
from rdkit.Chem import rdMolDescriptors
from chemlactica.mol_opt.utils import MoleculeEntry

class TPSA_Weight_Oracle:
    def __init__(self, max_oracle_calls: int):
        # Maximum number of oracle calls to make
        self.max_oracle_calls = max_oracle_calls

        # The frequency with which to log
        self.freq_log = 100

        # The buffer to keep track of all unique molecules generated
        self.mol_buffer = {}

        # The maximum possible oracle score or an upper bound
        self.max_possible_oracle_score = 800

        # If True the __call__ function takes list of MoleculeEntry objects
        # If False (or unspecified) the __call__ function takes list of SMILES strings
        self.takes_entry = True

    def __call__(self, molecules: List[MoleculeEntry]):
        """
        Evaluate and return the oracle scores for molecules.
        Log the intermediate results if necessary.
        """
        oracle_scores = []
        for molecule in molecules:
            if molecule.smiles in self.mol_buffer:
                # Reuse the cached score instead of re-evaluating the molecule
                oracle_scores.append(self.mol_buffer[molecule.smiles][0])
            else:
                try:
                    tpsa = rdMolDescriptors.CalcTPSA(molecule.mol)
                    oracle_score = tpsa
                    weight = rdMolDescriptors.CalcExactMolWt(molecule.mol)
                    num_rings = rdMolDescriptors.CalcNumRings(molecule.mol)
                    
                    # Apply constraints
                    if weight >= 350:
                        oracle_score = 0
                    if num_rings < 2:
                        oracle_score = 0

                except Exception as e:
                    print(e)
                    oracle_score = 0
                    
                self.mol_buffer[molecule.smiles] = [oracle_score, len(self.mol_buffer) + 1]
                if len(self.mol_buffer) % self.freq_log == 0:
                    self.log_intermediate()
                oracle_scores.append(oracle_score)
        return oracle_scores
    
    def log_intermediate(self):
        scores = [v[0] for v in self.mol_buffer.values()][-self.max_oracle_calls:]
        scores_sorted = sorted(scores, reverse=True)[:100]
        n_calls = len(self.mol_buffer)

        score_avg_top1 = np.max(scores_sorted)
        score_avg_top10 = np.mean(scores_sorted[:10])
        score_avg_top100 = np.mean(scores_sorted)

        print(f"{n_calls}/{self.max_oracle_calls} | "
              f"avg_top1: {score_avg_top1:.3f} | "
              f"avg_top10: {score_avg_top10:.3f} | "
              f"avg_top100: {score_avg_top100:.3f}")

    def __len__(self):
        return len(self.mol_buffer)

    @property
    def budget(self):
        return self.max_oracle_calls

    @property
    def finish(self):
        """The stopping condition for the optimization process"""
        return len(self.mol_buffer) >= self.max_oracle_calls
The oracle’s __call__ method receives a list of MoleculeEntry objects and must return a list of scores. Set takes_entry = True to receive MoleculeEntry objects instead of SMILES strings.
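The scoring rule used in `__call__` can be separated from RDKit so that it is easy to unit test. The sketch below restates the same rule as a plain function; the name `constrained_tpsa_score` and its keyword defaults are illustrative and not part of ChemLactica.

```python
def constrained_tpsa_score(tpsa: float, weight: float, num_rings: int,
                           max_weight: float = 350.0, min_rings: int = 2) -> float:
    """Return TPSA when both constraints hold, otherwise 0."""
    if weight >= max_weight:
        return 0.0  # too heavy
    if num_rings < min_rings:
        return 0.0  # too few rings
    return float(tpsa)

print(constrained_tpsa_score(tpsa=90.5, weight=300.0, num_rings=2))  # 90.5
print(constrained_tpsa_score(tpsa=90.5, weight=350.0, num_rings=2))  # 0.0
print(constrained_tpsa_score(tpsa=90.5, weight=300.0, num_rings=1))  # 0.0
```

Inside the oracle, the descriptor values computed by RDKit would be passed to this function in place of the inline `if` checks.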

Step 2: Create a configuration file

Create a YAML file to define the optimization hyperparameters. Here’s an example configuration for Chemlactica-125M:
model_hparams.yaml
checkpoint_path: yerevann/chemlactica-125m
tokenizer_path: yerevann/chemlactica-125m
pool_size: 10
validation_perc: 0.2
num_mols: 0
num_similars: 5
num_gens_per_iter: 200
device: cuda:0
sim_range: [0.4, 0.9]
num_processes: 8
generation_batch_size: 200
eos_token: "</s>"
generation_temperature: [1.0, 1.5]

generation_config:
  repetition_penalty: 1.0
  max_new_tokens: 100
  do_sample: true
  eos_token_id: 20

strategy: [rej-sample-v2]

rej_sample_config:
  train_tol_level: 3
  checkpoints_dir: checkpoints
  max_learning_rate: 0.0001
  lr_end: 0
  train_batch_size: 2
  gradient_accumulation_steps: 8
  weight_decay: 0.1
  adam_beta1: 0.9
  adam_beta2: 0.999
  warmup_steps: 10
  global_gradient_norm: 1.0
  dataloader_num_workers: 1
  max_seq_length: 2048
  num_train_epochs: 5
  packing: false

Key configuration parameters

  • pool_size: Number of top molecules to maintain in the pool
  • num_similars: Number of similar molecules to include in prompts
  • num_gens_per_iter: Number of molecules to generate per iteration
  • generation_temperature: Sampling temperature for diversity
  • strategy: Use [rej-sample-v2] for fine-tuning during optimization, or [default] for generation-only
Using strategy: [rej-sample-v2] enables rejection sampling with fine-tuning during optimization, which improves results but requires more computational resources.
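Before starting a long run, it can help to sanity-check a few of these values. The sketch below works on a plain dictionary that mirrors part of the YAML file. The `check_config` helper is hypothetical, and the bounds it enforces (similarities between 0 and 1, positive temperatures) are assumptions about sensible values, not rules taken from ChemLactica.

```python
config = {
    "pool_size": 10,
    "num_similars": 5,
    "num_gens_per_iter": 200,
    "sim_range": [0.4, 0.9],
    "generation_temperature": [1.0, 1.5],
    "strategy": ["rej-sample-v2"],
    "rej_sample_config": {"train_batch_size": 2, "gradient_accumulation_steps": 8},
}

def check_config(cfg: dict) -> list:
    """Return a list of human-readable problems (empty if none were found)."""
    problems = []
    low, high = cfg["sim_range"]
    if not 0.0 <= low <= high <= 1.0:
        problems.append("sim_range must satisfy 0 <= low <= high <= 1")
    if any(t <= 0 for t in cfg["generation_temperature"]):
        problems.append("generation_temperature values must be positive")
    if cfg["pool_size"] < 1 or cfg["num_gens_per_iter"] < 1:
        problems.append("pool_size and num_gens_per_iter must be at least 1")
    return problems

# Samples per optimizer step during fine-tuning, assuming a single device
# and the usual gradient-accumulation semantics.
rej = config["rej_sample_config"]
effective_batch = rej["train_batch_size"] * rej["gradient_accumulation_steps"]

print(check_config(config), effective_batch)
```

With the example values, two samples per step accumulated over eight steps give an effective batch of 16.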

Step 3: Run the optimization

Now put it all together in a complete script:
example_run.py
import yaml
import argparse
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from chemlactica.mol_opt.optimization import optimize
from chemlactica.mol_opt.utils import set_seed

# Import your oracle class (assumes the Step 1 code is saved as oracle.py)
from oracle import TPSA_Weight_Oracle

def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--config_default", type=str, required=True)
    parser.add_argument("--n_runs", type=int, required=False, default=1)
    args = parser.parse_args()
    return args

if __name__ == "__main__":
    args = parse_arguments()
    
    # Load configuration
    with open(args.config_default) as f:
        config = yaml.safe_load(f)

    # Load model and tokenizer
    model = AutoModelForCausalLM.from_pretrained(
        config["checkpoint_path"],
        torch_dtype=torch.bfloat16
    ).to(config["device"])
    
    tokenizer = AutoTokenizer.from_pretrained(
        config["tokenizer_path"],
        padding_side="left"
    )

    # Run optimization with different random seeds
    # (n_runs must not exceed the number of seeds listed below)
    seeds = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31]
    for i in range(args.n_runs):
        set_seed(seeds[i])
        
        # Create oracle
        oracle = TPSA_Weight_Oracle(max_oracle_calls=1000)
        
        # Set log directory
        config["log_dir"] = os.path.join(
            args.output_dir,
            f"results_chemlactica_tpsa+weight+num_rings_{seeds[i]}.log"
        )
        config["max_possible_oracle_score"] = oracle.max_possible_oracle_score
        
        # Run optimization
        optimize(model, tokenizer, oracle, config)

Step 4: Execute the optimization

Run your optimization script:
python example_run.py \
  --output_dir ./results \
  --config_default model_hparams.yaml \
  --n_runs 3
This will run the optimization 3 times with different random seeds and save the results to the ./results directory.
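The script indexes a fixed list of 11 seeds, so `--n_runs` values above 11 would raise an `IndexError` partway through. A small guard can fail early instead. The `plan_runs` helper below is a hypothetical sketch that reuses the seed list and log-file naming from the script.

```python
import os

SEEDS = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31]

def plan_runs(n_runs: int, output_dir: str) -> list:
    """Return (seed, log_path) pairs, failing early if n_runs is out of range."""
    if not 1 <= n_runs <= len(SEEDS):
        raise ValueError(f"n_runs must be between 1 and {len(SEEDS)}, got {n_runs}")
    return [
        (seed, os.path.join(
            output_dir, f"results_chemlactica_tpsa+weight+num_rings_{seed}.log"))
        for seed in SEEDS[:n_runs]
    ]

for seed, log_path in plan_runs(3, "./results"):
    print(seed, log_path)
```

With `--n_runs 3`, the runs use seeds 2, 3, and 5, and each run writes to its own log path.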

Understanding the output

During optimization, you’ll see periodic logs showing:
100/1000 | avg_top1: 156.234 | avg_top10: 142.891 | avg_top100: 128.456
200/1000 | avg_top1: 168.912 | avg_top10: 155.234 | avg_top100: 139.123
...
  • avg_top1: Best molecule score found so far
  • avg_top10: Average score of top 10 molecules
  • avg_top100: Average score of top 100 molecules
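To plot progress or compare runs, these lines can be parsed back into numbers. The following sketch assumes the log format shown above; `parse_log_line` is a hypothetical helper, not part of ChemLactica.

```python
import re

# Tolerates variable spacing around the "|" separators
LOG_LINE = re.compile(
    r"(?P<calls>\d+)/(?P<budget>\d+)\s*\|\s*"
    r"avg_top1:\s*(?P<top1>-?\d+\.\d+)\s*\|\s*"
    r"avg_top10:\s*(?P<top10>-?\d+\.\d+)\s*\|\s*"
    r"avg_top100:\s*(?P<top100>-?\d+\.\d+)"
)

def parse_log_line(line: str):
    """Return a dict of parsed values, or None if the line is not a progress log."""
    match = LOG_LINE.search(line)
    if match is None:
        return None
    return {
        "calls": int(match["calls"]),
        "budget": int(match["budget"]),
        "avg_top1": float(match["top1"]),
        "avg_top10": float(match["top10"]),
        "avg_top100": float(match["top100"]),
    }

row = parse_log_line(
    "100/1000 | avg_top1: 156.234 | avg_top10: 142.891 | avg_top100: 128.456"
)
print(row)
```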

Advanced usage

Using different models

To use a larger model such as Chemlactica-1.3B or Chemma-2B, change the checkpoint paths in your config. For Chemlactica-1.3B:
checkpoint_path: yerevann/chemlactica-1.3b
tokenizer_path: yerevann/chemlactica-1.3b

Custom property constraints

Modify the oracle to optimize for different properties:
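# Optimize for QED (drug-likeness)
from typing import List
from rdkit.Chem import QED
from chemlactica.mol_opt.utils import MoleculeEntry

class QED_Oracle:
    # Only the scoring logic is shown; the bookkeeping from Step 1
    # (__init__, mol_buffer, takes_entry, budget, finish) is omitted here.
    def __call__(self, molecules: List[MoleculeEntry]):
        oracle_scores = []
        for molecule in molecules:
            try:
                # QED.qed returns a drug-likeness score between 0 and 1
                qed_score = QED.qed(molecule.mol)
                oracle_scores.append(qed_score)
            except Exception:
                # Score molecules that RDKit cannot process as 0
                oracle_scores.append(0.0)
        return oracle_scores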
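
One way to avoid repeating the bookkeeping in every custom oracle is to move it into a shared base class, so that each new oracle only defines its scoring rule. The sketch below is an assumption about how you might organize your own code: `BufferedOracle`, `score_one`, and `HeteroatomCountOracle` are hypothetical names, and the toy scorer counts characters in a SMILES string instead of calling RDKit. Whether `optimize` relies on attributes beyond those shown in Step 1 is not covered here.

```python
from typing import Dict, List

class BufferedOracle:
    """Hypothetical base class holding the cache, budget, and stopping condition."""

    def __init__(self, max_oracle_calls: int, max_possible_oracle_score: float):
        self.max_oracle_calls = max_oracle_calls
        self.max_possible_oracle_score = max_possible_oracle_score
        self.mol_buffer: Dict[str, list] = {}
        self.takes_entry = False  # this sketch scores plain SMILES strings

    def score_one(self, smiles: str) -> float:
        raise NotImplementedError

    def __call__(self, molecules: List[str]) -> List[float]:
        scores = []
        for smiles in molecules:
            if smiles not in self.mol_buffer:
                try:
                    score = float(self.score_one(smiles))
                except Exception:
                    score = 0.0  # treat scoring failures as the worst score
                self.mol_buffer[smiles] = [score, len(self.mol_buffer) + 1]
            scores.append(self.mol_buffer[smiles][0])
        return scores

    def __len__(self) -> int:
        return len(self.mol_buffer)

    @property
    def budget(self) -> int:
        return self.max_oracle_calls

    @property
    def finish(self) -> bool:
        return len(self.mol_buffer) >= self.max_oracle_calls

class HeteroatomCountOracle(BufferedOracle):
    """Toy scorer: counts 'N' and 'O' characters in the SMILES string."""

    def score_one(self, smiles: str) -> float:
        return smiles.count("N") + smiles.count("O")

oracle = HeteroatomCountOracle(max_oracle_calls=2, max_possible_oracle_score=10)
print(oracle(["CCO", "NCCO", "CCO"]), len(oracle), oracle.finish)
```

A QED oracle built this way would override `score_one` with the RDKit call and leave the rest untouched.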

Generation-only mode

For faster execution without fine-tuning, use the default strategy:
strategy: [default]
This disables the rejection sampling and fine-tuning steps.
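
If you switch between strategies or models often, you can derive variants from one base configuration in code instead of maintaining several YAML files. This is a minimal sketch using a plain dictionary; `derive_config` is a hypothetical helper, and only top-level keys are overridden.

```python
import copy

def derive_config(base: dict, **overrides) -> dict:
    """Return a deep copy of base with top-level keys replaced by overrides."""
    cfg = copy.deepcopy(base)  # deep copy so nested sections are not shared
    cfg.update(overrides)
    return cfg

base = {
    "checkpoint_path": "yerevann/chemlactica-125m",
    "tokenizer_path": "yerevann/chemlactica-125m",
    "strategy": ["rej-sample-v2"],
    "rej_sample_config": {"num_train_epochs": 5},
}

generation_only = derive_config(base, strategy=["default"])
larger_model = derive_config(
    base,
    checkpoint_path="yerevann/chemlactica-1.3b",
    tokenizer_path="yerevann/chemlactica-1.3b",
)

print(generation_only["strategy"], larger_model["checkpoint_path"])
```

The base dictionary is left unchanged, so each run can start from the same defaults.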

Next steps

Advanced optimization

Explore complex oracles for docking, multi-objective optimization, and more

Fine-tuning

Learn how to fine-tune ChemLactica on your own molecular datasets

Test suite

Check out comprehensive examples in the ChemLacticaTestSuite repository

Paper

Read the full research paper for technical details
