Rejection sampling is an iterative fine-tuning strategy that generates molecules, scores them with an oracle, and fine-tunes the model on high-scoring examples.
Main Function
from chemlactica.rejection_sampling_ft import fine_tune
def fine_tune(
from_pretrained: str,
model_config_name: str,
valid_data_dir: str,
rounds: int,
steps_per_round: int,
eval_steps: int,
save_steps: int,
train_batch_size: int,
max_learning_rate: float,
experiment_name: str,
checkpoints_root_dir: str,
dataloader_num_workers: int,
use_flash_attn: bool,
gradient_accumulation_steps: int,
gradient_checkpointing: bool,
device: str,
seed: int,
track: bool = False,
track_dir: str = None,
check_reproducability: bool = False,
valid_batch_size: int = None,
profile: bool = False,
profile_dir: str = None
)
Source: rejection_sampling_ft.py:57
Parameters
from_pretrained
Path to pretrained model checkpoint or Hugging Face model ID.
model_config_name
Model configuration name (e.g., “125m”, “1.3b”).
valid_data_dir
Directory containing validation data.
rounds
Number of rejection sampling rounds to perform.
steps_per_round
Training steps per rejection sampling round.
eval_steps
Evaluate model every N steps.
save_steps
Save checkpoint every N steps.
train_batch_size
Training batch size per device.
max_learning_rate
Maximum learning rate for polynomial decay schedule.
experiment_name
Name for this training run.
checkpoints_root_dir
Root directory for saving checkpoints.
dataloader_num_workers
Number of workers for data loading.
use_flash_attn
Whether to use Flash Attention for efficient training.
gradient_accumulation_steps
Number of gradient accumulation steps.
gradient_checkpointing
Whether to use gradient checkpointing to save memory.
device
Device to use for training (e.g., “cuda”, “cpu”).
seed
Random seed for reproducibility.
track
Whether to enable experiment tracking (Aim).
track_dir
Directory for experiment tracking logs.
check_reproducability
Whether to run reproducibility checks during training.
valid_batch_size
Batch size used for validation.
profile
Whether to enable profiling.
profile_dir
Directory for profiler output.
Rejection Sampling Workflow
The rejection sampling process follows these steps:
Generate candidates
Generate molecules using the current model with specified prompts.
Score with oracle
Evaluate generated molecules using a scoring function (e.g., QED, docking score).
Select top performers
Filter molecules based on score thresholds and similarity constraints.
Fine-tune on selected data
Train the model on high-scoring molecules for N steps.
Repeat
Continue for the specified number of rounds.
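The five steps above can be sketched in plain Python. The `generate`, `score`, and `train` callables below are hypothetical stand-ins for the library's generation, oracle, and training internals, not its actual API:

```python
def rejection_sampling_loop(model, rounds, steps_per_round, threshold,
                            generate, score, train):
    """One rejection sampling run: generate -> score -> filter -> fine-tune."""
    for _ in range(rounds):
        candidates = generate(model)                          # 1. generate candidates
        scored = [(m, score(m)) for m in candidates]          # 2. score with oracle
        selected = [m for m, s in scored if s >= threshold]   # 3. select top performers
        model = train(model, selected, steps_per_round)       # 4. fine-tune on selections
    return model                                              # 5. repeated for all rounds
```

In the real workflow each step is far heavier (batched generation, RDKit-based oracles, gradient updates), but the control flow is this simple outer loop.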
Generation Configuration
from chemlactica.generation.rejection_sampling_configs import GenerationConfig
config = GenerationConfig(
max_new_tokens=128,
temperature=1.0,
do_sample=True,
num_return_sequences=10,
repetition_penalty=1.2,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id
)
GenerationConfig Parameters
max_new_tokens
Maximum number of new tokens to generate.
temperature
Sampling temperature for diversity.
do_sample
Whether to use sampling (vs greedy decoding).
num_return_sequences
Number of sequences to generate per prompt.
repetition_penalty
Penalty for repeating tokens (greater than 1.0 discourages repetition).
eos_token_id / pad_token_id
End-of-sequence and padding token IDs, taken from the tokenizer.
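A config like this is typically unpacked straight into a Hugging Face-style `model.generate(**kwargs)` call. The dataclass below is a hypothetical mirror of GenerationConfig used only to illustrate that pattern; field names follow the example above:

```python
from dataclasses import dataclass, asdict

@dataclass
class GenerationConfig:
    """Illustrative mirror of the generation config (not the library class)."""
    max_new_tokens: int = 128
    temperature: float = 1.0
    do_sample: bool = True
    num_return_sequences: int = 10
    repetition_penalty: float = 1.2

config = GenerationConfig(temperature=0.8)   # override a single field
kwargs = asdict(config)                      # e.g. model.generate(input_ids, **kwargs)
```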
Scoring Configuration
from chemlactica.generation.rejection_sampling_utils import ScoringConfig
scoring_config = ScoringConfig(
metric="qed", # or "tanimoto", "custom"
threshold=0.9,
similarity_threshold=0.4,
reference_smiles="CC(=O)OC1=CC=CC=C1C(=O)O"
)
ScoringConfig Parameters
metric
Scoring metric to use:
"qed": Quantitative Estimate of Drug-likeness
"tanimoto": Tanimoto similarity to reference
"custom": Custom scoring function
threshold
Minimum score threshold for acceptance.
similarity_threshold
Minimum Tanimoto similarity to reference molecule (when using similarity constraints).
reference_smiles
Reference SMILES for similarity-constrained optimization.
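The selection rule implied by threshold and similarity_threshold can be sketched as follows. Both functions are hypothetical helpers, not the library API, and tanimoto here operates on precomputed fingerprint bit sets rather than SMILES strings:

```python
def tanimoto(fp_a: set, fp_b: set) -> float:
    """Tanimoto similarity between two fingerprint bit sets."""
    if not fp_a and not fp_b:
        return 0.0
    return len(fp_a & fp_b) / len(fp_a | fp_b)

def select_candidates(scored, threshold, similarity_threshold, reference_fp):
    """Keep molecules whose oracle score meets `threshold` and whose
    fingerprint stays within `similarity_threshold` of the reference."""
    return [
        (smiles, score)
        for smiles, score, fp in scored
        if score >= threshold and tanimoto(fp, reference_fp) >= similarity_threshold
    ]
```

In practice the fingerprints would come from RDKit (e.g. Morgan fingerprints); the filtering logic itself is this simple.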
Example Usage
import torch
from chemlactica.rejection_sampling_ft import fine_tune
# Configure rejection sampling fine-tuning
fine_tune(
from_pretrained="yerevann/chemlactica-125m",
model_config_name="125m",
valid_data_dir="./validation_data",
rounds=10,
steps_per_round=100,
eval_steps=50,
save_steps=100,
train_batch_size=8,
max_learning_rate=5e-5,
experiment_name="chemlactica_qed_optimization",
checkpoints_root_dir="./checkpoints",
dataloader_num_workers=4,
use_flash_attn=True,
gradient_accumulation_steps=4,
gradient_checkpointing=True,
device="cuda",
seed=42,
track=True,
track_dir="./aim_logs"
)
Custom Trainer
Rejection sampling uses CustomIterativeSFTTrainer, which extends the standard trainer to support iterative, per-round fine-tuning:
from chemlactica.custom_trainer import CustomIterativeSFTTrainer
trainer = CustomIterativeSFTTrainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
formatting_func=lambda x: x["text"],
callbacks=callbacks
)
# Train for one round
trainer.train()
Learning Rate Schedule
Rejection sampling uses polynomial decay with warmup:
from transformers import get_polynomial_decay_schedule_with_warmup
scheduler = get_polynomial_decay_schedule_with_warmup(
optimizer=optimizer,
num_warmup_steps=warmup_steps,
num_training_steps=total_steps,
power=1.0 # Linear decay
)
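The multiplier this scheduler applies to max_learning_rate can be written out directly. This is a standalone sketch of the polynomial-decay-with-warmup shape, not the transformers implementation:

```python
def lr_multiplier(step, warmup_steps, total_steps, power=1.0, lr_end_ratio=0.0):
    """Factor applied to the max learning rate at a given optimizer step."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)    # linear warmup from 0 to 1
    if step >= total_steps:
        return lr_end_ratio                   # floor after training ends
    remaining = 1 - (step - warmup_steps) / (total_steps - warmup_steps)
    return (1 - lr_end_ratio) * remaining ** power + lr_end_ratio
```

With power=1.0 the decay is linear, as noted in the comment above; larger powers front-load the decay.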
Data Format
Generated data is stored in JSONL format:
{"text": "</s>[SAS]2.5[/SAS][QED]0.85[/QED][START_SMILES]CC(=O)OC1=CC=CC=C1C(=O)O[END_SMILES]", "score": 0.85}
{"text": "</s>[MW]300[/MW][TPSA]50[/TPSA][START_SMILES]COc1ccc(C=O)cc1[END_SMILES]", "score": 0.78}
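Reading these rows back for the next fine-tuning round takes only the standard json module. The helper below is a hypothetical sketch matching the field names in the records above:

```python
import json

def load_scored_texts(lines, min_score=0.0):
    """Parse JSONL rows and keep texts at or above `min_score`."""
    records = [json.loads(line) for line in lines]
    return [r["text"] for r in records if r["score"] >= min_score]
```

Applying a min_score here gives a second, cheap filtering pass on top of the oracle-based selection.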
Monitoring Progress
Track rejection sampling with Aim:
fine_tune(
...,
track=True,
track_dir="./aim_logs"
)
View results by launching the Aim UI (e.g., aim up --repo ./aim_logs).
Notes
Rejection sampling can significantly improve molecule generation quality by focusing the model on high-scoring examples.
Use appropriate score thresholds to avoid overfitting to a narrow region of chemical space. Monitor diversity metrics during training.
Start with a smaller number of rounds (5-10) and evaluate before scaling up. Excessive rounds can lead to mode collapse.
See Also