Skip to main content

Overview

Checkpointing allows you to save intermediate progress during long-running computations and resume from the last checkpoint if the process is interrupted. This is crucial for:
  • Long-running training jobs that may be interrupted
  • Expensive computations that you don’t want to repeat
  • Handling spot instance interruptions
  • Iterative algorithms that benefit from incremental progress
Metaflow automatically checkpoints data artifacts between steps. This guide focuses on checkpointing within steps.

How Metaflow Handles Interruptions

Automatic Step-Level Checkpointing

Metaflow automatically saves all artifacts at the end of each step:
from metaflow import FlowSpec, step

class AutoCheckpoint(FlowSpec):
    """Flow demonstrating Metaflow's automatic step-level checkpointing.

    Every artifact assigned to ``self`` is persisted when a step
    completes, so a failed step can be resumed with
    ``python flow.py resume`` without rerunning its predecessors.
    """

    @step
    def start(self):
        # Placeholder for an expensive computation defined elsewhere.
        self.data = expensive_computation()
        # self.data is automatically checkpointed when this step ends.
        self.next(self.process)

    @step
    def process(self):
        # If this step fails, `start` does not need to rerun on resume.
        # The helper is named `process_data` to avoid shadowing this step's
        # own name: the original `process(self.data)` would be a NameError
        # unless a separate global `process` function existed.
        self.result = process_data(self.data)
        self.next(self.end)

    @step
    def end(self):
        pass
If the `process` step fails, you can resume without rerunning `start`:
# Resume from the last successful step (previously completed steps are skipped)
python flow.py resume

Within-Step Checkpointing

For long computations within a single step, use manual checkpointing:
import os
import pickle
from metaflow import FlowSpec, step, current

class ManualCheckpoint(FlowSpec):
    """Flow that manually checkpoints progress inside a long-running step."""

    @step
    def train(self):
        """Train for 100 epochs, snapshotting local state every 10 epochs."""
        ckpt_path = f"checkpoint_{current.run_id}.pkl"

        # Resume from a previous snapshot when one is present on disk;
        # otherwise start a fresh model at epoch 0.
        if os.path.exists(ckpt_path):
            print("Resuming from checkpoint...")
            with open(ckpt_path, 'rb') as fh:
                state = pickle.load(fh)
            start_epoch, model = state['epoch'], state['model']
        else:
            print("Starting from scratch...")
            start_epoch, model = 0, initialize_model()

        # Run only the epochs that have not completed yet.
        for epoch in range(start_epoch, 100):
            model.train_epoch()

            # Persist a snapshot every tenth epoch; `epoch + 1` is stored
            # so a resumed run begins with the first unfinished epoch.
            if epoch % 10 == 0:
                with open(ckpt_path, 'wb') as fh:
                    pickle.dump({'epoch': epoch + 1, 'model': model}, fh)
                print(f"Checkpoint saved at epoch {epoch}")

        # Publish the trained model as a Metaflow artifact.
        self.model = model

        # The scratch file is no longer needed once the step succeeds.
        if os.path.exists(ckpt_path):
            os.remove(ckpt_path)

        self.next(self.end)

    @step
    def end(self):
        pass

Checkpointing to S3

For cloud execution, save checkpoints to S3:
from metaflow import FlowSpec, step, S3, current
import pickle

class S3Checkpoint(FlowSpec):
    """Flow that checkpoints within-step training progress to S3.

    S3 survives instance loss, so this pattern works for remote/cloud
    execution where local disk is ephemeral.
    """

    @step
    def train(self):
        with S3(run=self) as s3:
            checkpoint_key = f"checkpoints/{current.pathspec}/checkpoint.pkl"

            # Try to load an existing checkpoint. Catch Exception rather
            # than using a bare `except:`, so KeyboardInterrupt and
            # SystemExit still propagate and can stop the run.
            try:
                checkpoint_obj = s3.get(checkpoint_key)
                # NOTE: pickle is only safe on data this flow wrote itself;
                # never unpickle untrusted input.
                checkpoint = pickle.loads(checkpoint_obj.blob)
                print(f"Resuming from epoch {checkpoint['epoch']}")
                start_epoch = checkpoint['epoch']
                model = checkpoint['model']
            except Exception:
                print("No checkpoint found, starting fresh")
                start_epoch = 0
                model = initialize_model()

            # Training loop over the remaining epochs.
            for epoch in range(start_epoch, 100):
                model.train_epoch()

                # Checkpoint every 10 epochs; `epoch + 1` is stored so a
                # resumed run continues with the first unfinished epoch.
                if epoch % 10 == 0:
                    checkpoint = {
                        'epoch': epoch + 1,
                        'model': model,
                        'metrics': model.get_metrics()
                    }
                    s3.put(checkpoint_key, pickle.dumps(checkpoint))
                    print(f"Checkpoint saved to S3 at epoch {epoch}")

            # Publish the final model as a Metaflow artifact.
            self.model = model

        self.next(self.end)

    @step
    def end(self):
        pass

PyTorch Checkpointing

Basic PyTorch Checkpoint

import io
import torch
from metaflow import FlowSpec, step, S3

