
Overview

Supervised Fine-Tuning (SFT) adapts a pretrained base model to follow instructions and engage in conversation. The SFT stage:
  1. Loads a pretrained base model checkpoint
  2. Trains on a mixture of conversational datasets
  3. Teaches the model chat format, tool use, and task-specific skills
  4. Optionally warm-starts the optimizer from pretraining

Quick Start

Single GPU:
python -m scripts.chat_sft
Multi-GPU (8 GPUs):
torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft \
  --device-batch-size=16

Loading Pretrained Models

--model-tag
str
default:"None"
Model tag to load from base_checkpoints/. If not specified, loads the default model.
--model-step
int
default:"None"
Specific checkpoint step to load. If not specified, loads the latest checkpoint.
--load-optimizer
int
default:"1"
Warm-start the optimizer from the pretrained checkpoint (1=yes, 0=no). Recommended: keep at 1 to reuse momentum buffers from pretraining.

Training Data Mixture

The default SFT mixture includes:
# Conversational data
SmolTalk(split="train")              # 460K general conversations
CustomJSON(identity_conversations)   # 1K identity conversations (2 epochs)

# Task-specific data  
MMLU × args.mmlu_epochs              # 100K/epoch (teaches multiple choice)
GSM8K × args.gsm8k_epochs            # 8K/epoch (teaches math and tool use)
SimpleSpelling(200K)                 # Spell words
SpellingBee(80K)                     # Count letters in words
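The epoch multipliers (e.g. `MMLU × args.mmlu_epochs`) can be thought of as repeating a dataset before shuffling the combined list. A minimal sketch of that idea (the function and dataset stand-ins here are illustrative, not the actual nanochat API):

```python
import random

def build_mixture(sources, seed=0):
    """Repeat each source `epochs` times, then shuffle the combined list.
    (Illustrative sketch only; not the actual nanochat mixture code.)"""
    rows = []
    for dataset, epochs in sources:
        rows.extend(list(dataset) * epochs)
    random.Random(seed).shuffle(rows)
    return rows

# Toy stand-ins for the real datasets:
mixture = build_mixture([
    (["smoltalk"] * 4, 1),   # general conversations, 1 epoch
    (["identity"], 2),       # identity conversations, 2 epochs
    (["mmlu"] * 3, 3),       # MMLU, args.mmlu_epochs = 3
])
print(len(mixture))  # 4 + 2 + 9 = 15
```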

Data Mixture Parameters

--mmlu-epochs
int
default:"3"
Number of epochs of MMLU in the training mixture. MMLU teaches multiple-choice question answering.
--gsm8k-epochs
int
default:"4"
Number of epochs of GSM8K in the training mixture. GSM8K teaches math reasoning and tool use.

Training Horizon

--num-iterations
int
default:"-1"
Number of optimization steps. -1 = train for one full epoch through the data mixture.

Batch Size

By default, SFT inherits batch size settings from the pretrained checkpoint.
--max-seq-len
int
default:"None"
Maximum context length. None = inherit from pretrained checkpoint (typically 2048).
--device-batch-size
int
default:"None"
Per-device batch size. None = inherit from pretrained checkpoint (typically 32). Reduce this if you encounter OOM errors.
--total-batch-size
int
default:"None"
Total batch size in tokens across all devices. None = inherit from pretrained checkpoint (typically 524288).
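These three settings are related through gradient accumulation: when one forward/backward pass across all devices processes fewer tokens than the total batch size, the remainder is made up by accumulating gradients. A hedged sketch of the arithmetic (the formula is assumed from typical setups, not read from the script):

```python
def grad_accum_steps(total_batch_size, device_batch_size, max_seq_len, world_size):
    """How many forward/backward passes per optimizer step.
    (Assumed relationship; the actual script may differ in details.)"""
    tokens_per_fwdbwd = device_batch_size * max_seq_len * world_size
    assert total_batch_size % tokens_per_fwdbwd == 0, "must divide evenly"
    return total_batch_size // tokens_per_fwdbwd

# Typical inherited defaults on 8 GPUs: 32 * 2048 * 8 = 524288 tokens,
# so no accumulation is needed:
print(grad_accum_steps(524288, 32, 2048, 8))  # 1
```

Halving `--device-batch-size` (e.g. after an OOM) doubles the accumulation steps while keeping the effective batch size fixed.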

Learning Rates

By default, SFT inherits learning rates from pretraining and scales them down:
--init-lr-frac
float
default:"0.8"
Initial learning rate as a fraction of the pretrained LRs. Example: if pretraining used matrix_lr=0.02, SFT starts at 0.02 × 0.8 = 0.016.
--matrix-lr
float
default:"None"
Learning rate for transformer matrices (Muon optimizer). None = inherit from pretraining (typically 0.02).
--embedding-lr
float
default:"None"
Learning rate for input embedding (Adam). None = inherit from pretraining (typically 0.3).
--unembedding-lr
float
default:"None"
Learning rate for output unembedding (Adam). None = inherit from pretraining (typically 0.004).

Learning Rate Schedule

--warmup-ratio
float
default:"0.0"
Fraction of iterations for linear LR warmup. 0.0 = no warmup.
--warmdown-ratio
float
default:"0.5"
Fraction of iterations for linear LR warmdown. 0.5 = decay in last half of training.
--final-lr-frac
float
default:"0.0"
Final LR as fraction of initial LR. 0.0 = decay to zero.
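Together these flags define a piecewise-linear multiplier on the initial LR: optional linear warmup, a constant plateau, then linear decay to `final_lr_frac`. A sketch of that schedule as documented (not the script's exact code):

```python
def lr_multiplier(step, num_iterations, warmup_ratio=0.0,
                  warmdown_ratio=0.5, final_lr_frac=0.0):
    """Piecewise-linear LR multiplier implied by the three schedule flags.
    (A reading of the documented behavior, not the actual implementation.)"""
    warmup_steps = int(warmup_ratio * num_iterations)
    warmdown_steps = int(warmdown_ratio * num_iterations)
    if warmup_steps > 0 and step < warmup_steps:
        return (step + 1) / warmup_steps          # linear warmup to 1.0
    if step >= num_iterations - warmdown_steps:
        # linear decay from 1.0 down to final_lr_frac
        progress = (num_iterations - step) / warmdown_steps
        return final_lr_frac + (1.0 - final_lr_frac) * progress
    return 1.0                                    # constant plateau

# With the defaults, LR is constant for the first half of training,
# then decays linearly to zero:
print(lr_multiplier(250, 1000))  # 1.0
print(lr_multiplier(750, 1000))  # 0.5
```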

Evaluation

--eval-every
int
default:"200"
Evaluate validation bits-per-byte every N steps. -1 to disable.
--eval-tokens
int
default:"20971520"
Number of tokens for validation evaluation (default: 40 × 524288).
--chatcore-every
int
default:"200"
Evaluate the ChatCORE metric every N steps. -1 to disable. ChatCORE measures performance on:
  • ARC-Easy, ARC-Challenge (science Q&A)
  • MMLU (knowledge)
  • GSM8K (math reasoning)
  • HumanEval (code generation)
  • SpellingBee (spelling)
--chatcore-max-cat
int
default:"-1"
Max problems per categorical task (ARC, MMLU) for ChatCORE. -1 = no limit.
--chatcore-max-sample
int
default:"24"
Max problems per generative task (GSM8K, HumanEval) for ChatCORE.

Logging

--run
str
default:"dummy"
Wandb run name. Set to “dummy” to disable wandb logging.
--device-type
str
default:""
Device type: cuda, cpu, or mps. Empty string = autodetect.

BOS-Aligned Best-Fit Packing

SFT uses a specialized dataloader that:
  1. BOS-aligned: Each row starts with a <|bos|> token (beginning of conversation)
  2. Best-fit packing: Multiple conversations are packed into each sequence using a best-fit algorithm
  3. Padding instead of cropping: When no conversation fits, the row is padded (no tokens discarded)
  4. Target masking: Padding positions have targets set to -1 (ignored by loss)
This ensures:
  • Maximum token efficiency (no wasted computation)
  • No information loss (all training tokens are seen)
  • Clean conversation boundaries
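The four properties above can be sketched as a greedy best-fit packer (an illustrative reconstruction under the stated assumptions, not the actual dataloader): each conversation is placed into the fullest row it still fits in, rows that can't take another conversation are padded, and pad positions carry a masked target.

```python
def pack_best_fit(conversations, seq_len, pad_id=-1):
    """Best-fit packing sketch: every row starts at a conversation
    boundary (each conversation begins with BOS); conversations are
    never cropped, leftover positions are padded.
    (Assumed behavior modeled on the description above.)"""
    rows = []
    for conv in conversations:          # conv: list of token ids
        best = None
        for i, row in enumerate(rows):  # tightest row it still fits in
            space = seq_len - len(row)
            if len(conv) <= space and (
                best is None or space < seq_len - len(rows[best])
            ):
                best = i
        if best is None:
            rows.append(list(conv))     # no row fits: open a new one
        else:
            rows[best].extend(conv)
    for row in rows:                    # pad; targets at pad_id are masked
        row.extend([pad_id] * (seq_len - len(row)))
    return rows
```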

Example Workflows

Basic SFT on pretrained d12 model

torchrun --nproc_per_node=8 -m scripts.chat_sft \
  --model-tag=d12 \
  --run=sft_d12

SFT with custom data mixture

torchrun --nproc_per_node=8 -m scripts.chat_sft \
  --model-tag=d20 \
  --mmlu-epochs=5 \
  --gsm8k-epochs=6 \
  --run=sft_d20_heavy_math

SFT with custom learning rates

torchrun --nproc_per_node=8 -m scripts.chat_sft \
  --model-tag=d16 \
  --init-lr-frac=0.5 \
  --warmdown-ratio=0.3 \
  --run=sft_d16_conservative

SFT without optimizer warm-start

python -m scripts.chat_sft \
  --model-tag=d12 \
  --load-optimizer=0 \
  --run=sft_from_scratch

SFT with fixed number of iterations

torchrun --nproc_per_node=8 -m scripts.chat_sft \
  --model-tag=d20 \
  --num-iterations=5000 \
  --run=sft_5k_steps

Output

Checkpoints are saved to $NANOCHAT_BASE_DIR/chatsft_checkpoints/{model_tag}/:
  • step_{N}_model.pt - Model weights
  • step_{N}_optimizer.pt - Optimizer state
  • step_{N}_meta.json - Metadata (config, validation loss, ChatCORE scores)
Only the final checkpoint is saved (at the end of training).

Monitoring

Key metrics logged to console and wandb:
  • train/loss - Training loss
  • val/bpb - Validation bits per byte
  • chatcore_metric - Overall ChatCORE score (centered mean across tasks)
  • chatcore_cat - ChatCORE on categorical tasks only (ARC, MMLU)
  • chatcore/[task] - Per-task accuracy (ARC-Easy, ARC-Challenge, MMLU, GSM8K, HumanEval, SpellingBee)
  • train/epoch - Current epoch through the dataset
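The "centered mean" in chatcore_metric can be sketched as rescaling each task's accuracy so that random guessing scores 0 and a perfect score is 1, then averaging across tasks. This formula is an assumption modeled on the CORE metric, not read from the code:

```python
def centered_score(accuracy, baseline):
    """Rescale so random guessing -> 0.0 and perfect -> 1.0.
    (Assumed formula, modeled on the CORE metric.)"""
    return (accuracy - baseline) / (1.0 - baseline)

def chatcore_metric(per_task_acc, baselines):
    """Centered mean across tasks; generative tasks have baseline ~0."""
    scores = [centered_score(acc, baselines.get(task, 0.0))
              for task, acc in per_task_acc.items()]
    return sum(scores) / len(scores)

# e.g. a 4-way multiple-choice task has a 25% random baseline:
print(centered_score(0.25, 0.25))  # 0.0
print(centered_score(1.00, 0.25))  # 1.0
```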

Identity Conversations

The SFT mixture includes synthetic identity conversations from identity_conversations.jsonl. These teach the model:
  • Its name and identity
  • How to respond to meta questions (“Who are you?”, “Who made you?”)
  • Appropriate disclaimers and limitations
The identity file is loaded from $NANOCHAT_BASE_DIR/identity_conversations.jsonl.
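The exact schema of identity_conversations.jsonl is not shown here; a hypothetical example of one line, assuming the common role/content messages convention:

```python
import json

# Hypothetical record shape (the actual schema may differ):
line = json.dumps({
    "messages": [
        {"role": "user", "content": "Who are you?"},
        {"role": "assistant",
         "content": "I'm a small open-source chat model trained with nanochat."},
    ]
})
record = json.loads(line)
print(record["messages"][0]["content"])  # Who are you?
```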

Optimizer Warm-Starting

When --load-optimizer=1 (default):
  1. Loads optimizer state from pretrained checkpoint
  2. Preserves momentum buffers (useful for stability)
  3. Resets learning rates to the SFT values (ignoring the pretrained LRs, which were warmed down to ~0)
This improves training stability and convergence compared to training from scratch.
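The three steps above can be sketched in PyTorch as follows. This is an assumed shape: the real script's checkpoint layout and param-group order may differ.

```python
import torch

def warm_start_optimizer(optimizer, ckpt_path, sft_lrs):
    """Sketch of --load-optimizer=1: restore momentum buffers from the
    pretraining checkpoint, then overwrite the warmed-down LRs with
    fresh SFT values. (Illustrative; not the actual script.)"""
    state = torch.load(ckpt_path, map_location="cpu")
    optimizer.load_state_dict(state)        # restores momentum buffers
    for group, lr in zip(optimizer.param_groups, sft_lrs):
        group["lr"] = lr                    # pretrained LRs had decayed to ~0
```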

Weight Decay

Note: SFT uses weight_decay=0.0 because:
  • Pretraining already ramped weight decay to zero by end of training
  • SFT continues with zero weight decay for fine-grained adaptation
  • No regularization is needed on the small SFT dataset after strong pretraining
