
Supervised Fine-Tuning (SFT)

Overview

ChemLactica supports supervised fine-tuning (SFT) to adapt pretrained models to specific chemical tasks. The SFT framework uses TRL's SFTTrainer with a custom data collator for completion-only training.

Quick Start

1. Prepare SFT Data: format your data with prompts and completions in JSONL format.
2. Configure SFT Parameters: set up packing, sequence length, and noise parameters.
3. Run SFT Training: launch training with the sft train type.

Basic SFT Command

python chemlactica/train.py \
  --train_type sft \
  --from_pretrained /path/to/pretrained/model \
  --model_config 125m \
  --training_data_dirs /path/to/sft/data \
  --dir_data_types activity \
  --valid_data_dir /path/to/valid/data \
  --learning_rate 1.0e-4 \
  --warmup_steps 0 \
  --train_batch_size 8 \
  --eval_steps 100 \
  --save_steps 500 \
  --max_steps 5000 \
  --checkpoints_root_dir ./checkpoints \
  --experiment_name activity_prediction_sft

SFT Configuration

SFT-specific parameters are defined in config/default_train_config.py:
@dataclass
class SFTTrainConfig:
    packing: bool = False
    max_seq_length: int = 64
    neftune_noise_alpha: int = 10

SFT Parameters

packing (boolean, default: false)
  Enable packing multiple samples into a single sequence for efficiency.

max_seq_length (integer, default: 64)
  Maximum sequence length for SFT samples.

neftune_noise_alpha (integer, default: 10)
  NEFTune noise scale for embedding perturbation (improves generalization).
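
For example, a run with longer, packed sequences could override these defaults. An illustrative edit to the dataclass above (the values here are not recommendations):

from dataclasses import dataclass

@dataclass
class SFTTrainConfig:
    packing: bool = True           # pack short samples into full sequences
    max_seq_length: int = 256      # allow longer prompt-completion pairs
    neftune_noise_alpha: int = 5   # milder embedding noise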

Data Format

Response Template

The SFT trainer uses a response template to train only on completions:
# From get_trainer.py:28
response_template = tokenizer.encode("[PROPERTY]activity")
collator = DataCollatorForCompletionOnlyLM(
    response_template, tokenizer=tokenizer
)

Example Data Format

{"text": "[SIMILAR]CC(C)C[/SIMILAR][QED]0.85[/QED][PROPERTY]activity 0.92"}
{"text": "[SIMILAR]c1ccccc1[/SIMILAR][QED]0.78[/QED][PROPERTY]activity 0.88"}

Training Configuration

train_config:
  adam_beta1: 0.9
  adam_beta2: 0.95
  batch_size: 500000
  global_gradient_norm: 1.0
  max_learning_rate: 1.0e-4  # Lower than pretraining
  warmup_steps: 0             # Usually no warmup for SFT
  weight_decay: 0.1
  bf16: true
  evaluation_strategy: "steps"
  save_total_limit: 4

Advanced SFT Features

NEFTune (Noisy Embeddings)

NEFTune adds noise to embeddings during training to improve generalization:
# Configured in SFTTrainConfig
neftune_noise_alpha: int = 10  # Noise scale
This technique is particularly effective for instruction following and has been shown to improve downstream performance.
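
Conceptually, NEFTune adds uniform noise with magnitude alpha / sqrt(seq_len * hidden_dim) to the embedding output while training. A self-contained sketch of the mechanism (ChemLactica wires this up through TRL, not through a hook like this):

import torch

def neftune_hook(module, inputs, output, alpha=10.0):
    # add uniform noise scaled by alpha / sqrt(seq_len * hidden_dim),
    # but only in training mode; evaluation stays deterministic
    if module.training:
        dims = output.size(1) * output.size(2)  # seq_len * hidden_dim
        mag = alpha / dims ** 0.5
        output = output + torch.empty_like(output).uniform_(-mag, mag)
    return output

embedding = torch.nn.Embedding(1000, 64)
embedding.train()
embedding.register_forward_hook(neftune_hook)
noisy = embedding(torch.randint(0, 1000, (2, 8)))  # perturbed embeddings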

Completion-Only Training

The DataCollatorForCompletionOnlyLM ensures loss is computed only on the completion tokens:
# From get_trainer.py:29-31
collator = DataCollatorForCompletionOnlyLM(
    response_template, tokenizer=tokenizer
)
This focuses learning on generating the target output rather than memorizing the prompt.
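
A conceptual sketch of that masking (illustrative, not TRL's implementation): every label up to and including the response template is set to -100, which PyTorch's cross-entropy loss ignores.

import torch

IGNORE_INDEX = -100  # label value skipped by cross-entropy

def mask_prompt_tokens(input_ids, template_ids):
    # mask everything through the end of the response template so that
    # loss is computed only on the completion tokens that follow it
    labels = input_ids.clone()
    ids, tmpl = input_ids.tolist(), list(template_ids)
    for i in range(len(ids) - len(tmpl) + 1):
        if ids[i:i + len(tmpl)] == tmpl:
            labels[: i + len(tmpl)] = IGNORE_INDEX
            return labels
    labels[:] = IGNORE_INDEX  # template absent: mask the whole sample
    return labels

labels = mask_prompt_tokens(torch.tensor([5, 6, 7, 8, 9]), [7, 8])
# labels == [-100, -100, -100, -100, 9]; loss covers token 9 only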

SFT Numerical Evaluation

ChemLactica includes a custom callback for evaluating SFT models:
# From train.py:287-289
trainer_callback_dict["SFT numerical evaluation"] = SFTNumericalEval(
    dataset, aim_callback, model_config.separator_token
)
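
The callback's internals live in ChemLactica's codebase. Purely as a hypothetical illustration of numerical evaluation, one might parse the number from each generated completion and score it against the reference value:

import math
import re

def numerical_rmse(generated, targets):
    # hypothetical helper: extract the first float from each completion
    # and compute RMSE against the reference values
    preds = [float(re.search(r"-?\d+\.?\d*", g).group()) for g in generated]
    return math.sqrt(sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(targets))

print(numerical_rmse(["activity 0.90", "activity 0.85"], [0.92, 0.88]))  # ≈ 0.025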

Example: Activity Prediction SFT

Data Preparation

Create training data with molecular descriptors and activity values:
{"text": "[SMILES]CC(C)Cc1ccc(C)cc1[/SMILES][MW]148.25[/MW][PROPERTY]activity 0.85"}
{"text": "[SMILES]c1ccc2[nH]ccc2c1[/SMILES][MW]117.15[/MW][PROPERTY]activity 0.92"}

Training Command

python chemlactica/train.py \
  --train_type sft \
  --from_pretrained ./checkpoints/pretrained/galactica-125m \
  --model_config 125m \
  --training_data_dirs ./data/activity_sft/train \
  --dir_data_types activity \
  --valid_data_dir ./data/activity_sft/valid \
  --learning_rate 1.0e-4 \
  --warmup_steps 0 \
  --train_batch_size 16 \
  --gradient_accumulation_steps 2 \
  --eval_steps 200 \
  --save_steps 500 \
  --max_steps 5000 \
  --checkpoints_root_dir ./checkpoints/sft \
  --experiment_name activity_prediction \
  --track \
  --track_dir ./aim_logs

Multi-GPU SFT Training

For larger models or datasets:
accelerate launch --config_file config/accelerate_config.yaml \
  chemlactica/train.py \
  --train_type sft \
  --from_pretrained ./checkpoints/pretrained/galactica-1.3b \
  --model_config 1.3b \
  --training_data_dirs ./data/sft/train \
  --dir_data_types activity \
  --valid_data_dir ./data/sft/valid \
  --learning_rate 2.0e-5 \
  --warmup_steps 100 \
  --train_batch_size 4 \
  --gradient_accumulation_steps 8 \
  --gradient_checkpointing \
  --flash_attn \
  --eval_steps 100 \
  --save_steps 500 \
  --max_steps 10000 \
  --checkpoints_root_dir ./checkpoints/sft \
  --experiment_name large_model_sft

Evaluation-Only Mode

Evaluate a fine-tuned model without training:
python chemlactica/train.py \
  --train_type sft \
  --from_pretrained ./checkpoints/sft/checkpoint-5000 \
  --model_config 125m \
  --valid_data_dir ./data/sft/test \
  --evaluate_only \
  --valid_batch_size 32

SFT vs Pretraining

Aspect        | Pretraining            | SFT
------------- | ---------------------- | ------------------------
Learning Rate | Higher (1.4e-3 - 6e-4) | Lower (1e-4 - 2e-5)
Warmup Steps  | 500+                   | 0-100
Trainer       | CustomTrainer          | SFTTrainer
Data Format   | Raw text sequences     | Prompt-completion pairs
Loss          | All tokens             | Completion tokens only
Evaluation    | Perplexity             | Task-specific metrics

Best Practices

Learning Rate
  • Use a 10-100x lower learning rate than pretraining
  • Start with 1e-4 for small models
  • Use 2e-5 or lower for models >1B parameters
  • Monitor validation loss closely for overfitting

Data Quality
  • Ensure consistent formatting across all samples
  • Balance your dataset across different property values
  • Remove duplicates and low-quality samples
  • Use a validation set drawn from the same distribution as training

Sequence Length
  • Keep max_seq_length as small as possible
  • Typical values: 64-512 tokens
  • Longer sequences require more memory
  • Consider packing for short sequences (see the sketch after this list)

Training Duration
  • SFT typically requires 1,000-10,000 steps
  • Watch for overfitting after a few epochs
  • Use early stopping based on validation loss
  • Save checkpoints frequently for model selection
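
Packing (the packing flag above) concatenates tokenized samples, separated by EOS, and slices the stream into fixed-length chunks. A conceptual sketch of the idea, not TRL's implementation; note that per TRL's documentation, packing cannot be combined with DataCollatorForCompletionOnlyLM.

def pack_samples(token_lists, max_seq_length, eos_id):
    # concatenate samples separated by EOS, then cut fixed-length chunks
    buffer, packed = [], []
    for tokens in token_lists:
        buffer.extend(tokens + [eos_id])
        while len(buffer) >= max_seq_length:
            packed.append(buffer[:max_seq_length])
            buffer = buffer[max_seq_length:]
    return packed  # any trailing partial chunk is dropped

print(pack_samples([[1, 2], [3, 4, 5], [6]], 4, 0))
# -> [[1, 2, 0, 3], [4, 5, 0, 6]]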

Troubleshooting

Model Not Learning

  • Check that response template matches your data format
  • Verify data formatting is consistent
  • Try increasing learning rate slightly
  • Ensure sufficient training steps

Overfitting

  • Reduce learning rate
  • Increase weight decay
  • Add more training data
  • Enable dropout if needed
  • Use NEFTune noise

Memory Issues

  • Reduce max_seq_length
  • Decrease train_batch_size
  • Increase gradient_accumulation_steps
  • Enable packing for short sequences
  • Use --gradient_checkpointing

Next Steps

Rejection Sampling

Learn about the rejection sampling fine-tuning strategy

Configuration

Explore detailed configuration options
