Overview

nanochat’s checkpoint system saves model parameters, optimizer state, and training metadata so that training can be resumed and models can be evaluated or deployed. Checkpoints are organized by model size and training step.

Checkpoint Structure

Checkpoints are stored in base_checkpoints/<model_tag>/:
base_checkpoints/
└── d26/                          # Model tag (depth 26)
    ├── model_000500.pt           # Model parameters at step 500
    ├── meta_000500.json          # Metadata (config, metrics)
    ├── optim_000500_rank0.pt     # Optimizer state (rank 0)
    ├── optim_000500_rank1.pt     # Optimizer state (rank 1)
    ├── ...
    └── optim_000500_rank7.pt     # Optimizer state (rank 7)

File Contents

model_*.pt: PyTorch state dict with all model parameters
{
    'wte.weight': tensor(...),      # Token embeddings
    'h.0.ln_1.weight': tensor(...), # Layer 0 norm
    'h.0.attn.qkv.weight': tensor(...),
    # ... all transformer layers ...
    'lm_head.weight': tensor(...),  # Output projection
    'resid_lambdas': tensor(...),   # Residual scalars
    'x0_lambdas': tensor(...),      # X0 scalars
}
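To inspect a checkpoint's contents directly, load the state dict on CPU and print the parameter shapes (a quick sketch; the path and step are illustrative):
import torch

# Load the raw state dict on CPU; it contains only plain tensors
state_dict = torch.load("base_checkpoints/d26/model_000500.pt", map_location="cpu")
for name, tensor in state_dict.items():
    print(name, tuple(tensor.shape))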
meta_*.json: Training metadata
{
  "step": 5000,
  "val_bpb": 0.74504,
  "model_config": {
    "n_layer": 26,
    "n_embd": 1664,
    "n_head": 13,
    "sequence_len": 2048,
    "vocab_size": 33280
  },
  "user_config": {
    "depth": 26,
    "device_batch_size": 32,
    "fp8": true
  },
  "dataloader_state_dict": {...},
  "loop_state": {
    "min_val_bpb": 0.74123,
    "smooth_train_loss": 2.1234,
    "total_training_time": 7200.5
  }
}
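Since meta_*.json is plain JSON, it can be read without PyTorch, e.g. to recover the model config for analysis (the path is illustrative):
import json

# Read training metadata saved alongside the model file
with open("base_checkpoints/d26/meta_005000.json") as f:
    meta = json.load(f)
print(meta["step"], meta["val_bpb"])
print(meta["model_config"]["n_layer"], meta["model_config"]["n_embd"])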
optim_*_rank*.pt: Per-rank optimizer state (momentum buffers, etc.)

Saving Checkpoints

During Training

Checkpoints are saved automatically during training based on the --save-every flag:
# Save checkpoint every 1000 steps
python -m scripts.base_train --depth=12 --save-every=1000

# Save only at the end (-1)
python -m scripts.base_train --depth=12 --save-every=-1
From scripts/base_train.py:460-483:
if last_step or (step > 0 and step != args.resume_from_step and 
                 args.save_every > 0 and step % args.save_every == 0):
    save_checkpoint(
        checkpoint_dir,
        step,
        orig_model.state_dict(),
        optimizer.state_dict(),
        {
            "step": step,
            "val_bpb": val_bpb,
            "model_config": model_config_kwargs,
            "user_config": user_config,
            "dataloader_state_dict": dataloader_state_dict,
            "loop_state": {...},
        },
        rank=ddp_rank,
    )

Manual Saving

Use the save_checkpoint() function from Python:
from nanochat.checkpoint_manager import save_checkpoint

save_checkpoint(
    checkpoint_dir="base_checkpoints/d12",
    step=5000,
    model_data=model.state_dict(),
    optimizer_data=optimizer.state_dict(),
    meta_data={
        "step": 5000,
        "model_config": asdict(model.config),
        "val_bpb": 0.75123,
    },
    rank=0,  # Only rank 0 saves model/meta, all ranks save optimizer
)
From nanochat/checkpoint_manager.py:42-59.

Loading Checkpoints

For Resuming Training

Resume training from a specific step:
torchrun --nproc_per_node=8 -m scripts.base_train \
    --depth=26 \
    --resume-from-step=5000
The script automatically:
  1. Loads model parameters
  2. Loads optimizer state (per-rank shards)
  3. Restores dataloader position
  4. Restores training metrics (loss EMA, time, etc.)
From scripts/base_train.py:154-158:
if resuming:
    model_data, optimizer_data, meta_data = load_checkpoint(
        checkpoint_dir, args.resume_from_step, device, 
        load_optimizer=True, rank=ddp_rank
    )
    model.load_state_dict(model_data, strict=True, assign=True)
    optimizer.load_state_dict(optimizer_data)

For Evaluation

Load a model for evaluation without optimizer state:
from nanochat.checkpoint_manager import load_model

# Load the largest available model from base_checkpoints/
model, tokenizer, meta_data = load_model(
    source="base",      # "base", "sft", or "rl"
    device=device,
    phase="eval",       # "eval" or "train"
    model_tag=None,     # Auto-detect largest model
    step=None,          # Auto-detect last step
)
From nanochat/checkpoint_manager.py:164-172.
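The returned meta_data mirrors the meta_*.json contents shown above, which makes it easy to confirm what was loaded:
# Confirm which checkpoint was loaded (keys as in meta_*.json)
print(f"step={meta_data['step']}, val_bpb={meta_data.get('val_bpb')}")
print(meta_data["model_config"])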

Advanced: Direct Loading

For full control, use load_checkpoint() directly:
from nanochat.checkpoint_manager import load_checkpoint

model_data, optimizer_data, meta_data = load_checkpoint(
    checkpoint_dir="base_checkpoints/d26",
    step=10000,
    device=device,
    load_optimizer=False,  # Skip optimizer for inference
    rank=0,
)
From nanochat/checkpoint_manager.py:61-74.

Checkpoint Naming

Model Tags

By default, checkpoints use d{depth} as the model tag:
# Creates base_checkpoints/d12/
python -m scripts.base_train --depth=12

# Creates base_checkpoints/d26/
python -m scripts.base_train --depth=26
Override with --model-tag:
# Creates base_checkpoints/my_experiment/
python -m scripts.base_train --depth=12 --model-tag=my_experiment
From scripts/base_train.py:151-152:
output_dirname = args.model_tag if args.model_tag else f"d{args.depth}"
checkpoint_dir = os.path.join(base_dir, "base_checkpoints", output_dirname)

Step Numbers

Step numbers are zero-padded to 6 digits:
model_000500.pt   # Step 500
model_005000.pt   # Step 5000  
model_010000.pt   # Step 10000
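Because the filenames follow a fixed pattern, paths can be reconstructed from a step number with simple formatting (a sketch, assuming the naming shown above):
import os

def checkpoint_paths(checkpoint_dir: str, step: int, rank: int = 0):
    # Zero-pad the step to 6 digits, matching model_000500.pt etc.
    return (
        os.path.join(checkpoint_dir, f"model_{step:06d}.pt"),
        os.path.join(checkpoint_dir, f"meta_{step:06d}.json"),
        os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank}.pt"),
    )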

Checkpoint Utilities

Find Latest Checkpoint

from nanochat.checkpoint_manager import find_last_step

last_step = find_last_step("base_checkpoints/d26")
print(f"Last checkpoint: step {last_step}")
From nanochat/checkpoint_manager.py:138-144.
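A quick integrity check is to confirm that the expected files exist for that step (a hypothetical helper built on the naming convention above; not part of nanochat):
import glob, os

checkpoint_dir = "base_checkpoints/d26"
step = find_last_step(checkpoint_dir)  # from the snippet above
have_model = os.path.exists(os.path.join(checkpoint_dir, f"model_{step:06d}.pt"))
have_meta = os.path.exists(os.path.join(checkpoint_dir, f"meta_{step:06d}.json"))
optim_shards = glob.glob(os.path.join(checkpoint_dir, f"optim_{step:06d}_rank*.pt"))
print(have_model, have_meta, len(optim_shards))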

Find Largest Model

from nanochat.checkpoint_manager import find_largest_model

model_tag = find_largest_model("base_checkpoints")
print(f"Largest model: {model_tag}")  # e.g., "d26"
From nanochat/checkpoint_manager.py:118-135. Attempts to parse the d{depth} tag to compare model sizes; falls back to the most recently modified directory.

Build Model from Checkpoint

from nanochat.checkpoint_manager import build_model

model, tokenizer, meta_data = build_model(
    checkpoint_dir="base_checkpoints/d26",
    step=10000,
    device=device,
    phase="eval",
)
From nanochat/checkpoint_manager.py:77-115. Handles:
  • Meta device initialization (no memory allocation)
  • Weight loading with assign=True (zero-copy)
  • BF16 → FP32 conversion for CPU/MPS
  • Backward compatibility (patches missing config keys)

Backward Compatibility

nanochat patches missing parameters when loading old checkpoints:

Missing Config Keys

if "window_pattern" not in model_config_kwargs:
    model_config_kwargs["window_pattern"] = "L"  # Full context
From nanochat/checkpoint_manager.py:23-28.

Missing Model Parameters

if "resid_lambdas" not in model_data:
    model_data["resid_lambdas"] = torch.ones(n_layer)  # Identity
if "x0_lambdas" not in model_data:
    model_data["x0_lambdas"] = torch.zeros(n_layer)  # Disabled
From nanochat/checkpoint_manager.py:30-40.

Disk Usage

Checkpoint sizes depend on model depth:
Depth | Parameters | model_*.pt | optim_*.pt (per rank) | Total (8 ranks)
------|------------|------------|-----------------------|----------------
d12   | ~120M      | ~480 MB    | ~960 MB               | ~8.2 GB
d20   | ~330M      | ~1.3 GB    | ~2.6 GB               | ~22 GB
d26   | ~570M      | ~2.3 GB    | ~4.6 GB               | ~39 GB
Per-rank optimizer state is roughly 2x the model size (momentum and variance buffers for Adam, momentum for Muon).
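These figures follow from simple arithmetic: the model file uses roughly 4 bytes per parameter (as the table implies), and each rank's optimizer shard is about twice the model file. A rough sizing sketch under those assumptions:
def estimate_checkpoint_gb(n_params: float, n_ranks: int = 8) -> float:
    model_gb = n_params * 4 / 1e9        # ~4 bytes/param, per the table above
    optim_gb_per_rank = 2 * model_gb     # optimizer shard ~2x the model file
    return model_gb + n_ranks * optim_gb_per_rank

print(f"{estimate_checkpoint_gb(570e6):.1f} GB")  # ~39 GB for d26 on 8 ranks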

Best Practices

During Experimentation

# Save only at end to save disk space
python -m scripts.base_train --depth=12 --save-every=-1

During Long Runs

# Save every 2000 steps for safety
torchrun --nproc_per_node=8 -m scripts.base_train \
    --depth=26 \
    --save-every=2000

For Reproducibility

Always save:
  • Model parameters (model_*.pt)
  • Metadata (meta_*.json) with full config
  • Training script and commit hash

Checkpoint Cleanup

Delete intermediate checkpoints to save space:
# Keep only the last checkpoint
cd base_checkpoints/d26
ls -t model_*.pt | tail -n +2 | xargs rm  # Keep newest, delete rest
ls -t meta_*.json | tail -n +2 | xargs rm
ls -t optim_*.pt | tail -n +9 | xargs rm  # Keep last step (8 ranks)
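If file modification times are unreliable (for example after copying checkpoints between machines), a variant that parses step numbers from the filenames is safer (a hypothetical Python sketch, not part of nanochat):
import glob, os, re

checkpoint_dir = "base_checkpoints/d26"
# Collect every step that has a model file, sorted ascending
steps = sorted(
    int(m.group(1))
    for p in glob.glob(os.path.join(checkpoint_dir, "model_*.pt"))
    if (m := re.search(r"model_(\d{6})\.pt$", p))
)
for step in steps[:-1]:  # delete everything except the latest step
    for path in glob.glob(os.path.join(checkpoint_dir, f"*_{step:06d}*")):
        os.remove(path)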

Checkpoint Migration

Move checkpoints between systems:

Copy to Another Machine

# Copy specific checkpoint
scp -r base_checkpoints/d26 user@remote:/path/to/nanochat/base_checkpoints/

# Copy only model (skip optimizer for inference)
scp base_checkpoints/d26/model_010000.pt \
    base_checkpoints/d26/meta_010000.json \
    user@remote:/path/to/nanochat/base_checkpoints/d26/

Resume on Different GPU Count

nanochat handles different GPU counts automatically:
# Trained on 8 GPUs, resume on 4 GPUs
torchrun --nproc_per_node=4 -m scripts.base_train \
    --depth=26 \
    --resume-from-step=5000
Optimizer state is re-initialized if rank count changes.
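You can see why by counting the saved rank shards against the new world size (a hypothetical check based on the naming convention; the resume path handles this for you):
import glob, os

def optimizer_shards_match(checkpoint_dir: str, step: int, world_size: int) -> bool:
    # One optim_{step}_rank{r}.pt shard is written per rank at save time
    shards = glob.glob(os.path.join(checkpoint_dir, f"optim_{step:06d}_rank*.pt"))
    return len(shards) == world_size

# A run saved on 8 GPUs has 8 shards, so resuming on 4 re-initializes the optimizer
print(optimizer_shards_match("base_checkpoints/d26", 5000, 4))  # -> False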

Troubleshooting

“No checkpoints found”

Ensure checkpoint directory exists:
ls base_checkpoints/d26/
# Should show model_*.pt files

“Checkpoint version mismatch”

Backward compatibility patches should handle most cases. If not:
# Load with strict=False to skip missing keys
model.load_state_dict(model_data, strict=False)

“Out of memory when loading”

Use meta device for zero-memory loading:
with torch.device("meta"):
    model = GPT(config)  # No memory allocated
model.to_empty(device=device)  # Allocate storage
model.load_state_dict(model_data, assign=True)  # Zero-copy load
From nanochat/checkpoint_manager.py:100-105.

“Optimizer state shape mismatch”

This happens when model architecture changes. Re-initialize optimizer:
# Resume without loading optimizer (will restart optimization)
model_data, _, meta_data = load_checkpoint(
    checkpoint_dir, step, device, load_optimizer=False
)

Further Reading

  • nanochat/checkpoint_manager.py - Full checkpoint implementation
  • scripts/base_train.py:460-483 - Checkpoint saving during training
  • scripts/base_train.py:154-158 - Checkpoint loading for resume
  • PyTorch Saving & Loading
