Overview
The optimize() function is the core of ChemLactica’s molecular optimization pipeline. It orchestrates the iterative process of generating molecules, scoring them with an oracle function, and fine-tuning the model using rejection sampling.
Function Signature
Parameters
The pretrained language model for molecule generation. Typically a GPT-2 or OPT model fine-tuned on SMILES strings.
Tokenizer corresponding to the model, used to encode prompts and decode generated sequences.
Scoring function that evaluates generated molecules and returns their scores. It accepts either:
- A list of SMILES strings
- A list of MoleculeEntry objects, if oracle.takes_entry = True
The oracle also exposes the following attributes:
- max_oracle_calls: Maximum number of molecules to evaluate
- finish: Boolean indicating if optimization should stop
- mol_buffer: Dictionary tracking evaluated molecules
- budget: Total oracle call budget
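The oracle interface described above can be sketched as a small class. This is a minimal illustration, not ChemLactica's own oracle: the class name ToyOracle and its scoring rule (counting the character C) are hypothetical, and treating budget as equal to max_oracle_calls is an assumption made only for this example.

```python
from typing import Dict, List


class ToyOracle:
    """Hypothetical oracle that scores a SMILES string by how many 'C' characters it has."""

    takes_entry = False  # this oracle receives plain SMILES strings

    def __init__(self, max_oracle_calls: int):
        self.max_oracle_calls = max_oracle_calls
        self.budget = max_oracle_calls  # assumption: budget equals the call limit
        self.mol_buffer: Dict[str, float] = {}  # SMILES -> score for every evaluated molecule

    def __call__(self, smiles_list: List[str]) -> List[float]:
        scores = []
        for smiles in smiles_list:
            if smiles not in self.mol_buffer:
                # Placeholder score, not a real chemistry objective.
                self.mol_buffer[smiles] = min(smiles.count("C") / 10.0, 1.0)
            scores.append(self.mol_buffer[smiles])
        return scores

    @property
    def finish(self) -> bool:
        # Stop once the number of distinct evaluated molecules reaches the limit.
        return len(self.mol_buffer) >= self.max_oracle_calls
```

Caching scores in mol_buffer means a repeated molecule does not consume an extra oracle call in this sketch.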
Configuration dictionary containing optimization parameters.
Required keys:
- log_dir (str): Path to output log file
- pool_size (int): Maximum size of molecule pool
- validation_perc (float): Fraction of the pool reserved for validation (0-1)
- generation_config (dict): Parameters for model.generate(), including max_new_tokens, temperature, etc.
- generation_temperature (list): [min, max] temperature range for annealing
- strategy (list): Optimization strategies, e.g., ["rej-sample-v2"]
- num_gens_per_iter (int): Number of unique molecules to generate per iteration
- generation_batch_size (int): Batch size for generation
- num_mols (int): Number of molecules in each optimization entry
- num_similars (int): Number of similar molecules to include in prompts
- max_possible_oracle_score (float): Maximum achievable oracle score
- sim_range (list): [min, max] similarity range for generation prompts
- eos_token (str): End-of-sequence token for the model
Keys for the rej-sample-v2 strategy:
- rej_sample_config (dict): Training configuration including:
  - train_batch_size (int)
  - gradient_accumulation_steps (int)
  - num_train_epochs (int)
  - train_tol_level (int): Tolerance iterations before triggering training
  - rej_perc (float): Rejection sampling percentage
  - packing (bool): Whether to pack sequences
  - max_seq_length (int): Maximum sequence length
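As an illustration of the keys listed above, a configuration might look like the following. All values are placeholders chosen for this sketch, not recommended settings, and the missing_keys helper is a hypothetical convenience for checking a config before use; it is not part of ChemLactica.

```python
REQUIRED_KEYS = {
    "log_dir", "pool_size", "validation_perc", "generation_config",
    "generation_temperature", "strategy", "num_gens_per_iter",
    "generation_batch_size", "num_mols", "num_similars",
    "max_possible_oracle_score", "sim_range", "eos_token",
}

# Placeholder values for illustration only.
config = {
    "log_dir": "results/optimize.log",
    "pool_size": 10,
    "validation_perc": 0.2,
    "generation_config": {"max_new_tokens": 100, "temperature": 1.0},
    "generation_temperature": [1.0, 1.5],
    "strategy": ["rej-sample-v2"],
    "num_gens_per_iter": 200,
    "generation_batch_size": 100,
    "num_mols": 1,
    "num_similars": 5,
    "max_possible_oracle_score": 1.0,
    "sim_range": [0.4, 0.9],
    "eos_token": "</s>",
    "rej_sample_config": {
        "train_batch_size": 2,
        "gradient_accumulation_steps": 8,
        "num_train_epochs": 5,
        "train_tol_level": 3,
        "rej_perc": 0.2,
        "packing": False,
        "max_seq_length": 2048,
    },
}


def missing_keys(cfg: dict) -> list:
    """Return the names of required keys that are absent from cfg (hypothetical helper)."""
    missing = sorted(REQUIRED_KEYS - set(cfg))
    # rej_sample_config is only needed when the rej-sample-v2 strategy is selected.
    if "rej-sample-v2" in cfg.get("strategy", []) and "rej_sample_config" not in cfg:
        missing.append("rej_sample_config")
    return missing
```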
Dictionary of additional molecular properties to include in prompts. Each property should have:
- start_tag (str): Opening tag for the property
- end_tag (str): Closing tag for the property
- calculate_value (callable): Function to compute property value from MoleculeEntry
- infer_value (callable): Function to infer property value for generation prompts
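A property entry with these four fields could be laid out as below. This is a sketch under several assumptions: the property name, the tag strings, the StandInEntry class (a stand-in for MoleculeEntry), the single-argument signature of infer_value, and the render_property helper are all hypothetical and only show how the fields might fit together.

```python
from dataclasses import dataclass


@dataclass
class StandInEntry:
    """Hypothetical stand-in for MoleculeEntry, holding only a SMILES string."""
    smiles: str


# One hypothetical property: the length of the SMILES string.
additional_properties = {
    "length": {
        "start_tag": "[LENGTH]",
        "end_tag": "[/LENGTH]",
        # Computed from an already known molecule.
        "calculate_value": lambda entry: len(entry.smiles),
        # Assumed signature: returns a target value to request in a generation prompt.
        "infer_value": lambda entry: 20,
    }
}


def render_property(name: str, entry: StandInEntry) -> str:
    """Wrap the calculated value in the property's tags (illustrative formatting only)."""
    prop = additional_properties[name]
    return f"{prop['start_tag']}{prop['calculate_value'](entry)}{prop['end_tag']}"
```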
Function to validate generated SMILES strings. Takes a SMILES string and returns True if valid, False otherwise.
Returns
The function does not return a value. Instead, it:
- Writes optimization progress to the log file specified in config["log_dir"]
- Updates the model weights in-place through fine-tuning
- Maintains the oracle’s mol_buffer with all evaluated molecules
Behavior
Iterative Optimization Loop
- Generation Phase: Creates prompts from pool molecules and generates new candidates
- Scoring Phase: Evaluates unique molecules with the oracle function
- Pool Update: Adds high-scoring molecules to the pool with diversity filtering
- Fine-tuning Phase: Triggers model training when tolerance threshold is reached
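The four phases can be sketched as a plain Python loop. This is a simplified illustration rather than ChemLactica's implementation: run_loop and its callable arguments are hypothetical, the pool keeps the top-scoring molecules without any diversity filtering, and the loop stops early if no new molecule is proposed.

```python
import random
from typing import Callable, Dict, List, Tuple

Scored = Tuple[str, float]  # (smiles, score)


def run_loop(
    generate: Callable[[List[Scored]], List[str]],
    score: Callable[[List[str]], List[float]],
    fine_tune: Callable[[List[Scored]], None],
    budget: int,
    pool_size: int,
    train_tol_level: int,
) -> List[Scored]:
    pool: List[Scored] = []      # best molecules so far, highest score first
    seen: Dict[str, float] = {}  # plays the role of the oracle's mol_buffer
    best_score = float("-inf")
    tol_level = 0

    while len(seen) < budget:
        # Generation phase: propose candidates conditioned on the pool.
        proposals = generate(pool)
        unique = [s for s in dict.fromkeys(proposals) if s not in seen]
        unique = unique[: budget - len(seen)]  # never exceed the budget
        if not unique:
            break  # sketch-only exit when nothing new is proposed

        # Scoring phase: evaluate only the unique, unseen molecules.
        scores = score(unique)
        seen.update(zip(unique, scores))

        # Pool update: keep the top pool_size molecules (no diversity filter here).
        merged = pool + list(zip(unique, scores))
        pool = sorted(merged, key=lambda p: p[1], reverse=True)[:pool_size]

        # Fine-tuning phase: train after train_tol_level iterations without a new best.
        if pool[0][1] > best_score:
            best_score, tol_level = pool[0][1], 0
        else:
            tol_level += 1
        if tol_level >= train_tol_level:
            fine_tune(pool)
            tol_level = 0
    return pool


# Toy usage: candidates are carbon chains, scored by length.
rng = random.Random(0)
toy_pool = run_loop(
    generate=lambda pool: ["C" * rng.randint(1, 8) for _ in range(4)],
    score=lambda smiles: [len(s) / 8 for s in smiles],
    fine_tune=lambda pool: None,
    budget=6,
    pool_size=3,
    train_tol_level=2,
)
```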
Temperature Annealing
Generation temperature increases linearly from generation_temperature[0] to generation_temperature[1] based on oracle budget consumption.
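A linear schedule of this kind can be written as a short function. The sketch below assumes that budget consumption is measured as the fraction of oracle calls used so far; the function name annealed_temperature and its arguments are hypothetical.

```python
from typing import Sequence


def annealed_temperature(temp_range: Sequence[float], calls_used: int, budget: int) -> float:
    """Interpolate linearly between temp_range[0] and temp_range[1].

    Assumption: progress is the fraction of the oracle budget already consumed.
    """
    t_min, t_max = temp_range
    progress = min(max(calls_used / budget, 0.0), 1.0)  # clamp to [0, 1]
    return t_min + progress * (t_max - t_min)
```

With a range of [1.0, 1.5] and a budget of 1000, the temperature in this sketch is 1.0 before any calls, 1.25 after 500 calls, and 1.5 once the budget is exhausted.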
Training Triggers
For the rej-sample-v2 strategy, model fine-tuning occurs when tol_level >= config["rej_sample_config"]["train_tol_level"]. The counter behaves as follows:
- tol_level increments each iteration without finding a new best molecule
- tol_level resets to 0 after training or finding a new best molecule
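The counter rules above can be isolated in a small helper. This is a minimal sketch: the ToleranceCounter class is hypothetical, and it assumes that "a new best molecule" means an iteration score strictly higher than any seen before.

```python
class ToleranceCounter:
    """Hypothetical helper that tracks iterations without improvement."""

    def __init__(self, train_tol_level: int):
        self.train_tol_level = train_tol_level
        self.best_score = float("-inf")
        self.tol_level = 0

    def update(self, iteration_best: float) -> bool:
        """Record one iteration and return True when fine-tuning should run."""
        if iteration_best > self.best_score:
            # New best molecule: remember it and reset the counter.
            self.best_score = iteration_best
            self.tol_level = 0
        else:
            self.tol_level += 1

        if self.tol_level >= self.train_tol_level:
            self.tol_level = 0  # reset after training is triggered
            return True
        return False
```

With train_tol_level set to 2, feeding iteration scores 0.5, 0.4, 0.4 triggers training on the third call, since two consecutive iterations failed to beat 0.5.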
Example Usage
Related Functions
- MoleculeEntry - Represents individual molecules
- OptimEntry - Represents optimization entries with multiple molecules
- Pool - Manages the pool of high-scoring molecules