class PyTorchCheckpoint(FlowSpec):
    """Flow that checkpoints PyTorch model + optimizer state to S3.

    The `io` import was missing in the original snippet even though
    `io.BytesIO` is used throughout.
    """

    @step
    def train(self):
        model = MyModel()
        optimizer = torch.optim.Adam(model.parameters())

        with S3(run=self) as s3:
            checkpoint_key = "checkpoint.pt"

            # Restore model/optimizer state if a checkpoint exists. Catch
            # Exception (not a bare `except:`) so Ctrl-C still propagates.
            try:
                checkpoint_blob = s3.get(checkpoint_key).blob
                checkpoint = torch.load(io.BytesIO(checkpoint_blob))
                model.load_state_dict(checkpoint['model_state'])
                optimizer.load_state_dict(checkpoint['optimizer_state'])
                start_epoch = checkpoint['epoch']
                print(f"Resumed from epoch {start_epoch}")
            except Exception:
                start_epoch = 0
                print("Starting training from scratch")

            # Training loop
            for epoch in range(start_epoch, 100):
                # Train one epoch
                for batch in dataloader:
                    loss = train_step(model, batch, optimizer)

                # Save a checkpoint every 10 epochs; the optimizer state is
                # included so momentum/LR schedules survive a restart.
                if epoch % 10 == 0:
                    checkpoint = {
                        'epoch': epoch + 1,
                        'model_state': model.state_dict(),
                        'optimizer_state': optimizer.state_dict(),
                        'loss': loss
                    }

                    # Serialize to an in-memory buffer, then upload the bytes.
                    buffer = io.BytesIO()
                    torch.save(checkpoint, buffer)
                    buffer.seek(0)
                    s3.put(checkpoint_key, buffer.read())
                    print(f"Checkpoint saved at epoch {epoch}")

            self.model_state = model.state_dict()

        self.next(self.end)

    @step
    def end(self):
        pass

Distributed Training Checkpoint

import io
import torch
import torch.distributed as dist
from metaflow import FlowSpec, step, parallel, current, S3

class DistributedCheckpoint(FlowSpec):
    """Flow that checkpoints DistributedDataParallel training from rank 0.

    The original snippet used `S3` and `io` without importing either;
    both are added here so the example runs as written.
    """

    @parallel(cpu=8, gpu=4)
    @step
    def train(self):
        # Initialize distributed training from Metaflow's parallel context.
        rank = current.parallel.node_index
        world_size = current.parallel.num_nodes

        dist.init_process_group(
            backend='nccl',
            rank=rank,
            world_size=world_size
        )

        model = MyModel().to(rank)
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])

        with S3(run=self) as s3:
            checkpoint_key = f"checkpoint_rank_{rank}.pt"

            # Only rank 0 reads/writes the checkpoint; the other ranks get
            # the resume epoch via the broadcast below. Catch Exception,
            # not a bare `except:`, so interrupts propagate.
            if rank == 0:
                try:
                    checkpoint_blob = s3.get(checkpoint_key).blob
                    checkpoint = torch.load(io.BytesIO(checkpoint_blob))
                    model.load_state_dict(checkpoint['model_state'])
                    start_epoch = checkpoint['epoch']
                except Exception:
                    start_epoch = 0
            else:
                start_epoch = 0

            # Broadcast the start epoch so every rank begins at the same point.
            start_epoch = torch.tensor(start_epoch).to(rank)
            dist.broadcast(start_epoch, src=0)
            start_epoch = start_epoch.item()

            # Training loop
            for epoch in range(start_epoch, 100):
                train_one_epoch(model, rank)

                # Checkpoint from rank 0 every 10 epochs; DDP keeps all
                # ranks' weights in sync, so one copy suffices.
                if rank == 0 and epoch % 10 == 0:
                    checkpoint = {
                        'epoch': epoch + 1,
                        'model_state': model.state_dict()
                    }
                    buffer = io.BytesIO()
                    torch.save(checkpoint, buffer)
                    buffer.seek(0)
                    s3.put(checkpoint_key, buffer.read())

        if rank == 0:
            self.model_state = model.state_dict()

        self.next(self.end)

    @step
    def end(self, inputs):
        # After @parallel, `end` is a join step; keep the model from rank 0.
        self.model_state = inputs[0].model_state

TensorFlow Checkpointing

import tensorflow as tf
from metaflow import FlowSpec, step, S3
import tempfile
import os

