train_gpt.py is driven by environment variables. This page is a consolidated reference organized by subsystem.
Quick-Start Example
Training Hyperparameters
All variables in this section are read by the Hyperparameters class at process startup. Unset variables fall back to the listed defaults.
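The underlying pattern is plain os.environ lookup with typed defaults; a minimal sketch of that pattern (the helper name and the example variables below are illustrative placeholders, not the script's actual identifiers or defaults):

```python
import os

def env_int(name: str, default: int) -> int:
    # Read an integer environment variable, falling back to a default when unset.
    return int(os.environ.get(name, default))

# Placeholder names and defaults, for illustration only.
MAX_STEPS_EXAMPLE = env_int("EXAMPLE_MAX_STEPS", 1000)
SEED_EXAMPLE = env_int("EXAMPLE_SEED", 1234)
```

Setting `EXAMPLE_MAX_STEPS=42` before launch would override the default; leaving it unset keeps 1000.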
Data Paths
Root directory for tokenized dataset shards. The train and validation glob patterns (fineweb_train_*.bin and fineweb_val_*.bin) are derived from this path.

Path to the SentencePiece .model file. Used to build look-up tables for the tokenizer-agnostic BPB metric. Its vocab size must match VOCAB_SIZE exactly, or training raises an error.

Human-readable identifier for this run. Determines the log filename at logs/<RUN_ID>.txt.

Global random seed applied to Python, NumPy, and PyTorch (including cuda.manual_seed_all) before training.

Validation
Total token budget across all ranks per validation pass. Must provide at least one full TRAIN_SEQ_LEN-length sequence per rank.

Run validation every N training steps. Set to 0 to disable periodic validation (final evaluation still runs at the end).

Log a train_loss line every N steps. Steps 1–10 are always logged regardless of this setting.

Training Length
Maximum number of gradient update steps before training stops. The wallclock cap may cause an earlier stop.
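How the step cap and the wallclock cap combine can be sketched as a stop predicate; this is an illustrative reconstruction of the described behavior, not the script's actual loop:

```python
import time

def should_stop(step: int, start_time: float, max_steps: int, time_limit_s: float) -> bool:
    # Stop once the step budget is exhausted, or (when a positive time limit
    # is configured) once elapsed wallclock time has reached the cap.
    if step >= max_steps:
        return True
    if time_limit_s > 0 and (time.time() - start_time) >= time_limit_s:
        return True
    return False
```

A time limit of 0 disables the wallclock check entirely, matching the documented "set to 0 to disable" convention.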
Number of steps (or equivalent wallclock duration) over which the learning rate linearly decays to zero at training end.
Number of pre-training “warmup” steps that prime compiled kernels. Model and optimizer state are fully reset after warmup completes, so effective training always starts from the true initialization.
Total tokens consumed per gradient update across all ranks. Gradient accumulation steps = 8 // WORLD_SIZE.

Sequence length for both training and validation. Affects memory usage and the minimum VAL_BATCH_SIZE.

Hard cap on training time in seconds. When elapsed training time reaches this limit, training stops after the current step finishes. Set to 0 to disable the cap.

Initial value for the per-head learnable q_gain parameter in each attention block. Scales query vectors before the dot product.

Model Shape
Vocabulary size. Must exactly match the SentencePiece tokenizer’s vocab size.
Total number of transformer blocks. Split evenly into encoder and decoder halves for U-Net-style skip connections.
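One common way to realize U-Net-style skips is to mirror-pair the encoder and decoder halves around the midpoint; the pairing below is an assumption for illustration only and may not match the script's exact wiring:

```python
def skip_pairs(num_layers: int) -> list[tuple[int, int]]:
    # Mirror-pair encoder block i with decoder block num_layers - 1 - i.
    # Assumed pairing; requires an even block count so halves split evenly.
    assert num_layers % 2 == 0, "blocks must split evenly into encoder/decoder halves"
    half = num_layers // 2
    return [(i, num_layers - 1 - i) for i in range(half)]
```

With 12 blocks this pairs block 0 with block 11, block 1 with block 10, and so on.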
Number of key/value heads for Grouped Query Attention (GQA). Must evenly divide NUM_HEADS.

Hidden/embedding dimension. Must be divisible by NUM_HEADS, and MODEL_DIM // NUM_HEADS must be even (required for RoPE).

Number of query attention heads.

MLP hidden-layer multiplier. The feedforward hidden size is MLP_MULT * MODEL_DIM.

Set to 1 to tie input embedding and output projection weights (saves parameters). Set to 0 for a separate lm_head.

Base frequency for Rotary Position Embeddings.
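The base frequency enters through the standard RoPE inverse-frequency schedule, which also shows why the per-head dimension must be even: channels are rotated in pairs. A sketch assuming the standard formulation (not verified against the script):

```python
def rope_inv_freq(head_dim: int, base: float = 10000.0) -> list[float]:
    # Standard RoPE schedule: one rotation frequency per pair of channels,
    # geometrically spaced from 1 down to base**-(1 - 2/head_dim).
    assert head_dim % 2 == 0, "RoPE rotates channel pairs, so head_dim must be even"
    return [base ** (-2.0 * i / head_dim) for i in range(head_dim // 2)]
```

A larger base stretches the longest wavelength, which is the usual lever for extending usable context length.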
Logit soft-cap. Applied as softcap * tanh(logits / softcap) before cross-entropy. Must be positive.

Optimizer
Adam learning rate for the token embedding when TIE_EMBEDDINGS=0.

Adam learning rate for the untied lm_head when TIE_EMBEDDINGS=0.

Adam learning rate for the token embedding when TIE_EMBEDDINGS=1.

Standard deviation for normal initialization of the tied embedding weight.
Muon learning rate for 2D matrix parameters in transformer blocks.
Adam learning rate for scalar and vector parameters (scales, norms, gains) in transformer blocks.
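The split between Muon (2D matrices) and Adam (scalars/vectors) is typically made on tensor rank; a hedged sketch of such a partition over plain shapes (the real grouping may add further rules, e.g. excluding embeddings):

```python
def partition_by_rank(named_shapes: dict[str, tuple[int, ...]]):
    # Route rank-2 (matrix) parameters to Muon; everything lower-rank
    # (scales, norms, gains) to Adam. Illustrative partition only.
    muon, adam = [], []
    for name, shape in named_shapes.items():
        (muon if len(shape) == 2 else adam).append(name)
    return muon, adam
```

This rank-based routing is why the reference distinguishes "2D matrix parameters" from "scalar and vector parameters" when listing learning rates.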
Steady-state momentum for the Muon optimizer.
Number of Newton-Schulz iterations used to orthogonalize gradient matrices in Muon.
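Newton-Schulz orthogonalization repeatedly applies a polynomial of X·Xᵀ to push all singular values of the gradient matrix toward 1. Shown here is the classic cubic iteration as an illustration; Muon implementations commonly use a tuned quintic polynomial, and nothing below is taken from the script:

```python
def matmul(a, b):
    # Tiny dense matmul on lists of lists (for a self-contained sketch).
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def transpose(a):
    return [list(col) for col in zip(*a)]

def newton_schulz(g, steps=30):
    # Cubic Newton-Schulz: X <- 1.5 X - 0.5 (X X^T) X, after normalizing so the
    # spectral norm is <= 1 (Frobenius norm used as a cheap upper bound).
    fro = sum(v * v for row in g for v in row) ** 0.5 + 1e-7
    x = [[v / fro for v in row] for row in g]
    for _ in range(steps):
        xxtx = matmul(matmul(x, transpose(x)), x)
        x = [[1.5 * a - 0.5 * b for a, b in zip(r1, r2)] for r1, r2 in zip(x, xxtx)]
    return x
```

More iterations give a closer-to-orthogonal update at higher cost, which is the tradeoff this variable controls.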
Starting Muon momentum value at step 0, linearly warmed up to MUON_MOMENTUM over MUON_MOMENTUM_WARMUP_STEPS steps.

Steps over which Muon momentum is linearly warmed from MUON_MOMENTUM_WARMUP_START to MUON_MOMENTUM.

Adam β₁ (first-moment decay). Applies to all Adam optimizer groups.
Adam β₂ (second-moment decay). Applies to all Adam optimizer groups.
Adam numerical stability epsilon. Applies to all Adam optimizer groups.
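The linear Muon momentum warmup described above is a one-line schedule; a sketch (the numeric values in the test are placeholders, not the script's defaults):

```python
def muon_momentum(step: int, start: float, target: float, warmup_steps: int) -> float:
    # Linearly interpolate momentum from `start` at step 0 to `target` at
    # `warmup_steps`, holding it constant at `target` afterwards.
    if warmup_steps <= 0 or step >= warmup_steps:
        return target
    frac = step / warmup_steps
    return start + (target - start) * frac
```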
Global gradient norm clip threshold. Set to 0.0 to disable gradient clipping.

Quantization
These variables control which tensors are kept in floating point during int8 post-training quantization.

Comma-separated list of name substrings. Any parameter whose name contains one of these patterns is treated as a "control tensor": kept in fp32 during training and excluded from int8 quantization. These are typically low-dimensional scalar/vector parameters that are sensitive to precision loss.

Comma-separated list of name substrings. Tensors matching these patterns are kept in full fp32 in the quantized artifact rather than being downcast to fp16. Defaults to the same value as CONTROL_TENSOR_NAME_PATTERNS.

Regardless of these patterns, tensors with 65,536 elements or fewer are always kept as floating point (stored as fp16) rather than quantized to int8. Large 2D float tensors use per-row int8 quantization; other large float tensors use per-tensor int8 quantization.
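Per-row int8 quantization gives each matrix row its own scale, so one large-magnitude row cannot destroy precision elsewhere. A minimal sketch, assuming a symmetric scheme (the symmetric choice is an illustration, not confirmed from the script):

```python
def quantize_per_row_int8(matrix):
    # Symmetric per-row int8 quantization: each row gets its own scale so the
    # row's max-magnitude value maps to +/-127. All-zero rows get scale 1.0
    # to avoid division by zero.
    quantized, scales = [], []
    for row in matrix:
        scale = max(abs(v) for v in row) / 127.0 or 1.0
        scales.append(scale)
        quantized.append([round(v / scale) for v in row])
    return quantized, scales
```

Dequantization is just `value * scale` per row; per-tensor quantization is the same idea with a single scale for the whole tensor.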
Distributed Training
These variables are set automatically by torchrun. You do not need to set them manually.
Global rank of the current process across all nodes. Process 0 is the master process that writes logs and saves checkpoints.
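Taken together, the three torchrun-provided variables are typically consumed like this (a sketch of the documented relationships; the script's actual wiring may differ):

```python
import os

# torchrun exports these; the defaults emulate a single-process run.
rank = int(os.environ.get("RANK", 0))
world_size = int(os.environ.get("WORLD_SIZE", 1))
local_rank = int(os.environ.get("LOCAL_RANK", 0))

assert 8 % world_size == 0, "WORLD_SIZE must divide 8"
grad_accum_steps = 8 // world_size          # from the Training Length section
device = f"cuda:{local_rank}"               # per-node device selection
is_master = rank == 0                       # master writes logs, saves checkpoints
```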
Total number of processes in the distributed job. Must divide 8 so that gradient accumulation steps (8 // WORLD_SIZE) remain an integer. Valid values: 1, 2, 4, 8.

Rank of the current process on its local node. Used to select the CUDA device (cuda:<LOCAL_RANK>).

Data Pipeline
These variables configure the dataset download and tokenization scripts in data/.
Hugging Face dataset repository ID to download shards and tokenizers from.
Subdirectory prefix within the HF repo under which dataset shards and manifest are stored.
Batch size for SentencePiece tokenizer encoding during shard export. Useful for tuning CPU-heavy export throughput.
Number of threads for the tokenizer encoding pool during shard export.
Number of threads for tiktoken encoding during shard export (used when tokenizing with the tiktoken backend).
Batch size for GPT-2 decoding during the blobstore docs-cache path. Useful for tuning memory vs. throughput tradeoff.
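The batch-size and thread-count knobs above interact in the usual producer/worker shape: documents are chunked into batches, and batches are encoded on a thread pool. A hedged sketch with the tokenizer's batch-encode call abstracted behind a placeholder (`encode_batch` stands in for a SentencePiece or tiktoken call and is not the scripts' actual API):

```python
from concurrent.futures import ThreadPoolExecutor

def encode_shard(texts, encode_batch, batch_size=256, num_threads=4):
    # Chunk documents into batches, then encode batches concurrently.
    # `encode_batch` is a placeholder for the real tokenizer's batch encoder.
    batches = [texts[i:i + batch_size] for i in range(0, len(texts), batch_size)]
    with ThreadPoolExecutor(max_workers=num_threads) as pool:
        results = pool.map(encode_batch, batches)  # preserves input order
    return [ids for batch in results for ids in batch]
```

Larger batches amortize per-call overhead; more threads help while encoding is CPU-bound, which is the tradeoff these variables tune.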
MLX-Only Variables
These variables are specific to train_gpt_mlx.py and have no effect on train_gpt.py.
Maximum tokens per sub-batch within each logical microbatch. MLX splits each microbatch into smaller chunks of at most this size to reduce peak memory pressure on Apple Silicon without changing the effective optimizer batch size.
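Splitting a microbatch into token-bounded sub-batches can be sketched as greedy packing; this is an illustrative reconstruction (the MLX script may instead split fixed-length sequences more simply):

```python
def split_microbatch(sequences, max_tokens):
    # Greedily pack whole sequences into sub-batches of at most max_tokens
    # tokens each, never splitting an individual sequence.
    sub_batches, current, current_tokens = [], [], 0
    for seq in sequences:
        if current and current_tokens + len(seq) > max_tokens:
            sub_batches.append(current)
            current, current_tokens = [], 0
        current.append(seq)
        current_tokens += len(seq)
    if current:
        sub_batches.append(current)
    return sub_batches
```

Gradients from the sub-batches are accumulated before the optimizer step, so the effective batch size is unchanged; only peak memory drops.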
Number of gradient accumulation steps per optimizer update in train_gpt_mlx.py. In train_gpt.py this is always derived as 8 // WORLD_SIZE and is not independently configurable.

Output directory for log files and model artifacts in train_gpt_mlx.py. In train_gpt.py the log directory is always logs/ and is not configurable.

Number of tokens per logit-computation chunk in train_gpt_mlx.py. Set to a positive value to reduce peak memory by computing the final projection and cross-entropy loss in chunks. 0 (default) computes all tokens in a single matmul.
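The chunked-loss idea is framework-agnostic: the final projection and cross-entropy are computed over a slice of tokens at a time, and the results averaged. A plain-Python sketch (not MLX, and not the script's implementation) showing that chunking changes memory, not the result:

```python
import math

def chunked_cross_entropy(hidden, weight, targets, chunk_tokens=0):
    # logits = hidden @ weight.T, then mean cross-entropy, processing
    # `chunk_tokens` tokens per pass; 0 means one pass over all tokens.
    n = len(hidden)
    step = chunk_tokens if chunk_tokens > 0 else n
    total = 0.0
    for start in range(0, n, step):
        for h, t in zip(hidden[start:start + step], targets[start:start + step]):
            logits = [sum(a * b for a, b in zip(h, w)) for w in weight]
            m = max(logits)  # stabilize the log-sum-exp
            log_z = m + math.log(sum(math.exp(l - m) for l in logits))
            total += log_z - logits[t]
    return total / n
```

Only `step` rows of logits need to exist at once, which is the memory saving; the mean loss is identical for any chunk size.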