Overview
SFT adapts the pretrained model’s language understanding capabilities to the instruction-following task format. The model learns to:
- Follow natural language instructions
- Answer questions accurately
- Generate helpful, harmless responses
- Maintain conversational context
Supported datasets
Modern LLM supports several high-quality instruction datasets:

| Dataset | Examples | Format | Description |
|---|---|---|---|
| tatsu-lab/alpaca | 52K | Instruction-input-output | Stanford Alpaca dataset with GPT-3.5-generated responses |
| databricks/databricks-dolly-15k | 15K | Instruction-context-response | Human-written instruction-following examples |
| Open-Orca/OpenOrca | 4.2M | System-question-response | Large-scale instruction dataset based on FLAN |
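As an illustration, the three record formats in the table above map to examples like these (the field values are invented for this sketch, not actual dataset rows):

```python
# Illustrative records for each supported dataset format.
# Field names follow the "Format" column above; the values are invented.
alpaca_example = {  # tatsu-lab/alpaca: instruction-input-output
    "instruction": "Summarize the text.",
    "input": "SFT adapts a pretrained model to instructions.",
    "output": "SFT teaches a pretrained model to follow instructions.",
}
dolly_example = {  # databricks-dolly-15k: instruction-context-response
    "instruction": "What is SFT?",
    "context": "",
    "response": "Supervised fine-tuning on instruction data.",
}
openorca_example = {  # Open-Orca/OpenOrca: system-question-response
    "system_prompt": "You are a helpful assistant.",
    "question": "What is SFT?",
    "response": "Supervised fine-tuning on instruction data.",
}
```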
The `gpu` preset uses multiple SFT datasets to improve instruction-following diversity. You can specify multiple datasets with the `sft_datasets` config parameter.
Usage
Using the pipeline runner (recommended)
Run the SFT stage through `scripts/run_pipeline.py`, which integrates SFT via its `run_sft()` entry point.
Direct script usage
You can also use the standalone SFT script, `scripts/sft.py`.
Configuration
Config presets
SFT hyperparameters are defined in the pipeline config presets.
SFT learning rates are typically 10-100x lower than pretraining to avoid catastrophic forgetting of the pretrained knowledge.
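As a sketch, an SFT preset might bundle the defaults documented on this page (this is a hypothetical dict; the project's actual preset format may differ):

```python
# Hypothetical SFT preset built from the defaults documented on this page.
sft_preset = {
    "sft_lr": 1e-5,          # 10-100x lower than a typical pretraining LR
    "sft_max_steps": 5000,   # single-dataset default
    "sft_batch_size": 32,
    "sft_datasets": ["tatsu-lab/alpaca"],
}
```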
Hyperparameter tuning
Key hyperparameters for SFT:

Learning rate (`sft_lr`)
- Default: `1e-5` balances adaptation and stability
- Too high: Model forgets pretrained knowledge
- Too low: Slow adaptation to instruction format
- Range: `5e-6` to `5e-5`
Training steps (`sft_max_steps`)
- Default: `5000` for single dataset, `10000` for multiple
- Small datasets (Alpaca): 3000-5000 steps sufficient
- Large datasets (OpenOrca): 10000+ steps for full coverage
- Stop early if validation loss plateaus
Batch size (`sft_batch_size`)
- Default: `32` provides stable gradients
- Smaller batches = more frequent updates
- Larger batches = smoother but slower convergence
Training details
Optimization
SFT uses:
- Optimizer: AdamW with β₁=0.9, β₂=0.95
- Learning rate schedule: Cosine annealing from `sft_lr` to 0
- Gradient accumulation: Automatic (`batch_size / micro_batch_size`)
- Mixed precision: BF16 on supported GPUs
- Weight decay: 0.01 (lighter than pretraining)
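The schedule and accumulation arithmetic above can be sketched in plain Python (the function name here is illustrative, not the project's API):

```python
import math

def cosine_lr(step, total_steps, base_lr=1e-5, min_lr=0.0):
    """Cosine annealing from base_lr at step 0 down to min_lr at total_steps."""
    progress = min(step / total_steps, 1.0)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

# Gradient accumulation: with batch_size=32 and micro_batch_size=8,
# gradients are accumulated over 4 micro-batches per optimizer step.
accum_steps = 32 // 8
```

At step 0 this returns the full `sft_lr`; by `total_steps` it has decayed to 0, matching the cosine-to-zero schedule described above.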
Loss function
Causal language modeling loss with response-only masking: the loss is computed only on response tokens, while prompt tokens are excluded.
Data format
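A minimal Alpaca-style template sketch (the project's actual prompt template is not shown on this page and may differ):

```python
def format_example(ex):
    """Render an Alpaca-style record into (prompt, response) strings.

    Hypothetical template -- the project's actual template may differ.
    """
    prompt = f"### Instruction:\n{ex['instruction']}\n\n"
    if ex.get("input"):
        prompt += f"### Input:\n{ex['input']}\n\n"
    prompt += "### Response:\n"
    return prompt, ex["output"]
```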
Instruction datasets are formatted as instruction-response pairs rendered into a single training sequence.
Checkpoints
SFT saves checkpoints at regular intervals:
- Regular checkpoints: every `save_every` steps (default: 2000)
  - Format: `<run_name>-sft_step{N}.pt`
  - Contains model state, optimizer state, config
- Final checkpoint: at end of training
  - Format: `<run_name>-sft_final.pt`
  - Used as input for DPO stage
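The naming scheme can be sketched with a small helper (illustrative, not the project's actual code):

```python
def checkpoint_name(run_name, step=None):
    """Build a checkpoint filename following the scheme above.

    step=None means the final checkpoint at the end of training.
    """
    if step is None:
        return f"{run_name}-sft_final.pt"
    return f"{run_name}-sft_step{step}.pt"
```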
Monitoring
SFT training progress is logged to the console and to `training.log`.
Quality indicators
Good SFT training:
- Loss decreases steadily from ~1.5 to ~0.8
- No sudden spikes or NaN losses
- Validation loss follows training loss

Signs of overfitting:
- Training loss continues decreasing but validation loss increases
- Model generates repetitive or memorized responses
- Solution: Reduce steps, increase weight decay, or add more data

Signs of underfitting:
- Loss plateaus early at high value (>1.0)
- Model fails to follow basic instructions
- Solution: Train longer, increase LR, or check data quality
Implementation details
The SFT implementation is located at:
- `src/modern_llm/training/train_sft.py:run_sft()` - Main training loop
- `scripts/sft.py` - CLI wrapper
- `scripts/run_pipeline.py:run_sft()` - Pipeline integration
`src/modern_llm/training/train_sft.py:52-131`
Main SFT entrypoint:
- Load pretrained model from checkpoint
- Load and format instruction dataset
- Setup optimizer with cosine annealing
- Run training with response-only masking
- Save final SFT checkpoint

The dataset-loading helper:
- Fetch dataset from Hugging Face
- Apply instruction-response formatting
- Tokenize with response masking
- Return a PyTorch Dataset
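The masking step can be sketched with a toy tokenizer (a stand-in for the real one; the `-100` ignore index is an assumption borrowed from common PyTorch practice, not confirmed project code):

```python
IGNORE_INDEX = -100  # label value excluded from the loss (PyTorch CrossEntropyLoss convention)

def toy_encode(text):
    # Stand-in tokenizer: one fake "token id" per whitespace-separated word.
    return [abs(hash(w)) % 1000 for w in text.split()]

def tokenize_with_response_mask(prompt, response, encode):
    """Tokenize prompt+response; mask prompt positions out of the labels."""
    prompt_ids = encode(prompt)
    response_ids = encode(response)
    input_ids = prompt_ids + response_ids
    # Labels: ignore prompt tokens so loss is computed on the response only.
    labels = [IGNORE_INDEX] * len(prompt_ids) + list(response_ids)
    return input_ids, labels
```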
Performance tips
Reduce memory usage
- Lower `micro_batch_size` (use gradient accumulation)
- Reduce `max_seq_len` to 512 or 768
- Use LoRA/QLoRA for parameter-efficient fine-tuning (not yet supported)
- Enable gradient checkpointing
Speed up training
- Increase `micro_batch_size` if GPU memory allows
- Use shorter max sequence length
- Sample large datasets (e.g., `OpenOrca:50000`)
- Enable bf16 mixed precision
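Sampling a large dataset can be as simple as taking a seeded random subset of indices (a sketch; the project's actual sampling may differ):

```python
import random

def sample_indices(dataset_size, k, seed=0):
    """Pick k example indices without replacement, reproducibly."""
    rng = random.Random(seed)
    return rng.sample(range(dataset_size), k)

# e.g. 50K of OpenOrca's 4.2M examples
subset = sample_indices(4_200_000, 50_000)
```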
Improve instruction following
- Train on multiple diverse datasets
- Increase training steps (10K+)
- Use curriculum learning (start with simple instructions)
- Add evaluation on held-out test set
Prevent catastrophic forgetting
- Keep learning rate low (1e-5 or lower)
- Reduce training steps if model degrades
- Mix in pretraining data (10-20% of batches)
- Use lower weight decay
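The 10-20% mixing tip can be sketched as a per-batch coin flip (illustrative; the project's data loader may implement mixing differently):

```python
import random

def pick_batch_source(rng, pretrain_fraction=0.15):
    """Return 'pretrain' for roughly pretrain_fraction of batches, 'sft' otherwise."""
    return "pretrain" if rng.random() < pretrain_fraction else "sft"

rng = random.Random(0)
sources = [pick_batch_source(rng) for _ in range(1000)]
```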
Evaluation
After SFT, evaluate the model’s instruction-following ability.
Quick qualitative check
Generate responses to test instructions.
Quantitative benchmarks
Run evaluation on standard benchmarks.
Next steps
After SFT completes:
- Verify the checkpoint exists at `experiments/runs/<run_name>/sft_final.pt`
- Run DPO to further align the model
- Or continue with the full pipeline
Direct preference optimization
Learn how to align your SFT model using preference data