
run_sft

from modern_llm.training.train_sft import run_sft
Run supervised fine-tuning on a pretrained model using instruction-following data. Implements the SFT stage from InstructGPT (Ouyang et al., 2022) with response-only loss masking.

Parameters

pretrain_checkpoint
Path
required
Path to pretrained model checkpoint containing model_state and config keys.
train_config
TrainingConfig
required
Training configuration with hyperparameters, batch sizes, and logging settings.
dataset_config
InstructionDatasetConfig
required
Dataset configuration specifying which instruction dataset to use and formatting options.
tokenizer_name
str
default:"gpt2"
HuggingFace tokenizer identifier. Must match the tokenizer used during pretraining.
eval_split
Optional[str]
default:None
Optional evaluation split name (e.g., “test” or “validation”). If provided, runs evaluation during training.

Returns

checkpoint_path
Path
Path to the final SFT checkpoint (e.g., experiments/runs/sft_final.pt).
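As an illustrative sketch (assuming the SFT checkpoint mirrors the pretrain checkpoint format described above, i.e. a torch-saved dict with model_state and config keys), the returned path can be loaded with torch.load. The stand-in file written here is purely for demonstration:

```python
import tempfile
from pathlib import Path

import torch

# Illustrative only: write a stand-in checkpoint with the assumed
# "model_state" and "config" keys, then load it the same way you would
# load the path returned by run_sft.
ckpt_path = Path(tempfile.mkdtemp()) / "sft_final.pt"
torch.save(
    {"model_state": {"w": torch.zeros(2)}, "config": {"n_layer": 12}},
    ckpt_path,
)

state = torch.load(ckpt_path, map_location="cpu")
model_state = state["model_state"]  # parameter tensors for the fine-tuned model
config = state["config"]            # model configuration
```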

Usage

from pathlib import Path
from modern_llm.config import TrainingConfig
from modern_llm.data.instruction_datasets import InstructionDatasetConfig
from modern_llm.training.train_sft import run_sft

# Point to pretrained checkpoint
pretrain_ckpt = Path("experiments/runs/pretrain_final.pt")

# Configure SFT training
train_config = TrainingConfig(
    run_name="sft-alpaca",
    dataset_name="tatsu-lab/alpaca",
    tokenizer_name="gpt2",
    output_dir=Path("experiments/sft"),
    batch_size=32,
    micro_batch_size=2,
    gradient_accumulation_steps=16,
    learning_rate=1e-5,
    max_steps=5000,
    warmup_steps=100,
    weight_decay=0.01,
    eval_every=500,
    save_every=1000,
    log_every=50,
    mixed_precision="bf16",
)

# Configure instruction dataset
dataset_config = InstructionDatasetConfig(
    dataset_name="tatsu-lab/alpaca",
    max_length=1024,
    split="train",
    num_examples=50000,
)

# Run SFT
sft_checkpoint = run_sft(
    pretrain_checkpoint=pretrain_ckpt,
    train_config=train_config,
    dataset_config=dataset_config,
    tokenizer_name="gpt2",
    eval_split="test",
)

print(f"SFT complete: {sft_checkpoint}")

Supported Datasets

The following common instruction datasets are supported:
  • Alpaca: tatsu-lab/alpaca
  • Dolly: databricks/databricks-dolly-15k
  • FLAN: Muennighoff/flan
  • ShareGPT: RyokoAI/ShareGPT52K

# Fine-tune on Dolly
dataset_config = InstructionDatasetConfig(
    dataset_name="databricks/databricks-dolly-15k",
    max_length=1024,
    split="train",
)

sft_checkpoint = run_sft(
    pretrain_checkpoint=pretrain_ckpt,
    train_config=train_config,
    dataset_config=dataset_config,
)

Training Configuration

Learning Rate

SFT typically uses a lower learning rate than pretraining:
train_config = TrainingConfig(
    learning_rate=1e-5,  # 10-100x lower than pretraining
    warmup_steps=100,
    weight_decay=0.01,
    max_grad_norm=1.0,
)
Batch Size

Use gradient accumulation for large effective batch sizes:
train_config = TrainingConfig(
    batch_size=64,        # Effective batch size
    micro_batch_size=2,   # Per-GPU batch size
    gradient_accumulation_steps=32,  # 64 / 2 = 32
)
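The relationship between these three settings can be sketched with a small helper (hypothetical, not part of the library):

```python
def accumulation_steps(effective_batch_size: int, micro_batch_size: int) -> int:
    """Hypothetical helper: gradient accumulation steps needed to reach
    an effective batch size from a per-device micro-batch size."""
    if effective_batch_size % micro_batch_size != 0:
        raise ValueError("effective batch size must be divisible by micro batch size")
    return effective_batch_size // micro_batch_size

print(accumulation_steps(64, 2))  # -> 32
```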
Sequence Length

Instruction datasets may need longer sequences than pretraining:
dataset_config = InstructionDatasetConfig(
    dataset_name="tatsu-lab/alpaca",
    max_length=2048,  # Longer for instruction + response
    split="train",
)

Response-Only Loss Masking

The SFT trainer automatically masks the loss on prompt tokens, computing loss only on the assistant’s response. This focuses learning on generating good responses rather than memorizing prompts.
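A minimal sketch of how such masking typically works (illustrative, not the library's actual implementation): labels for prompt positions are set to -100, which cross_entropy ignores, so the loss covers only response tokens.

```python
import torch
import torch.nn.functional as F

IGNORE_INDEX = -100  # cross_entropy skips positions with this label

def mask_prompt_labels(input_ids: torch.Tensor, prompt_len: int) -> torch.Tensor:
    """Return labels with the first prompt_len positions masked out."""
    labels = input_ids.clone()
    labels[:prompt_len] = IGNORE_INDEX  # no loss on prompt tokens
    return labels

# Toy sequence: 4 prompt tokens followed by 3 response tokens.
input_ids = torch.tensor([10, 11, 12, 13, 20, 21, 22])
labels = mask_prompt_labels(input_ids, prompt_len=4)

logits = torch.randn(7, 100)  # (seq_len, vocab_size)
loss = F.cross_entropy(logits, labels, ignore_index=IGNORE_INDEX)
```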

Evaluation

Provide an eval split to monitor validation loss during training:
sft_checkpoint = run_sft(
    pretrain_checkpoint=pretrain_ckpt,
    train_config=train_config,
    dataset_config=dataset_config,
    eval_split="test",  # Run eval on test set
)
Evaluation runs every train_config.eval_every steps and logs validation metrics.

Pipeline Integration

SFT is typically the second stage after pretraining:
from modern_llm.training.train_lm import run_training
from modern_llm.training.train_sft import run_sft

# Stage 1: Pretrain
pretrain_ckpt = run_training(
    model_config=model_config,
    train_config=pretrain_config,
    dataset_names=["wikitext-2-raw-v1"],
)

# Stage 2: SFT
sft_ckpt = run_sft(
    pretrain_checkpoint=pretrain_ckpt,
    train_config=sft_config,
    dataset_config=instruction_config,
)
