
Overview

Rejection sampling is an advanced fine-tuning strategy that iteratively generates and filters molecular candidates based on desired properties. This approach is particularly effective for molecular optimization tasks where you want to improve specific properties while maintaining structural similarity.

How It Works

1. Generate Candidates: The model generates multiple molecular candidates based on a lead molecule and similar structures.
2. Filter by Criteria: Candidates are filtered based on:
  • QED (drug-likeness) score
  • Tanimoto similarity to the lead molecule
  • Other custom properties
3. Create Training Samples: Accepted molecules are formatted into training samples with their properties.
4. Fine-tune Model: The model is fine-tuned on the filtered samples.
5. Iterate: The process repeats for multiple rounds, progressively improving the model.
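The loop above can be sketched in a few lines of plain Python. This is an illustrative outline only: `generate_candidates`, `passes_criteria`, and the commented-out `fine_tune` call are hypothetical stand-ins for the script's actual generation, filtering, and training code.

```python
import random

def generate_candidates(lead, n):
    # Hypothetical stand-in: the real script samples SMILES from the model
    # and scores them with RDKit.
    return [{"smiles": f"{lead}_cand{i}",
             "qed": random.random(),
             "morgan_sim_to_lead": random.random()} for i in range(n)]

def passes_criteria(mol, qed_min=0.9, sim_min=0.4):
    # Default acceptance criteria used by rejection_sampling_ft.py
    return mol["qed"] >= qed_min and mol["morgan_sim_to_lead"] >= sim_min

def rejection_sampling(lead, rounds, samples_per_round):
    accepted_per_round = []
    for _ in range(rounds):
        candidates = generate_candidates(lead, samples_per_round)
        accepted = [m for m in candidates if passes_criteria(m)]
        accepted_per_round.append(accepted)
        # fine_tune(accepted)  # train on the filtered samples, then repeat
    return accepted_per_round

random.seed(42)
history = rejection_sampling("CCO", rounds=3, samples_per_round=100)
print([len(r) for r in history])
```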

Quick Start

python chemlactica/rejection_sampling_ft.py \
  --from_pretrained facebook/galactica-125m \
  --model_config 125m \
  --valid_data_dir ./data/valid \
  --rounds 10 \
  --steps_per_round 100 \
  --train_batch_size 8 \
  --max_learning_rate 1.0e-4 \
  --eval_steps 50 \
  --save_steps 100 \
  --checkpoints_root_dir ./checkpoints/rejection_sampling \
  --experiment_name mol_optimization \
  --device cuda:0 \
  --seed 42

Command-Line Arguments

Required Arguments

--from_pretrained (string): Path to pretrained model or model identifier
--model_config (string): Model configuration name (e.g., 125m, 1.3b)
--valid_data_dir (string): Directory containing validation data in JSONL format
--rounds (integer): Number of rejection sampling rounds to perform
--steps_per_round (integer): Number of training steps per round (also the number of samples generated)
--train_batch_size (integer): Training batch size
--max_learning_rate (float): Maximum learning rate for training
--eval_steps (integer): Steps between evaluation runs
--save_steps (integer): Steps between checkpoint saves
--checkpoints_root_dir (string): Root directory for saving checkpoints
--device (string): Device to use for generation (e.g., cuda:0, cpu)

Optional Arguments

--experiment_name (string, default "none"): Name for the experiment
--valid_batch_size (integer): Validation batch size (defaults to the training batch size)
--dataloader_num_workers (integer, default 0): Number of dataloader workers
--gradient_accumulation_steps (integer, default 1): Gradient accumulation steps
--gradient_checkpointing (boolean, default false): Enable gradient checkpointing
--flash_attn (boolean, default false): Use Flash Attention
--track (boolean, default true): Enable Aim tracking
--track_dir (string): Directory for tracking data
--seed (integer, default 42): Random seed for reproducibility

Generation Configuration

The rejection sampling process uses two sets of generation parameters defined in rejection_sampling_configs.py:

Similar Molecule Generation

sample_gen_args = {
    "max_new_tokens": 50,
    "temperature": 1.0,
    "repetition_penalty": 1.0,
    "do_sample": True,
    "eos_token_id": 2
}
max_new_tokens (integer, default 50): Maximum tokens to generate for similar molecules
temperature (float, default 1.0): Sampling temperature (higher = more diverse)
repetition_penalty (float, default 1.0): Penalty for repeating tokens

Target Molecule Generation

rej_sample_args = {
    "max_new_tokens": 300,
    "temperature": 1.0,
    "repetition_penalty": 1.0,
    "do_sample": True,
    "num_return_sequences": 20,
    "eos_token_id": 20,
}
num_return_sequences (integer, default 20): Number of candidate molecules to generate per prompt
max_new_tokens (integer, default 300): Maximum tokens for target molecule generation

Molecular Scoring

Molecules are evaluated using multiple metrics:

QED Score

# From rejection_sampling_utils.py:55
def compute_qed(smiles: str):
    return qed(Chem.MolFromSmiles(smiles))
Quantitative Estimate of Drug-likeness (0-1 scale, higher is better).

Tanimoto Similarity

# From rejection_sampling_utils.py:48
def tanimoto_dist_func(
    smiles1: str, 
    smiles2: str, 
    fingerprint: FingerprintType=FingerprintType.Morgan
):
    return DataStructs.TanimotoSimilarity(
        get_morgan_fingerprint(smiles1),
        get_morgan_fingerprint(smiles2)
    )
Measures structural similarity between molecules (0-1 scale).
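Tanimoto similarity over bit fingerprints is the size of the intersection of the on-bits divided by the size of their union. A pure-Python illustration on toy bit sets (the actual code computes it over RDKit Morgan fingerprints):

```python
def tanimoto(bits1: set, bits2: set) -> float:
    """Tanimoto similarity of two fingerprints given as sets of on-bit indices."""
    if not bits1 and not bits2:
        return 1.0  # two empty fingerprints: conventionally identical
    common = len(bits1 & bits2)
    return common / (len(bits1) + len(bits2) - common)

# Two toy fingerprints sharing 3 bits out of 5 distinct bits total
print(tanimoto({1, 2, 3, 4}, {2, 3, 4, 5}))  # → 0.6
```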

Combined Score

# From rejection_sampling_utils.py:74
def score(self):
    return self.morgan_sim_to_lead + self.qed
Molecules are ranked by the sum of similarity and QED scores.
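A minimal sketch of ranking candidates by this combined score, using a hypothetical `Candidate` dataclass in place of the utility module's own class (SMILES and scores below are made-up examples):

```python
from dataclasses import dataclass

@dataclass
class Candidate:
    smiles: str
    qed: float
    morgan_sim_to_lead: float

    def score(self) -> float:
        # Sum of similarity-to-lead and drug-likeness,
        # mirroring score() in rejection_sampling_utils.py
        return self.morgan_sim_to_lead + self.qed

candidates = [
    Candidate("CCO", qed=0.41, morgan_sim_to_lead=0.35),
    Candidate("c1ccccc1O", qed=0.55, morgan_sim_to_lead=0.62),
    Candidate("CC(C)Cc1ccccc1", qed=0.72, morgan_sim_to_lead=0.48),
]
best = max(candidates, key=Candidate.score)
print(best.smiles)  # → CC(C)Cc1ccccc1
```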

Optimization Criteria

The default optimization targets molecules with:
# From rejection_sampling_ft.py:266-269
optimized_molecule_mask = np.bitwise_and(
    df_samples["qed"].values >= 0.9,
    df_samples["morgan_sim_to_lead"].values >= 0.4,
)
  • QED ≥ 0.9: High drug-likeness
  • Tanimoto similarity ≥ 0.4: Maintains structural features of lead

Prompt Format

The rejection sampling generates prompts in this format:
</s>[SIMILAR]{mol1_smiles} {similarity}[/SIMILAR][SIMILAR]{mol2_smiles} {similarity}[/SIMILAR][QED]{target_qed}[/QED][START_SMILES]{target_smiles}[END_SMILES]</s>

Example

</s>[SIMILAR]CC(C)Cc1ccc(C)cc1 0.92[/SIMILAR][SIMILAR]c1ccc2[nH]ccc2c1 0.88[/SIMILAR][QED]0.95[/QED][START_SMILES]CC(C)Cc1ccc2[nH]ccc2c1[END_SMILES]</s>
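Assembling such a prompt is plain string formatting. A sketch, assuming the tag strings shown above (`build_prompt` is a hypothetical helper for illustration, not part of the script):

```python
def build_prompt(similars, target_qed, target_smiles=None):
    """similars: list of (smiles, similarity) pairs for the [SIMILAR] tags."""
    parts = ["</s>"]
    for smiles, sim in similars:
        parts.append(f"[SIMILAR]{smiles} {sim:.2f}[/SIMILAR]")
    parts.append(f"[QED]{target_qed:.2f}[/QED]")
    if target_smiles is not None:
        # Training samples include the answer; generation prompts stop
        # at the open [START_SMILES] tag for the model to complete.
        parts.append(f"[START_SMILES]{target_smiles}[END_SMILES]</s>")
    else:
        parts.append("[START_SMILES]")
    return "".join(parts)

print(build_prompt([("CC(C)Cc1ccc(C)cc1", 0.92)], 0.95, "CC(C)Cc1ccc2[nH]ccc2c1"))
```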

Training Process

Round-by-Round Workflow

For each round (from rejection_sampling_ft.py:235-299):
1. Load Generator Model

if trainer.state.global_step == 0:
    generator_checkpoint_path = from_pretrained
else:
    generator_checkpoint_path = os.path.join(
        training_args.output_dir, 
        f"checkpoint-{trainer.state.global_step}"
    )
2. Generate Samples

train_ds_name = generate_dataset(
    checkpoint_path=generator_checkpoint_path,
    run_hash=experiment_hash,
    round=i,
    num_samples=steps_per_round,
    max_similars_in_prompt=5,
    lead_molecule=lead_molecule,
    use_flash_attn=use_flash_attn,
    device=device,
    seed=seed,
)
3. Filter and Score

df_samples = pd.read_csv(train_ds_name)
optimized_molecule_mask = np.bitwise_and(
    df_samples["qed"].values >= 0.9,
    df_samples["morgan_sim_to_lead"].values >= 0.4,
)
4. Train on Samples

rej_sampled_samples = list(df_samples["samples"].values)
random.shuffle(rej_sampled_samples)
# batches: the shuffled samples chunked into train_batch_size-sized groups
for batch_of_texts in batches:
    trainer.step(texts=batch_of_texts)
5. Save Checkpoint

trainer._save_checkpoint(model=None, trial=None)

Learning Rate Schedule

Rejection sampling uses polynomial decay with warmup:
# From rejection_sampling_ft.py:157-163
lr_scheduler = get_polynomial_decay_schedule_with_warmup(
    optimizer,
    num_warmup_steps=train_config["warmup_steps"],
    num_training_steps=steps_per_round * rounds,
    lr_end=0.1 * train_config["max_learning_rate"],
    power=1.0,
)
This creates a linear decay from max LR to 10% of max LR over all rounds.
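With power=1.0 the post-warmup schedule is a straight line from the max LR down to lr_end. A quick endpoint check with the Quick Start settings (max LR 1e-4, 10 rounds × 100 steps); this is a simplified re-derivation of the schedule's shape, assuming zero warmup steps, not the library implementation itself:

```python
def linear_decay_lr(step, max_lr, total_steps, lr_end_frac=0.1, warmup_steps=0):
    """Polynomial decay (power=1) with optional linear warmup, illustrating
    the shape produced by the script's scheduler."""
    lr_end = lr_end_frac * max_lr
    if step < warmup_steps:
        return max_lr * step / max(warmup_steps, 1)
    frac = (step - warmup_steps) / max(total_steps - warmup_steps, 1)
    return lr_end + (max_lr - lr_end) * (1.0 - min(frac, 1.0))

max_lr, total = 1.0e-4, 10 * 100  # Quick Start: 10 rounds of 100 steps
print(linear_decay_lr(0, max_lr, total))     # → 0.0001 (starts at max LR)
print(linear_decay_lr(500, max_lr, total))   # → 5.5e-05 (halfway)
print(linear_decay_lr(1000, max_lr, total))  # → 1e-05 (ends at 10% of max)
```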

Example: Lead Optimization

Define Lead Molecule

# From rejection_sampling_ft.py:245-247
lead_molecule = "c1ccc(-c2cc(N3C[C@H]4[C@@H]5CC[C@@H](O5)[C@H]4C3)c3ccccc3[nH+]2)cc1"

Run Optimization

python chemlactica/rejection_sampling_ft.py \
  --from_pretrained ./checkpoints/pretrained/galactica-125m \
  --model_config 125m \
  --valid_data_dir ./data/valid \
  --rounds 20 \
  --steps_per_round 200 \
  --train_batch_size 8 \
  --max_learning_rate 5.0e-5 \
  --eval_steps 100 \
  --save_steps 200 \
  --checkpoints_root_dir ./checkpoints/rejection_sampling \
  --experiment_name lead_optimization \
  --dataloader_num_workers 4 \
  --gradient_accumulation_steps 2 \
  --device cuda:0 \
  --flash_attn \
  --track \
  --track_dir ./aim_logs \
  --seed 42

Monitor Progress

The training logs will show:
---------Rej Sampling ROUND 1---------
Generator model: ./checkpoints/pretrained/galactica-125m
Found optimized molecule: True
Successful rounds find molecules meeting the optimization criteria (QED ≥ 0.9, similarity ≥ 0.4).

Advanced Configuration

Custom Similarity Thresholds

Modify the prompt generation in generate_dataset:
# Adjust similarity range
morgan_sim_to_lead=random.uniform(0.85, 0.95)  # Stricter similarity

Custom QED Targets

# Target higher QED values
input_text += f"[QED]{random.uniform(0.95, 0.99):.2f}[/QED]"

Multiple Lead Molecules

Modify the training loop to alternate between different leads:
from itertools import cycle, islice

lead_molecules = ["SMILES1", "SMILES2", "SMILES3"]
# cycle() is infinite, so bound the loop by the number of rounds
for i, lead in enumerate(islice(cycle(lead_molecules), rounds)):
    train_ds_name = generate_dataset(
        lead_molecule=lead,
        round=i,
        # ... other args
    )

Data Storage

Generated samples are saved as CSV files:
# From rejection_sampling_utils.py:110-113
base_path = f"/nfs/dgx/raid/chem/data/rej_sampling_data/{run_hash}"
ds_file_name = f'{base_path}/round:{round}_hash:{model_name[0][:6]}_step:{model_name[1].split("-")[-1]}_date:{formatted_date_time}.csv'
Each CSV contains:
  • samples: Full formatted training sample
  • smiles: Generated SMILES string
  • qed: QED score
  • morgan_sim_to_lead: Tanimoto similarity (Morgan fingerprint)
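A small stdlib-only sketch of loading such a CSV and applying the default acceptance criteria (column names as listed above; the script itself uses pandas, and the rows here are made-up examples):

```python
import csv
import io

# A tiny in-memory stand-in for one round's generated CSV
raw = """samples,smiles,qed,morgan_sim_to_lead
"</s>[QED]0.95[/QED]...",CC(C)Cc1ccccc1,0.93,0.45
"</s>[QED]0.95[/QED]...",c1ccccc1,0.44,0.21
"""

accepted = []
with io.StringIO(raw) as f:
    for row in csv.DictReader(f):
        # Default criteria: QED >= 0.9 and Tanimoto similarity to lead >= 0.4
        if float(row["qed"]) >= 0.9 and float(row["morgan_sim_to_lead"]) >= 0.4:
            accepted.append(row["smiles"])

print(accepted)  # → ['CC(C)Cc1ccccc1']
```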

Performance Tips

  • Use --flash_attn for faster generation
  • Adjust num_return_sequences to balance diversity and speed
  • Use GPU for generation (--device cuda:0)
  • Consider batch generation for multiple leads
  • Increase max_similars_in_prompt (default 5) for more context
  • Adjust temperature for more/less diversity
  • Filter invalid SMILES early to save computation
  • Use canonical and Kekulé SMILES for consistency
Clear the cache between rounds to prevent out-of-memory errors:
# From rejection_sampling_ft.py:261-263
gc.collect()
torch.cuda.empty_cache()

Troubleshooting

No Optimized Molecules Found

  • Relax optimization criteria (lower QED/similarity thresholds)
  • Increase num_return_sequences for more candidates
  • Try different lead molecules
  • Increase temperature for more diversity

Invalid SMILES Generated

  • Check tokenizer is properly configured
  • Verify model was trained on SMILES data
  • Ensure proper formatting of start/end tokens
  • Consider adding SMILES validation callback

Training Instability

  • Reduce learning rate
  • Increase warmup steps
  • Use gradient clipping (controlled by global_gradient_norm)
  • Monitor gradient norms in Aim

Next Steps

Pretraining

Learn about pretraining models from scratch

Configuration

Explore all training configuration options