class TensorFlowCheckpoint(FlowSpec):
    """Flow that checkpoints Keras weights locally per epoch and to S3 at the end.

    NOTE(review): the per-epoch ModelCheckpoint files stay on local disk;
    only the final weights are uploaded to S3, so a mid-training instance
    loss still restarts from the last S3 copy.
    """

    @step
    def train(self):
        model = create_model()

        # Epoch-level checkpoints go to a scratch directory on local disk.
        checkpoint_dir = tempfile.mkdtemp()
        checkpoint_path = os.path.join(checkpoint_dir, 'checkpoint')

        checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
            filepath=checkpoint_path,
            save_weights_only=True,
            save_freq='epoch'
        )

        with S3(run=self) as s3:
            # Restore prior weights from S3 if they exist. Catch Exception
            # rather than a bare `except:` so interrupts still propagate.
            try:
                checkpoint_blob = s3.get('checkpoint.h5').blob
                with open(checkpoint_path + '.h5', 'wb') as f:
                    f.write(checkpoint_blob)
                model.load_weights(checkpoint_path + '.h5')
                print("Restored from checkpoint")
            except Exception:
                print("No checkpoint found")

            # Train model
            history = model.fit(
                train_dataset,
                epochs=100,
                callbacks=[checkpoint_callback]
            )

            # Save final checkpoint to S3
            model.save_weights(checkpoint_path + '_final.h5')
            with open(checkpoint_path + '_final.h5', 'rb') as f:
                s3.put('checkpoint.h5', f.read())

        self.model = model
        self.next(self.end)

    @step
    def end(self):
        pass

Handling Spot Instance Interruptions

import signal
import sys
from metaflow import FlowSpec, step, S3, current

class SpotInstanceFlow(FlowSpec):
    """Flow that saves a checkpoint when a spot instance is reclaimed.

    Spot reclamation is delivered as SIGTERM; the handler flushes one
    final checkpoint to S3 and exits.
    """

    @step
    def train(self):
        model = initialize_model()

        with S3(run=self) as s3:
            # Mutable state shared with the signal handler so it always
            # sees the *current* model and epoch. (The original snippet
            # installed a handler that referenced `model`/`epoch` before
            # they existed, used `s3` outside its scope, and called
            # save_checkpoint with a different signature than the rest of
            # the example.)
            state = {'model': model, 'epoch': 0}

            def signal_handler(signum, frame):
                print("Spot instance interruption detected!")
                print("Saving checkpoint...")
                save_checkpoint(s3, state['model'], state['epoch'])
                sys.exit(0)

            signal.signal(signal.SIGTERM, signal_handler)

            # Resume from the last checkpoint if one exists. Exception,
            # not a bare `except:`, so interrupts propagate.
            try:
                checkpoint = load_checkpoint(s3)
                model = checkpoint['model']
                start_epoch = checkpoint['epoch']
            except Exception:
                start_epoch = 0

            # Training loop
            for epoch in range(start_epoch, 100):
                # Keep the handler's view of progress up to date.
                state['model'], state['epoch'] = model, epoch
                try:
                    model.train_epoch()

                    # Checkpoint frequently -- spot capacity can vanish at
                    # any moment, so keep the recovery window small.
                    if epoch % 5 == 0:
                        save_checkpoint(s3, model, epoch)

                except KeyboardInterrupt:
                    print("Interrupted! Saving checkpoint...")
                    save_checkpoint(s3, model, epoch)
                    raise

            self.model = model

        self.next(self.end)

    @step
    def end(self):
        pass

Best Practices

For jobs that run for hours, checkpoint every 10-30 minutes:
import time
last_checkpoint = time.time()

for iteration in range(1000000):
    # ... do one unit of work here ...

    # Wall-clock-based checkpointing: save every 30 minutes (1800 s)
    # regardless of how fast or slow individual iterations run.
    if time.time() - last_checkpoint > 1800:
        save_checkpoint()
        last_checkpoint = time.time()
Always save checkpoints to S3 when running on cloud compute:
@batch(cpu=16, memory=32000)
@step
def train(self):
    with S3(run=self) as s3:
        # Persist the checkpoint to S3 -- local disk on AWS Batch is
        # ephemeral and is lost when the container stops.
        s3.put('checkpoint.pkl', checkpoint_data)
For ML training, always save the optimizer (and learning-rate scheduler) state along with the model:
# A complete training checkpoint: model weights alone are not enough --
# optimizer and LR-scheduler state must be saved to resume exactly.
checkpoint = {
    'epoch': epoch,
    'model_state': model.state_dict(),
    'optimizer_state': optimizer.state_dict(),
    'scheduler_state': scheduler.state_dict(),
    'loss': loss
}
Remove intermediate checkpoints when done:
import glob
import os

# Keep only the last N checkpoints (here N = 5). The imports were missing
# from the original snippet. NOTE(review): sorted() orders names
# lexicographically -- this matches chronological order only if epoch
# numbers in the filenames are zero-padded; confirm the naming scheme.
checkpoints = sorted(glob.glob('checkpoint_*.pt'))
for old_checkpoint in checkpoints[:-5]:
    os.remove(old_checkpoint)
Always test that your checkpoint recovery works:
# Save a checkpoint at a known epoch
save_checkpoint(model, epoch=5)

# Simulate a restart: rebuild the model from scratch, then restore
model = initialize_model()
checkpoint = load_checkpoint()
model.load_state_dict(checkpoint['model_state'])

# Verify the epoch counter round-tripped through the checkpoint
assert checkpoint['epoch'] == 5

Retry Decorator

Automatic retry on failure

AWS Batch

Running on AWS Batch with spot instances

S3 Data Tools

Using S3 for data storage

Error Handling

Handling errors in flows

Build docs developers (and LLMs) love