Overview
Oracles are scoring functions that evaluate molecular properties during optimization. ChemLactica uses oracles to guide the generation of molecules with desired characteristics.
Oracle Structure
A custom oracle is a Python class that implements specific methods and properties. Here’s the basic structure:
from typing import List
from chemlactica.mol_opt.utils import MoleculeEntry
class CustomOracle:
    """Skeleton of a user-defined oracle.

    Shows the attributes, the ``__call__``/``__len__`` methods and the
    ``budget``/``finish`` properties that make up a custom oracle.
    """

    def __init__(self, max_oracle_calls: int):
        # Set to True if __call__ takes MoleculeEntry objects
        self.takes_entry = True
        # Maximum possible score (upper bound)
        self.max_possible_oracle_score = 1.0
        # Buffer to track unique molecules
        self.mol_buffer = {}
        # Logging frequency
        self.freq_log = 100
        # Maximum number of oracle calls
        self.max_oracle_calls = max_oracle_calls

    def __call__(self, molecules: List[MoleculeEntry]):
        """Evaluate and return oracle scores for molecules.

        The template has no scoring logic, so it returns ``None``; a concrete
        oracle replaces this body.
        """
        return None

    def __len__(self):
        """Return how many molecules are currently held in ``mol_buffer``."""
        buffered = self.mol_buffer
        return len(buffered)

    @property
    def budget(self):
        """Return the configured maximum number of oracle calls."""
        return self.max_oracle_calls

    @property
    def finish(self):
        """Return True once the buffer holds at least ``max_oracle_calls`` molecules."""
        return not len(self.mol_buffer) < self.max_oracle_calls
Step-by-Step Implementation
Define the Oracle Class
Create a new class and initialize required attributes in __init__:
class TPSA_Weight_Oracle :
def __init__(self, max_oracle_calls: int):
    """Initialize the oracle's budget, logging interval, cache and score bound."""
    # __call__ receives MoleculeEntry objects rather than plain SMILES strings.
    self.takes_entry = True
    # Upper bound used for this oracle's scores.
    self.max_possible_oracle_score = 800
    # Cache of scored molecules.
    self.mol_buffer = {}
    # Logging frequency.
    self.freq_log = 100
    # Maximum number of oracle calls.
    self.max_oracle_calls = max_oracle_calls
Set takes_entry = True if your oracle needs access to RDKit molecule objects and fingerprints from MoleculeEntry.
Implement the Scoring Logic
Define the __call__ method to compute scores:
from rdkit.Chem import rdMolDescriptors
def __call__(self, molecules: List[MoleculeEntry]):
    """Score each molecule by TPSA, subject to weight and ring-count constraints.

    A molecule's score is its TPSA, forced to 0 when its exact molecular
    weight is >= 350 or it has fewer than 2 rings. Any exception raised while
    computing the descriptors is printed and the molecule scores 0.

    Scores are cached in ``self.mol_buffer`` keyed by SMILES as
    ``[score, insertion_index]``; a repeated molecule gets its cached score
    back without being recomputed.

    Args:
        molecules: Entries exposing ``.smiles`` and ``.mol``.

    Returns:
        A list of scores, one per input molecule, in input order.
    """
    oracle_scores = []
    for molecule in molecules:
        if molecule.smiles in self.mol_buffer:
            # Return cached score. The cached value is a scalar, so it is
            # appended as-is (calling sum() on it would raise TypeError).
            oracle_scores.append(self.mol_buffer[molecule.smiles][0])
            continue

        try:
            # Calculate TPSA
            tpsa = rdMolDescriptors.CalcTPSA(molecule.mol)
            oracle_score = tpsa
            # Apply constraints
            weight = rdMolDescriptors.CalcExactMolWt(molecule.mol)
            num_rings = rdMolDescriptors.CalcNumRings(molecule.mol)
            if weight >= 350:
                oracle_score = 0
            if num_rings < 2:
                oracle_score = 0
        except Exception as e:
            print(e)
            oracle_score = 0

        # Cache the score together with its 1-based insertion index
        self.mol_buffer[molecule.smiles] = [oracle_score, len(self.mol_buffer) + 1]
        # Log progress at the interval configured in __init__
        if len(self.mol_buffer) % self.freq_log == 0:
            self.log_intermediate()
        oracle_scores.append(oracle_score)
    return oracle_scores
Add Logging Functionality
Implement log_intermediate() to track optimization progress:
import numpy as np
def log_intermediate(self):
    """Print a one-line summary of the best scores recorded so far.

    Takes the score (first element) of every ``self.mol_buffer`` entry,
    slices that list with ``[-self.max_oracle_calls:]``, and reports the top
    score plus the mean of the top 10 and of the top 100. Prints nothing when
    there are no scores yet.
    """
    scores = [entry[0] for entry in self.mol_buffer.values()][-self.max_oracle_calls:]
    if not scores:
        # np.max raises ValueError on an empty sequence; nothing to report.
        return
    scores_sorted = sorted(scores, reverse=True)[:100]
    n_calls = len(self.mol_buffer)

    score_avg_top1 = np.max(scores_sorted)
    score_avg_top10 = np.mean(scores_sorted[:10])
    score_avg_top100 = np.mean(scores_sorted)

    print(
        f"{n_calls}/{self.max_oracle_calls} | ",
        f"avg_top1: {score_avg_top1:.3f} | "
        f"avg_top10: {score_avg_top10:.3f} | "
        f"avg_top100: {score_avg_top100:.3f}",
    )
