Overview
Rejection sampling is an advanced fine-tuning strategy that iteratively generates and filters molecular candidates based on desired properties. This approach is particularly effective for molecular optimization tasks where you want to improve specific properties while maintaining structural similarity.
How It Works
Generate Candidates
The model generates multiple molecular candidates based on a lead molecule and similar structures
Filter by Criteria
Candidates are filtered based on:
QED (drug-likeness) score
Tanimoto similarity to lead molecule
Other custom properties
Create Training Samples
Accepted molecules are formatted into training samples with their properties
Fine-tune Model
The model is fine-tuned on the filtered samples
Iterate
Process repeats for multiple rounds, progressively improving the model
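The loop above can be sketched in pure Python. Everything here is a stand-in (the `generate`, `accept`, and `finetune` callables are hypothetical placeholders, not the project's actual API); the point is only the generate-filter-train-repeat shape:

```python
import random

def rejection_sampling_round(generate, accept, finetune, num_samples):
    """One round: generate candidates, keep those passing the filter, train on them."""
    candidates = [generate() for _ in range(num_samples)]
    accepted = [c for c in candidates if accept(c)]
    if accepted:
        finetune(accepted)
    return accepted

# Toy demonstration: "molecules" are just scores drawn uniformly;
# the filter accepts anything >= 0.8, and "fine-tuning" records the batch.
random.seed(0)
trained_on = []
accepted = rejection_sampling_round(
    generate=random.random,
    accept=lambda score: score >= 0.8,
    finetune=trained_on.extend,
    num_samples=100,
)
```

Running this for several rounds, with the generator updated on each accepted batch, is what progressively shifts the model toward the acceptance region.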
Quick Start
python chemlactica/rejection_sampling_ft.py \
--from_pretrained facebook/galactica-125m \
--model_config 125m \
--valid_data_dir ./data/valid \
--rounds 10 \
--steps_per_round 100 \
--train_batch_size 8 \
--max_learning_rate 1.0e-4 \
--eval_steps 50 \
--save_steps 100 \
--checkpoints_root_dir ./checkpoints/rejection_sampling \
--experiment_name mol_optimization \
--device cuda:0 \
--seed 42
Command-Line Arguments
Required Arguments
--from_pretrained
Path to pretrained model or model identifier
--model_config
Model configuration name (e.g., 125m, 1.3b)
--valid_data_dir
Directory containing validation data in JSONL format
--rounds
Number of rejection sampling rounds to perform
--steps_per_round
Number of training steps per round (also the number of samples generated)
--max_learning_rate
Maximum learning rate for training
--eval_steps
Steps between evaluation runs
--save_steps
Steps between checkpoint saves
--checkpoints_root_dir
Root directory for saving checkpoints
--device
Device to use for generation (e.g., cuda:0, cpu)
Optional Arguments
Validation batch size (defaults to training batch size)
--dataloader_num_workers
Number of dataloader workers
--gradient_accumulation_steps
Gradient accumulation steps
Enable gradient checkpointing
--track_dir
Directory for tracking data
--seed
Random seed for reproducibility
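For local experimentation, the argument surface above can be mirrored with a small argparse parser. This is a sketch reconstructed from the flags shown in the quick-start command, not the project's actual parser:

```python
import argparse

def build_parser():
    # Flag names taken from the quick-start command; types are illustrative.
    p = argparse.ArgumentParser(description="Rejection sampling fine-tuning (sketch)")
    p.add_argument("--from_pretrained", required=True)
    p.add_argument("--model_config", required=True)
    p.add_argument("--valid_data_dir", required=True)
    p.add_argument("--rounds", type=int, required=True)
    p.add_argument("--steps_per_round", type=int, required=True)
    p.add_argument("--max_learning_rate", type=float, required=True)
    p.add_argument("--device", default="cuda:0")
    p.add_argument("--seed", type=int, default=42)
    return p

args = build_parser().parse_args([
    "--from_pretrained", "facebook/galactica-125m",
    "--model_config", "125m",
    "--valid_data_dir", "./data/valid",
    "--rounds", "10",
    "--steps_per_round", "100",
    "--max_learning_rate", "1.0e-4",
])
```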
Generation Configuration
The rejection sampling process uses two sets of generation parameters defined in rejection_sampling_configs.py:
Similar Molecule Generation
sample_gen_args = {
    "max_new_tokens": 50,
    "temperature": 1.0,
    "repetition_penalty": 1.0,
    "do_sample": True,
    "eos_token_id": 2,
}
max_new_tokens: Maximum tokens to generate for similar molecules
temperature: Sampling temperature (higher = more diverse)
repetition_penalty: Penalty for repeating tokens
Target Molecule Generation
rej_sample_args = {
    "max_new_tokens": 300,
    "temperature": 1.0,
    "repetition_penalty": 1.0,
    "do_sample": True,
    "num_return_sequences": 20,
    "eos_token_id": 20,
}
num_return_sequences: Number of candidate molecules to generate per prompt
max_new_tokens: Maximum tokens for target molecule generation
Molecular Scoring
Molecules are evaluated using multiple metrics:
QED Score
# From rejection_sampling_utils.py:55
def compute_qed(smiles: str):
    return qed(Chem.MolFromSmiles(smiles))
Quantitative Estimation of Drug-likeness (0-1 scale, higher is better).
Tanimoto Similarity
# From rejection_sampling_utils.py:48
def tanimoto_dist_func(
    smiles1: str,
    smiles2: str,
    fingerprint: FingerprintType = FingerprintType.Morgan,
):
    return DataStructs.TanimotoSimilarity(
        get_morgan_fingerprint(smiles1),
        get_morgan_fingerprint(smiles2),
    )
Measures structural similarity between molecules (0-1 scale).
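The metric itself is easy to illustrate on toy data. Treating a fingerprint as a set of on-bits (a simplification of RDKit's Morgan fingerprints), Tanimoto similarity is intersection over union:

```python
def tanimoto(bits1: set, bits2: set) -> float:
    """Tanimoto similarity between two fingerprints given as sets of on-bits."""
    if not bits1 and not bits2:
        return 1.0  # convention: two empty fingerprints are identical
    return len(bits1 & bits2) / len(bits1 | bits2)

# Two toy fingerprints sharing 2 of 4 distinct bits -> 2/4 = 0.5
assert tanimoto({1, 2, 3}, {2, 3, 4}) == 0.5
assert tanimoto({1, 2}, {1, 2}) == 1.0
```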
Combined Score
# From rejection_sampling_utils.py:74
def score(self):
    return self.morgan_sim_to_lead + self.qed
Molecules are ranked by the sum of similarity and QED scores.
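A minimal sketch of that ranking, assuming each candidate carries precomputed similarity and QED values (the Candidate class and the numbers are illustrative):

```python
from dataclasses import dataclass

@dataclass
class Candidate:
    smiles: str
    morgan_sim_to_lead: float
    qed: float

    def score(self) -> float:
        # Same additive score as in rejection_sampling_utils.py
        return self.morgan_sim_to_lead + self.qed

candidates = [
    Candidate("CCO", 0.30, 0.41),
    Candidate("c1ccccc1", 0.55, 0.46),
    Candidate("CC(C)O", 0.45, 0.50),
]
ranked = sorted(candidates, key=Candidate.score, reverse=True)
assert ranked[0].smiles == "c1ccccc1"  # 1.01 beats 0.95 and 0.71
```

The additive score weights both terms equally; reweighting the sum is a natural knob if similarity matters more than drug-likeness (or vice versa).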
Optimization Criteria
The default optimization targets molecules with:
# From rejection_sampling_ft.py:266-269
optimized_molecule_mask = np.bitwise_and(
    df_samples["qed"].values >= 0.9,
    df_samples["morgan_sim_to_lead"].values >= 0.4,
)
QED ≥ 0.9: High drug-likeness
Tanimoto similarity ≥ 0.4: Maintains structural features of the lead
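The same conjunctive filter can be expressed without NumPy for clarity; a pure-Python sketch over (qed, similarity) pairs:

```python
QED_THRESHOLD = 0.9
SIM_THRESHOLD = 0.4

def is_optimized(qed: float, sim_to_lead: float) -> bool:
    """Mirror of the NumPy mask: both criteria must hold."""
    return qed >= QED_THRESHOLD and sim_to_lead >= SIM_THRESHOLD

samples = [(0.95, 0.45), (0.95, 0.30), (0.80, 0.60)]
kept = [s for s in samples if is_optimized(*s)]
assert kept == [(0.95, 0.45)]
```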
Prompt Format
Rejection sampling generates prompts in this format:
</s>[SIMILAR]{mol1_smiles} {similarity}[/SIMILAR][SIMILAR]{mol2_smiles} {similarity}[/SIMILAR][QED]{target_qed}[/QED][START_SMILES]{target_smiles}[END_SMILES]</s>
Example
</s>[SIMILAR]CC(C)Cc1ccc(C)cc1 0.92[/SIMILAR][SIMILAR]c1ccc2[nH]ccc2c1 0.88[/SIMILAR][QED]0.95[/QED][START_SMILES]CC(C)Cc1ccc2[nH]ccc2c1[END_SMILES]</s>
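Prompts in this format can be assembled with a small helper. This builder is a sketch (not the project's actual prompt code), with the token strings taken from the template above:

```python
def build_prompt(similars, target_qed, target_smiles=None):
    """similars: list of (smiles, tanimoto_similarity) pairs."""
    parts = ["</s>"]
    for smiles, sim in similars:
        parts.append(f"[SIMILAR]{smiles} {sim:.2f}[/SIMILAR]")
    parts.append(f"[QED]{target_qed:.2f}[/QED]")
    if target_smiles is None:
        # At generation time, the model continues from the opening tag.
        parts.append("[START_SMILES]")
    else:
        # At training time, the target completion is included.
        parts.append(f"[START_SMILES]{target_smiles}[END_SMILES]</s>")
    return "".join(parts)

prompt = build_prompt(
    [("CC(C)Cc1ccc(C)cc1", 0.92), ("c1ccc2[nH]ccc2c1", 0.88)],
    target_qed=0.95,
    target_smiles="CC(C)Cc1ccc2[nH]ccc2c1",
)
```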
Training Process
Round-by-Round Workflow
For each round (from rejection_sampling_ft.py:235-299):
Load Generator Model
if trainer.state.global_step == 0:
    generator_checkpoint_path = from_pretrained
else:
    generator_checkpoint_path = os.path.join(
        training_args.output_dir,
        f"checkpoint-{trainer.state.global_step}",
    )
Generate Samples
train_ds_name = generate_dataset(
    checkpoint_path=generator_checkpoint_path,
    run_hash=experiment_hash,
    round=i,
    num_samples=steps_per_round,
    max_similars_in_prompt=5,
    lead_molecule=lead_molecule,
    use_flash_attn=use_flash_attn,
    device=device,
    seed=seed,
)
Filter and Score
df_samples = pd.read_csv(train_ds_name)
optimized_molecule_mask = np.bitwise_and(
    df_samples["qed"].values >= 0.9,
    df_samples["morgan_sim_to_lead"].values >= 0.4,
)
Train on Samples
rej_sampled_samples = list(df_samples["samples"].values)
random.shuffle(rej_sampled_samples)
for batch_of_texts in batches:
    trainer.step(texts=batch_of_texts)
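The shuffle-and-batch step can be sketched with a simple chunking helper (`batch_texts` is a hypothetical helper for illustration, not part of the project):

```python
import random

def batch_texts(texts, batch_size):
    """Yield successive fixed-size batches; the last batch may be smaller."""
    for start in range(0, len(texts), batch_size):
        yield texts[start:start + batch_size]

random.seed(42)
samples = [f"sample-{i}" for i in range(20)]
random.shuffle(samples)
batches = list(batch_texts(samples, batch_size=8))
assert [len(b) for b in batches] == [8, 8, 4]
```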
Save Checkpoint
trainer._save_checkpoint(model=None, trial=None)
Learning Rate Schedule
Rejection sampling uses polynomial decay with warmup:
# From rejection_sampling_ft.py:157-163
lr_scheduler = get_polynomial_decay_schedule_with_warmup(
    optimizer,
    num_warmup_steps=train_config["warmup_steps"],
    num_training_steps=steps_per_round * rounds,
    lr_end=0.1 * train_config["max_learning_rate"],
    power=1.0,
)
This creates a linear decay from max LR to 10% of max LR over all rounds.
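With power=1.0 the schedule is piecewise linear. A self-contained sketch of the same shape (an illustration of the formula, not the transformers implementation):

```python
def lr_at_step(step, max_lr, warmup_steps, total_steps, lr_end_frac=0.1, power=1.0):
    """Linear warmup from 0 to max_lr, then polynomial decay down to lr_end."""
    lr_end = lr_end_frac * max_lr
    if step < warmup_steps:
        return max_lr * step / warmup_steps
    if step >= total_steps:
        return lr_end
    remaining = 1 - (step - warmup_steps) / (total_steps - warmup_steps)
    return lr_end + (max_lr - lr_end) * remaining ** power

# 100 warmup steps, 20 rounds x 100 steps_per_round = 2000 total steps
assert lr_at_step(0, 1e-4, 100, 2000) == 0.0
assert abs(lr_at_step(100, 1e-4, 100, 2000) - 1e-4) < 1e-12
assert abs(lr_at_step(2000, 1e-4, 100, 2000) - 1e-5) < 1e-12
```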
Example: Lead Optimization
Define Lead Molecule
# From rejection_sampling_ft.py:245-247
lead_molecule = "c1ccc(-c2cc(N3C[C@H]4[C@@H]5CC[C@@H](O5)[C@H]4C3)c3ccccc3[nH+]2)cc1"
Run Optimization
python chemlactica/rejection_sampling_ft.py \
--from_pretrained ./checkpoints/pretrained/galactica-125m \
--model_config 125m \
--valid_data_dir ./data/valid \
--rounds 20 \
--steps_per_round 200 \
--train_batch_size 8 \
--max_learning_rate 5.0e-5 \
--eval_steps 100 \
--save_steps 200 \
--checkpoints_root_dir ./checkpoints/rejection_sampling \
--experiment_name lead_optimization \
--dataloader_num_workers 4 \
--gradient_accumulation_steps 2 \
--device cuda:0 \
--flash_attn \
--track \
--track_dir ./aim_logs \
--seed 42
Monitor Progress
The training logs will show:
---------Rej Sampling ROUND 1---------
Generator model: ./checkpoints/pretrained/galactica-125m
Found optimized molecule: True
Successful rounds find molecules meeting the optimization criteria (QED ≥ 0.9, similarity ≥ 0.4).
Advanced Configuration
Custom Similarity Thresholds
Modify the prompt generation in generate_dataset:
# Adjust similarity range
morgan_sim_to_lead = random.uniform(0.85, 0.95)  # stricter similarity
Custom QED Targets
# Target higher QED values
input_text += f"[QED]{random.uniform(0.95, 0.99):.2f}[/QED]"
Multiple Lead Molecules
Modify the training loop to alternate between different leads:
from itertools import cycle, islice

lead_molecules = ["SMILES1", "SMILES2", "SMILES3"]
# Bound the cycle by the number of rounds, or the loop never terminates.
for i, lead in enumerate(islice(cycle(lead_molecules), rounds)):
    train_ds_name = generate_dataset(
        lead_molecule=lead,
        # ... other args
    )
Data Storage
Generated samples are saved as CSV files:
# From rejection_sampling_utils.py:110-113
base_path = f"/nfs/dgx/raid/chem/data/rej_sampling_data/{run_hash}"
ds_file_name = f'{base_path}/round:{round}_hash:{model_name[0][:6]}_step:{model_name[1].split("-")[-1]}_date:{formatted_date_time}.csv'
Each CSV contains:
samples: Full formatted training sample
smiles: Generated SMILES string
qed: QED score
morgan_sim_to_lead: Tanimoto similarity (Morgan fingerprint)
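Reading these files back doesn't require pandas; a stdlib sketch that applies the same QED/similarity filter while loading (toy in-memory data, with the column names from the list above):

```python
import csv
import io

# A toy CSV with the same columns the pipeline writes; values are illustrative.
raw = """samples,smiles,qed,morgan_sim_to_lead
</s>[QED]0.41[/QED][START_SMILES]CCO[END_SMILES]</s>,CCO,0.41,0.30
</s>[QED]0.92[/QED][START_SMILES]c1ccccc1[END_SMILES]</s>,c1ccccc1,0.92,0.55
"""

optimized = [
    row for row in csv.DictReader(io.StringIO(raw))
    if float(row["qed"]) >= 0.9 and float(row["morgan_sim_to_lead"]) >= 0.4
]
assert [row["smiles"] for row in optimized] == ["c1ccccc1"]
```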
Best Practices
Use --flash_attn for faster generation
Adjust num_return_sequences to balance diversity and speed
Use GPU for generation (--device cuda:0)
Consider batch generation for multiple leads
Increase max_similars_in_prompt (default 5) for more context
Adjust temperature for more/less diversity
Filter invalid SMILES early to save computation
Use canonical and Kekulé SMILES for consistency
# From rejection_sampling_ft.py:261-263
gc.collect()
torch.cuda.empty_cache()
Clear cache between rounds to prevent OOM
Troubleshooting
No Optimized Molecules Found
Relax optimization criteria (lower QED/similarity thresholds)
Increase num_return_sequences for more candidates
Try different lead molecules
Increase temperature for more diversity
Invalid SMILES Generated
Check tokenizer is properly configured
Verify model was trained on SMILES data
Ensure proper formatting of start/end tokens
Consider adding SMILES validation callback
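As a cheap pre-filter before full RDKit parsing, a balance check on brackets and parentheses catches many truncated generations. This heuristic is an assumption of this guide, not part of the project, and is no substitute for Chem.MolFromSmiles:

```python
def looks_balanced(smiles: str) -> bool:
    """Cheap sanity check: () and [] must nest properly. Not a validity proof."""
    pairs = {")": "(", "]": "["}
    stack = []
    for ch in smiles:
        if ch in "([":
            stack.append(ch)
        elif ch in ")]":
            if not stack or stack.pop() != pairs[ch]:
                return False
    return not stack

assert looks_balanced("CC(C)Cc1ccc2[nH]ccc2c1")
assert not looks_balanced("CC(C")       # truncated mid-branch
assert not looks_balanced("C[C@H)O")    # mismatched bracket types
```

Strings that pass still need a real parse; the check only lets you skip the expensive step for obviously broken outputs.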
Training Instability
Reduce learning rate
Increase warmup steps
Use gradient clipping (controlled by global_gradient_norm)
Monitor gradient norms in Aim
Next Steps
Pretraining Learn about pretraining models from scratch
Configuration Explore all training configuration options