Pretrain a base language model from scratch using distributed training.

Usage

# Single GPU
python -m scripts.base_train

# Distributed (8 GPUs)
torchrun --nproc_per_node=8 -m scripts.base_train

# CPU/MacBook (small model)
python -m scripts.base_train --depth=4 --max-seq-len=512 --device-batch-size=1 \
  --eval-tokens=512 --core-metric-every=-1 --total-batch-size=512 --num-iterations=20

Parameters

Logging

--run (str, default: "dummy")
Weights & Biases run name. Use 'dummy' to disable wandb logging.

Runtime

--device-type (str, default: "")
Device type: cuda, cpu, or mps. Empty string enables autodetection.

FP8 Training

--fp8 (bool, default: false)
Enable FP8 training. Requires H100+ GPU and torchao.

--fp8-recipe (str, default: "tensorwise")
FP8 scaling recipe: tensorwise (faster, recommended) or rowwise (more accurate but slower).
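
For reference, the sketch below shows roughly what these flags could map to using torchao's float8 training API (Float8LinearConfig.from_recipe_name and convert_to_float8_training). It is an illustrative Python sketch under that assumption, not the script's actual implementation.

# Illustrative sketch only: assumes torchao's float8 training API;
# scripts.base_train may wire this up differently.
from torchao.float8 import Float8LinearConfig, convert_to_float8_training

def maybe_convert_to_fp8(model, fp8: bool, fp8_recipe: str = "tensorwise"):
    if not fp8:
        return model
    # "tensorwise": one scale per tensor (faster); "rowwise": per-row scales (more accurate, slower).
    config = Float8LinearConfig.from_recipe_name(fp8_recipe)
    convert_to_float8_training(model, config=config)  # replaces eligible nn.Linear modules in place
    return model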

Model Architecture

--depth (int, default: 20)
Depth of the Transformer model (number of layers).

--aspect-ratio (int, default: 64)
Model dimension is calculated as depth * aspect_ratio.

--head-dim (int, default: 128)
Target head dimension for attention.

--max-seq-len (int, default: 2048)
Maximum context length (sequence length).

--window-pattern (str, default: "SSSL")
Sliding window pattern tiled across layers. L = full context, S = half context (e.g. 'SSL').
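
These flags combine arithmetically: --aspect-ratio sets the model width from the depth, --head-dim splits that width into attention heads, and --window-pattern is tiled across the layers. A small Python sketch of that arithmetic, assuming the script divides and tiles exactly as described above (rounding details may differ):

# Hypothetical sketch of how the shape flags combine; exact rounding may differ in the script.
depth = 20            # --depth
aspect_ratio = 64     # --aspect-ratio
head_dim = 128        # --head-dim (target)
pattern = "SSSL"      # --window-pattern

model_dim = depth * aspect_ratio      # 20 * 64 = 1280
num_heads = model_dim // head_dim     # 1280 // 128 = 10 attention heads

# The window pattern is tiled across all layers:
# L = full context (max_seq_len), S = half context.
layer_windows = [pattern[i % len(pattern)] for i in range(depth)]
# ['S', 'S', 'S', 'L', 'S', 'S', 'S', 'L', ...]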

Training Horizon

Only one is used, in order of precedence:
--num-iterations (int, default: -1)
Explicit number of optimization steps. -1 = disabled.

--target-flops (float, default: -1.0)
Calculate num_iterations to reach target FLOPs. -1 = disabled.

--target-param-data-ratio (float, default: 10.5)
Calculate num_iterations to maintain a data:param ratio. Chinchilla = 20. -1 = disabled.
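
A hedged Python sketch of that precedence, assuming tokens per optimizer step equals the total batch size in tokens and that num_params and flops_per_token come from the model config (the names here are illustrative, not the script's):

# Hypothetical sketch of the training-horizon precedence; the script's exact accounting may differ.
def compute_num_iterations(num_iterations, target_flops, target_param_data_ratio,
                           num_params, tokens_per_step, flops_per_token):
    if num_iterations > 0:            # 1) explicit step count wins
        return num_iterations
    if target_flops > 0:              # 2) otherwise train up to a FLOPs budget
        return int(target_flops / (flops_per_token * tokens_per_step))
    if target_param_data_ratio > 0:   # 3) otherwise hit a tokens-per-parameter ratio
        target_tokens = target_param_data_ratio * num_params   # e.g. Chinchilla uses 20
        return int(target_tokens / tokens_per_step)
    raise ValueError("no training horizon specified")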

Optimization

--device-batch-size (int, default: 32)
Per-device batch size. Reduce to 16, 8, 4, etc. if you encounter OOM errors.

--total-batch-size (int, default: -1)
Total batch size in tokens. Good value: 524288. -1 = auto-compute optimal.

--embedding-lr (float, default: 0.3)
Learning rate for embedding parameters (Adam).

--unembedding-lr (float, default: 0.004)
Learning rate for unembedding parameters (Adam).

--weight-decay (float, default: 0.2)
Cautious weight decay for the Muon optimizer (for weights).

--matrix-lr (float, default: 0.02)
Learning rate for matrix parameters (Muon).

--scalar-lr (float, default: 0.5)
Learning rate for scalars (resid_lambdas, x0_lambdas).

--adam-beta1 (float, default: 0.8)
Adam beta1 for embedding/unembedding.

--adam-beta2 (float, default: 0.95)
Adam beta2 for embedding/unembedding.

--warmup-ratio (float, default: 0.0)
Ratio of iterations for learning rate warmup.

--warmdown-ratio (float, default: 0.5)
Ratio of iterations for learning rate warmdown.

--final-lr-frac (float, default: 0.0)
Final learning rate as fraction of initial learning rate.

--resume-from-step (int, default: -1)
Resume training from this step. -1 = disabled.
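
The batch-size flags are related through gradient accumulation, and the warmup/warmdown ratios describe a trapezoidal learning-rate schedule. The Python sketch below illustrates both under stated assumptions; the script's exact rounding and schedule shape may differ.

# Hypothetical sketch of the batch-size relationship and LR schedule (illustrative only).
device_batch_size = 32      # --device-batch-size (sequences per device per micro-step)
max_seq_len = 2048          # --max-seq-len
world_size = 8              # GPUs launched via torchrun
total_batch_size = 524288   # --total-batch-size, in tokens

tokens_per_micro_step = device_batch_size * max_seq_len * world_size   # 32 * 2048 * 8 = 524288
grad_accum_steps = total_batch_size // tokens_per_micro_step           # 1 here; > 1 when total is larger

# Warmup/warmdown are fractions of num_iterations forming a trapezoidal schedule:
def lr_multiplier(step, num_iterations, warmup_ratio=0.0, warmdown_ratio=0.5, final_lr_frac=0.0):
    warmup = int(warmup_ratio * num_iterations)
    warmdown = int(warmdown_ratio * num_iterations)
    if warmup and step < warmup:
        return (step + 1) / warmup                          # linear ramp up
    if warmdown and step >= num_iterations - warmdown:
        progress = (num_iterations - step) / warmdown       # linear ramp down
        return final_lr_frac + (1.0 - final_lr_frac) * progress
    return 1.0                                              # constant plateau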

Evaluation

--eval-every (int, default: 250)
Evaluate validation bits-per-byte every N steps. -1 = disabled.

--eval-tokens (int, default: 20971520)
Number of tokens to evaluate validation loss on (default: 40 * 524288).

--core-metric-every (int, default: 2000)
Evaluate CORE metric every N steps. -1 = disabled.

--core-metric-max-per-task (int, default: 500)
Maximum examples per task for CORE metric.

--sample-every (int, default: 2000)
Sample from model every N steps. -1 = disabled.

--save-every (int, default: -1)
Save checkpoints every N steps. -1 = only save at end.
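
Bits-per-byte normalizes the validation loss by raw bytes rather than tokens, which makes it comparable across tokenizers. The standard conversion, shown as a hedged Python sketch (not necessarily the script's exact code):

import math

def bits_per_byte(mean_loss_nats, num_tokens, num_bytes):
    # cross-entropy in nats/token -> total bits -> bits per raw byte
    total_bits = mean_loss_nats * num_tokens / math.log(2)
    return total_bits / num_bytes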

Output

--model-tag (str, default: None)
Override model tag for checkpoint directory name. If not provided, defaults to d{depth} (e.g. d12).

Examples

Full Training Run

torchrun --nproc_per_node=8 -m scripts.base_train \
  --run=my-experiment \
  --depth=24 \
  --max-seq-len=2048 \
  --device-batch-size=32 \
  --eval-every=500

Resume Training

torchrun --nproc_per_node=8 -m scripts.base_train \
  --model-tag=d24 \
  --resume-from-step=5000

FP8 Training (H100+)

torchrun --nproc_per_node=8 -m scripts.base_train \
  --fp8 \
  --fp8-recipe=tensorwise
