Skip to main content
Modern LLM provides robust checkpoint management for saving and loading model states, optimizer states, and training metadata. Checkpoints support compiled models, distributed training, and flexible recovery.

Saving checkpoints

Use the save_checkpoint function to persist model and optimizer state:
from pathlib import Path
import torch
from modern_llm.utils.checkpointing import save_checkpoint
from modern_llm.models.transformer import ModernDecoderLM
from modern_llm.config import ModernLLMConfig

# Create model
# The same config object is stored in the checkpoint below, which lets the
# loading examples rebuild the model with ModernLLMConfig(**checkpoint["config"]).
config = ModernLLMConfig(
    vocab_size=50257,
    d_model=768,
    n_layers=12,
    n_heads=12,
    ffn_hidden_size=3072,
    max_seq_len=1024,
)
model = ModernDecoderLM(config)

# Create optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

# Save checkpoint
# step, loss and config are metadata keywords: they are stored alongside the
# model and optimizer state and can be read back from the loaded checkpoint.
save_checkpoint(
    path=Path("experiments/my-run/checkpoint_1000.pt"),
    model_state=model.state_dict(),
    optimizer_state=optimizer.state_dict(),
    step=1000,
    loss=2.45,
    config=config.__dict__,
)
The save_checkpoint function automatically creates parent directories if they don’t exist.

Checkpoint structure

Checkpoints are PyTorch .pt files containing:
{
    "model_state": {...},      # model.state_dict()
    "optimizer": {...},        # optimizer.state_dict() (optional)
    "metadata": {
        "step": 1000,
        "loss": 2.45,
        "config": {...},
        # ... any additional metadata
    },
    # Metadata also available at top level
    "step": 1000,
    "loss": 2.45,
    "config": {...},
}

Saving with metadata

Include arbitrary metadata using keyword arguments:
# Reuses the model, optimizer and config created in the first example.
save_checkpoint(
    path=Path("experiments/checkpoint_5000.pt"),
    model_state=model.state_dict(),
    optimizer_state=optimizer.state_dict(),
    # Metadata
    # Every keyword below is free-form; it is stored under "metadata" and is
    # also available at the top level of the loaded checkpoint.
    step=5000,
    epoch=2,
    train_loss=1.87,
    val_loss=2.12,
    learning_rate=2.5e-4,
    timestamp="2024-03-15T10:30:00",
    config=config.__dict__,
    git_commit="a3f4c2d",
)

Saving during training

Typical training loop integration:
from pathlib import Path
from modern_llm.utils.checkpointing import save_checkpoint

# NOTE: max_steps, save_every, run_name, batch and train_step are placeholders
# supplied by the surrounding training script; model, optimizer and config come
# from the earlier examples.
output_dir = Path("experiments/my-training-run")
output_dir.mkdir(parents=True, exist_ok=True)

for step in range(max_steps):
    # Training step
    loss = train_step(model, batch, optimizer)

    # Save checkpoint every N steps
    if step % save_every == 0:
        checkpoint_path = output_dir / f"checkpoint_{step}.pt"
        save_checkpoint(
            path=checkpoint_path,
            model_state=model.state_dict(),
            optimizer_state=optimizer.state_dict(),
            step=step,
            loss=loss.item(),
            config=config.__dict__,
        )
        print(f"Saved checkpoint to {checkpoint_path}")

# Save final checkpoint
# `step` and `loss` are only bound once the loop body has run, so skip the
# final save when there were no steps instead of failing with a NameError.
if max_steps > 0:
    final_path = output_dir / f"{run_name}_final.pt"
    save_checkpoint(
        path=final_path,
        model_state=model.state_dict(),
        optimizer_state=optimizer.state_dict(),
        step=step,
        loss=loss.item(),
        config=config.__dict__,
        final=True,
    )

Loading checkpoints

Use the load_checkpoint function to restore saved state:
from pathlib import Path
from modern_llm.utils.checkpointing import load_checkpoint

# Load checkpoint
checkpoint_path = Path("experiments/my-run/checkpoint_1000.pt")
checkpoint = load_checkpoint(checkpoint_path)

# Metadata saved as keywords (step, loss, ...) is readable directly from the
# top level of the returned dict.
print(f"Checkpoint keys: {checkpoint.keys()}")
print(f"Training step: {checkpoint['step']}")
print(f"Training loss: {checkpoint['loss']}")

Restoring model state

from pathlib import Path  # required by Path("checkpoint.pt") below

from modern_llm.utils.checkpointing import load_checkpoint
from modern_llm.models.transformer import ModernDecoderLM
from modern_llm.config import ModernLLMConfig

# Load checkpoint
checkpoint = load_checkpoint(Path("checkpoint.pt"))

# Recreate model from saved config
# This relies on the checkpoint having been saved with config=config.__dict__.
config = ModernLLMConfig(**checkpoint["config"])
model = ModernDecoderLM(config)

# Load model weights
model.load_state_dict(checkpoint["model_state"])
# Switch to evaluation mode; this snippet restores the model for inference only.
model.eval()

print(f"Loaded model from step {checkpoint['step']}")

Resuming training

Restore both model and optimizer state:
from pathlib import Path  # required by Path("checkpoint_5000.pt") below

import torch
from modern_llm.utils.checkpointing import load_checkpoint
from modern_llm.models.transformer import ModernDecoderLM
from modern_llm.config import ModernLLMConfig

# Load checkpoint
checkpoint = load_checkpoint(Path("checkpoint_5000.pt"))

# Restore model
config = ModernLLMConfig(**checkpoint["config"])
model = ModernDecoderLM(config)
model.load_state_dict(checkpoint["model_state"])

# Restore optimizer
# The optimizer entry is optional in the checkpoint format, so only restore it
# when it was actually saved.
optimizer = torch.optim.AdamW(model.parameters())
if "optimizer" in checkpoint:
    optimizer.load_state_dict(checkpoint["optimizer"])

# Resume training from saved step
# The stored step is the last completed one, so continue with the next.
# NOTE: max_steps is a placeholder supplied by the surrounding training script.
start_step = checkpoint["step"] + 1
for step in range(start_step, max_steps):
    # Continue training...
    pass
# Resume with all state
# Compact variant of the resume example; it relies on the same imports
# (Path, torch, load_checkpoint, ModernDecoderLM, ModernLLMConfig).
checkpoint = load_checkpoint(Path("checkpoint.pt"))

model = ModernDecoderLM(ModernLLMConfig(**checkpoint["config"]))
model.load_state_dict(checkpoint["model_state"])

optimizer = torch.optim.AdamW(model.parameters())
# The optimizer entry is optional in the checkpoint format (a checkpoint can be
# saved with model_state only), so guard the lookup instead of raising KeyError.
if "optimizer" in checkpoint:
    optimizer.load_state_dict(checkpoint["optimizer"])
else:
    print("Checkpoint has no optimizer state; starting with a fresh optimizer")

start_step = checkpoint["step"] + 1

Compiled model support

Modern LLM automatically handles checkpoints from torch.compile:
from pathlib import Path  # required by the Path(...) calls below

import torch
from modern_llm.utils.checkpointing import save_checkpoint, load_checkpoint
from modern_llm.models.transformer import ModernDecoderLM
from modern_llm.config import ModernLLMConfig

# Create and compile model
# `...` is a placeholder: fill in the real hyperparameters (see the first example).
config = ModernLLMConfig(...)
model = ModernDecoderLM(config)
compiled_model = torch.compile(model)

# Train with compiled model...
# State dict keys have "_orig_mod." prefix

# Save (prefix is automatically stripped)
# No optimizer_state is passed here, so this checkpoint holds model weights only.
save_checkpoint(
    path=Path("compiled_checkpoint.pt"),
    model_state=compiled_model.state_dict(),
)

# Load into non-compiled model (works seamlessly)
checkpoint = load_checkpoint(Path("compiled_checkpoint.pt"))
model_new = ModernDecoderLM(config)
model_new.load_state_dict(checkpoint["model_state"])  # No prefix issues
When you use torch.compile, PyTorch wraps your model and prefixes all parameter keys with "_orig_mod.":
# Compiled model state_dict:
{
    "_orig_mod.embeddings.weight": ...,
    "_orig_mod.layers.0.attention.q_proj.weight": ...,
    # ...
}
The save_checkpoint function calls _strip_orig_mod_prefix() to normalize these keys:
# Normalized state_dict:
{
    "embeddings.weight": ...,
    "layers.0.attention.q_proj.weight": ...,
    # ...
}
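This ensures checkpoints are loadable regardless of whether the model was compiled when it was saved. To resume with a compiled model, load the weights into the plain model first and then call torch.compile on it.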

Checkpoint management patterns

Keep best checkpoint

from pathlib import Path
from modern_llm.utils.checkpointing import save_checkpoint

# NOTE: max_steps, eval_every, output_dir, train_step and evaluate are
# placeholders supplied by the surrounding training script.
best_val_loss = float("inf")
best_checkpoint_path = None

for step in range(max_steps):
    train_loss = train_step(...)

    if step % eval_every == 0:
        val_loss = evaluate(...)

        # Save if best so far
        if val_loss < best_val_loss:
            # The best checkpoint always lives at the same path. Write the new
            # one to a temporary file and swap it into place afterwards, so the
            # previous best survives if the save is interrupted (deleting the
            # old file first would leave no best checkpoint at all).
            target_path = output_dir / "best_checkpoint.pt"
            tmp_path = target_path.with_name(target_path.name + ".tmp")
            save_checkpoint(
                path=tmp_path,
                model_state=model.state_dict(),
                optimizer_state=optimizer.state_dict(),
                step=step,
                train_loss=train_loss,
                val_loss=val_loss,
            )
            tmp_path.replace(target_path)

            # Only record the new best once it is safely on disk.
            best_val_loss = val_loss
            best_checkpoint_path = target_path
            print(f"New best model: val_loss={val_loss:.4f}")

Keep last N checkpoints

from pathlib import Path
from collections import deque
from modern_llm.utils.checkpointing import save_checkpoint

# NOTE: max_steps, save_every, output_dir and train_step are placeholders
# supplied by the surrounding training script.
keep_last_n = 3
# Paths written by this loop, oldest first. Only these files are ever deleted,
# so checkpoints created elsewhere (an earlier run, a manually kept copy) that
# happen to match "checkpoint_*.pt" are left alone.
checkpoint_queue = deque()

for step in range(max_steps):
    train_loss = train_step(...)

    if step % save_every == 0:
        checkpoint_path = output_dir / f"checkpoint_{step}.pt"
        save_checkpoint(
            path=checkpoint_path,
            model_state=model.state_dict(),
            optimizer_state=optimizer.state_dict(),
            step=step,
            loss=train_loss,
        )
        checkpoint_queue.append(checkpoint_path)

        # Clean up old checkpoints
        # Evict only after the new file is written, so at least keep_last_n
        # checkpoints exist at all times.
        while len(checkpoint_queue) > keep_last_n:
            old_path = checkpoint_queue.popleft()
            # missing_ok: tolerate a file that was already removed by hand.
            old_path.unlink(missing_ok=True)
            print(f"Deleted old checkpoint: {old_path}")

Stage-based checkpoints

from pathlib import Path
from modern_llm.utils.checkpointing import save_checkpoint

def save_stage_checkpoint(stage: str, model, optimizer, step: int, **metadata):
    """Write the final checkpoint of one training stage and return its path.

    The file is saved as
    ``experiments/{run_name}-{stage}/{run_name}-{stage}_final.pt``, where
    ``run_name`` is read from the enclosing scope. ``stage`` and ``step`` are
    stored as metadata, together with any extra keyword arguments.
    """
    # Directory for this stage; created on demand.
    stage_dir = Path(f"experiments/{run_name}-{stage}")
    stage_dir.mkdir(parents=True, exist_ok=True)

    final_file = stage_dir / f"{run_name}-{stage}_final.pt"
    model_weights = model.state_dict()
    optimizer_snapshot = optimizer.state_dict()
    save_checkpoint(
        path=final_file,
        model_state=model_weights,
        optimizer_state=optimizer_snapshot,
        stage=stage,
        step=step,
        **metadata,
    )
    return final_file

# Each call writes experiments/{run_name}-{stage}/{run_name}-{stage}_final.pt
# and returns that path; `loss` is passed through as extra metadata.

# Pretrain stage
pretrain_ckpt = save_stage_checkpoint(
    "pretrain", model, optimizer, step=20000, loss=2.1
)

# SFT stage
sft_ckpt = save_stage_checkpoint(
    "sft", model, optimizer, step=5000, loss=1.5
)

# DPO stage
dpo_ckpt = save_stage_checkpoint(
    "dpo", model, optimizer, step=2000, loss=0.8
)

Error handling

Handle missing or corrupted checkpoints:
from pathlib import Path
from modern_llm.utils.checkpointing import load_checkpoint

checkpoint_path = Path("checkpoint.pt")

# Keep the try body down to the call that can fail, so an error raised while
# reporting success is not mistaken for a load failure.
try:
    checkpoint = load_checkpoint(checkpoint_path)
except FileNotFoundError:
    print(f"Checkpoint not found: {checkpoint_path}")
    # Start training from scratch
    checkpoint = None
except Exception as e:
    # Deliberately broad: a corrupted file can fail in several different ways.
    print(f"Error loading checkpoint: {e}")
    # Fall back to earlier checkpoint or start fresh
    checkpoint = None
else:
    print(f"Loaded checkpoint from step {checkpoint['step']}")

# Compare against None explicitly; "no checkpoint" is signalled by None above.
if checkpoint is not None:
    model.load_state_dict(checkpoint["model_state"])
    start_step = checkpoint["step"] + 1
else:
    start_step = 0

Complete example

Full training loop with checkpoint management:
from pathlib import Path
import torch
from modern_llm.models.transformer import ModernDecoderLM
from modern_llm.config import ModernLLMConfig, TrainingConfig
from modern_llm.utils.checkpointing import save_checkpoint, load_checkpoint

# Configuration
model_config = ModernLLMConfig(
    vocab_size=50257,
    d_model=768,
    n_layers=12,
    n_heads=12,
    ffn_hidden_size=3072,
    max_seq_len=1024,
)

train_config = TrainingConfig(
    run_name="my-experiment",
    dataset_name="wikitext",
    tokenizer_name="gpt2",
    output_dir=Path("experiments/my-experiment"),
    batch_size=64,
    micro_batch_size=8,
    gradient_accumulation_steps=8,
    learning_rate=3e-4,
    max_steps=10000,
    save_every=1000,
)

# Initialize or resume
model = ModernDecoderLM(model_config)
optimizer = torch.optim.AdamW(model.parameters(), lr=train_config.learning_rate)
start_step = 0

# Try to resume from latest checkpoint
latest_checkpoint = train_config.output_dir / "latest_checkpoint.pt"
if latest_checkpoint.exists():
    print(f"Resuming from {latest_checkpoint}")
    checkpoint = load_checkpoint(latest_checkpoint)
    model.load_state_dict(checkpoint["model_state"])
    # The optimizer entry is optional in the checkpoint format.
    if "optimizer" in checkpoint:
        optimizer.load_state_dict(checkpoint["optimizer"])
    # The stored step is the last completed one, so continue with the next.
    start_step = checkpoint["step"] + 1
else:
    print("Starting training from scratch")

# Training loop
# NOTE: train_step and batch are placeholders for the real training step and
# data loading.
for step in range(start_step, train_config.max_steps):
    # Training step
    loss = train_step(model, batch, optimizer)

    # Save checkpoint
    if step % train_config.save_every == 0:
        # latest_checkpoint.pt is the only resume point, so never overwrite it
        # in place: write a temporary file and swap it in once it is complete.
        tmp_checkpoint = latest_checkpoint.with_name(latest_checkpoint.name + ".tmp")
        save_checkpoint(
            path=tmp_checkpoint,
            model_state=model.state_dict(),
            optimizer_state=optimizer.state_dict(),
            step=step,
            loss=loss.item(),
            config=model_config.__dict__,
            learning_rate=optimizer.param_groups[0]["lr"],
        )
        tmp_checkpoint.replace(latest_checkpoint)
        print(f"Checkpoint saved at step {step}")

# Save final checkpoint
# `step` and `loss` are only bound if the loop ran at least once. When a
# finished run is resumed there is nothing left to train, so skip the final
# save instead of failing with a NameError.
if start_step < train_config.max_steps:
    final_checkpoint = train_config.output_dir / f"{train_config.run_name}_final.pt"
    save_checkpoint(
        path=final_checkpoint,
        model_state=model.state_dict(),
        optimizer_state=optimizer.state_dict(),
        step=step,
        loss=loss.item(),
        config=model_config.__dict__,
        final=True,
    )
    print(f"Training complete. Final checkpoint: {final_checkpoint}")
else:
    print("No training steps left to run; final checkpoint not rewritten")

See also

Build docs developers (and LLMs) love