Supervised fine-tuning (SFT) is the second stage of the Modern LLM pipeline. It takes a pretrained language model and fine-tunes it on instruction-response pairs to teach the model to follow user instructions and engage in helpful dialog.

Overview

SFT adapts the pretrained model’s language understanding capabilities to the instruction-following task format. The model learns to:
  • Follow natural language instructions
  • Answer questions accurately
  • Generate helpful, harmless responses
  • Maintain conversational context
This stage implements the supervised fine-tuning approach from InstructGPT (Ouyang et al., 2022), using standard causal language modeling with response-only masking to focus learning on the assistant’s responses.

Supported datasets

Modern LLM supports several high-quality instruction datasets:
| Dataset | Examples | Format | Description |
|---------|----------|--------|-------------|
| tatsu-lab/alpaca | 52K | Instruction-input-output | Stanford Alpaca dataset with GPT-3.5-generated responses |
| databricks/databricks-dolly-15k | 15K | Instruction-context-response | Human-written instruction-following examples |
| Open-Orca/OpenOrca | 4.2M | System-question-response | Large-scale instruction dataset based on FLAN |
The gpu preset uses multiple SFT datasets to improve instruction-following diversity. You can specify multiple datasets with the sft_datasets config parameter.
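For example, a multi-dataset preset might look like this (illustrative; only the sft_datasets parameter is documented above, so the list syntax is an assumption):

```yaml
# Hypothetical preset combining two datasets; exact YAML layout may differ
sft_datasets:
  - "tatsu-lab/alpaca"
  - "databricks/databricks-dolly-15k"
```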

Usage

python scripts/run_pipeline.py --config local --stage sft \
    --checkpoint experiments/runs/local-full/pretrain_final.pt

Direct script usage

You can also use the standalone SFT script:
python scripts/sft.py \
    --pretrain-checkpoint experiments/runs/local-full/pretrain_final.pt \
    --config local

Configuration

Config presets

SFT hyperparameters are defined in the pipeline config presets:
# Quick test (~2 minutes)
sft_max_steps: 50
sft_lr: 1e-5
sft_batch_size: 32
sft_micro_batch_size: 2
sft_dataset: "tatsu-lab/alpaca"
SFT learning rates are typically 10-100x lower than pretraining to avoid catastrophic forgetting of the pretrained knowledge.

Hyperparameter tuning

Key hyperparameters for SFT:

Learning rate (sft_lr)
  • Default: 1e-5 balances adaptation and stability
  • Too high: Model forgets pretrained knowledge
  • Too low: Slow adaptation to instruction format
  • Range: 5e-6 to 5e-5
Training steps (sft_max_steps)
  • Default: 5000 for single dataset, 10000 for multiple
  • Small datasets (Alpaca): 3000-5000 steps sufficient
  • Large datasets (OpenOrca): 10000+ steps for full coverage
  • Stop early if validation loss plateaus
Batch size (sft_batch_size)
  • Default: 32 provides stable gradients
  • Smaller batches = more frequent updates
  • Larger batches = smoother but slower convergence

Training details

Optimization

SFT uses:
  • Optimizer: AdamW with β₁=0.9, β₂=0.95
  • Learning rate schedule: Cosine annealing from sft_lr to 0
  • Gradient accumulation: Automatic (batch_size / micro_batch_size)
  • Mixed precision: BF16 on supported GPUs
  • Weight decay: 0.01 (lighter than pretraining)
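The cosine schedule and gradient-accumulation arithmetic above can be sketched in plain Python (an illustrative sketch, not the pipeline's actual code; function names are hypothetical):

```python
import math

def cosine_lr(step: int, max_steps: int, peak_lr: float = 1e-5) -> float:
    """Cosine annealing from peak_lr at step 0 down to 0 at max_steps."""
    progress = min(step / max_steps, 1.0)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * progress))

def accumulation_steps(batch_size: int = 32, micro_batch_size: int = 2) -> int:
    """Micro-batches accumulated per optimizer step to reach the effective batch size."""
    return batch_size // micro_batch_size

print(cosine_lr(0, 5000))      # peak LR at the start
print(cosine_lr(5000, 5000))   # ~0 at the end
print(accumulation_steps())    # 16 micro-batches per optimizer step
```

With the default sft_batch_size of 32 and sft_micro_batch_size of 2, each optimizer step accumulates gradients over 16 micro-batches.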

Loss function

Causal language modeling loss with response-only masking:
# Only compute loss on assistant responses, not instructions
mask = create_response_mask(input_ids, tokenizer)  # (B, T), 1 on response tokens
loss = F.cross_entropy(
    logits[:, :-1, :].reshape(-1, vocab_size),
    labels[:, 1:].reshape(-1),
    reduction='none'
)
# Shift and flatten the mask so it lines up with the next-token targets
mask = mask[:, 1:].reshape(-1).float()
loss = (loss * mask).sum() / mask.sum()
This focuses the model’s learning on generating good responses rather than memorizing instruction formats.
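create_response_mask is project code not shown here; a simplified pure-Python version of the same idea (operating on token strings rather than token IDs, with the "Response:" delimiter as an assumption) might look like:

```python
def create_response_mask(tokens: list, delimiter: str = "Response:") -> list:
    """Return a 0/1 mask that is 1 only for tokens after the response delimiter."""
    mask = [0] * len(tokens)
    for i, tok in enumerate(tokens):
        if tok == delimiter:
            # Everything after the delimiter is the assistant's response
            for j in range(i + 1, len(tokens)):
                mask[j] = 1
            break
    return mask

tokens = ["Instruction:", "Write", "a", "haiku.", "Response:", "Cherry", "blossoms"]
print(create_response_mask(tokens))  # [0, 0, 0, 0, 0, 1, 1]
```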

Data format

Instruction datasets are formatted as:
{
  "instruction": "Write a haiku about spring.",
  "input": "",
  "output": "Cherry blossoms bloom\nSoft petals dance in the breeze\nSpring awakens now"
}
The loader automatically formats these into conversational templates:
Instruction: Write a haiku about spring.

Response: Cherry blossoms bloom
Soft petals dance in the breeze
Spring awakens now
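The loader's templating step can be sketched as a small formatting function (illustrative; the project's actual implementation may differ, and the handling of the optional input field is an assumption based on the Alpaca schema):

```python
def format_example(example: dict) -> str:
    """Render an Alpaca-style record into the Instruction/Response template."""
    parts = [f"Instruction: {example['instruction']}"]
    if example.get("input"):
        # Alpaca records may carry an optional input (context) field
        parts.append(f"Input: {example['input']}")
    parts.append(f"Response: {example['output']}")
    return "\n\n".join(parts)

record = {
    "instruction": "Write a haiku about spring.",
    "input": "",
    "output": "Cherry blossoms bloom\nSoft petals dance in the breeze\nSpring awakens now",
}
print(format_example(record))
```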

Checkpoints

SFT saves checkpoints at regular intervals:
  • Regular checkpoints: Every save_every steps (default: 2000)
    • Format: experiments/runs/<run_name>/sft_step{N}.pt
    • Contains model state, optimizer state, config
  • Final checkpoint: At end of training
    • Format: experiments/runs/<run_name>/sft_final.pt
    • Used as input for the DPO stage
Checkpoint structure:
checkpoint = {
    'model_state': OrderedDict(...),      # Fine-tuned model weights
    'optimizer_state': {...},             # Optimizer state
    'config': {...},                      # Model config (same as pretrain)
    'step': 5000,                         # SFT step counter
    'run_name': 'local-full-sft',
}
SFT checkpoints contain the full model weights, not deltas. They are completely independent of the pretrain checkpoint after training.

Monitoring

SFT training progress is logged to console and training.log:
Loading pretrained model from experiments/runs/local-full/pretrain_final.pt
Model: 117.2M parameters
Loading instruction dataset: tatsu-lab/alpaca
Training examples: 51760
Starting SFT for 5000 steps

SFT Training: 100%|████████| 5000/5000 [2:15:30<00:00, loss=0.8234]
step=1000 loss=1.2345 lr=9.511e-06
step=2000 loss=0.9876 lr=8.090e-06
step=3000 loss=0.8654 lr=6.180e-06
step=4000 loss=0.8123 lr=4.090e-06
step=5000 loss=0.7891 lr=1.545e-06

SFT complete. Final checkpoint: experiments/runs/local-full/sft_final.pt

Quality indicators

Good SFT training:
  • Loss decreases steadily from ~1.5 to ~0.8
  • No sudden spikes or NaN losses
  • Validation loss follows training loss
Signs of overfitting:
  • Training loss continues decreasing but validation loss increases
  • Model generates repetitive or memorized responses
  • Solution: Reduce steps, increase weight decay, or add more data
Signs of underfitting:
  • Loss plateaus early at high value (>1.0)
  • Model fails to follow basic instructions
  • Solution: Train longer, increase LR, or check data quality
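The validation-loss check can be automated with a simple early-stopping rule (an illustrative sketch, not part of the pipeline):

```python
def should_stop_early(val_losses: list, patience: int = 3) -> bool:
    """Stop when validation loss has not improved for `patience` evaluations."""
    if len(val_losses) <= patience:
        return False
    best = min(val_losses[:-patience])
    # If none of the last `patience` evals beat the earlier best, stop
    return min(val_losses[-patience:]) >= best

print(should_stop_early([1.2, 1.0, 0.9, 0.91, 0.92, 0.95]))  # True
print(should_stop_early([1.2, 1.0, 0.9, 0.85]))              # False
```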

Implementation details

SFT implementation is located at:
  • src/modern_llm/training/train_sft.py:run_sft() - Main training loop
  • scripts/sft.py - CLI wrapper
  • scripts/run_pipeline.py:run_sft() - Pipeline integration
Key functions:

run_sft(pretrain_checkpoint, train_config, dataset_config, tokenizer_name) (src/modern_llm/training/train_sft.py:52-131) — the main SFT entrypoint:
  1. Load pretrained model from checkpoint
  2. Load and format instruction dataset
  3. Setup optimizer with cosine annealing
  4. Run training with response-only masking
  5. Save final SFT checkpoint
load_instruction_dataset(config, tokenizer) — loads the instruction dataset and applies the conversational template:
  1. Fetch dataset from Hugging Face
  2. Apply instruction-response formatting
  3. Tokenize with response masking
  4. Return PyTorch Dataset

Performance tips

Reducing GPU memory:
  • Lower micro_batch_size (use gradient accumulation)
  • Reduce max_seq_len to 512 or 768
  • Enable gradient checkpointing
  • Use LoRA/QLoRA for parameter-efficient fine-tuning (not yet supported)
Speeding up training:
  • Increase micro_batch_size if GPU memory allows
  • Use a shorter max sequence length
  • Sample large datasets (e.g., OpenOrca:50000)
  • Enable bf16 mixed precision
Improving instruction-following quality:
  • Train on multiple diverse datasets
  • Increase training steps (10K+)
  • Use curriculum learning (start with simple instructions)
  • Add evaluation on a held-out test set
Avoiding catastrophic forgetting:
  • Keep the learning rate low (1e-5 or lower)
  • Reduce training steps if the model degrades
  • Mix in pretraining data (10-20% of batches)
  • Use lower weight decay
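Mixing pretraining data back into SFT batches (the 10-20% tip above) can be done with a simple probabilistic sampler; this sketch is illustrative and the function name is hypothetical:

```python
import random

def choose_batch_source(mix_ratio: float = 0.15, rng=None) -> str:
    """Pick 'pretrain' with probability mix_ratio, otherwise 'sft'."""
    r = (rng or random).random()
    return "pretrain" if r < mix_ratio else "sft"

# Over many steps, roughly mix_ratio of batches come from pretraining data
rng = random.Random(42)
sources = [choose_batch_source(0.15, rng) for _ in range(10_000)]
print(sources.count("pretrain") / len(sources))  # close to 0.15
```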

Evaluation

After SFT, evaluate the model’s instruction-following ability:

Quick qualitative check

Generate responses to test instructions:
from modern_llm.models import ModernDecoderLM
from modern_llm.utils.checkpointing import load_checkpoint
from transformers import AutoTokenizer

# Load SFT model
ckpt = load_checkpoint("experiments/runs/local-full/sft_final.pt")
model = ModernDecoderLM.from_checkpoint(ckpt)
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Test instruction
prompt = "Instruction: Explain quantum computing in simple terms.\n\nResponse:"
response = model.generate(prompt, max_new_tokens=100, temperature=0.7)
print(response)

Quantitative benchmarks

Run evaluation on standard benchmarks:
# Evaluate on held-out Alpaca test set
python scripts/evaluate_pipeline.py \
    --checkpoint experiments/runs/local-full/sft_final.pt \
    --stage sft
See the evaluation guide for more details.

Next steps

After SFT completes:
  1. Verify the checkpoint exists at experiments/runs/<run_name>/sft_final.pt
  2. Run DPO to further align the model:
    python scripts/run_pipeline.py --config local --stage dpo \
        --checkpoint experiments/runs/local-full/sft_final.pt
    
  3. Or continue with full pipeline:
    python scripts/run_pipeline.py --config local --stage all
    

Direct preference optimization

Learn how to align your SFT model using preference data
