Learn how to implement oracle functions to evaluate and guide molecular optimization
The oracle is the core component that evaluates generated molecules and guides the optimization process. It defines what makes a molecule “good” for your specific use case.
class CustomOracle:
    """Skeleton of the oracle interface: copy it and implement ``__call__``.

    The constructor arguments shown here are one possible choice; the rest of
    this skeleton only relies on the attributes assigned in ``__init__``.
    """

    def __init__(
        self,
        max_oracle_calls: int,
        freq_log: int,
        max_possible_oracle_score: float,
        takes_entry: bool = False,
    ):
        # Maximum number of oracle calls to make
        self.max_oracle_calls: int = max_oracle_calls

        # Frequency for logging intermediate results
        self.freq_log: int = freq_log

        # Buffer to track all unique molecules generated.
        # It starts empty so that ``finish`` and ``__len__`` work immediately.
        self.mol_buffer: dict = {}

        # Maximum possible oracle score (upper bound)
        self.max_possible_oracle_score: float = max_possible_oracle_score

        # Optional: if True, __call__ receives MoleculeEntry objects;
        # if False (default), __call__ receives SMILES strings
        self.takes_entry: bool = takes_entry

    def __call__(self, molecules):
        """Evaluate and return oracle scores for molecules.

        Log intermediate results if necessary.

        Args:
            molecules: List of SMILES strings or MoleculeEntry objects
                (depending on ``self.takes_entry``).

        Returns:
            List of float scores (same order as input).

        Raises:
            NotImplementedError: Always, until the scoring logic is filled in.
        """
        raise NotImplementedError(
            "Implement __call__ to score `molecules` and return one float per input"
        )

    @property
    def finish(self):
        """Specify the stopping condition for optimization."""
        return len(self.mol_buffer) >= self.max_oracle_calls

    def __len__(self):
        """Return the number of molecules evaluated so far."""
        return len(self.mol_buffer)

    @property
    def budget(self):
        """Return the maximum oracle calls budget."""
        return self.max_oracle_calls
This example from mol_opt/example_run.py optimizes molecules for high TPSA (topological polar surface area). The score is the molecule's TPSA, set to 0 when its exact molecular weight is 350 or higher or when it has fewer than 2 rings:
from typing import List

import numpy as np
from rdkit.Chem import rdMolDescriptors

from chemlactica.mol_opt.utils import MoleculeEntry


class TPSA_Weight_Oracle:
    """Oracle that scores a molecule by its TPSA under weight and ring limits.

    A molecule's score is its TPSA, forced to 0 when its exact molecular
    weight is >= 350 or when it has fewer than 2 rings. Every distinct SMILES
    is scored once and remembered in ``mol_buffer``; repeated molecules are
    answered from the buffer and do not consume budget.
    """

    def __init__(self, max_oracle_calls: int):
        self.max_oracle_calls = max_oracle_calls
        self.freq_log = 100
        # Maps SMILES -> [score, 1-based order in which it was first scored]
        self.mol_buffer = {}
        self.max_possible_oracle_score = 800

        # Request MoleculeEntry objects instead of SMILES strings
        self.takes_entry = True

    def __call__(self, molecules: List[MoleculeEntry]) -> List[float]:
        """Evaluate molecules based on TPSA, weight, and ring constraints.

        Args:
            molecules: Entries exposing ``.smiles`` and ``.mol``.

        Returns:
            One score per input molecule, in input order.
        """
        oracle_scores = []

        for molecule in molecules:
            cached = self.mol_buffer.get(molecule.smiles)
            if cached is not None:
                # Already evaluated: the buffer entry is [score, order], so the
                # score is a plain number and is returned as-is. (Wrapping it
                # in sum() would raise TypeError on a non-iterable float.)
                oracle_scores.append(cached[0])
                continue

            try:
                # Calculate TPSA as the base score
                oracle_score = rdMolDescriptors.CalcTPSA(molecule.mol)

                # Apply weight constraint
                weight = rdMolDescriptors.CalcExactMolWt(molecule.mol)
                if weight >= 350:
                    oracle_score = 0

                # Apply ring constraint
                num_rings = rdMolDescriptors.CalcNumRings(molecule.mol)
                if num_rings < 2:
                    oracle_score = 0
            except Exception as e:
                # Best effort: a molecule whose descriptors cannot be computed
                # is reported and scored 0 instead of aborting the whole batch.
                print(e)
                oracle_score = 0

            # Store in buffer
            self.mol_buffer[molecule.smiles] = [oracle_score, len(self.mol_buffer) + 1]

            # Log periodically
            if len(self.mol_buffer) % self.freq_log == 0:
                self.log_intermediate()

            oracle_scores.append(oracle_score)

        return oracle_scores

    def log_intermediate(self):
        """Log statistics of top molecules.

        Prints the best score and the mean of the top 10 and top 100 scores
        among the buffered molecules. Does nothing when the buffer is empty.
        """
        scores = [entry[0] for entry in self.mol_buffer.values()][-self.max_oracle_calls:]
        if not scores:
            # np.max / np.mean are undefined on an empty sequence.
            return

        scores_sorted = sorted(scores, reverse=True)[:100]
        n_calls = len(self.mol_buffer)

        score_avg_top1 = np.max(scores_sorted)
        score_avg_top10 = np.mean(scores_sorted[:10])
        score_avg_top100 = np.mean(scores_sorted)

        print(f"{n_calls}/{self.max_oracle_calls} | "
              f'avg_top1: {score_avg_top1:.3f} | '
              f'avg_top10: {score_avg_top10:.3f} | '
              f'avg_top100: {score_avg_top100:.3f}')

    def __len__(self):
        """Return the number of distinct molecules scored so far."""
        return len(self.mol_buffer)

    @property
    def budget(self):
        """Return the maximum oracle calls budget."""
        return self.max_oracle_calls

    @property
    def finish(self):
        """Return True once the number of scored molecules reaches the budget."""
        return len(self.mol_buffer) >= self.max_oracle_calls