Overview

The nanochat.checkpoint_manager module provides utilities for saving and loading model, optimizer, and metadata checkpoints during training.

save_checkpoint

Save model, optimizer, and metadata to checkpoint directory.
def save_checkpoint(
    checkpoint_dir: str,
    step: int,
    model_data: dict,
    optimizer_data: dict | None,
    meta_data: dict,
    rank: int = 0
)

Parameters

checkpoint_dir
str
required
Directory to save checkpoint files
step
int
required
Training step number (used in filenames)
model_data
dict
required
Model state dict to save
optimizer_data
dict | None
Optimizer state dict to save. If None, optimizer state is not saved. Note: in distributed training, optimizer state is sharded across ranks, so each rank saves its own shard.
meta_data
dict
required
Metadata dict to save as JSON (e.g., model config, training hyperparameters)
rank
int
default:"0"
Rank of the current process (for distributed training)

Saved Files

  • model_{step:06d}.pt - Model parameters (saved only on rank 0)
  • meta_{step:06d}.json - Metadata JSON (saved only on rank 0)
  • optim_{step:06d}_rank{rank}.pt - Optimizer state shard for each rank

Example

from nanochat.checkpoint_manager import save_checkpoint

# Prepare checkpoint data
model_data = model.state_dict()
optimizer_data = optimizer.state_dict()
meta_data = {
    'model_config': model.config.__dict__,
    'step': step,
    'learning_rate': lr
}

# Save checkpoint
save_checkpoint(
    checkpoint_dir='checkpoints/d12',
    step=1000,
    model_data=model_data,
    optimizer_data=optimizer_data,
    meta_data=meta_data,
    rank=0
)
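
In multi-GPU runs, every rank calls save_checkpoint with its own rank so that each optimizer shard gets written; the model and metadata files are written by rank 0 only. A minimal sketch, assuming torch.distributed is already initialized and that step, model, optimizer, and meta_data are as in the example above:

import torch.distributed as dist
from nanochat.checkpoint_manager import save_checkpoint

# Every rank participates: rank 0 writes model/meta, and each rank
# writes its own optim_{step:06d}_rank{rank}.pt shard
rank = dist.get_rank() if dist.is_initialized() else 0
save_checkpoint(
    checkpoint_dir='checkpoints/d12',
    step=step,
    model_data=model.state_dict(),
    optimizer_data=optimizer.state_dict(),
    meta_data=meta_data,
    rank=rank,
)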

load_checkpoint

Load model, optimizer, and metadata from checkpoint directory.
def load_checkpoint(
    checkpoint_dir: str,
    step: int,
    device: torch.device,
    load_optimizer: bool = False,
    rank: int = 0
) -> tuple[dict, dict | None, dict]

Parameters

checkpoint_dir
str
required
Directory containing checkpoint files
step
int
required
Training step number to load
device
torch.device
required
Device to map loaded tensors to
load_optimizer
bool
default:"False"
Whether to load optimizer state
rank
int
default:"0"
Rank of the current process (for loading optimizer shard)

Returns

model_data
dict
Model state dict
optimizer_data
dict | None
Optimizer state dict (or None if load_optimizer=False)
meta_data
dict
Metadata dict loaded from JSON

Example

import torch
from nanochat.checkpoint_manager import load_checkpoint

# Load checkpoint
model_data, optimizer_data, meta_data = load_checkpoint(
    checkpoint_dir='checkpoints/d12',
    step=1000,
    device=torch.device('cuda'),
    load_optimizer=True,
    rank=0
)

# Restore model and optimizer
model.load_state_dict(model_data)
if optimizer_data is not None:
    optimizer.load_state_dict(optimizer_data)
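
Continuing the example above: since meta_data typically carries the model config, a fresh model instance can be rebuilt before restoring weights. A sketch, assuming a GPT/GPTConfig pair importable from nanochat.gpt (the import path and constructor are assumptions):

import torch
from nanochat.gpt import GPT, GPTConfig  # assumed import path

# Rebuild the architecture from the saved config, then restore weights
config = GPTConfig(**meta_data['model_config'])
model = GPT(config)
model.load_state_dict(model_data)
model = model.to(torch.device('cuda'))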

Utility Functions

build_model

Build a model from a checkpoint.
def build_model(
    checkpoint_dir: str,
    step: int,
    device: torch.device,
    phase: str
) -> tuple[GPT, Tokenizer, dict]
Parameters:
  • checkpoint_dir (str): Directory containing checkpoints
  • step (int): Training step to load
  • device (torch.device): Device to load model on
  • phase (str): Either 'train' or 'eval'
Returns: (model, tokenizer, meta_data)
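
For example, loading a model for evaluation straight from a checkpoint directory:

import torch
from nanochat.checkpoint_manager import build_model

model, tokenizer, meta_data = build_model(
    checkpoint_dir='checkpoints/d12',
    step=1000,
    device=torch.device('cuda'),
    phase='eval',  # 'train' or 'eval'
)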

find_largest_model

Find the largest model in a checkpoints directory.
def find_largest_model(checkpoints_dir: str) -> str
Attempts to find the model tag with the largest depth (e.g., d12 > d6).

find_last_step

Find the last checkpoint step in a directory.
def find_last_step(checkpoint_dir: str) -> int
Scans for model_*.pt files and returns the highest step number.
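
Together, these two helpers resolve the newest checkpoint of the largest model without hard-coding a tag or step. A minimal sketch (the checkpoints/<model_tag>/ directory layout is an assumption):

import os
import torch
from nanochat.checkpoint_manager import build_model, find_largest_model, find_last_step

checkpoints_dir = 'checkpoints'                    # assumed layout: checkpoints/<model_tag>/
model_tag = find_largest_model(checkpoints_dir)    # e.g. 'd12'
checkpoint_dir = os.path.join(checkpoints_dir, model_tag)
step = find_last_step(checkpoint_dir)              # highest step among model_*.pt files
model, tokenizer, meta_data = build_model(checkpoint_dir, step, torch.device('cuda'), phase='eval')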

load_model_from_dir

Convenience function to load a model with automatic model tag and step detection.
def load_model_from_dir(
    checkpoints_dir: str,
    device: torch.device,
    phase: str,
    model_tag: str | None = None,
    step: int | None = None
) -> tuple[GPT, Tokenizer, dict]
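
This wraps the discovery shown above: when model_tag or step is None, the largest model and its latest checkpoint are selected. For example:

import torch
from nanochat.checkpoint_manager import load_model_from_dir

# Omitting model_tag and step auto-detects the largest model and latest step
model, tokenizer, meta_data = load_model_from_dir(
    checkpoints_dir='checkpoints',
    device=torch.device('cuda'),
    phase='eval',
)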

load_model

Load a model from standard nanochat directories.
def load_model(
    source: str,
    device: torch.device,
    phase: str,
    model_tag: str | None = None,
    step: int | None = None
) -> tuple[GPT, Tokenizer, dict]
Parameters:
  • source (str): One of 'base', 'sft', or 'rl'
  • Other parameters same as load_model_from_dir
Returns: (model, tokenizer, meta_data)
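
For example, loading the latest SFT checkpoint for inference:

import torch
from nanochat.checkpoint_manager import load_model

# source selects the standard directory for that training stage
model, tokenizer, meta_data = load_model(
    source='sft',
    device=torch.device('cuda'),
    phase='eval',
)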

load_optimizer_state

Load just the optimizer shard for a given rank.
def load_optimizer_state(
    source: str,
    device: torch.device,
    rank: int,
    model_tag: str | None = None,
    step: int | None = None
) -> dict | None
Useful for resuming distributed training without reloading the model.
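
A sketch of resuming distributed training, where load_model restores weights on every rank and each rank fetches only its own optimizer shard (the AdamW optimizer here is a placeholder):

import torch
import torch.distributed as dist
from nanochat.checkpoint_manager import load_model, load_optimizer_state

rank = dist.get_rank() if dist.is_initialized() else 0
device = torch.device('cuda')

# Weights and metadata are identical on every rank
model, tokenizer, meta_data = load_model('base', device, phase='train')

# Per-rank optimizer shard; returns None if no optimizer state was saved
optimizer = torch.optim.AdamW(model.parameters())  # placeholder optimizer
optimizer_data = load_optimizer_state('base', device, rank)
if optimizer_data is not None:
    optimizer.load_state_dict(optimizer_data)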

Notes

  • The module handles backward compatibility by patching missing config keys from older checkpoints
  • On CPU/MPS devices, bfloat16 tensors are automatically converted to float32 on load
  • torch.compile artifacts are handled by stripping _orig_mod. prefixes from state dict keys (illustrated below)
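
The prefix handling in the last note is roughly equivalent to the following (an illustration, not the module's exact code):

# torch.compile wraps modules, prefixing state dict keys with '_orig_mod.'
def strip_compile_prefix(state_dict: dict) -> dict:
    return {k.removeprefix('_orig_mod.'): v for k, v in state_dict.items()}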
