## Overview

The `nanochat.checkpoint_manager` module provides utilities for saving and loading model, optimizer, and metadata checkpoints during training.
## save_checkpoint

Save model, optimizer, and metadata to a checkpoint directory.

**Parameters**

- Directory to save checkpoint files
- Training step number (used in the filename)
- Model state dict to save
- Optimizer state dict to save. If `None`, optimizer state is not saved. Note: in distributed training, optimizer state is sharded across ranks, so each rank saves its own shard.
- Metadata dict to save as JSON (e.g., model config, training hyperparameters)
- Rank of the current process (for distributed training)
**Saved Files**

- `model_{step:06d}.pt` - Model parameters (saved only on rank 0)
- `meta_{step:06d}.json` - Metadata JSON (saved only on rank 0)
- `optim_{step:06d}_rank{rank}.pt` - Optimizer state shard for each rank
**Example**
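The file layout above can be sketched in plain Python. This is a hedged, illustrative stand-in for the real `save_checkpoint` (whose exact signature is not shown here): `pickle` replaces `torch.save` so the sketch runs without torch, and the function name `save_checkpoint_sketch` is ours, not the module's.

```python
import json
import os
import pickle
import tempfile

def save_checkpoint_sketch(checkpoint_dir, step, model_state, optimizer_state, meta, rank=0):
    """Illustrative sketch of the documented file layout (not the real implementation)."""
    os.makedirs(checkpoint_dir, exist_ok=True)
    if rank == 0:
        # Model parameters and metadata are written only on rank 0.
        with open(os.path.join(checkpoint_dir, f"model_{step:06d}.pt"), "wb") as f:
            pickle.dump(model_state, f)  # torch.save in the real module
        with open(os.path.join(checkpoint_dir, f"meta_{step:06d}.json"), "w") as f:
            json.dump(meta, f, indent=2)
    if optimizer_state is not None:
        # Every rank writes its own optimizer shard.
        with open(os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank}.pt"), "wb") as f:
            pickle.dump(optimizer_state, f)

ckpt_dir = tempfile.mkdtemp()
save_checkpoint_sketch(ckpt_dir, 1500, {"w": [0.1]}, {"lr": 0.001}, {"depth": 12})
print(sorted(os.listdir(ckpt_dir)))
# ['meta_001500.json', 'model_001500.pt', 'optim_001500_rank0.pt']
```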
## load_checkpoint

Load model, optimizer, and metadata from a checkpoint directory.

**Parameters**

- Directory containing checkpoint files
- Training step number to load
- Device to map loaded tensors to
- Whether to load optimizer state
- Rank of the current process (for loading its optimizer shard)
**Returns**

- Model state dict
- Optimizer state dict (or `None` if `load_optimizer=False`)
- Metadata dict loaded from JSON
**Example**
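A runnable sketch of the load path, mirroring the saved-file naming convention documented for `save_checkpoint`. This is an assumption-laden illustration, not the real implementation: `pickle` stands in for `torch.load` so the example needs no torch, the name `load_checkpoint_sketch` is ours, and the real function additionally maps tensors onto the requested device.

```python
import json
import os
import pickle
import tempfile

def load_checkpoint_sketch(checkpoint_dir, step, load_optimizer=True, rank=0):
    """Illustrative sketch of loading the documented checkpoint files."""
    with open(os.path.join(checkpoint_dir, f"model_{step:06d}.pt"), "rb") as f:
        model_state = pickle.load(f)  # torch.load(..., map_location=device) in the real module
    with open(os.path.join(checkpoint_dir, f"meta_{step:06d}.json")) as f:
        meta = json.load(f)
    optimizer_state = None
    if load_optimizer:
        # Each rank reads back only its own shard.
        with open(os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank}.pt"), "rb") as f:
            optimizer_state = pickle.load(f)
    return model_state, optimizer_state, meta

# Write a tiny fake checkpoint, then load it back without the optimizer state.
d = tempfile.mkdtemp()
with open(os.path.join(d, "model_000200.pt"), "wb") as f:
    pickle.dump({"wte.weight": [0.1, 0.2]}, f)
with open(os.path.join(d, "meta_000200.json"), "w") as f:
    json.dump({"depth": 6}, f)
model_state, optim_state, meta = load_checkpoint_sketch(d, 200, load_optimizer=False)
print(optim_state, meta)  # None {'depth': 6}
```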
## Utility Functions
### build_model

Build a model from a checkpoint.

- `checkpoint_dir` (str): Directory containing checkpoints
- `step` (int): Training step to load
- `device` (torch.device): Device to load the model on
- `phase` (str): Either `'train'` or `'eval'`
### find_largest_model

Find the largest model in a checkpoints directory (e.g., `d12` > `d6`).
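The `d12` > `d6` ordering implies model tags are compared by their numeric depth rather than as strings (where `"d6" > "d12"` lexicographically). A small sketch of that comparison, under the assumption that tags have the form `d<depth>`; `tag_depth` is a hypothetical helper, not part of the module:

```python
import re

def tag_depth(tag):
    """Extract the numeric depth from a model tag like 'd12' (assumed format)."""
    m = re.fullmatch(r"d(\d+)", tag)
    if m is None:
        raise ValueError(f"unrecognized model tag: {tag!r}")
    return int(m.group(1))

tags = ["d6", "d12", "d20"]
print(max(tags, key=tag_depth))   # 'd20' -- numeric comparison
print(max(tags))                  # 'd6'  -- naive string comparison gets it wrong
```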
### find_last_step

Find the last checkpoint step in a directory. Scans `model_*.pt` files and returns the highest step number.
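The scan can be sketched as follows. This is an illustration of the described behavior under stated assumptions (filenames follow the `model_{step:06d}.pt` convention), not the module's actual code; the name `find_last_step_sketch` is ours.

```python
import os
import re
import tempfile

def find_last_step_sketch(checkpoint_dir):
    """Return the highest step number among model_*.pt files in the directory."""
    steps = []
    for name in os.listdir(checkpoint_dir):
        m = re.fullmatch(r"model_(\d+)\.pt", name)
        if m:
            steps.append(int(m.group(1)))
    if not steps:
        raise FileNotFoundError(f"no model_*.pt files found in {checkpoint_dir}")
    return max(steps)

d = tempfile.mkdtemp()
for step in (500, 1000, 1500):
    open(os.path.join(d, f"model_{step:06d}.pt"), "wb").close()
print(find_last_step_sketch(d))  # 1500
```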
### load_model_from_dir

Convenience function to load a model, with automatic model-tag and step detection.

### load_model
Load a model from standard nanochat directories.

- `source` (str): One of `'base'`, `'sft'`, or `'rl'`
- Other parameters are the same as `load_model_from_dir`
load_optimizer_state
Load just the optimizer shard for a given rank.Notes
- The module handles backward compatibility by patching missing config keys from older checkpoints
- On CPU/MPS devices, bfloat16 tensors are automatically converted to float32
- `torch.compile` artifacts are handled by removing `_orig_mod.` prefixes from state-dict keys
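The prefix cleanup in the last note exists because `torch.compile` wraps a module and exposes the original under `_orig_mod`, so state dicts saved from a compiled model carry `_orig_mod.`-prefixed keys. A minimal sketch of that cleanup (illustrative; `strip_compile_prefix` is a hypothetical name, not the module's function):

```python
def strip_compile_prefix(state_dict):
    """Remove the '_orig_mod.' prefix torch.compile adds to state-dict keys."""
    prefix = "_orig_mod."
    return {
        key[len(prefix):] if key.startswith(prefix) else key: value
        for key, value in state_dict.items()
    }

compiled_keys = {"_orig_mod.wte.weight": 1, "_orig_mod.lm_head.weight": 2}
print(strip_compile_prefix(compiled_keys))
# {'wte.weight': 1, 'lm_head.weight': 2}
```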