
Overview

ChemLactica supports two types of fine-tuning:
  • Supervised Fine-Tuning (SFT): Train on labeled molecular property data
  • Rejection Sampling: Optimize for specific properties during molecule generation
This guide covers both approaches and when to use each.

Supervised Fine-Tuning (SFT)

When to Use SFT

Use supervised fine-tuning when you:
  • Have a labeled dataset with molecular properties
  • Want to predict specific properties (QED, activity, etc.)
  • Need the model to learn property-structure relationships

Dataset Format

Prepare your data in the ChemLactica format with property tags:
</s>[SIMILAR]CCO 0.95[/SIMILAR][QED]0.87[/QED][START_SMILES]CC(C)O[END_SMILES]</s>
</s>[SIMILAR]c1ccccc1 0.82[/SIMILAR][PROPERTY]activity 0.92[/PROPERTY][START_SMILES]c1ccc(O)cc1[END_SMILES]</s>
Each training sample should include:
  • EOS token </s> as separator
  • Similar molecules with similarity scores
  • Property values in [PROPERTY]name value[/PROPERTY] format
  • Target SMILES in [START_SMILES]...[END_SMILES]
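Samples in this format can be assembled programmatically from raw fields. A minimal sketch (the helper name and field names are illustrative, not part of the ChemLactica API):

```python
def build_sample(similar_smiles: str, similarity: float,
                 prop_name: str, prop_value: float,
                 target_smiles: str) -> str:
    """Assemble one ChemLactica training sample with property tags."""
    return (
        f"</s>[SIMILAR]{similar_smiles} {similarity:.2f}[/SIMILAR]"
        f"[PROPERTY]{prop_name} {prop_value:.2f}[/PROPERTY]"
        f"[START_SMILES]{target_smiles}[END_SMILES]</s>"
    )

sample = build_sample("CCO", 0.95, "qed", 0.87, "CC(C)O")
```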

Configuration

Create an SFT configuration file:
train_config:
  adam_beta1: 0.9
  adam_beta2: 0.95
  max_learning_rate: 1.0e-4
  warmup_steps: 0
  weight_decay: 0.1
  global_gradient_norm: 1.0
  bf16: true
  evaluation_strategy: "steps"
  save_total_limit: 4

model_config:
  n_heads: 12
  n_layers: 12
  block_size: 2048
  vocab_size: 50000
  separator_token: "</s>"
  tokenizer_path: "./chemlactica/tokenizer/ChemLacticaTokenizer66"

sft_config:
  packing: true
  max_seq_length: 512
  neftune_noise_alpha: 5

Key SFT Parameters

packing (boolean, default: true)
Pack multiple short sequences into a single example for efficiency.

max_seq_length (integer, default: 512)
Maximum sequence length for training samples.

neftune_noise_alpha (integer, default: 5)
NEFTune noise parameter for improved generalization.

Rejection Sampling Fine-Tuning

Overview

Rejection sampling optimizes the model during molecule generation by:
  1. Generating candidate molecules
  2. Scoring them with an oracle
  3. Fine-tuning on high-scoring examples
This approach is used in the optimization loop (see Custom Oracles).
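The three steps above can be sketched as a single round of the loop. This is a schematic, not ChemLactica's implementation: `generate` and `oracle` are stubs standing in for the real model sampling and scoring code:

```python
def rejection_sampling_round(generate, oracle, pool, num_candidates, threshold):
    """One round: generate candidates, score them, keep high scorers in the pool."""
    candidates = [generate() for _ in range(num_candidates)]
    scored = [(mol, oracle(mol)) for mol in candidates]
    accepted = [(mol, s) for mol, s in scored if s >= threshold]
    pool.extend(accepted)
    # The pool is later sorted by score and used as fine-tuning data.
    pool.sort(key=lambda pair: pair[1], reverse=True)
    return pool

# Deterministic stubs for illustration only:
mols = iter(["CCO", "c1ccccc1", "CC(C)O", "CCN"])
scores = {"CCO": 0.9, "c1ccccc1": 0.3, "CC(C)O": 0.7, "CCN": 0.2}
pool = rejection_sampling_round(lambda: next(mols), scores.get, [], 4, 0.5)
```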

Configuration

Rejection sampling is configured within the optimization config:
optimization_config.yaml
strategy: ["rej-sample-v2"]

rej_sample_config:
  # Training parameters
  train_batch_size: 8
  gradient_accumulation_steps: 4
  num_train_epochs: 3
  train_tol_level: 2  # Train after this many iterations without improvement
  
  # Learning rate
  max_learning_rate: 1.0e-5
  adam_beta1: 0.9
  adam_beta2: 0.95
  warmup_steps: 0
  lr_end: 1.0e-6
  weight_decay: 0.1
  global_gradient_norm: 1.0
  
  # Data handling
  packing: false
  max_seq_length: 2048
  dataloader_num_workers: 4
  
  # Checkpointing
  checkpoints_dir: "./checkpoints"
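Note that the optimizer sees an effective batch size of train_batch_size times gradient_accumulation_steps; with the values above:

```python
# Effective batch size implied by the configuration above
train_batch_size = 8
gradient_accumulation_steps = 4
effective_batch_size = train_batch_size * gradient_accumulation_steps
print(effective_batch_size)  # 32
```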

How It Works

1. Generate Molecules

The model generates molecules based on similar molecules and desired properties:
# Optimization creates prompts like:
prompts = [
    optim_entry.to_prompt(
        is_generation=True,
        include_oracle_score=True,
        config=config,
        max_score=max_score
    )
    for optim_entry in optim_entries
]
2. Score with Oracle

Generated molecules are evaluated:
oracle_scores = oracle([entry.last_entry for entry in optim_entries])
3. Build Training Pool

High-scoring molecules are added to a pool:
pool.add(list(iter_unique_optim_entries.values()))
train_entries, validation_entries = pool.get_train_valid_entries()
4. Fine-tune Model

When improvement stalls (tol_level >= train_tol_level), fine-tune on the pool:
trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=validation_dataset,
    formatting_func=lambda x: x["sample"],
    args=training_args,
    packing=config["rej_sample_config"]["packing"],
    tokenizer=tokenizer,
    max_seq_length=config["rej_sample_config"]["max_seq_length"],
    callbacks=[model_selection_callback],
    optimizers=(optimizer, lr_scheduler)
)
trainer.train()

Pool Management

The pool maintains diversity and quality:
pool = Pool(
    size=config["pool_size"],  # e.g., 100
    validation_perc=config["validation_perc"]  # e.g., 0.2
)
  • Molecules are sorted by score
  • Duplicates and highly similar molecules are removed (Tanimoto > threshold)
  • Top 20% are reserved for validation
  • Remaining 80% used for training
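The Tanimoto similarity used for the duplicate filter is |A ∩ B| / |A ∪ B| over fingerprint on-bits. A minimal sketch of the filtering logic (real code would compute fingerprints with RDKit; the set-of-bits representation here is illustrative):

```python
def tanimoto(fp_a: set, fp_b: set) -> float:
    """Tanimoto similarity between fingerprints given as sets of on-bit indices."""
    if not fp_a and not fp_b:
        return 0.0
    return len(fp_a & fp_b) / len(fp_a | fp_b)

def deduplicate(entries, threshold=0.9):
    """Keep an entry only if it is not too similar to any already-kept entry."""
    kept = []
    for fp, score in entries:  # entries assumed pre-sorted by score, best first
        if all(tanimoto(fp, kept_fp) <= threshold for kept_fp, _ in kept):
            kept.append((fp, score))
    return kept
```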

Training Arguments

Common Parameters

learning_rate (float, default: 1e-4)
Peak learning rate for training.

warmup_steps (integer, default: 500)
Number of warmup steps for the learning rate scheduler.

weight_decay (float, default: 0.1)
Weight decay for the AdamW optimizer.

gradient_accumulation_steps (integer, default: 1)
Accumulate gradients over multiple batches before each optimizer step.

bf16 (boolean, default: true)
Use bfloat16 mixed-precision training.

Optimizer Configuration

ChemLactica uses AdamW with polynomial decay:
from transformers import get_polynomial_decay_schedule_with_warmup
import torch

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=config["max_learning_rate"],
    betas=[config["adam_beta1"], config["adam_beta2"]],
    weight_decay=config["weight_decay"]
)

lr_scheduler = get_polynomial_decay_schedule_with_warmup(
    optimizer,
    num_warmup_steps=config["warmup_steps"],
    num_training_steps=max_train_steps,
    lr_end=config["lr_end"],
    power=1.0
)

Callbacks

Model Selection Callback

Saves the best model based on validation loss:
from chemlactica.mol_opt.tunning import CustomModelSelectionCallback

model_selection_callback = CustomModelSelectionCallback()

# After training
model.load_state_dict(model_selection_callback.best_model_state_dict)
print(f"Best validation loss: {model_selection_callback.best_validation_loss}")

Early Stopping Callback

Stops training when validation loss stops improving:
from chemlactica.mol_opt.tunning import CustomEarlyStopCallback

early_stopping = CustomEarlyStopCallback(
    early_stopping_patience=3,
    early_stopping_threshold=0.0001
)

Example: Complete SFT Workflow

import yaml
import torch
import pandas as pd
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
from datasets import Dataset

# Load configuration
config = yaml.safe_load(open("sft_config.yaml"))

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(
    "path/to/pretrained/model",
    torch_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained(config["tokenizer_path"])

# Prepare dataset
df = pd.read_csv("property_data.csv")
train_samples = [
    f"</s>[SIMILAR]{row['similar_smiles']} {row['similarity']:.2f}[/SIMILAR]"
    f"[PROPERTY]activity {row['activity']:.2f}[/PROPERTY]"
    f"[START_SMILES]{row['smiles']}[END_SMILES]</s>"
    for _, row in df.iterrows()
]

train_dataset = Dataset.from_dict({"sample": train_samples})

# Configure response template for loss masking
response_template = tokenizer.encode("[PROPERTY]activity")
collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)

# Training arguments
training_args = TrainingArguments(
    output_dir="./sft_checkpoints",
    per_device_train_batch_size=8,
    gradient_accumulation_steps=4,
    num_train_epochs=3,
    learning_rate=1e-4,
    bf16=True,
    logging_steps=10,
    evaluation_strategy="steps",
    eval_steps=100,
    save_steps=100,
    save_total_limit=3
)

# Create trainer
trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    formatting_func=lambda x: x["sample"],
    args=training_args,
    packing=False,  # DataCollatorForCompletionOnlyLM is incompatible with packing
    tokenizer=tokenizer,
    max_seq_length=512,
    data_collator=collator,
    neftune_noise_alpha=5
)

# Train
trainer.train()
trainer.save_model("./sft_final_model")

Best Practices

Data preparation
  • Ensure property values are normalized to appropriate ranges
  • Include diverse molecular scaffolds in training data
  • Balance the dataset across property value ranges
  • Validate SMILES strings before training

Training
  • Start with smaller learning rates (1e-5 to 1e-4) for fine-tuning
  • Use warmup steps (10-20% of training steps)
  • Monitor validation loss to prevent overfitting
  • Adjust neftune_noise_alpha for better generalization

Rejection sampling
  • Set train_tol_level based on oracle budget (typically 2-3)
  • Use larger pool sizes (100-200) for better diversity
  • Keep validation_perc around 0.2 for reliable evaluation
  • Save checkpoints periodically during long optimizations

Monitoring
  • Track top-1, top-10, top-100 scores during optimization
  • Check molecular validity and diversity of generated molecules
  • Use TensorBoard or Weights & Biases for visualization
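For the normalization point above, a minimal min-max scaling sketch (assumes the target range is [0, 1]; the helper name is illustrative):

```python
def min_max_normalize(values):
    """Scale a list of property values into [0, 1]."""
    lo, hi = min(values), max(values)
    if hi == lo:
        # All values identical: no spread to normalize over.
        return [0.0 for _ in values]
    return [(v - lo) / (hi - lo) for v in values]
```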
When using rejection sampling, ensure your oracle has appropriate caching to avoid redundant calculations. The optimization process may evaluate the same molecule multiple times.

Next Steps

Custom Oracles

Build custom scoring functions for optimization

Benchmarking

Evaluate model performance on standard benchmarks
