Rejection sampling is an iterative fine-tuning strategy that generates molecules, scores them with an oracle, and fine-tunes the model on high-scoring examples.

Main Function

from chemlactica.rejection_sampling_ft import fine_tine

def fine_tine(
    from_pretrained: str,
    model_config_name: str,
    valid_data_dir: str,
    rounds: int,
    steps_per_round: int,
    eval_steps: int,
    save_steps: int,
    train_batch_size: int,
    max_learning_rate: float,
    experiment_name: str,
    checkpoints_root_dir: str,
    dataloader_num_workers: int,
    use_flash_attn: bool,
    gradient_accumulation_steps: int,
    gradient_checkpointing: bool,
    device: str,
    seed: int,
    track: bool = False,
    track_dir: str = None,
    check_reproducability: bool = False,
    valid_batch_size: int = None,
    profile: bool = False,
    profile_dir: str = None
)
Source: rejection_sampling_ft.py:57

Parameters

from_pretrained
str
required
Path to pretrained model checkpoint or Hugging Face model ID.
model_config_name
str
required
Model configuration name (e.g., “125m”, “1.3b”).
valid_data_dir
str
required
Directory containing validation data.
rounds
int
required
Number of rejection sampling rounds to perform.
steps_per_round
int
required
Training steps per rejection sampling round.
eval_steps
int
required
Evaluate model every N steps.
save_steps
int
required
Save checkpoint every N steps.
train_batch_size
int
required
Training batch size per device.
max_learning_rate
float
required
Maximum learning rate for polynomial decay schedule.
experiment_name
str
required
Name for this training run.
checkpoints_root_dir
str
required
Root directory for saving checkpoints.
dataloader_num_workers
int
required
Number of workers for data loading.
use_flash_attn
bool
required
Whether to use Flash Attention for efficient training.
gradient_accumulation_steps
int
required
Number of gradient accumulation steps.
gradient_checkpointing
bool
required
Whether to use gradient checkpointing to save memory.
device
str
required
Device to use for training (e.g., “cuda”, “cpu”).
seed
int
required
Random seed for reproducibility.
track
bool
default:"False"
Whether to enable experiment tracking (Aim).
track_dir
str
default:"None"
Directory for experiment tracking logs.
check_reproducability
bool
default:"False"
Whether to run the reproducibility check (parameter name as it appears in the signature).
valid_batch_size
int
default:"None"
Batch size used for validation.
profile
bool
default:"False"
Whether to enable profiling of the training run.
profile_dir
str
default:"None"
Directory for profiler output.

Rejection Sampling Workflow

The rejection sampling process follows these steps:
1. Generate candidates: generate molecules using the current model with specified prompts.
2. Score with oracle: evaluate generated molecules using a scoring function (e.g., QED, docking score).
3. Select top performers: filter molecules based on score thresholds and similarity constraints.
4. Fine-tune on selected data: train the model on high-scoring molecules for N steps.
5. Repeat: continue for the specified number of rounds.
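The steps above can be sketched as a single round of the loop. Here `generate_one`, the oracle callable, and the selection rule are hypothetical stand-ins for the package's actual generation and scoring code, shown only to make the control flow concrete:

```python
def rejection_sampling_round(model, prompts, oracle_score, threshold, top_k):
    """One rejection sampling round: generate, score, select.

    Returns the accepted (smiles, score) pairs, best-first, which would
    then be written out as training data for the fine-tuning step.
    """
    # 1. Generate candidate molecules from the current model
    #    (generate_one is a hypothetical helper, one completion per prompt).
    candidates = [model.generate_one(p) for p in prompts]
    # 2. Score each candidate with the oracle (e.g., QED).
    scored = [(smiles, oracle_score(smiles)) for smiles in candidates]
    # 3. Keep only candidates at or above the threshold, best-first, top_k at most.
    accepted = sorted(
        [(s, sc) for s, sc in scored if sc >= threshold],
        key=lambda pair: pair[1],
        reverse=True,
    )[:top_k]
    return accepted
```

Steps 4 and 5 then fine-tune on the accepted set and repeat with the updated model.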

Generation Configuration

from chemlactica.generation.rejection_sampling_configs import GenerationConfig

config = GenerationConfig(
    max_new_tokens=128,
    temperature=1.0,
    do_sample=True,
    num_return_sequences=10,
    repetition_penalty=1.2,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id
)

GenerationConfig Parameters

max_new_tokens
int
default:"128"
Maximum number of new tokens to generate.
temperature
float
default:"1.0"
Sampling temperature for diversity.
do_sample
bool
default:"True"
Whether to use sampling (vs greedy decoding).
num_return_sequences
int
default:"1"
Number of sequences to generate per prompt.
repetition_penalty
float
default:"1.0"
Penalty for repeating tokens (greater than 1.0 discourages repetition).
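Temperature rescales the model's logits before sampling: values above 1.0 flatten the token distribution (more diverse molecules), values below 1.0 sharpen it toward the most likely tokens. A minimal, self-contained illustration of that effect (not chemlactica code):

```python
import math

def softmax_with_temperature(logits, temperature):
    """Convert logits to probabilities after dividing by temperature."""
    scaled = [l / temperature for l in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.0]
sharp = softmax_with_temperature(logits, 0.5)  # low temperature: peaked
flat = softmax_with_temperature(logits, 2.0)   # high temperature: closer to uniform
```

With `temperature=1.0` (the default above) the logits are used unchanged.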

Scoring Configuration

from chemlactica.generation.rejection_sampling_utils import ScoringConfig

scoring_config = ScoringConfig(
    metric="qed",  # or "tanimoto", "custom"
    threshold=0.9,
    similarity_threshold=0.4,
    reference_smiles="CC(=O)OC1=CC=CC=C1C(=O)O"
)

ScoringConfig Parameters

metric
str
required
Scoring metric to use:
  • "qed": Quantitative Estimate of Drug-likeness
  • "tanimoto": Tanimoto similarity to reference
  • "custom": Custom scoring function
threshold
float
required
Minimum score threshold for acceptance.
similarity_threshold
float
Minimum Tanimoto similarity to reference molecule (when using similarity constraints).
reference_smiles
str
Reference SMILES for similarity-constrained optimization.
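A sketch of the acceptance rule these fields imply. The per-molecule `score` and `similarity_to_reference` values are assumed inputs here; in practice they would come from an oracle such as RDKit's QED and a Tanimoto similarity against `reference_smiles`:

```python
def accept(score, similarity_to_reference, threshold, similarity_threshold=None):
    """Accept a molecule if it clears the score threshold and, when a
    similarity constraint is configured, stays close enough to the
    reference molecule."""
    if score < threshold:
        return False
    if similarity_threshold is not None and similarity_to_reference < similarity_threshold:
        return False
    return True
```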

Example Usage

import torch
from chemlactica.rejection_sampling_ft import fine_tine

# Configure rejection sampling fine-tuning
fine_tine(
    from_pretrained="yerevann/chemlactica-125m",
    model_config_name="125m",
    valid_data_dir="./validation_data",
    rounds=10,
    steps_per_round=100,
    eval_steps=50,
    save_steps=100,
    train_batch_size=8,
    max_learning_rate=5e-5,
    experiment_name="chemlactica_qed_optimization",
    checkpoints_root_dir="./checkpoints",
    dataloader_num_workers=4,
    use_flash_attn=True,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    device="cuda",
    seed=42,
    track=True,
    track_dir="./aim_logs"
)

Custom Trainer

The rejection sampling loop uses CustomIterativeSFTTrainer, which extends the standard trainer:
from chemlactica.custom_trainer import CustomIterativeSFTTrainer

trainer = CustomIterativeSFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    formatting_func=lambda x: x["text"],
    callbacks=callbacks
)

# Train for one round
trainer.train()

Learning Rate Schedule

Rejection sampling uses polynomial decay with warmup:
from transformers import get_polynomial_decay_schedule_with_warmup

scheduler = get_polynomial_decay_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps,
    power=1.0  # Linear decay
)
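With `power=1.0` the schedule reduces to linear warmup followed by linear decay. A minimal reimplementation of that special case, ignoring the small `lr_end` floor the Transformers function also supports; the returned value is the multiplier applied to `max_learning_rate`:

```python
def linear_warmup_linear_decay(step, warmup_steps, total_steps):
    """Learning-rate multiplier at a given step (power=1.0 case)."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)  # linear warmup from 0 to 1
    remaining = total_steps - step
    span = max(1, total_steps - warmup_steps)
    return max(0.0, remaining / span)       # linear decay from 1 to 0
```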

Data Format

Generated data is stored in JSONL format:
{"text": "</s>[SAS]2.5[/SAS][QED]0.85[/QED][START_SMILES]CC(=O)OC1=CC=CC=C1C(=O)O[END_SMILES]", "score": 0.85}
{"text": "</s>[MW]300[/MW][TPSA]50[/TPSA][START_SMILES]COc1ccc(C=O)cc1[END_SMILES]", "score": 0.78}

Monitoring Progress

Track rejection sampling with Aim:
fine_tine(
    ...,
    track=True,
    track_dir="./aim_logs"
)
View results:
aim up --repo ./aim_logs

Notes

Rejection sampling can significantly improve molecule generation quality by focusing the model on high-scoring examples.
Use appropriate score thresholds to avoid overfitting to a narrow region of chemical space. Monitor diversity metrics during training.
Start with a smaller number of rounds (5-10) and evaluate before scaling up. Excessive rounds can lead to mode collapse.
