Saving checkpoints
Use the save_checkpoint function to persist model and optimizer state:
from pathlib import Path
import torch
from modern_llm.utils.checkpointing import save_checkpoint
from modern_llm.models.transformer import ModernDecoderLM
from modern_llm.config import ModernLLMConfig
# Create model
config = ModernLLMConfig(
    vocab_size=50257,
    d_model=768,
    n_layers=12,
    n_heads=12,
    ffn_hidden_size=3072,
    max_seq_len=1024,
)
model = ModernDecoderLM(config)

# Create optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

# Save checkpoint
save_checkpoint(
    path=Path("experiments/my-run/checkpoint_1000.pt"),
    model_state=model.state_dict(),
    optimizer_state=optimizer.state_dict(),
    step=1000,
    loss=2.45,
    config=config.__dict__,
)
The save_checkpoint function automatically creates parent directories if they don't exist.
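For example, saving straight into a run directory that does not exist yet needs no mkdir call (the path below is purely illustrative):

save_checkpoint(
    path=Path("experiments/brand-new-run/checkpoint_0.pt"),  # parent dirs created on demand
    model_state=model.state_dict(),
    step=0,
)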
Checkpoint structure
Checkpoints are PyTorch .pt files containing:
{
    "model_state": {...},   # model.state_dict()
    "optimizer": {...},     # optimizer.state_dict() (optional)
    "metadata": {
        "step": 1000,
        "loss": 2.45,
        "config": {...},
        # ... any additional metadata
    },
    # Metadata also available at top level
    "step": 1000,
    "loss": 2.45,
    "config": {...},
}
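Since checkpoints are ordinary torch.save files, you can also inspect one directly with torch.load as a quick sanity check (load_checkpoint, covered below, is the supported interface):

import torch

ckpt = torch.load("experiments/my-run/checkpoint_1000.pt", map_location="cpu")
print(ckpt["metadata"]["step"])  # 1000
print(ckpt["step"])              # 1000, mirrored at the top level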
Saving with metadata
Include arbitrary metadata using keyword arguments:
save_checkpoint(
    path=Path("experiments/checkpoint_5000.pt"),
    model_state=model.state_dict(),
    optimizer_state=optimizer.state_dict(),
    # Metadata
    step=5000,
    epoch=2,
    train_loss=1.87,
    val_loss=2.12,
    learning_rate=2.5e-4,
    timestamp="2024-03-15T10:30:00",
    config=config.__dict__,
    git_commit="a3f4c2d",
)
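Each keyword is stored under "metadata" and mirrored at the top level of the checkpoint, so it reads back directly after loading:

from modern_llm.utils.checkpointing import load_checkpoint

checkpoint = load_checkpoint(Path("experiments/checkpoint_5000.pt"))
print(checkpoint["val_loss"])    # 2.12
print(checkpoint["git_commit"])  # "a3f4c2d"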
Saving during training
Typical training loop integration:
from pathlib import Path
from modern_llm.utils.checkpointing import save_checkpoint

output_dir = Path("experiments/my-training-run")
output_dir.mkdir(parents=True, exist_ok=True)

for step in range(max_steps):
    # Training step
    loss = train_step(model, batch, optimizer)

    # Save checkpoint every N steps
    if step % save_every == 0:
        checkpoint_path = output_dir / f"checkpoint_{step}.pt"
        save_checkpoint(
            path=checkpoint_path,
            model_state=model.state_dict(),
            optimizer_state=optimizer.state_dict(),
            step=step,
            loss=loss.item(),
            config=config.__dict__,
        )
        print(f"Saved checkpoint to {checkpoint_path}")

# Save final checkpoint
final_path = output_dir / f"{run_name}_final.pt"
save_checkpoint(
    path=final_path,
    model_state=model.state_dict(),
    optimizer_state=optimizer.state_dict(),
    step=step,
    loss=loss.item(),
    config=config.__dict__,
    final=True,
)
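The loop above assumes a train_step helper. A minimal sketch, assuming the model maps token IDs to next-token logits (the real signature may differ):

import torch.nn.functional as F

def train_step(model, batch, optimizer):
    """Hypothetical single step: forward, loss, backward, update."""
    optimizer.zero_grad()
    input_ids = batch["input_ids"]  # assumed batch layout
    logits = model(input_ids)       # assumed forward signature
    # Shift so position t predicts token t+1
    loss = F.cross_entropy(
        logits[:, :-1].reshape(-1, logits.size(-1)),
        input_ids[:, 1:].reshape(-1),
    )
    loss.backward()
    optimizer.step()
    return loss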
Loading checkpoints
Use the load_checkpoint function to restore saved state:
from pathlib import Path
from modern_llm.utils.checkpointing import load_checkpoint
# Load checkpoint
checkpoint_path = Path("experiments/my-run/checkpoint_1000.pt")
checkpoint = load_checkpoint(checkpoint_path)
print(f"Checkpoint keys: {checkpoint.keys()}")
print(f"Training step: {checkpoint['step']}")
print(f"Training loss: {checkpoint['loss']}")
Restoring model state
from pathlib import Path
from modern_llm.utils.checkpointing import load_checkpoint
from modern_llm.models.transformer import ModernDecoderLM
from modern_llm.config import ModernLLMConfig
# Load checkpoint
checkpoint = load_checkpoint(Path("checkpoint.pt"))
# Recreate model from saved config
config = ModernLLMConfig(**checkpoint["config"])
model = ModernDecoderLM(config)
# Load model weights
model.load_state_dict(checkpoint["model_state"])
model.eval()
print(f"Loaded model from step {checkpoint['step']}")
Resuming training
Restore both model and optimizer state:
from pathlib import Path
import torch
from modern_llm.utils.checkpointing import load_checkpoint
from modern_llm.models.transformer import ModernDecoderLM
from modern_llm.config import ModernLLMConfig
# Load checkpoint
checkpoint = load_checkpoint(Path("checkpoint_5000.pt"))
# Restore model
config = ModernLLMConfig(**checkpoint["config"])
model = ModernDecoderLM(config)
model.load_state_dict(checkpoint["model_state"])
# Restore optimizer
optimizer = torch.optim.AdamW(model.parameters())
if "optimizer" in checkpoint:
optimizer.load_state_dict(checkpoint["optimizer"])
# Resume training from saved step
start_step = checkpoint["step"] + 1
for step in range(start_step, max_steps):
# Continue training...
pass
The snippets below show three common loading patterns: full resume, inference only, and fine-tune.
# Resume with all state
checkpoint = load_checkpoint(Path("checkpoint.pt"))
model = ModernDecoderLM(ModernLLMConfig(**checkpoint["config"]))
model.load_state_dict(checkpoint["model_state"])
optimizer = torch.optim.AdamW(model.parameters())
optimizer.load_state_dict(checkpoint["optimizer"])
start_step = checkpoint["step"] + 1
# Load for inference (no optimizer)
checkpoint = load_checkpoint(Path("checkpoint.pt"))
model = ModernDecoderLM(ModernLLMConfig(**checkpoint["config"]))
model.load_state_dict(checkpoint["model_state"])
model.eval()
# No optimizer needed for inference
# Load for fine-tuning (new optimizer)
checkpoint = load_checkpoint(Path("checkpoint.pt"))
model = ModernDecoderLM(ModernLLMConfig(**checkpoint["config"]))
model.load_state_dict(checkpoint["model_state"])
# Create new optimizer with different LR
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
# Start from step 0 (new training run)
Compiled model support
Modern LLM automatically handles checkpoints from torch.compile:
import torch
from modern_llm.utils.checkpointing import save_checkpoint, load_checkpoint
from modern_llm.models.transformer import ModernDecoderLM
from modern_llm.config import ModernLLMConfig
# Create and compile model
config = ModernLLMConfig(...)
model = ModernDecoderLM(config)
compiled_model = torch.compile(model)
# Train with compiled model...
# State dict keys have "_orig_mod." prefix
# Save (prefix is automatically stripped)
save_checkpoint(
path=Path("compiled_checkpoint.pt"),
model_state=compiled_model.state_dict(),
)
# Load into non-compiled model (works seamlessly)
checkpoint = load_checkpoint(Path("compiled_checkpoint.pt"))
model_new = ModernDecoderLM(config)
model_new.load_state_dict(checkpoint["model_state"]) # No prefix issues
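The reverse direction is just as straightforward: load the normalized weights into a plain module first, then compile it (a simple, safe ordering):

checkpoint = load_checkpoint(Path("compiled_checkpoint.pt"))
model_new = ModernDecoderLM(config)
model_new.load_state_dict(checkpoint["model_state"])  # plain keys load cleanly
compiled_new = torch.compile(model_new)               # compile after loading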
How prefix stripping works
When you use torch.compile, PyTorch wraps your model and prefixes all parameter keys with "_orig_mod.":
# Compiled model state_dict:
{
    "_orig_mod.embeddings.weight": ...,
    "_orig_mod.layers.0.attention.q_proj.weight": ...,
    # ...
}
The save_checkpoint function calls _strip_orig_mod_prefix() to normalize these keys:
# Normalized state_dict:
{
    "embeddings.weight": ...,
    "layers.0.attention.q_proj.weight": ...,
    # ...
}
This ensures checkpoints are loadable into both compiled and non-compiled models.
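For intuition, a stripped-down version of such a helper might look like this (a sketch; the actual _strip_orig_mod_prefix implementation may differ):

def strip_orig_mod_prefix(state_dict: dict) -> dict:
    """Drop torch.compile's "_orig_mod." prefix from state_dict keys."""
    prefix = "_orig_mod."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }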
Checkpoint management patterns
Keep best checkpoint
from pathlib import Path
from modern_llm.utils.checkpointing import save_checkpoint
best_val_loss = float("inf")
best_checkpoint_path = None

for step in range(max_steps):
    train_loss = train_step(...)

    if step % eval_every == 0:
        val_loss = evaluate(...)

        # Save if best so far
        if val_loss < best_val_loss:
            best_val_loss = val_loss

            # Remove old best
            if best_checkpoint_path and best_checkpoint_path.exists():
                best_checkpoint_path.unlink()

            # Save new best
            best_checkpoint_path = output_dir / "best_checkpoint.pt"
            save_checkpoint(
                path=best_checkpoint_path,
                model_state=model.state_dict(),
                optimizer_state=optimizer.state_dict(),
                step=step,
                train_loss=train_loss,
                val_loss=val_loss,
            )
            print(f"New best model: val_loss={val_loss:.4f}")
Keep last N checkpoints
from pathlib import Path
from collections import deque
from modern_llm.utils.checkpointing import save_checkpoint
keep_last_n = 3
checkpoint_queue = deque(maxlen=keep_last_n)

for step in range(max_steps):
    train_loss = train_step(...)

    if step % save_every == 0:
        checkpoint_path = output_dir / f"checkpoint_{step}.pt"
        save_checkpoint(
            path=checkpoint_path,
            model_state=model.state_dict(),
            optimizer_state=optimizer.state_dict(),
            step=step,
            loss=train_loss,
        )

        # Add to queue (automatically removes oldest if full)
        checkpoint_queue.append(checkpoint_path)

        # Clean up old checkpoints
        for old_path in output_dir.glob("checkpoint_*.pt"):
            if old_path not in checkpoint_queue:
                old_path.unlink()
                print(f"Deleted old checkpoint: {old_path}")
Stage-based checkpoints
from pathlib import Path
from modern_llm.utils.checkpointing import save_checkpoint
def save_stage_checkpoint(stage: str, model, optimizer, step: int, **metadata):
    """Save checkpoint for a specific training stage (assumes a global run_name)."""
    checkpoint_dir = Path(f"experiments/{run_name}-{stage}")
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    checkpoint_path = checkpoint_dir / f"{run_name}-{stage}_final.pt"
    save_checkpoint(
        path=checkpoint_path,
        model_state=model.state_dict(),
        optimizer_state=optimizer.state_dict(),
        stage=stage,
        step=step,
        **metadata,
    )
    return checkpoint_path

# Pretrain stage
pretrain_ckpt = save_stage_checkpoint(
    "pretrain", model, optimizer, step=20000, loss=2.1
)
# SFT stage
sft_ckpt = save_stage_checkpoint(
    "sft", model, optimizer, step=5000, loss=1.5
)
# DPO stage
dpo_ckpt = save_stage_checkpoint(
    "dpo", model, optimizer, step=2000, loss=0.8
)
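Between stages you would typically warm-start from the previous stage's checkpoint with a fresh optimizer, mirroring the fine-tune pattern above (the learning rate here is illustrative):

# Warm-start SFT from the pretrain checkpoint
checkpoint = load_checkpoint(pretrain_ckpt)
model.load_state_dict(checkpoint["model_state"])
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)  # hypothetical fine-tune LR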
Error handling
Handle missing or corrupted checkpoints:
from pathlib import Path
from modern_llm.utils.checkpointing import load_checkpoint

checkpoint_path = Path("checkpoint.pt")
try:
    checkpoint = load_checkpoint(checkpoint_path)
    print(f"Loaded checkpoint from step {checkpoint['step']}")
except FileNotFoundError:
    print(f"Checkpoint not found: {checkpoint_path}")
    # Start training from scratch
    checkpoint = None
except Exception as e:
    print(f"Error loading checkpoint: {e}")
    # Fall back to earlier checkpoint or start fresh
    checkpoint = None

if checkpoint:
    model.load_state_dict(checkpoint["model_state"])
    start_step = checkpoint["step"] + 1
else:
    start_step = 0
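To actually fall back to an earlier checkpoint, you need a way to enumerate candidates. A hypothetical helper that picks the highest-numbered checkpoint_*.pt file:

def find_latest_checkpoint(output_dir: Path) -> Path | None:
    """Hypothetical helper: return the checkpoint with the highest step, if any."""
    candidates = sorted(
        output_dir.glob("checkpoint_*.pt"),
        key=lambda p: int(p.stem.split("_")[-1]),
    )
    return candidates[-1] if candidates else None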
Complete example
Full training loop with checkpoint management:
from pathlib import Path
import torch
from modern_llm.models.transformer import ModernDecoderLM
from modern_llm.config import ModernLLMConfig, TrainingConfig
from modern_llm.utils.checkpointing import save_checkpoint, load_checkpoint

# Configuration
model_config = ModernLLMConfig(
    vocab_size=50257,
    d_model=768,
    n_layers=12,
    n_heads=12,
    ffn_hidden_size=3072,
    max_seq_len=1024,
)
train_config = TrainingConfig(
    run_name="my-experiment",
    dataset_name="wikitext",
    tokenizer_name="gpt2",
    output_dir=Path("experiments/my-experiment"),
    batch_size=64,
    micro_batch_size=8,
    gradient_accumulation_steps=8,
    learning_rate=3e-4,
    max_steps=10000,
    save_every=1000,
)

# Initialize or resume
model = ModernDecoderLM(model_config)
optimizer = torch.optim.AdamW(model.parameters(), lr=train_config.learning_rate)
start_step = 0

# Try to resume from latest checkpoint
latest_checkpoint = train_config.output_dir / "latest_checkpoint.pt"
if latest_checkpoint.exists():
    print(f"Resuming from {latest_checkpoint}")
    checkpoint = load_checkpoint(latest_checkpoint)
    model.load_state_dict(checkpoint["model_state"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    start_step = checkpoint["step"] + 1
else:
    print("Starting training from scratch")

# Training loop
for step in range(start_step, train_config.max_steps):
    # Training step
    loss = train_step(model, batch, optimizer)

    # Save checkpoint
    if step % train_config.save_every == 0:
        save_checkpoint(
            path=latest_checkpoint,
            model_state=model.state_dict(),
            optimizer_state=optimizer.state_dict(),
            step=step,
            loss=loss.item(),
            config=model_config.__dict__,
            learning_rate=optimizer.param_groups[0]["lr"],
        )
        print(f"Checkpoint saved at step {step}")

# Save final checkpoint
final_checkpoint = train_config.output_dir / f"{train_config.run_name}_final.pt"
save_checkpoint(
    path=final_checkpoint,
    model_state=model.state_dict(),
    optimizer_state=optimizer.state_dict(),
    step=step,
    loss=loss.item(),
    config=model_config.__dict__,
    final=True,
)
print(f"Training complete. Final checkpoint: {final_checkpoint}")
See also
- Configuration - Configure training runs and output directories
- Training Scripts - Full training pipeline with checkpointing
- Evaluation - Load checkpoints for evaluation