Supervised Fine-Tuning (SFT)
Overview
ChemLactica supports supervised fine-tuning (SFT) to adapt pretrained models for specific chemical tasks. The SFT framework uses TRL's SFTTrainer with a custom data collator for completion-only training.
Quick Start
Prepare SFT Data
Format your data as JSONL with prompts and completions
Configure SFT Parameters
Set up packing, sequence length, and noise parameters
Run SFT Training
Launch training with the sft train type
Basic SFT Command
python chemlactica/train.py \
--train_type sft \
--from_pretrained /path/to/pretrained/model \
--model_config 125m \
--training_data_dirs /path/to/sft/data \
--dir_data_types activity \
--valid_data_dir /path/to/valid/data \
--learning_rate 1.0e-4 \
--warmup_steps 0 \
--train_batch_size 8 \
--eval_steps 100 \
--save_steps 500 \
--max_steps 5000 \
--checkpoints_root_dir ./checkpoints \
--experiment_name activity_prediction_sft
SFT Configuration
SFT-specific parameters are defined in config/default_train_config.py:
from dataclasses import dataclass

@dataclass
class SFTTrainConfig:
    packing: bool = False
    max_seq_length: int = 64
    neftune_noise_alpha: int = 10
SFT Parameters
packing: enable packing multiple samples into a single sequence for efficiency
max_seq_length: maximum sequence length for SFT samples
neftune_noise_alpha: NEFTune noise scale for embedding perturbation (improves generalization)
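The defaults above can be overridden per experiment. A minimal sketch that mirrors the dataclass from config/default_train_config.py and uses dataclasses.replace to change one field without mutating the shared default (the SFTTrainConfig definition here is copied from the doc, not imported from ChemLactica):

```python
from dataclasses import dataclass, replace

@dataclass
class SFTTrainConfig:
    packing: bool = False
    max_seq_length: int = 64
    neftune_noise_alpha: int = 10

# Shared defaults for small property-prediction tasks
default_config = SFTTrainConfig()

# Override a single field for a long-sequence experiment
long_seq_config = replace(default_config, max_seq_length=512)

print(long_seq_config.max_seq_length)  # 512
print(default_config.max_seq_length)   # 64 (unchanged)
```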
Response Template
The SFT trainer uses a response template to train only on completions:
# From get_trainer.py:28
response_template = tokenizer.encode("[PROPERTY]activity")
collator = DataCollatorForCompletionOnlyLM(
    response_template, tokenizer=tokenizer
)
JSONL Format
Each training sample is a JSON object with a text field containing the prompt and completion:
{"text": "[SIMILAR]CC(C)C[/SIMILAR][QED]0.85[/QED][PROPERTY]activity 0.92"}
{"text": "[SIMILAR]c1ccccc1[/SIMILAR][QED]0.78[/QED][PROPERTY]activity 0.88"}
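A hypothetical formatting function that produces lines in this shape (format_sample is an illustrative helper, not part of ChemLactica):

```python
import json

def format_sample(similar_smiles, qed, activity):
    """Hypothetical formatter: build one SFT sample in the tag format above."""
    text = (
        f"[SIMILAR]{similar_smiles}[/SIMILAR]"
        f"[QED]{qed}[/QED]"
        f"[PROPERTY]activity {activity}"
    )
    return json.dumps({"text": text})

line = format_sample("CC(C)C", 0.85, 0.92)
print(line)
# {"text": "[SIMILAR]CC(C)C[/SIMILAR][QED]0.85[/QED][PROPERTY]activity 0.92"}
```

Writing one such line per sample yields a JSONL file that can be passed via --training_data_dirs.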
Training Configuration
Recommended Settings for SFT
train_config:
  adam_beta1: 0.9
  adam_beta2: 0.95
  batch_size: 500000
  global_gradient_norm: 1.0
  max_learning_rate: 1.0e-4  # Lower than pretraining
  warmup_steps: 0  # Usually no warmup for SFT
  weight_decay: 0.1
  bf16: true
  evaluation_strategy: "steps"
  save_total_limit: 4
Advanced SFT Features
NEFTune (Noisy Embeddings)
NEFTune adds noise to embeddings during training to improve generalization:
# Configured in SFTTrainConfig
neftune_noise_alpha: int = 10 # Noise scale
This technique is particularly effective for instruction following and has been shown to improve downstream performance.
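The perturbation itself is simple: uniform noise in [-1, 1] scaled by alpha / sqrt(seq_len * hidden_dim). A dependency-free sketch of that scaling rule (real implementations operate on torch embedding tensors during the forward pass):

```python
import math
import random

def neftune_noise(embeddings, alpha=10.0):
    """Add uniform noise scaled by alpha / sqrt(L * d) to token embeddings.

    embeddings: list of L token vectors, each with d floats.
    """
    seq_len = len(embeddings)
    dim = len(embeddings[0])
    scale = alpha / math.sqrt(seq_len * dim)
    return [
        [x + random.uniform(-1.0, 1.0) * scale for x in row]
        for row in embeddings
    ]

random.seed(0)
emb = [[0.0] * 4 for _ in range(8)]          # 8 tokens, hidden dim 4
noisy = neftune_noise(emb, alpha=10.0)
max_dev = max(abs(x) for row in noisy for x in row)
print(max_dev <= 10.0 / math.sqrt(8 * 4))    # True: noise is bounded by the scale
```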
Completion-Only Training
The DataCollatorForCompletionOnlyLM ensures loss is computed only on the completion tokens:
# From get_trainer.py:29-31
collator = DataCollatorForCompletionOnlyLM(
    response_template, tokenizer=tokenizer
)
This focuses learning on generating the target output rather than memorizing the prompt.
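A minimal sketch of what completion-only masking does (hypothetical token ids; TRL's actual collator also handles batching, padding, and edge cases):

```python
IGNORE_INDEX = -100  # positions with this label are excluded from the loss

def mask_prompt_labels(input_ids, template_ids):
    """Sketch of completion-only masking: ignore every token up to and
    including the response template, so loss covers only the completion."""
    labels = list(input_ids)
    n = len(template_ids)
    for start in range(len(input_ids) - n + 1):
        if input_ids[start:start + n] == template_ids:
            for i in range(start + n):       # mask prompt + template
                labels[i] = IGNORE_INDEX
            break
    return labels

# Hypothetical ids: prompt [5, 6], template [7, 8], completion [9, 10]
print(mask_prompt_labels([5, 6, 7, 8, 9, 10], [7, 8]))
# [-100, -100, -100, -100, 9, 10]
```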
SFT Numerical Evaluation
ChemLactica includes a custom callback for evaluating SFT models:
# From train.py:287-289
trainer_callback_dict["SFT numerical evaluation"] = SFTNumericalEval(
    dataset, aim_callback, model_config.separator_token
)
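The callback's internals are not shown here; a hypothetical sketch of the idea is to parse the number the model generates after the [PROPERTY]activity tag and score it against the ground truth (parse_activity and rmse are illustrative names, not ChemLactica functions):

```python
import math
import re

def parse_activity(generated_text):
    """Extract the numeric activity value from a generated completion."""
    match = re.search(r"\[PROPERTY\]activity\s+([-+]?\d*\.?\d+)", generated_text)
    return float(match.group(1)) if match else None

def rmse(predictions, targets):
    """Root-mean-square error over samples where a number was parsed."""
    pairs = [(p, t) for p, t in zip(predictions, targets) if p is not None]
    return math.sqrt(sum((p - t) ** 2 for p, t in pairs) / len(pairs))

preds = [parse_activity("[SIMILAR]CC[/SIMILAR][PROPERTY]activity 0.90"),
         parse_activity("[SIMILAR]CO[/SIMILAR][PROPERTY]activity 0.80")]
print(rmse(preds, [0.92, 0.88]))  # ≈ 0.0583
```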
Example: Activity Prediction SFT
Data Preparation
Create training data with molecular descriptors and activity values:
{"text": "[SMILES]CC(C)Cc1ccc(C)cc1[/SMILES][MW]148.25[/MW][PROPERTY]activity 0.85"}
{"text": "[SMILES]c1ccc2[nH]ccc2c1[/SMILES][MW]117.15[/MW][PROPERTY]activity 0.92"}
Training Command
python chemlactica/train.py \
--train_type sft \
--from_pretrained ./checkpoints/pretrained/galactica-125m \
--model_config 125m \
--training_data_dirs ./data/activity_sft/train \
--dir_data_types activity \
--valid_data_dir ./data/activity_sft/valid \
--learning_rate 1.0e-4 \
--warmup_steps 0 \
--train_batch_size 16 \
--gradient_accumulation_steps 2 \
--eval_steps 200 \
--save_steps 500 \
--max_steps 5000 \
--checkpoints_root_dir ./checkpoints/sft \
--experiment_name activity_prediction \
--track \
--track_dir ./aim_logs
Multi-GPU SFT Training
For larger models or datasets:
accelerate launch --config_file config/accelerate_config.yaml \
chemlactica/train.py \
--train_type sft \
--from_pretrained ./checkpoints/pretrained/galactica-1.3b \
--model_config 1.3b \
--training_data_dirs ./data/sft/train \
--dir_data_types activity \
--valid_data_dir ./data/sft/valid \
--learning_rate 2.0e-5 \
--warmup_steps 100 \
--train_batch_size 4 \
--gradient_accumulation_steps 8 \
--gradient_checkpointing \
--flash_attn \
--eval_steps 100 \
--save_steps 500 \
--max_steps 10000 \
--checkpoints_root_dir ./checkpoints/sft \
--experiment_name large_model_sft
Evaluation-Only Mode
Evaluate a fine-tuned model without training:
python chemlactica/train.py \
--train_type sft \
--from_pretrained ./checkpoints/sft/checkpoint-5000 \
--model_config 125m \
--valid_data_dir ./data/sft/test \
--evaluate_only \
--valid_batch_size 32
SFT vs Pretraining
| Aspect | Pretraining | SFT |
| --- | --- | --- |
| Learning Rate | Higher (6e-4 to 1.4e-3) | Lower (2e-5 to 1e-4) |
| Warmup Steps | 500+ | 0-100 |
| Trainer | CustomTrainer | SFTTrainer |
| Data Format | Raw text sequences | Prompt-completion pairs |
| Loss | All tokens | Completion tokens only |
| Evaluation | Perplexity | Task-specific metrics |
Best Practices
Learning Rate
Use a 10-100x lower learning rate than for pretraining
Start with 1e-4 for small models
Use 2e-5 or lower for models >1B parameters
Monitor validation loss closely for overfitting
Data Quality
Ensure consistent formatting across all samples
Balance your dataset across different property values
Remove duplicates and low-quality samples
Use a validation set from the same distribution as training
Sequence Length
Keep max_seq_length as small as possible (typical values: 64-512 tokens)
Longer sequences require more memory
Consider packing for short sequences
Training Duration
SFT typically requires 1000-10000 steps
Watch for overfitting after a few epochs
Use early stopping based on validation loss
Save checkpoints frequently for model selection
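The early-stopping advice above can be sketched as a simple patience counter over validation losses (a generic sketch, not ChemLactica's implementation):

```python
def should_stop(val_losses, patience=3, min_delta=0.0):
    """Stop when validation loss hasn't improved by at least min_delta
    over the last `patience` evaluations."""
    if len(val_losses) <= patience:
        return False
    best_before = min(val_losses[:-patience])
    best_recent = min(val_losses[-patience:])
    return best_recent >= best_before - min_delta

print(should_stop([2.0, 1.5, 1.2, 1.3, 1.31, 1.32]))  # True: plateaued for 3 evals
print(should_stop([2.0, 1.5, 1.2, 1.1]))              # False: still improving
```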
Troubleshooting
Model Not Learning
Check that response template matches your data format
Verify data formatting is consistent
Try increasing learning rate slightly
Ensure sufficient training steps
Overfitting
Reduce learning rate
Increase weight decay
Add more training data
Enable dropout if needed
Use NEFTune noise
Memory Issues
Reduce max_seq_length
Decrease train_batch_size
Increase gradient_accumulation_steps
Enable packing for short sequences
Use --gradient_checkpointing
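Packing, mentioned above, concatenates short tokenized samples into full-length sequences so fewer pad tokens are wasted. A greedy sketch under the assumption that samples are already token-id lists (TRL's actual packing implementation differs):

```python
def pack_samples(samples, max_seq_length, sep_token_id=0):
    """Greedily concatenate tokenized samples, joined by sep_token_id,
    into sequences of at most max_seq_length tokens."""
    packed, current = [], []
    for sample in samples:
        candidate = current + ([sep_token_id] if current else []) + sample
        if len(candidate) <= max_seq_length:
            current = candidate
        else:
            if current:
                packed.append(current)
            current = sample[:max_seq_length]  # truncate oversized samples
    if current:
        packed.append(current)
    return packed

print(pack_samples([[1, 2], [3, 4, 5], [6]], max_seq_length=8))
# [[1, 2, 0, 3, 4, 5, 0, 6]]
```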
Next Steps
Rejection Sampling Learn about rejection sampling fine-tuning strategy
Configuration Explore detailed configuration options