Supervised fine-tuning script that converts a pretrained base model into a conversational chat model.

Usage

# Single GPU
python -m scripts.chat_sft

# Distributed (8 GPUs)
torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft -- --device-batch-size=16

Parameters

Logging

--run
str
default:"dummy"
Weights & Biases run name. Use 'dummy' to disable wandb logging.

Runtime

--device-type
str
default:""
Device type: cuda, cpu, or mps. Empty string enables autodetection.

Model Loading

--model-tag
str
default:"None"
Model tag to load from base checkpoints (e.g. d24).
--model-step
int
default:"None"
Model step to load. If not specified, loads the last checkpoint.
--load-optimizer
int
default:"1"
Warm-start optimizer from pretrained checkpoint. 0 = no, 1 = yes.

Training Horizon

--num-iterations
int
default:"-1"
Number of optimization steps. -1 = full epoch through the training dataset.
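With -1, the number of steps is derived from the dataset size and the total batch size. A minimal sketch of that derivation (names are illustrative, not the script's exact code):

```python
def num_iterations_for_epoch(dataset_tokens: int, total_batch_size: int) -> int:
    """One full epoch: dataset tokens divided by tokens consumed per
    optimizer step, rounded up so no data is dropped."""
    return -(-dataset_tokens // total_batch_size)  # ceil division

# e.g. a 1M-token dataset with a 524288-token total batch
print(num_iterations_for_epoch(1_000_000, 524_288))  # -> 2
```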

Batch Sizes

Defaults are inherited from the pretrained checkpoint if not specified.
--max-seq-len
int
default:"None"
Maximum context length. Default: inherit from pretrain.
--device-batch-size
int
default:"None"
Per-device batch size. Default: inherit from pretrain.
--total-batch-size
int
default:"None"
Total batch size in tokens. Default: inherit from pretrain.
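These three flags together determine gradient accumulation: the total batch size (in tokens) must be an exact multiple of the tokens processed per forward/backward pass across all devices. A sketch of the arithmetic, under that assumption:

```python
def grad_accum_steps(total_batch_size: int, device_batch_size: int,
                     max_seq_len: int, world_size: int) -> int:
    """Gradient-accumulation steps so that
    world_size * device_batch_size * max_seq_len * steps == total_batch_size."""
    tokens_per_fwdbwd = device_batch_size * max_seq_len * world_size
    assert total_batch_size % tokens_per_fwdbwd == 0, \
        "total batch size must divide evenly across devices and accumulation steps"
    return total_batch_size // tokens_per_fwdbwd

# e.g. 524288 tokens total, 16 sequences of 2048 tokens per device, 8 GPUs
print(grad_accum_steps(524_288, 16, 2048, 8))  # -> 2
```

This is why shrinking --device-batch-size (to fit memory) does not change the effective batch: the script compensates with more accumulation steps.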

Optimization

Defaults are inherited from the pretrained checkpoint if not specified.
--embedding-lr
float
default:"None"
Learning rate for embedding parameters (Adam). Default: inherit from pretrain.
--unembedding-lr
float
default:"None"
Learning rate for unembedding parameters (Adam). Default: inherit from pretrain.
--matrix-lr
float
default:"None"
Learning rate for matrix parameters (Muon). Default: inherit from pretrain.
--init-lr-frac
float
default:"0.8"
Initial learning rate as fraction of base learning rate.
--warmup-ratio
float
default:"0.0"
Ratio of iterations for learning rate warmup.
--warmdown-ratio
float
default:"0.5"
Ratio of iterations for learning rate warmdown.
--final-lr-frac
float
default:"0.0"
Final learning rate as fraction of initial learning rate.
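The four schedule flags imply a piecewise-linear multiplier on the base learning rate: warm up to init-lr-frac, hold, then warm down to final-lr-frac of that initial value. A hedged sketch (the script's exact interpolation may differ):

```python
def lr_multiplier(step: int, num_iterations: int, init_lr_frac: float = 0.8,
                  warmup_ratio: float = 0.0, warmdown_ratio: float = 0.5,
                  final_lr_frac: float = 0.0) -> float:
    """Multiplier on the base LR at a given step: linear warmup, flat
    middle, linear warmdown to final_lr_frac * init_lr_frac."""
    warmup_steps = int(warmup_ratio * num_iterations)
    warmdown_steps = int(warmdown_ratio * num_iterations)
    if warmup_steps > 0 and step < warmup_steps:
        return init_lr_frac * (step + 1) / warmup_steps
    if warmdown_steps > 0 and step >= num_iterations - warmdown_steps:
        progress = (num_iterations - step) / warmdown_steps  # 1 -> 0
        return init_lr_frac * (final_lr_frac + (1 - final_lr_frac) * progress)
    return init_lr_frac

# with the defaults: flat at 0.8 for the first half, then decaying to 0
print(lr_multiplier(0, 1000))    # -> 0.8
print(lr_multiplier(500, 1000))  # -> 0.8
```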

Evaluation

--eval-every
int
default:"200"
Evaluate validation bits-per-byte every N steps. -1 = disabled.
--eval-tokens
int
default:"20971520"
Number of tokens to evaluate validation loss on (default: 40*524288).
--chatcore-every
int
default:"200"
Evaluate ChatCORE metric every N steps. -1 = disabled.
--chatcore-max-cat
int
default:"-1"
Maximum problems per categorical task for ChatCORE. -1 = all problems.
--chatcore-max-sample
int
default:"24"
Maximum problems per generative task for ChatCORE.

Data Mixture

--mmlu-epochs
int
default:"3"
Number of epochs of MMLU in training mixture (teaches multiple choice).
--gsm8k-epochs
int
default:"4"
Number of epochs of GSM8K in training mixture (teaches math and tool use).

Training Mixture

The SFT script uses a carefully balanced mixture of tasks:
  • SmolTalk (460K rows): General conversations
  • Identity Conversations (1K rows × 2 epochs): Synthetic identity conversations
  • MMLU (100K rows × --mmlu-epochs): Multiple choice questions
  • GSM8K (8K rows × --gsm8k-epochs): Math word problems with tool use
  • Simple Spelling (200K rows): Basic spelling tasks
  • Spelling Bee (80K rows): Character counting tasks
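With the default --mmlu-epochs=3 and --gsm8k-epochs=4, the row counts above imply the following mixture size (simple arithmetic on the table, using approximate round-number row counts):

```python
def mixture_rows(mmlu_epochs: int = 3, gsm8k_epochs: int = 4) -> int:
    """Total training rows implied by the mixture table above."""
    return (460_000                    # SmolTalk
            + 1_000 * 2                # Identity conversations x 2 epochs
            + 100_000 * mmlu_epochs    # MMLU
            + 8_000 * gsm8k_epochs     # GSM8K
            + 200_000                  # Simple Spelling
            + 80_000)                  # Spelling Bee

print(mixture_rows())  # -> 1074000 rows at the defaults
```

Raising --mmlu-epochs or --gsm8k-epochs therefore also lengthens a full-epoch run (--num-iterations=-1) proportionally.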

Examples

Basic SFT Training

torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft \
  --run=my-sft-run \
  --model-tag=d24

Custom Data Mixture

torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft \
  --model-tag=d24 \
  --mmlu-epochs=5 \
  --gsm8k-epochs=2

Override Learning Rate

torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft \
  --model-tag=d24 \
  --matrix-lr=0.01 \
  --init-lr-frac=0.5

Fixed Number of Iterations

torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft \
  --model-tag=d24 \
  --num-iterations=1000

ChatCORE Metric

The ChatCORE metric evaluates the chat model across 6 tasks:
  • ARC-Easy (categorical)
  • ARC-Challenge (categorical)
  • MMLU (categorical)
  • GSM8K (generative)
  • HumanEval (generative)
  • SpellingBee (generative)
Like CORE, it uses centered accuracy to normalize performance from 0 (random) to 1 (perfect).
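Centered accuracy rescales raw accuracy against each task's chance baseline, so tasks with different numbers of answer options are comparable. A minimal sketch (the 0.25 baseline below is illustrative, as for 4-way multiple choice):

```python
def centered_accuracy(acc: float, baseline: float) -> float:
    """Map raw accuracy to 0 at the random-guessing baseline and 1 at perfect."""
    return (acc - baseline) / (1.0 - baseline)

print(centered_accuracy(0.25, 0.25))   # -> 0.0 (random 4-way guessing)
print(centered_accuracy(0.625, 0.25))  # -> 0.5 (halfway to perfect)
print(centered_accuracy(1.0, 0.25))    # -> 1.0
```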
