Overview
Checkpointing allows you to save intermediate progress during long-running computations and resume from the last checkpoint if the process is interrupted. This is crucial for:
- Long-running training jobs that may be interrupted
- Expensive computations that you don’t want to repeat
- Handling spot instance interruptions
- Iterative algorithms that benefit from incremental progress
Metaflow automatically checkpoints data artifacts between steps. This guide focuses on checkpointing within steps.
How Metaflow Handles Interruptions
Automatic Step-Level Checkpointing
Metaflow automatically saves all artifacts at the end of each step:
from metaflow import FlowSpec, step

class AutoCheckpoint(FlowSpec):

    @step
    def start(self):
        self.data = expensive_computation()
        # data is automatically checkpointed here
        self.next(self.process)

    @step
    def process(self):
        # If this fails, start doesn't need to rerun
        self.result = process(self.data)
        self.next(self.end)

    @step
    def end(self):
        pass
If process fails, you can resume without rerunning start:
# Resume from the last successful step
python flow.py resume
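Resume can also target a specific step, and it can clone results from a particular earlier run via --origin-run-id (the run ID below is a placeholder):
# Resume execution starting at a specific step
python flow.py resume process

# Resume using the artifacts of a particular earlier run
python flow.py resume --origin-run-id 1746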
Within-Step Checkpointing
For long computations within a single step, use manual checkpointing:
import os
import pickle
from metaflow import FlowSpec, step, current

class ManualCheckpoint(FlowSpec):

    @step
    def train(self):
        checkpoint_file = f"checkpoint_{current.run_id}.pkl"

        # Check if checkpoint exists
        if os.path.exists(checkpoint_file):
            print("Resuming from checkpoint...")
            with open(checkpoint_file, 'rb') as f:
                checkpoint = pickle.load(f)
            start_epoch = checkpoint['epoch']
            model = checkpoint['model']
        else:
            print("Starting from scratch...")
            start_epoch = 0
            model = initialize_model()

        # Train for remaining epochs
        for epoch in range(start_epoch, 100):
            model.train_epoch()

            # Save checkpoint every 10 epochs
            if epoch % 10 == 0:
                checkpoint = {
                    'epoch': epoch + 1,
                    'model': model
                }
                with open(checkpoint_file, 'wb') as f:
                    pickle.dump(checkpoint, f)
                print(f"Checkpoint saved at epoch {epoch}")

        # Save final model
        self.model = model

        # Clean up checkpoint
        if os.path.exists(checkpoint_file):
            os.remove(checkpoint_file)

        self.next(self.end)

    @step
    def end(self):
        pass
Checkpointing to S3
For cloud execution, save checkpoints to S3:
from metaflow import FlowSpec, step, S3, current
import pickle

class S3Checkpoint(FlowSpec):

    @step
    def train(self):
        with S3(run=self) as s3:
            checkpoint_key = f"checkpoints/{current.pathspec}/checkpoint.pkl"

            # Try to load checkpoint
            try:
                checkpoint_obj = s3.get(checkpoint_key)
                checkpoint = pickle.loads(checkpoint_obj.blob)
                print(f"Resuming from epoch {checkpoint['epoch']}")
                start_epoch = checkpoint['epoch']
                model = checkpoint['model']
            except Exception:
                print("No checkpoint found, starting fresh")
                start_epoch = 0
                model = initialize_model()

            # Training loop
            for epoch in range(start_epoch, 100):
                model.train_epoch()

                # Checkpoint every 10 epochs
                if epoch % 10 == 0:
                    checkpoint = {
                        'epoch': epoch + 1,
                        'model': model,
                        'metrics': model.get_metrics()
                    }
                    s3.put(checkpoint_key, pickle.dumps(checkpoint))
                    print(f"Checkpoint saved to S3 at epoch {epoch}")

        # Save final model
        self.model = model
        self.next(self.end)

    @step
    def end(self):
        pass
PyTorch Checkpointing
Basic PyTorch Checkpoint
import io
import torch
from metaflow import FlowSpec, step, S3

class PyTorchCheckpoint(FlowSpec):

    @step
    def train(self):
        # MyModel, dataloader and train_step are assumed to be defined elsewhere
        model = MyModel()
        optimizer = torch.optim.Adam(model.parameters())

        with S3(run=self) as s3:
            checkpoint_key = "checkpoint.pt"

            # Load checkpoint if it exists
            try:
                checkpoint_blob = s3.get(checkpoint_key).blob
                checkpoint = torch.load(io.BytesIO(checkpoint_blob))
                model.load_state_dict(checkpoint['model_state'])
                optimizer.load_state_dict(checkpoint['optimizer_state'])
                start_epoch = checkpoint['epoch']
                print(f"Resumed from epoch {start_epoch}")
            except Exception:
                start_epoch = 0
                print("Starting training from scratch")

            # Training loop
            for epoch in range(start_epoch, 100):
                # Train one epoch
                for batch in dataloader:
                    loss = train_step(model, batch, optimizer)

                # Save checkpoint
                if epoch % 10 == 0:
                    checkpoint = {
                        'epoch': epoch + 1,
                        'model_state': model.state_dict(),
                        'optimizer_state': optimizer.state_dict(),
                        'loss': loss
                    }
                    buffer = io.BytesIO()
                    torch.save(checkpoint, buffer)
                    buffer.seek(0)
                    s3.put(checkpoint_key, buffer.read())
                    print(f"Checkpoint saved at epoch {epoch}")

        self.model_state = model.state_dict()
        self.next(self.end)

    @step
    def end(self):
        pass
Distributed Training Checkpoint
import io
import torch
import torch.distributed as dist
from metaflow import FlowSpec, step, parallel, current, S3

class DistributedCheckpoint(FlowSpec):

    @parallel(cpu=8, gpu=4)
    @step
    def train(self):
        # Initialize distributed training
        rank = current.parallel.node_index
        world_size = current.parallel.num_nodes
        dist.init_process_group(
            backend='nccl',
            rank=rank,
            world_size=world_size
        )

        model = MyModel().to(rank)
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])

        with S3(run=self) as s3:
            checkpoint_key = f"checkpoint_rank_{rank}.pt"

            # Only rank 0 handles checkpointing
            if rank == 0:
                try:
                    checkpoint_blob = s3.get(checkpoint_key).blob
                    checkpoint = torch.load(io.BytesIO(checkpoint_blob))
                    model.load_state_dict(checkpoint['model_state'])
                    start_epoch = checkpoint['epoch']
                except Exception:
                    start_epoch = 0
            else:
                start_epoch = 0

            # Broadcast start epoch to all ranks
            start_epoch = torch.tensor(start_epoch).to(rank)
            dist.broadcast(start_epoch, src=0)
            start_epoch = start_epoch.item()

            # Training loop
            for epoch in range(start_epoch, 100):
                train_one_epoch(model, rank)

                # Checkpoint from rank 0
                if rank == 0 and epoch % 10 == 0:
                    checkpoint = {
                        'epoch': epoch + 1,
                        'model_state': model.state_dict()
                    }
                    buffer = io.BytesIO()
                    torch.save(checkpoint, buffer)
                    buffer.seek(0)
                    s3.put(checkpoint_key, buffer.read())

        if rank == 0:
            self.model_state = model.state_dict()
        self.next(self.end)

    @step
    def end(self, inputs):
        # Use model from rank 0
        self.model_state = inputs[0].model_state
TensorFlow Checkpointing
import tensorflow as tf
from metaflow import FlowSpec, step, S3
import tempfile
import os

class TensorFlowCheckpoint(FlowSpec):

    @step
    def train(self):
        # create_model and train_dataset are assumed to be defined elsewhere
        model = create_model()

        # Create checkpoint callback
        checkpoint_dir = tempfile.mkdtemp()
        checkpoint_path = os.path.join(checkpoint_dir, 'checkpoint')
        checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
            filepath=checkpoint_path,
            save_weights_only=True,
            save_freq='epoch'
        )

        with S3(run=self) as s3:
            # Try to restore from S3
            try:
                checkpoint_blob = s3.get('checkpoint.h5').blob
                with open(checkpoint_path + '.h5', 'wb') as f:
                    f.write(checkpoint_blob)
                model.load_weights(checkpoint_path + '.h5')
                print("Restored from checkpoint")
            except Exception:
                print("No checkpoint found")

            # Train model
            history = model.fit(
                train_dataset,
                epochs=100,
                callbacks=[checkpoint_callback]
            )

            # Save final checkpoint to S3
            model.save_weights(checkpoint_path + '_final.h5')
            with open(checkpoint_path + '_final.h5', 'rb') as f:
                s3.put('checkpoint.h5', f.read())

        self.model = model
        self.next(self.end)

    @step
    def end(self):
        pass
Handling Spot Instance Interruptions
Spot instances typically receive a termination signal (for example, SIGTERM) shortly before they are reclaimed, so register a signal handler that saves a checkpoint before the process exits:
import signal
import sys
from metaflow import FlowSpec, step, S3, current

class SpotInstanceFlow(FlowSpec):

    @step
    def train(self):
        # save_checkpoint, load_checkpoint and initialize_model are assumed helpers
        model = initialize_model()

        with S3(run=self) as s3:
            # Load checkpoint if it exists
            try:
                checkpoint = load_checkpoint(s3)
                model = checkpoint['model']
                start_epoch = checkpoint['epoch']
            except Exception:
                start_epoch = 0
            epoch = start_epoch

            # Set up signal handler for spot interruption
            def signal_handler(signum, frame):
                print("Spot instance interruption detected!")
                print("Saving checkpoint...")
                save_checkpoint(s3, model, epoch)
                sys.exit(0)

            signal.signal(signal.SIGTERM, signal_handler)

            # Training loop
            for epoch in range(start_epoch, 100):
                try:
                    model.train_epoch()

                    # Checkpoint frequently
                    if epoch % 5 == 0:
                        save_checkpoint(s3, model, epoch)
                except KeyboardInterrupt:
                    print("Interrupted! Saving checkpoint...")
                    save_checkpoint(s3, model, epoch)
                    raise

        self.model = model
        self.next(self.end)

    @step
    def end(self):
        pass
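This pattern pairs naturally with Metaflow's @retry decorator (see Related Topics): if an interruption kills the task, the retried attempt reloads the latest S3 checkpoint and continues. A minimal sketch, assuming the same hypothetical save_checkpoint, load_checkpoint, and initialize_model helpers as above:
from metaflow import FlowSpec, step, retry, S3

class RetryingSpotFlow(FlowSpec):

    @retry(times=3)
    @step
    def train(self):
        with S3(run=self) as s3:
            # On a retried attempt, the checkpoint saved by the
            # interrupted attempt is picked up here (assumed helpers)
            try:
                checkpoint = load_checkpoint(s3)
                start_epoch = checkpoint['epoch']
                model = checkpoint['model']
            except Exception:
                start_epoch = 0
                model = initialize_model()

            for epoch in range(start_epoch, 100):
                model.train_epoch()
                save_checkpoint(s3, model, epoch)

        self.model = model
        self.next(self.end)

    @step
    def end(self):
        pass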
Best Practices
Checkpoint frequently for long jobs
For jobs that run for hours, checkpoint every 10-30 minutes:
import time

last_checkpoint = time.time()

for iteration in range(1000000):
    # Do work

    # Checkpoint every 30 minutes
    if time.time() - last_checkpoint > 1800:
        save_checkpoint()
        last_checkpoint = time.time()
Use S3 for cloud execution
Always save checkpoints to S3 when running on cloud compute:
@batch(cpu=16, memory=32000)
@step
def train(self):
    with S3(run=self) as s3:
        # Checkpoint to S3
        s3.put('checkpoint.pkl', checkpoint_data)
Save optimizer state
For ML training, always save optimizer state:
checkpoint = {
    'epoch': epoch,
    'model_state': model.state_dict(),
    'optimizer_state': optimizer.state_dict(),
    'scheduler_state': scheduler.state_dict(),
    'loss': loss
}
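When restoring, load each state dict back into freshly constructed objects before resuming training. A minimal sketch using the standard PyTorch API, assuming model, optimizer, and scheduler have been recreated with the same configuration:
# Restore states saved above into newly constructed objects
model.load_state_dict(checkpoint['model_state'])
optimizer.load_state_dict(checkpoint['optimizer_state'])
scheduler.load_state_dict(checkpoint['scheduler_state'])
start_epoch = checkpoint['epoch']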
Clean up old checkpoints
Remove intermediate checkpoints when done:
import glob
import os

# Keep only the last N checkpoints
checkpoints = sorted(glob.glob('checkpoint_*.pt'))
for old_checkpoint in checkpoints[:-5]:
    os.remove(old_checkpoint)
Test checkpoint recovery
Always test that your checkpoint recovery works:
# Save checkpoint
save_checkpoint(model, epoch=5)

# Simulate restart
model = initialize_model()
checkpoint = load_checkpoint()
model.load_state_dict(checkpoint['model_state'])

# Verify epoch counter
assert checkpoint['epoch'] == 5
Related Topics
- Retry Decorator: automatic retry on failure
- AWS Batch: running on AWS Batch with spot instances
- S3 Data Tools: using S3 for data storage
- Error Handling: handling errors in flows
