Skip to main content

Overview

Oracles are scoring functions that evaluate molecular properties during optimization. ChemLactica uses oracles to guide the generation of molecules with desired characteristics.

Oracle Structure

A custom oracle is a Python class that implements specific methods and properties. Here’s the basic structure:
from typing import List
from chemlactica.mol_opt.utils import MoleculeEntry

class CustomOracle:
    """Skeleton of a scoring oracle used during molecule optimization.

    Fill in ``__call__`` with the scoring logic. The remaining members expose
    the call budget and how much of it has been consumed so far.
    """

    def __init__(self, max_oracle_calls: int):
        # True means __call__ receives MoleculeEntry objects
        self.takes_entry = True

        # Upper bound on the score this oracle can return
        self.max_possible_oracle_score = 1.0

        # Tracks the unique molecules seen so far
        self.mol_buffer = {}

        # How often progress is logged
        self.freq_log = 100

        # Maximum number of oracle calls
        self.max_oracle_calls = max_oracle_calls

    def __call__(self, molecules: List[MoleculeEntry]):
        """Score the given molecules (placeholder: returns nothing yet)."""
        return None

    def __len__(self):
        """Return how many molecules are stored in the buffer."""
        return len(self.mol_buffer)

    @property
    def budget(self):
        """Maximum number of oracle calls."""
        return self.max_oracle_calls

    @property
    def finish(self):
        """True once the buffer holds at least ``max_oracle_calls`` entries."""
        used = len(self.mol_buffer)
        return used >= self.max_oracle_calls

Step-by-Step Implementation

1

Define the Oracle Class

Create a new class and initialize required attributes in __init__:
class TPSA_Weight_Oracle:
    """Oracle state for the TPSA example: call budget, cache and score bound."""

    def __init__(self, max_oracle_calls: int):
        # True means __call__ receives MoleculeEntry objects
        self.takes_entry = True
        # Upper bound on the score this oracle can return
        self.max_possible_oracle_score = 800
        # Tracks the unique molecules seen so far
        self.mol_buffer = {}
        # How often progress is logged
        self.freq_log = 100
        # Maximum number of oracle calls
        self.max_oracle_calls = max_oracle_calls
Set takes_entry = True if your oracle needs access to RDKit molecule objects and fingerprints from MoleculeEntry.
2

Implement the Scoring Logic

Define the __call__ method inside your oracle class (indented as part of the class body) to compute scores:
from rdkit.Chem import rdMolDescriptors

def __call__(self, molecules: List[MoleculeEntry]):
    """Score a batch of molecules, reusing cached scores where available.

    The score is the molecule's TPSA, forced to 0 when the exact molecular
    weight is 350 or more, or when the molecule has fewer than two rings.
    A molecule whose descriptors raise an exception also scores 0.

    Args:
        molecules: Entries to score; each is read through its ``smiles``
            and ``mol`` attributes.

    Returns:
        A list with one score per input molecule, in input order.
    """
    oracle_scores = []
    for molecule in molecules:
        cached = self.mol_buffer.get(molecule.smiles)
        if cached is not None:
            # Buffer entries are [score, insertion_index]. The score is a
            # plain number, so it is returned directly; calling sum() on it
            # would raise TypeError.
            oracle_scores.append(cached[0])
            continue

        try:
            # Calculate TPSA
            tpsa = rdMolDescriptors.CalcTPSA(molecule.mol)

            # Apply constraints
            weight = rdMolDescriptors.CalcExactMolWt(molecule.mol)
            num_rings = rdMolDescriptors.CalcNumRings(molecule.mol)
            violates_constraints = weight >= 350 or num_rings < 2
            oracle_score = 0 if violates_constraints else tpsa
        except Exception as e:
            # Best effort: report the failure and give the molecule the
            # penalty score instead of aborting the whole batch.
            print(e)
            oracle_score = 0

        # Cache the score together with its 1-based insertion index
        self.mol_buffer[molecule.smiles] = [oracle_score, len(self.mol_buffer) + 1]

        # Log progress at the interval configured on the instance
        if len(self.mol_buffer) % self.freq_log == 0:
            self.log_intermediate()

        oracle_scores.append(oracle_score)
    return oracle_scores
3

Add Logging Functionality

Implement log_intermediate() as a method of your oracle class to track optimization progress:
import numpy as np

def log_intermediate(self):
    """Print the best, top-10 mean and top-100 mean of the buffered scores.

    Only the most recent ``max_oracle_calls`` buffer entries are considered.
    """
    budget = self.max_oracle_calls
    calls_made = len(self.mol_buffer)

    # Each buffer value is [score, insertion_index]; keep only the score.
    recent_scores = [entry[0] for entry in self.mol_buffer.values()][-budget:]
    best_first = sorted(recent_scores, reverse=True)
    top100 = best_first[:100]

    top1 = np.max(top100)
    top10_mean = np.mean(top100[:10])
    top100_mean = np.mean(top100)

    progress = f"{calls_made}/{budget} | "
    summary = (
        f'avg_top1: {top1:.3f} | '
        f'avg_top10: {top10_mean:.3f} | '
        f'avg_top100: {top100_mean:.3f}'
    )
    print(progress, summary)
