Overview
ChemLactica supports two types of fine-tuning:
Supervised Fine-Tuning (SFT) : Train on labeled molecular property data
Rejection Sampling : Optimize for specific properties during molecule generation
This guide covers both approaches and when to use each.
Supervised Fine-Tuning (SFT)
When to Use SFT
Use supervised fine-tuning when you:
Have a labeled dataset with molecular properties
Want to predict specific properties (QED, activity, etc.)
Need the model to learn property-structure relationships
Data Format
Prepare your data in the ChemLactica format with property tags:
</s>[SIMILAR]CCO 0.95[/SIMILAR][QED]0.87[/QED][START_SMILES]CC(C)O[END_SMILES]</s>
</s>[SIMILAR]c1ccccc1 0.82[/SIMILAR][PROPERTY]activity 0.92[/PROPERTY][START_SMILES]c1ccc(O)cc1[END_SMILES]</s>
Each training sample should include:
EOS token </s> as separator
Similar molecules with similarity scores
Property values in [PROPERTY]name value[/PROPERTY] format
Target SMILES in [START_SMILES]...[END_SMILES]
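For example, samples in this format can be assembled with a small helper. This is a sketch; build_sample is a hypothetical name, and RDKit is used here only to validate and canonicalize the target SMILES:

from rdkit import Chem

def build_sample(similar_smiles, similarity, prop_name, prop_value, target_smiles):
    # Validate and canonicalize the target SMILES before formatting
    mol = Chem.MolFromSmiles(target_smiles)
    if mol is None:
        raise ValueError(f"Invalid SMILES: {target_smiles}")
    return (
        f"</s>[SIMILAR]{similar_smiles} {similarity:.2f}[/SIMILAR]"
        f"[PROPERTY]{prop_name} {prop_value:.2f}[/PROPERTY]"
        f"[START_SMILES]{Chem.MolToSmiles(mol)}[END_SMILES]</s>"
    )

print(build_sample("CCO", 0.95, "activity", 0.92, "c1ccc(O)cc1"))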
Configuration
Create an SFT configuration file:
sft_config.yaml
train_config:
  adam_beta1: 0.9
  adam_beta2: 0.95
  max_learning_rate: 1.0e-4
  warmup_steps: 0
  weight_decay: 0.1
  global_gradient_norm: 1.0
  bf16: true
  evaluation_strategy: "steps"
  save_total_limit: 4

model_config:
  n_heads: 12
  n_layers: 12
  block_size: 2048
  vocab_size: 50000
  separator_token: "</s>"
  tokenizer_path: "./chemlactica/tokenizer/ChemLacticaTokenizer66"

sft_config:
  packing: true
  max_seq_length: 512
  neftune_noise_alpha: 5
Key SFT Parameters
packing : Pack multiple short sequences into a single example for efficiency
max_seq_length : Maximum sequence length for training samples
neftune_noise_alpha : NEFTune noise parameter for improved generalization
Rejection Sampling Fine-Tuning
Overview
Rejection sampling optimizes the model during molecule generation by:
Generating candidate molecules
Scoring them with an oracle
Fine-tuning on high-scoring examples
This approach is used in the optimization loop (see Custom Oracles).
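The control flow is easiest to see in a toy version of the loop. The following self-contained sketch uses stand-in toy_oracle and toy_generate functions rather than the actual ChemLactica components described below:

import random

def toy_oracle(smiles: str) -> float:
    # Stand-in for a real oracle: scores longer SMILES higher, purely for illustration
    return min(len(smiles) / 20.0, 1.0)

def toy_generate(n: int) -> list[str]:
    # Stand-in for model generation: samples from a fixed candidate set
    return random.choices(["CCO", "CC(C)O", "c1ccccc1", "c1ccc(O)cc1"], k=n)

pool: list[tuple[float, str]] = []
best_score, stall = 0.0, 0
for iteration in range(5):
    candidates = toy_generate(8)                          # 1. generate candidate molecules
    scored = [(toy_oracle(s), s) for s in candidates]     # 2. score them with the oracle
    pool = sorted(set(pool + scored), reverse=True)[:20]  # 3. keep the best in a pool
    if pool[0][0] > best_score:
        best_score, stall = pool[0][0], 0
    else:
        stall += 1
    if stall >= 2:                                        # 4. fine-tune when progress stalls
        print(f"iteration {iteration}: would fine-tune on {len(pool)} pooled samples")
        stall = 0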
Configuration
Rejection sampling is configured within the optimization config:
strategy: ["rej-sample-v2"]

rej_sample_config:
  # Training parameters
  train_batch_size: 8
  gradient_accumulation_steps: 4
  num_train_epochs: 3
  train_tol_level: 2  # Train after this many iterations without improvement

  # Learning rate
  max_learning_rate: 1.0e-5
  adam_beta1: 0.9
  adam_beta2: 0.95
  warmup_steps: 0
  lr_end: 1.0e-6
  weight_decay: 0.1
  global_gradient_norm: 1.0

  # Data handling
  packing: false
  max_seq_length: 2048
  dataloader_num_workers: 4

  # Checkpointing
  checkpoints_dir: "./checkpoints"
How It Works
Generate Molecules
The model generates molecules based on similar molecules and desired properties:

# Optimization creates prompts like:
prompts = [
    optim_entry.to_prompt(
        is_generation=True,
        include_oracle_score=True,
        config=config,
        max_score=max_score,
    )
    for optim_entry in optim_entries
]
Score with Oracle
Generated molecules are evaluated:

oracle_scores = oracle([entry.last_entry for entry in optim_entries])
Build Training Pool
High-scoring molecules are added to a pool:

pool.add(list(iter_unique_optim_entries.values()))
train_entries, validation_entries = pool.get_train_valid_entries()
Fine-tune Model
When improvement stalls (tol_level >= train_tol_level), the model is fine-tuned on the pool:

trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=validation_dataset,
    formatting_func=lambda x: x["sample"],
    args=training_args,
    packing=config["rej_sample_config"]["packing"],
    tokenizer=tokenizer,
    max_seq_length=config["rej_sample_config"]["max_seq_length"],
    callbacks=[model_selection_callback],
    optimizers=[optimizer, lr_scheduler],
)
trainer.train()
Pool Management
The pool maintains diversity and quality:
pool = Pool(
    size=config["pool_size"],                   # e.g., 100
    validation_perc=config["validation_perc"],  # e.g., 0.2
)
Molecules are sorted by oracle score
Duplicates and highly similar molecules are removed (Tanimoto similarity above a threshold; see the sketch after this list)
The top 20% are reserved for validation
The remaining 80% are used for training
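The deduplication step can be illustrated with Morgan fingerprints and Tanimoto similarity from RDKit. This is an illustrative sketch, not the actual Pool implementation, and the 0.95 threshold is an assumption:

from rdkit import Chem
from rdkit.Chem import AllChem, DataStructs

def filter_similar(smiles_scores, tanimoto_threshold=0.95):
    # Keep the highest-scoring molecule from each cluster of near-duplicates
    kept, kept_fps = [], []
    for score, smi in sorted(smiles_scores, reverse=True):
        fp = AllChem.GetMorganFingerprintAsBitVect(Chem.MolFromSmiles(smi), 2, nBits=2048)
        if all(DataStructs.TanimotoSimilarity(fp, other) <= tanimoto_threshold for other in kept_fps):
            kept.append((score, smi))
            kept_fps.append(fp)
    return kept

print(filter_similar([(0.9, "CCO"), (0.8, "OCC"), (0.7, "c1ccccc1")]))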
Training Arguments
Common Parameters
max_learning_rate : Peak learning rate for training
warmup_steps : Number of warmup steps for the learning rate scheduler
weight_decay : Weight decay for the AdamW optimizer
gradient_accumulation_steps : Accumulate gradients over multiple batches (see the example below)
bf16 : Use bfloat16 mixed precision training
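For example, with a per-device batch size of 8 and gradient_accumulation_steps : 4 (as in the configs above), each optimizer update is computed from an effective batch of 8 × 4 = 32 samples per device.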
Optimizer Configuration
ChemLactica uses AdamW with polynomial decay:
from transformers import get_polynomial_decay_schedule_with_warmup
import torch

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=config["max_learning_rate"],
    betas=[config["adam_beta1"], config["adam_beta2"]],
    weight_decay=config["weight_decay"],
)

lr_scheduler = get_polynomial_decay_schedule_with_warmup(
    optimizer,
    num_warmup_steps=config["warmup_steps"],
    num_training_steps=max_train_steps,
    lr_end=config["lr_end"],
    power=1.0,
)
Callbacks
Model Selection Callback
Saves the best model based on validation loss:
from chemlactica.mol_opt.tunning import CustomModelSelectionCallback

model_selection_callback = CustomModelSelectionCallback()

# After training
model.load_state_dict(model_selection_callback.best_model_state_dict)
print(f"Best validation loss: {model_selection_callback.best_validation_loss}")
Early Stopping Callback
Stops training when validation loss stops improving:
from chemlactica.mol_opt.tunning import CustomEarlyStopCallback

early_stopping = CustomEarlyStopCallback(
    early_stopping_patience=3,
    early_stopping_threshold=0.0001,
)
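Both callbacks can be combined by passing them together in the trainer's callbacks list, e.g. callbacks=[model_selection_callback, early_stopping] in the SFTTrainer call shown above.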
Example: Complete SFT Workflow
import yaml
import torch
import pandas as pd
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM

# Load configuration
with open("sft_config.yaml") as f:
    config = yaml.safe_load(f)

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(
    "path/to/pretrained/model",
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(config["model_config"]["tokenizer_path"])

# Prepare dataset
df = pd.read_csv("property_data.csv")
samples = [
    f"</s>[SIMILAR]{row['similar_smiles']} {row['similarity']:.2f}[/SIMILAR]"
    f"[PROPERTY]activity {row['activity']:.2f}[/PROPERTY]"
    f"[START_SMILES]{row['smiles']}[END_SMILES]</s>"
    for _, row in df.iterrows()
]
dataset = Dataset.from_dict({"sample": samples}).train_test_split(test_size=0.1)

# Configure response template for loss masking (loss is computed only after this template)
response_template = tokenizer.encode("[PROPERTY]activity", add_special_tokens=False)
collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)

# Training arguments
training_args = TrainingArguments(
    output_dir="./sft_checkpoints",
    per_device_train_batch_size=8,
    gradient_accumulation_steps=4,
    num_train_epochs=3,
    learning_rate=1e-4,
    bf16=True,
    logging_steps=10,
    evaluation_strategy="steps",
    eval_steps=100,
    save_steps=100,
    save_total_limit=3,
)

# Create trainer
# Note: packing must be disabled when using DataCollatorForCompletionOnlyLM
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    formatting_func=lambda x: x["sample"],
    args=training_args,
    packing=False,
    tokenizer=tokenizer,
    max_seq_length=512,
    data_collator=collator,
    neftune_noise_alpha=5,
)

# Train
trainer.train()
trainer.save_model("./sft_final_model")
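After training, the model can be prompted in the same tag format used for the training samples, letting it complete a SMILES string conditioned on a target property value. A minimal generation sketch (the prompt layout mirrors the training samples above; the generation parameters are illustrative, not ChemLactica defaults):

# Prompt the fine-tuned model with target property values and let it complete the SMILES
prompt = "</s>[SIMILAR]CCO 0.95[/SIMILAR][PROPERTY]activity 0.90[/PROPERTY][START_SMILES]"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(
    **inputs,
    max_new_tokens=64,
    do_sample=True,
    top_p=0.95,
)
print(tokenizer.decode(outputs[0]))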
Best Practices
Data Preparation
Ensure property values are normalized to appropriate ranges
Include diverse molecular scaffolds in the training data
Balance the dataset across property value ranges
Validate SMILES strings before training
Training
Start with smaller learning rates (1e-5 to 1e-4) for fine-tuning
Use warmup steps (10-20% of training steps)
Monitor validation loss to prevent overfitting
Adjust neftune_noise_alpha for better generalization
Rejection Sampling
Set train_tol_level based on oracle budget (typically 2-3)
Use larger pool sizes (100-200) for better diversity
Keep validation_perc around 0.2 for reliable evaluation
Save checkpoints periodically during long optimizations
Monitoring
Track top-1, top-10, and top-100 oracle scores during optimization (see the sketch after this list)
Monitor validation loss during fine-tuning
Check molecular validity and diversity of generated molecules
Use TensorBoard or Weights & Biases for visualization
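For example, top-k scores can be tracked with a small helper like this (an illustrative sketch; compute_topk is a hypothetical name and the scores are dummy values):

def compute_topk(scores, k):
    # Mean of the k best oracle scores seen so far
    top = sorted(scores, reverse=True)[:k]
    return sum(top) / len(top)

oracle_scores = [0.91, 0.85, 0.78, 0.95, 0.60]
for k in (1, 10, 100):
    print(f"top-{k}: {compute_topk(oracle_scores, k):.3f}")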
When using rejection sampling, ensure your oracle has appropriate caching to avoid redundant calculations. The optimization process may evaluate the same molecule multiple times.
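A simple in-memory cache keyed by canonical SMILES is often enough. This is an illustrative sketch in which QED stands in for an expensive oracle call; adapt the pattern to your own oracle:

from functools import lru_cache
from rdkit import Chem
from rdkit.Chem import QED

def oracle_score(smiles: str) -> float:
    # Canonicalize so that equivalent SMILES strings hit the same cache entry
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return 0.0
    return _cached_score(Chem.MolToSmiles(mol))

@lru_cache(maxsize=100_000)
def _cached_score(canonical_smiles: str) -> float:
    # Stand-in for an expensive oracle call; QED is used here purely as an example
    return QED.qed(Chem.MolFromSmiles(canonical_smiles))

print(oracle_score("CCO"), oracle_score("OCC"))  # the second call reuses the cached result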
Next Steps
Custom Oracles : Build custom scoring functions for optimization
Benchmarking : Evaluate model performance on standard benchmarks