Implement Required Properties
Add the required properties and methods:
def __len__ ( self ):
return len ( self .mol_buffer)
@property
def budget(self):
    """Return the maximum number of oracle calls configured for this oracle."""
    call_limit = self.max_oracle_calls
    return call_limit
@property
def finish(self):
    """Stopping condition for optimization.

    True once the number of buffered molecules reaches ``max_oracle_calls``.
    """
    evaluated = len(self.mol_buffer)
    return evaluated >= self.max_oracle_calls
Using Your Oracle
Once your oracle is defined, use it with the optimization workflow:
import yaml
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from chemlactica.mol_opt.optimization import optimize
from chemlactica.mol_opt.utils import set_seed
# Load configuration; the context manager closes the file as soon as it is parsed
with open("config.yaml", encoding="utf-8") as config_file:
    config = yaml.safe_load(config_file)

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(
    config["checkpoint_path"],
    torch_dtype=torch.bfloat16,
).to(config["device"])
tokenizer = AutoTokenizer.from_pretrained(
    config["tokenizer_path"],
    padding_side="left",
)

# Initialize oracle (seed first so the run is reproducible)
set_seed(42)
oracle = TPSA_Weight_Oracle(max_oracle_calls=1000)

# Set up config
config["log_dir"] = "results.log"
config["max_possible_oracle_score"] = oracle.max_possible_oracle_score

# Run optimization
optimize(model, tokenizer, oracle, config)
Oracle Configuration
Key configuration parameters for optimization (defined in your YAML config):
checkpoint_path: "path/to/model"
tokenizer_path: "path/to/tokenizer"
device: "cuda:0"

# Optimization settings
pool_size: 100
validation_perc: 0.2
num_gens_per_iter: 50
generation_batch_size: 32
num_mols: 3
num_similars: 5

# Generation settings
generation_temperature: [0.8, 1.2]
# The keys below must be indented under generation_config; at the same level
# they would parse as separate top-level keys and generation_config would be null.
generation_config:
  max_new_tokens: 512
  do_sample: true
  top_k: 50
  top_p: 0.95

# Strategy
strategy: ["rej-sample-v2"]
Example: QED Oracle
Here’s a simpler oracle that maximizes QED (drug-likeness):
from rdkit.Chem. QED import qed
class QED_Oracle:
    """Oracle that scores molecules by QED (drug-likeness).

    Each distinct SMILES is scored once; the result is cached in
    ``mol_buffer`` and reused for repeated molecules.
    """

    def __init__(self, max_oracle_calls: int):
        # Maximum number of oracle calls
        self.max_oracle_calls = max_oracle_calls
        # Maps SMILES -> [score, 1-based insertion index]
        self.mol_buffer = {}
        # QED scores lie in the 0-1 range
        self.max_possible_oracle_score = 1.0
        # __call__ receives MoleculeEntry objects
        self.takes_entry = True

    def __call__(self, molecules: List[MoleculeEntry]):
        """Return the QED score of every molecule, in input order.

        A molecule whose QED computation raises is reported and scored 0.0.
        Only ``Exception`` subclasses are handled, so ``KeyboardInterrupt``
        and ``SystemExit`` still propagate.
        """
        oracle_scores = []
        for molecule in molecules:
            if molecule.smiles in self.mol_buffer:
                # Return cached score
                oracle_scores.append(self.mol_buffer[molecule.smiles][0])
                continue
            try:
                score = qed(molecule.mol)
            except Exception as e:
                print(e)
                score = 0.0  # Penalty score
            self.mol_buffer[molecule.smiles] = [score, len(self.mol_buffer) + 1]
            oracle_scores.append(score)
        return oracle_scores

    def __len__(self):
        """Return the number of distinct molecules scored so far."""
        return len(self.mol_buffer)

    @property
    def budget(self):
        """Return the maximum number of oracle calls."""
        return self.max_oracle_calls

    @property
    def finish(self):
        """Return True once the buffer holds at least ``max_oracle_calls`` molecules."""
        return len(self.mol_buffer) >= self.max_oracle_calls
Best Practices
Always cache computed scores in mol_buffer to avoid redundant calculations:
if molecule.smiles in self .mol_buffer:
return cached_score
else :
compute_and_cache()
Wrap scoring calculations in try-except blocks to handle invalid molecules:
try :
score = calculate_property(molecule.mol)
except Exception as e:
print (e)
score = 0 # Return penalty score
Set max_possible_oracle_score to help the model understand the score range:
# For QED (0-1 range)
self .max_possible_oracle_score = 1.0
# For custom scores, set an upper bound
self .max_possible_oracle_score = 800
Log intermediate results periodically to track progress:
if len ( self .mol_buffer) % self .freq_log == 0 :
self .log_intermediate()
The takes_entry attribute determines whether __call__ receives MoleculeEntry objects (with .mol and .fingerprint attributes) or just SMILES strings.
Next Steps
Property Prediction: Learn how to fine-tune models for property prediction.
Benchmarking: Evaluate your oracle’s performance.