run_sft
from modern_llm.training.train_sft import run_sft
Run supervised fine-tuning on a pretrained model using instruction-following data. Implements the SFT stage from InstructGPT (Ouyang et al., 2022) with response-only loss masking.
Parameters
pretrain_checkpoint
Path
Path to the pretrained model checkpoint containing model_state and config keys.
train_config
TrainingConfig
Training configuration with hyperparameters, batch sizes, and logging settings.
dataset_config
InstructionDatasetConfig
required
Dataset configuration specifying which instruction dataset to use and formatting options.
tokenizer_name
str
HuggingFace tokenizer identifier. Must match the tokenizer used during pretraining.
eval_split
Optional[str]
default: None
Optional evaluation split name (e.g., “test” or “validation”). If provided, runs evaluation during training.
Returns
Path to the final SFT checkpoint (e.g., experiments/runs/sft_final.pt).
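The returned path can be loaded like any other checkpoint produced by the pipeline. A minimal sketch, assuming the SFT checkpoint uses the same model_state and config keys as the pretraining checkpoint described above:

import torch

# Path returned by run_sft; key names assumed to match the pretraining
# checkpoint format (model_state and config).
checkpoint = torch.load("experiments/runs/sft_final.pt", map_location="cpu")
state_dict = checkpoint["model_state"]
model_config = checkpoint["config"]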
Usage
from pathlib import Path
from modern_llm.config import TrainingConfig
from modern_llm.data.instruction_datasets import InstructionDatasetConfig
from modern_llm.training.train_sft import run_sft
# Point to pretrained checkpoint
pretrain_ckpt = Path("experiments/runs/pretrain_final.pt")
# Configure SFT training
train_config = TrainingConfig(
run_name="sft-alpaca",
dataset_name="tatsu-lab/alpaca",
tokenizer_name="gpt2",
output_dir=Path("experiments/sft"),
batch_size=32,
micro_batch_size=2,
gradient_accumulation_steps=16,
learning_rate=1e-5,
max_steps=5000,
warmup_steps=100,
weight_decay=0.01,
eval_every=500,
save_every=1000,
log_every=50,
mixed_precision="bf16",
)
# Configure instruction dataset
dataset_config = InstructionDatasetConfig(
dataset_name="tatsu-lab/alpaca",
max_length=1024,
split="train",
num_examples=50000,
)
# Run SFT
sft_checkpoint = run_sft(
pretrain_checkpoint=pretrain_ckpt,
train_config=train_config,
dataset_config=dataset_config,
tokenizer_name="gpt2",
eval_split="test",
)
print(f"SFT complete: {sft_checkpoint}")
Supported Datasets
The following instruction datasets are supported:
- Alpaca: tatsu-lab/alpaca
- Dolly: databricks/databricks-dolly-15k
- FLAN: Muennighoff/flan
- ShareGPT: RyokoAI/ShareGPT52K
# Fine-tune on Dolly
dataset_config = InstructionDatasetConfig(
dataset_name="databricks/databricks-dolly-15k",
max_length=1024,
split="train",
)
sft_checkpoint = run_sft(
pretrain_checkpoint=pretrain_ckpt,
train_config=train_config,
dataset_config=dataset_config,
)
Training Configuration
Learning Rate
SFT typically uses a lower learning rate than pretraining:
train_config = TrainingConfig(
learning_rate=1e-5, # 10-100x lower than pretraining
warmup_steps=100,
weight_decay=0.01,
max_grad_norm=1.0,
)
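For reference, warmup ramps the learning rate linearly from zero to the configured value over the first warmup_steps updates. A small helper for sanity-checking that ramp (what happens after warmup, e.g. cosine decay, is trainer-specific and not assumed here):

def warmup_lr(step: int, base_lr: float = 1e-5, warmup_steps: int = 100) -> float:
    # Linear ramp from 0 to base_lr over the first warmup_steps updates,
    # then hold; any post-warmup decay depends on the trainer.
    return base_lr * min(1.0, (step + 1) / warmup_steps)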
Batch Size
Use gradient accumulation for large effective batch sizes:
train_config = TrainingConfig(
batch_size=64, # Effective batch size
micro_batch_size=2, # Per-GPU batch size
gradient_accumulation_steps=32, # 64 / 2 = 32
)
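Conceptually, gradient accumulation sums gradients over several micro-batches before each optimizer step, so 32 micro-batches of 2 examples behave like one batch of 64. A rough sketch of the loop (model, optimizer, dataloader, and compute_loss stand in for the trainer's own objects; the actual training loop may differ):

accum_steps = train_config.gradient_accumulation_steps
optimizer.zero_grad()
for i, micro_batch in enumerate(dataloader):
    # Scale the loss so accumulated gradients average over the effective batch.
    loss = compute_loss(model, micro_batch) / accum_steps
    loss.backward()
    if (i + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()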
Sequence Length
Instruction examples (prompt plus response) may need longer sequences than those used during pretraining:
dataset_config = InstructionDatasetConfig(
dataset_name="tatsu-lab/alpaca",
max_length=2048, # Longer for instruction + response
split="train",
)
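To choose a max_length, it can help to tokenize a formatted example and check its length, since examples longer than max_length are typically truncated. A quick check with the Hugging Face tokenizer (the Alpaca-style template below is only an illustration; the dataset's actual prompt format may differ):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
example = (
    "### Instruction:\nSummarize the following article.\n\n"
    "### Input:\n" + "lorem ipsum " * 400 + "\n\n"
    "### Response:\nA short summary."
)
print(len(tokenizer(example)["input_ids"]))  # compare against max_length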
Response-Only Loss Masking
The SFT trainer automatically masks the loss on prompt tokens, computing loss only on the assistant’s response. This focuses learning on generating good responses rather than memorizing prompts.
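A minimal sketch of how response-only masking is commonly implemented (not necessarily this trainer's exact code): label positions belonging to the prompt are set to -100, which PyTorch's cross-entropy loss ignores, so only response tokens contribute to the loss. The Alpaca-style prompt template here is an illustrative assumption:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

prompt = "### Instruction:\nName the capital of France.\n\n### Response:\n"
response = "The capital of France is Paris."

prompt_ids = tokenizer(prompt, add_special_tokens=False)["input_ids"]
response_ids = tokenizer(response, add_special_tokens=False)["input_ids"]

input_ids = prompt_ids + response_ids
# -100 is the ignore_index of torch.nn.CrossEntropyLoss, so prompt positions
# contribute nothing to the loss; only the response is learned.
labels = [-100] * len(prompt_ids) + response_ids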
Evaluation
Provide an eval split to monitor validation loss during training:
sft_checkpoint = run_sft(
pretrain_checkpoint=pretrain_ckpt,
train_config=train_config,
dataset_config=dataset_config,
eval_split="test", # Run eval on test set
)
Evaluation runs every train_config.eval_every steps and logs validation metrics.
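Under the hood, periodic evaluation amounts to averaging the loss over the eval split with gradients disabled; a sketch of that computation (model and eval_dataloader stand in for the trainer's objects, and the HF-style .loss interface is an assumption):

import torch

@torch.no_grad()
def evaluate(model, eval_dataloader):
    # Average loss over the eval split; the trainer logs this every eval_every steps.
    model.eval()
    losses = [model(**batch).loss.item() for batch in eval_dataloader]
    model.train()
    return sum(losses) / max(len(losses), 1)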
Pipeline Integration
SFT is typically the second stage after pretraining:
from modern_llm.training.train_lm import run_training
from modern_llm.training.train_sft import run_sft
# Stage 1: Pretrain
pretrain_ckpt = run_training(
model_config=model_config,
train_config=pretrain_config,
dataset_names=["wikitext-2-raw-v1"],
)
# Stage 2: SFT
sft_ckpt = run_sft(
pretrain_checkpoint=pretrain_ckpt,
train_config=sft_config,
dataset_config=instruction_config,
)