The optimize() function implements a genetic-like optimization algorithm that iteratively generates and refines molecules based on oracle feedback.

The optimize() Function

Located in chemlactica/mol_opt/optimization.py, this is the main entry point:
from chemlactica.mol_opt.optimization import optimize

optimize(
    model,              # Pre-trained ChemLactica model
    tokenizer,          # Corresponding tokenizer
    oracle,             # Your custom oracle
    config,             # Hyperparameter configuration
    additional_properties={},  # Optional additional properties
    validate_smiles=lambda x: True  # Optional SMILES validation
)

Parameters

  • model (AutoModelForCausalLM, required) — The ChemLactica model (125M, 1.3B, or Chemma-2B)
  • tokenizer (AutoTokenizer, required) — Tokenizer corresponding to the model
  • oracle (Oracle, required) — Your custom oracle implementing the oracle interface
  • config (dict, required) — Hyperparameter configuration dictionary (from YAML)
  • additional_properties (dict, default: {}) — Optional additional properties to include in prompts (e.g., molecular weight, logP)
  • validate_smiles (callable, default: lambda x: True) — Optional function to validate generated SMILES before scoring
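As an illustration, a minimal configuration and validator might look like the following. The config keys and values here are illustrative placeholders, not the library's exact defaults:

```python
# Hypothetical minimal setup; the exact config keys your YAML provides
# may differ from this sketch.
def validate_smiles(smiles):
    # Reject empty strings and implausibly long molecules.
    return 0 < len(smiles) <= 100

config = {
    "pool_size": 50,
    "validation_perc": 0.2,
    "num_gens_per_iter": 200,
    "generation_temperature": [1.0, 1.5],
    "generation_config": {"temperature": 1.0, "max_new_tokens": 64},
}
```

These values would then be passed to optimize() alongside the model, tokenizer, and oracle.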

Algorithm Workflow

The optimization process follows these steps:
1. Initialize Pool

Create an empty pool to store the top-performing molecules
pool = Pool(config["pool_size"], validation_perc=config["validation_perc"])
The pool maintains:
  • Top pool_size molecules sorted by score
  • Training/validation split for fine-tuning
  • Diversity filtering to avoid duplicates
2. Generation Loop

Generate num_gens_per_iter new molecules per iteration.
a) Create prompts from pool molecules:
prompts = [
    optim_entry.to_prompt(
        is_generation=True,
        include_oracle_score=prev_train_iter != 0,
        config=config,
        max_score=max_score
    )
    for optim_entry in optim_entries
]
Each prompt includes:
  • Similar molecules with Tanimoto similarities
  • Desired oracle score (random between max_score and max_possible_oracle_score)
  • Additional properties if specified
b) Generate molecules:
output = model.generate(**data, **config["generation_config"])
output_texts = tokenizer.batch_decode(output)
c) Parse and validate:
molecule = create_molecule_entry(output_text, validate_smiles)
Extract SMILES between [START_SMILES] and [END_SMILES] tags
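The tag-extraction step can be sketched as follows. This is a simplified stand-in for create_molecule_entry, which may differ in detail:

```python
import re

# Hedged sketch of extracting the SMILES between the generation tags;
# the real parsing in chemlactica.mol_opt may be more involved.
def extract_smiles(output_text):
    """Return the SMILES between [START_SMILES] and [END_SMILES], or None."""
    match = re.search(r"\[START_SMILES\](.*?)\[END_SMILES\]", output_text)
    return match.group(1).strip() if match else None

text = "[SIMILAR]CCO 0.65[/SIMILAR][START_SMILES]CC(C)O[END_SMILES]"
extract_smiles(text)  # -> "CC(C)O"
```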
3. Oracle Evaluation

Score unique molecules with the oracle:
if oracle.takes_entry:
    oracle_scores = oracle(
        [iter_unique_optim_entries[smiles].last_entry for smiles in unique_smiles]
    )
else:
    oracle_scores = oracle(unique_smiles)
Track best score and log results:
if score > max_score:
    max_score = score
    new_best_molecule_generated = True
4. Update Pool

Add new molecules to the pool:
pool.add(list(iter_unique_optim_entries.values()))
The pool:
  • Sorts molecules by score (descending)
  • Removes duplicates and highly similar molecules
  • Keeps only top pool_size molecules
  • Maintains train/validation split
5. Adaptive Fine-tuning (Optional)

If using the rej-sample-v2 strategy and there is no improvement for train_tol_level iterations:
a) Prepare datasets:
train_entries, validation_entries = pool.get_train_valid_entries()

train_dataset = Dataset.from_dict({
    "sample": [
        entry.to_prompt(is_generation=False, include_oracle_score=True, ...)
        for entry in train_entries
    ]
})
b) Fine-tune model:
trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=validation_dataset,
    ...
)
trainer.train()
model.load_state_dict(best_model_state_dict)
The model learns to generate molecules that match the oracle’s preferences.
6. Adjust Temperature

Increase sampling temperature over time for exploration:
config["generation_config"]["temperature"] += (
    (temp_end - temp_start) / num_iterations
)
Starts at generation_temperature[0] and increases to generation_temperature[1].
7. Check Stopping Condition

Stop when oracle budget is exhausted:
if oracle.finish:
    break
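The budget-tracking behavior can be illustrated with a toy oracle wrapper. This is a sketch of the general pattern, not the library's actual Oracle class, whose attributes may differ:

```python
# Illustrative oracle with a call budget and a result cache; hypothetical,
# mirroring only the oracle.finish / oracle.mol_buffer behavior described above.
class BudgetedOracle:
    def __init__(self, score_fn, budget):
        self.score_fn = score_fn
        self.budget = budget
        self.mol_buffer = {}  # smiles -> cached score

    @property
    def finish(self):
        # Exhausted once the number of unique scored molecules hits the budget.
        return len(self.mol_buffer) >= self.budget

    def __call__(self, smiles_list):
        scores = []
        for smiles in smiles_list:
            if smiles not in self.mol_buffer:
                self.mol_buffer[smiles] = self.score_fn(smiles)
            scores.append(self.mol_buffer[smiles])
        return scores

oracle = BudgetedOracle(score_fn=len, budget=2)
oracle(["CCO", "CC"])  # -> [3, 2]
oracle.finish          # True: two unique molecules scored, budget of 2 spent
```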

Key Components

Pool Management

The Pool class (mol_opt/utils.py:88) maintains top molecules:
class Pool:
    def __init__(self, size, validation_perc: float):
        self.size = size  # Maximum pool size
        self.optim_entries: List[OptimEntry] = []
        self.num_validation_entries = int(size * validation_perc + 1)
    
    def add(self, entries: List, diversity_score=1.0):
        # Sort by score
        self.optim_entries.extend(entries)
        self.optim_entries.sort(key=lambda x: x.last_entry, reverse=True)
        
        # Remove duplicates and similar molecules
        new_optim_entries = []
        for entry in self.optim_entries:
            insert = True
            for e in new_optim_entries:
                # Skip if duplicate or too similar
                if (entry.last_entry == e.last_entry or
                    tanimoto_dist_func(entry.last_entry.fingerprint, 
                                     e.last_entry.fingerprint) > diversity_score):
                    insert = False
                    break
            if insert:
                new_optim_entries.append(entry)
        
        # Keep top molecules only
        self.optim_entries = new_optim_entries[:min(len(new_optim_entries), self.size)]

Prompt Construction

The OptimEntry.to_prompt() method creates prompts for generation.
During generation (is_generation=True):
</s>[SIMILAR]CCO 0.65[/SIMILAR][SIMILAR]CC(C)O 0.58[/SIMILAR][PROPERTY]oracle_score 85.23[/PROPERTY][START_SMILES]
During training (is_generation=False):
</s>[SIMILAR]CCO 0.65[/SIMILAR][SIMILAR]CC(C)O 0.58[/SIMILAR][PROPERTY]oracle_score 85.23[/PROPERTY][START_SMILES]CC(O)CO[END_SMILES]</s>
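The two prompt formats above can be reconstructed with a small helper. The tag names follow the documented examples; this is not the actual to_prompt() implementation:

```python
# Hedged reconstruction of the prompt format shown above; training prompts
# additionally include the target SMILES and a closing </s>.
def build_prompt(similars, oracle_score, smiles=None):
    """similars: list of (smiles, tanimoto_similarity) pairs."""
    parts = ["</s>"]
    for sim_smiles, sim in similars:
        parts.append(f"[SIMILAR]{sim_smiles} {sim:.2f}[/SIMILAR]")
    parts.append(f"[PROPERTY]oracle_score {oracle_score:.2f}[/PROPERTY]")
    parts.append("[START_SMILES]")
    if smiles is not None:  # training prompt includes the answer
        parts.append(f"{smiles}[END_SMILES]</s>")
    return "".join(parts)

build_prompt([("CCO", 0.65)], 85.23)
# -> '</s>[SIMILAR]CCO 0.65[/SIMILAR][PROPERTY]oracle_score 85.23[/PROPERTY][START_SMILES]'
```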

Molecule Entry

The MoleculeEntry class represents a molecule:
class MoleculeEntry:
    def __init__(self, smiles, score=0, **kwargs):
        self.smiles = canonicalize(smiles)  # Canonical SMILES
        self.score = score                   # Oracle score
        self.mol = Chem.MolFromSmiles(smiles)  # RDKit mol
        self.fingerprint = get_morgan_fingerprint(self.mol)  # For similarity
        self.similar_mol_entries = []        # Similar mols in prompt
        self.add_props = kwargs              # Additional properties
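The fingerprint is what powers similarity comparisons. The library uses RDKit Morgan fingerprints; the Tanimoto computation itself can be shown with plain Python sets of on-bits:

```python
# Minimal Tanimoto (Jaccard) similarity on fingerprint bit sets, standing in
# for the RDKit-based tanimoto_dist_func used by the library.
def tanimoto(fp_a, fp_b):
    """Similarity of two sets of on-bits: |A ∩ B| / |A ∪ B|."""
    if not fp_a and not fp_b:
        return 1.0  # two empty fingerprints are trivially identical
    return len(fp_a & fp_b) / len(fp_a | fp_b)

tanimoto({1, 2, 3}, {2, 3, 4})  # -> 0.5
```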

Tolerance Mechanism

The algorithm tracks iterations without improvement:
tol_level = 0  # Tolerance counter

if new_best_molecule_generated:
    tol_level = 0  # Reset if we find a better molecule
else:
    tol_level += 1  # Increment if no improvement

# Trigger fine-tuning after train_tol_level iterations without improvement
if tol_level >= config["rej_sample_config"]["train_tol_level"]:
    # Fine-tune model...
    tol_level = 0  # Reset after fine-tuning
This adaptive mechanism ensures:
  • Continued exploration when making progress
  • Model adaptation when stuck
  • Efficient use of oracle budget
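The counter logic above can be simulated over a sequence of "did the best score improve?" flags; train_tol_level is assumed to be 3 here for illustration:

```python
# Toy run of the tolerance counter, mirroring the logic above.
def run_tolerance(improvements, train_tol_level=3):
    tol_level, finetune_rounds = 0, 0
    for improved in improvements:
        tol_level = 0 if improved else tol_level + 1
        if tol_level >= train_tol_level:
            finetune_rounds += 1  # fine-tune the model...
            tol_level = 0         # ...then reset the counter
    return finetune_rounds

run_tolerance([True, False, False, False, True, False])  # -> 1
```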

Generation Temperature Schedule

Temperature increases linearly to encourage exploration:
# Initial temperature
config["generation_config"]["temperature"] = config["generation_temperature"][0]

# Each iteration
config["generation_config"]["temperature"] += (
    config["num_gens_per_iter"] / (oracle.budget - config["num_gens_per_iter"])
    * (config["generation_temperature"][1] - config["generation_temperature"][0])
)
Why increase temperature?
  • Early: Low temperature (1.0) focuses on likely, high-quality molecules
  • Late: High temperature (1.5) explores diverse, creative solutions
  • This schedule balances exploitation and exploration
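Working through the update formula for a hypothetical run (budget of 1,000 oracle calls, 200 generations per iteration, temperatures 1.0 → 1.5) shows the linear ramp:

```python
# Reproduces the per-iteration temperature values implied by the update above;
# the budget and batch size are illustrative, not library defaults.
def temperature_schedule(budget, gens_per_iter, temp_start, temp_end):
    num_iters = budget // gens_per_iter
    delta = gens_per_iter / (budget - gens_per_iter) * (temp_end - temp_start)
    return [temp_start + i * delta for i in range(num_iters)]

temps = temperature_schedule(1000, 200, 1.0, 1.5)
[round(t, 3) for t in temps]  # -> [1.0, 1.125, 1.25, 1.375, 1.5]
```

The last iteration reaches exactly temp_end because the denominator excludes the final batch.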

Diversity Filtering

The pool removes molecules that are too similar:
for entry in self.optim_entries:
    insert = True
    for existing in new_optim_entries:
        similarity = tanimoto_dist_func(
            entry.last_entry.fingerprint,
            existing.last_entry.fingerprint
        )
        if similarity > diversity_score:  # Default: 1.0 (only exact duplicates)
            insert = False
            break
By default, only exact duplicates are removed (a Tanimoto similarity can never exceed 1.0). You can lower diversity_score in the pool to filter more aggressively, since any molecule more similar than the threshold to an already-kept molecule is rejected.
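A standalone version of this filter, using plain sets as fingerprints for illustration, makes the threshold behavior concrete. Entries must arrive sorted best-score first, as in the Pool:

```python
# Simplified diversity filter mirroring Pool.add; fingerprints are plain
# sets of on-bits here instead of RDKit Morgan fingerprints.
def filter_diverse(entries, threshold=1.0):
    """entries: list of (smiles, fingerprint_set), best-scored first."""
    kept = []
    for smiles, fp in entries:
        too_similar = any(
            smiles == k_smiles                       # exact duplicate
            or len(fp & k_fp) / len(fp | k_fp) > threshold  # Tanimoto check
            for k_smiles, k_fp in kept
        )
        if not too_similar:
            kept.append((smiles, fp))
    return kept

entries = [("CCO", {1, 2, 3}), ("CCO", {1, 2, 3}), ("CCN", {1, 2, 4})]
len(filter_diverse(entries, threshold=0.9))  # -> 2 (exact duplicate dropped)
```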

Algorithm Pseudocode

Here’s the complete algorithm in pseudocode:
INPUT: model, tokenizer, oracle, config
OUTPUT: Optimized molecules in oracle.mol_buffer

1. Initialize:
   pool ← empty Pool(pool_size)
   max_score ← 0
   tol_level ← 0
   temperature ← config.generation_temperature[0]

2. While not oracle.finish:
   
   a. Generate molecules:
      unique_molecules ← {}
      while len(unique_molecules) < num_gens_per_iter:
          prompts ← create_prompts_from_pool(pool)
          outputs ← model.generate(prompts, temperature)
          molecules ← parse_and_validate(outputs)
          unique_molecules.add(new molecules not in oracle.mol_buffer)
   
   b. Evaluate with oracle:
      scores ← oracle(unique_molecules)
      if max(scores) > max_score:
          max_score ← max(scores)
          tol_level ← 0
      else:
          tol_level += 1
   
   c. Update pool:
      pool.add(molecules with scores)
      pool ← top pool_size molecules (sorted, filtered for diversity)
   
   d. Fine-tune if needed (rej-sample-v2 only):
      if tol_level >= train_tol_level:
          train_data, val_data ← pool.get_train_valid_split()
          model ← fine_tune(model, train_data, val_data)
          tol_level ← 0
   
   e. Increase temperature:
      temperature ← temperature + delta

3. Return oracle.mol_buffer

Comparison with Other Approaches

vs. Genetic Algorithms

Similarities:
  • Maintains a pool of top candidates
  • Iterative generation and selection
  • Fitness-based ranking
Differences:
  • Uses LLM generation instead of crossover/mutation
  • Prompts guide generation with similarity constraints
  • Optional model fine-tuning for adaptation
vs. Reinforcement Learning

Similarities:
  • Oracle acts as reward function
  • Model learns to maximize rewards
  • Exploration-exploitation tradeoff
Differences:
  • No RL training loop (uses supervised fine-tuning)
  • Simpler implementation
  • Faster convergence on many benchmarks
vs. Bayesian Optimization

Similarities:
  • Efficient use of oracle budget
  • Balances exploration and exploitation
Differences:
  • No surrogate model
  • LLM directly generates candidates
  • Better handles discrete molecular space

Performance Characteristics

Memory Usage

  • Pool: ~10-100 molecules (negligible)
  • Model: 125M (~0.5GB), 1.3B (~5GB), 2B (~8GB)
  • Fine-tuning: +2x model size during training

Oracle Calls

  • Typical budget: 1K - 10K calls
  • Batch size: 200 molecules/iteration
  • PMO benchmark: ~5K calls to SOTA

Runtime

  • Generation: ~1-2 sec/batch (200 molecules)
  • Fine-tuning: ~30 sec - 2 min per round
  • Total: Minutes to hours depending on oracle complexity

Success Rate

  • QED optimization: 99% success (10K calls)
  • PMO benchmark: 17.5 avg score
  • Docking: 3-4x fewer calls than SOTA

Next Steps

Configure Hyperparameters

Learn how to tune pool size, temperature, and fine-tuning settings

Complete Examples

See full working examples with different oracles
