
Overview

The optimize() function is the core of ChemLactica’s molecular optimization pipeline. It orchestrates the iterative process of generating molecules, scoring them with an oracle function, and fine-tuning the model using rejection sampling.

Function Signature

def optimize(
    model,
    tokenizer,
    oracle,
    config,
    additional_properties={},
    validate_smiles=lambda x: True
)

Parameters

model
transformers.PreTrainedModel
required
The pretrained language model for molecule generation. Typically a GPT-2 or OPT model fine-tuned on SMILES strings.
tokenizer
transformers.PreTrainedTokenizer
required
Tokenizer corresponding to the model, used to encode prompts and decode generated sequences.
oracle
callable
required
Scoring function that evaluates generated molecules. Can accept either:
  • A list of SMILES strings and return scores
  • A list of MoleculeEntry objects if oracle.takes_entry = True
Must have attributes:
  • max_oracle_calls: Maximum number of molecules to evaluate
  • finish: Boolean indicating if optimization should stop
  • mol_buffer: Dictionary tracking evaluated molecules
  • budget: Total oracle call budget
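A minimal oracle that sets `takes_entry = True` and scores `MoleculeEntry` objects directly might look like the following sketch. The `entry.smiles` attribute and the placeholder scorer are assumptions for illustration, not part of the documented API:

```python
# Hypothetical oracle that receives MoleculeEntry objects rather than
# SMILES strings; optimize() checks the takes_entry attribute to decide
# which form to pass. Attribute names follow the parameter list above.
class EntryOracle:
    takes_entry = True

    def __init__(self, budget=1000):
        self.max_oracle_calls = budget
        self.budget = budget
        self.mol_buffer = {}
        self.finish = False

    def __call__(self, entries):
        scores = []
        for entry in entries:
            # entry.smiles is assumed here; the real MoleculeEntry
            # attribute names may differ.
            score = float(len(entry.smiles) % 10) / 10  # placeholder scorer
            self.mol_buffer[entry.smiles] = (score, len(self.mol_buffer))
            scores.append(score)
        if len(self.mol_buffer) >= self.max_oracle_calls:
            self.finish = True
        return scores
```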
config
dict
required
Configuration dictionary containing optimization parameters.
Required keys:
  • log_dir (str): Path to output log file
  • pool_size (int): Maximum size of molecule pool
  • validation_perc (float): Percentage of pool reserved for validation (0-1)
  • generation_config (dict): Parameters for model.generate() including max_new_tokens, temperature, etc.
  • generation_temperature (list): [min, max] temperature range for annealing
  • strategy (list): Optimization strategies, e.g., ["rej-sample-v2"]
  • num_gens_per_iter (int): Number of unique molecules to generate per iteration
  • generation_batch_size (int): Batch size for generation
  • num_mols (int): Number of molecules in each optimization entry
  • num_similars (int): Number of similar molecules to include in prompts
  • max_possible_oracle_score (float): Maximum achievable oracle score
  • sim_range (list): [min, max] similarity range for generation prompts
  • eos_token (str): End-of-sequence token for the model
Required for rejection sampling (rej-sample-v2):
  • rej_sample_config (dict): Training configuration including:
    • train_batch_size (int)
    • gradient_accumulation_steps (int)
    • num_train_epochs (int)
    • train_tol_level (int): Tolerance iterations before triggering training
    • rej_perc (float): Rejection sampling percentage
    • packing (bool): Whether to pack sequences
    • max_seq_length (int): Maximum sequence length
additional_properties
dict
default:"{}"
Dictionary of additional molecular properties to include in prompts. Each property should have:
  • start_tag (str): Opening tag for the property
  • end_tag (str): Closing tag for the property
  • calculate_value (callable): Function to compute property value from MoleculeEntry
  • infer_value (callable): Function to infer property value for generation prompts
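As an illustration, an `additional_properties` entry for a hypothetical "synthesizability" property could be structured as follows. The tag names, the `entry.smiles` attribute, and the `infer_value` signature (receiving a list of pool entries) are assumptions, not the documented API:

```python
# Hedged sketch of one additional property. A real calculate_value would
# call an actual property predictor; this placeholder derives a number
# from the SMILES length so the structure is runnable.
def calc_synth(entry):
    # Placeholder property value computed from an assumed entry.smiles.
    return round(len(entry.smiles) / 100, 2)

def infer_synth(entries):
    # For generation prompts, reuse the best value seen so far
    # (a common heuristic; an assumption, not the library's behavior).
    return max((calc_synth(e) for e in entries), default=0.0)

additional_properties = {
    "synth": {
        "start_tag": "[SYNTH]",
        "end_tag": "[/SYNTH]",
        "calculate_value": calc_synth,
        "infer_value": infer_synth,
    }
}
```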
validate_smiles
callable
default:"lambda x: True"
Function to validate generated SMILES strings. Takes a SMILES string and returns True if valid, False otherwise.
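A stricter validator than the permissive default typically uses RDKit. This sketch falls back to a non-empty check when RDKit is not installed:

```python
# SMILES validator suitable for the validate_smiles parameter.
def validate_smiles(smiles):
    try:
        from rdkit import Chem
    except ImportError:
        # Fallback when RDKit is unavailable: reject only empty strings.
        return bool(smiles)
    # Chem.MolFromSmiles returns None for unparseable SMILES.
    return Chem.MolFromSmiles(smiles) is not None
```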

Returns

The function does not return a value. Instead, it:
  • Writes optimization progress to the log file specified in config["log_dir"]
  • Updates the model weights in-place through fine-tuning
  • Maintains the oracle’s mol_buffer with all evaluated molecules

Behavior

Iterative Optimization Loop

  1. Generation Phase: Creates prompts from pool molecules and generates new candidates
  2. Scoring Phase: Evaluates unique molecules with the oracle function
  3. Pool Update: Adds high-scoring molecules to the pool with diversity filtering
  4. Fine-tuning Phase: Triggers model training when tolerance threshold is reached
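The four phases above can be sketched as a generic driver. This is illustrative only: `optimize()` implements the loop internally, and the callables here are stand-ins for its generation, scoring, pool, and training logic:

```python
# Schematic of one optimize() pass per iteration, looping until the
# oracle signals completion.
def optimization_loop(generate, score, update_pool, maybe_train, oracle):
    while not oracle.finish:
        candidates = generate()          # 1. Generation Phase
        scores = score(candidates)       # 2. Scoring Phase
        update_pool(candidates, scores)  # 3. Pool Update
        maybe_train()                    # 4. Fine-tuning Phase
```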

Temperature Annealing

Generation temperature increases linearly from generation_temperature[0] to generation_temperature[1] based on oracle budget consumption:
temperature += (num_gens_per_iter / (budget - num_gens_per_iter)) * (temp_max - temp_min)
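Plugging in the values from the Example Usage config below, each iteration adds a fixed step so the temperature reaches `generation_temperature[1]` as the budget runs out:

```python
# Worked example of the per-iteration annealing step, using
# num_gens_per_iter=50, a budget of 1000, and a [1.0, 1.5] range.
num_gens_per_iter = 50
budget = 1000
temp_min, temp_max = 1.0, 1.5

step = (num_gens_per_iter / (budget - num_gens_per_iter)) * (temp_max - temp_min)
# step = (50 / 950) * 0.5 ~= 0.0263; after the 19 increments that fit in
# the remaining budget, the temperature arrives at temp_max.
temperature = temp_min + 19 * step
```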

Training Triggers

For the rej-sample-v2 strategy, model fine-tuning occurs when tol_level >= config["rej_sample_config"]["train_tol_level"], where:
  • tol_level increments each iteration that does not find a new best molecule
  • tol_level resets to 0 after training or after finding a new best molecule
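The tolerance bookkeeping can be sketched as a small helper (illustrative; the real logic lives inside `optimize()`):

```python
# Returns the updated tolerance counter and whether training should fire,
# following the trigger rules described above.
def update_tol_level(tol_level, found_new_best, train_tol_level):
    if found_new_best:
        return 0, False          # new best molecule: reset, no training
    tol_level += 1
    if tol_level >= train_tol_level:
        return 0, True           # threshold reached: train, then reset
    return tol_level, False
```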

Example Usage

from chemlactica.mol_opt.optimization import optimize
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained("allenai/chemlactica-125m")
tokenizer = AutoTokenizer.from_pretrained("allenai/chemlactica-125m")

# Define oracle function
class Oracle:
    def __init__(self):
        self.max_oracle_calls = 1000
        self.budget = 1000
        self.mol_buffer = {}
        self.finish = False
    
    def __call__(self, smiles_list):
        scores = []
        for smiles in smiles_list:
            score = compute_score(smiles)  # Your scoring function
            self.mol_buffer[smiles] = (score, len(self.mol_buffer))
            scores.append(score)
            if len(self.mol_buffer) >= self.max_oracle_calls:
                self.finish = True
        return scores
    
    def __len__(self):
        return len(self.mol_buffer)

oracle = Oracle()

# Configuration
config = {
    "log_dir": "optimization.log",
    "pool_size": 100,
    "validation_perc": 0.1,
    "generation_config": {
        "max_new_tokens": 512,
        "do_sample": True,
        "top_k": 50,
    },
    "generation_temperature": [1.0, 1.5],
    "strategy": ["rej-sample-v2"],
    "num_gens_per_iter": 50,
    "generation_batch_size": 10,
    "num_mols": 3,
    "num_similars": 5,
    "max_possible_oracle_score": 1.0,
    "sim_range": [0.3, 0.7],
    "eos_token": "</s>",
    "rej_sample_config": {
        "train_batch_size": 8,
        "gradient_accumulation_steps": 4,
        "num_train_epochs": 3,
        "train_tol_level": 3,
        "rej_perc": 0.5,
        "packing": False,
        "max_seq_length": 2048,
    }
}

# Run optimization
optimize(model, tokenizer, oracle, config)

Related Classes

  • MoleculeEntry - Represents individual molecules
  • OptimEntry - Represents optimization entries with multiple molecules
  • Pool - Manages the pool of high-scoring molecules