4

Implement Required Properties

Add the required properties and methods to your oracle class:
def __len__(self):
    """Return how many molecules are stored in the buffer."""
    buffered = self.mol_buffer
    return len(buffered)

@property
def budget(self):
    """Maximum number of oracle calls."""
    return self.max_oracle_calls

@property
def finish(self):
    """Stopping condition: the buffer holds at least ``max_oracle_calls`` entries."""
    used = len(self.mol_buffer)
    return used >= self.max_oracle_calls

Using Your Oracle

Once your oracle is defined, use it with the optimization workflow:
import yaml
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from chemlactica.mol_opt.optimization import optimize
from chemlactica.mol_opt.utils import set_seed

# Load configuration; the context manager closes the file once it is parsed
with open("config.yaml") as config_file:
    config = yaml.safe_load(config_file)

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(
    config["checkpoint_path"], 
    torch_dtype=torch.bfloat16
).to(config["device"])
tokenizer = AutoTokenizer.from_pretrained(
    config["tokenizer_path"], 
    padding_side="left"
)

# Initialize oracle (seed is fixed before the oracle is created)
set_seed(42)
oracle = TPSA_Weight_Oracle(max_oracle_calls=1000)

# Set up config: pass the oracle's score upper bound on to the optimizer
config["log_dir"] = "results.log"
config["max_possible_oracle_score"] = oracle.max_possible_oracle_score

# Run optimization
optimize(model, tokenizer, oracle, config)

Oracle Configuration

Key configuration parameters for optimization (defined in your YAML config):
checkpoint_path: "path/to/model"
tokenizer_path: "path/to/tokenizer"
device: "cuda:0"

# Optimization settings
pool_size: 100
validation_perc: 0.2
num_gens_per_iter: 50
generation_batch_size: 32
num_mols: 3
num_similars: 5

# Generation settings
generation_temperature: [0.8, 1.2]
generation_config:
  max_new_tokens: 512
  do_sample: true
  top_k: 50
  top_p: 0.95

# Strategy
strategy: ["rej-sample-v2"]

Example: QED Oracle

Here’s a simpler oracle that maximizes QED (drug-likeness):
from rdkit.Chem.QED import qed

class QED_Oracle:
    """Oracle that scores molecules by QED (drug-likeness, 0-1 range).

    Scores are cached per SMILES string, so each unique molecule is scored
    once and counts once toward the call budget.
    """

    def __init__(self, max_oracle_calls: int):
        # Maximum number of oracle calls
        self.max_oracle_calls = max_oracle_calls
        # Maps SMILES -> [score, 1-based insertion index]
        self.mol_buffer = {}
        # Upper bound on the score this oracle can return
        self.max_possible_oracle_score = 1.0
        # True means __call__ receives MoleculeEntry objects
        self.takes_entry = True

    def __call__(self, molecules: List[MoleculeEntry]):
        """Return one QED score per molecule, in input order.

        Args:
            molecules: Entries to score; each is read through its ``smiles``
                and ``mol`` attributes.

        Returns:
            A list of scores. A molecule whose QED computation raises an
            exception scores 0.0.
        """
        oracle_scores = []
        for molecule in molecules:
            if molecule.smiles in self.mol_buffer:
                oracle_scores.append(self.mol_buffer[molecule.smiles][0])
            else:
                try:
                    score = qed(molecule.mol)
                except Exception:
                    # Catch Exception rather than using a bare except, so
                    # KeyboardInterrupt and SystemExit still propagate.
                    score = 0.0
                self.mol_buffer[molecule.smiles] = [score, len(self.mol_buffer) + 1]
                oracle_scores.append(score)
        return oracle_scores

    def __len__(self):
        """Return how many molecules are stored in the buffer."""
        return len(self.mol_buffer)

    @property
    def budget(self):
        """Maximum number of oracle calls."""
        return self.max_oracle_calls

    @property
    def finish(self):
        """True once the buffer holds at least ``max_oracle_calls`` entries."""
        return len(self.mol_buffer) >= self.max_oracle_calls

Best Practices

Always cache computed scores in mol_buffer to avoid redundant calculations:
if molecule.smiles in self.mol_buffer:
    return cached_score
else:
    compute_and_cache()
Wrap scoring calculations in try-except blocks to handle invalid molecules:
try:
    score = calculate_property(molecule.mol)
except Exception as e:
    print(e)
    score = 0  # Return penalty score
Set max_possible_oracle_score to help the model understand the score range:
# For QED (0-1 range)
self.max_possible_oracle_score = 1.0

# For custom scores, set an upper bound
self.max_possible_oracle_score = 800
Log intermediate results periodically to track progress:
if len(self.mol_buffer) % self.freq_log == 0:
    self.log_intermediate()
The takes_entry attribute determines whether __call__ receives MoleculeEntry objects (with .mol and .fingerprint attributes) or just SMILES strings.

Next Steps

Property Prediction

Learn how to fine-tune models for property prediction

Benchmarking

Evaluate your oracle’s performance

Build docs developers (and LLMs) love