Save and restore task state mid-execution to enable fault-tolerant long-running tasks on spot instances.
Flyte’s workflow model provides implicit checkpoints at task boundaries — if one task fails, only that task retries. However, when a single task runs a long computation (such as a training loop over many epochs), you may want checkpoints within the task itself. Intratask checkpointing lets you save progress to a file during task execution and resume from the latest checkpoint on retry, rather than restarting from scratch.
AWS spot and GCP preemptible instances can be reclaimed at any time. Checkpointing makes long tasks resilient to interruption at a fraction of the cost.
Long training loops
ML model training over many epochs or large datasets is expensive to restart. Checkpoint each epoch to resume from the last completed one.
Avoid task fan-out overhead
Breaking a loop into individual tasks via dynamic workflows adds overhead per iteration. A single checkpointed task avoids that cost.
Tight computation loops
Some computations are logically a single unit but run long enough to need fault tolerance. Checkpointing bridges both requirements.
Define a task that iterates n_iterations times, checkpoints its progress, and can recover from simulated failures:
```python
import flytekit
from flytekit import task
from flytekit.exceptions.user import FlyteRecoverableException


@task(retries=3)
def use_checkpoint(n_iterations: int) -> int:
    ctx = flytekit.current_context()
    cp = ctx.checkpoint

    # Try to restore state from a previous attempt
    start = 0
    prev = cp.read()
    if prev:
        start = int(prev.decode())

    for i in range(start, n_iterations):
        # Simulate a failure on the first attempt at iteration 5
        if i == 5 and start == 0:
            raise FlyteRecoverableException(f"Simulated failure at iteration {i}")
        # Save progress after each iteration
        cp.write(str(i + 1).encode())
    return n_iterations
```
Only FlyteRecoverableException triggers a retry. Other exceptions cause the task to fail immediately without retry.
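The retry-and-resume semantics can be illustrated without a cluster. The sketch below is a local simulation under stated assumptions: `RecoverableError`, `InMemoryCheckpoint`, and `run_with_retries` are illustrative stand-ins for `FlyteRecoverableException`, the checkpoint object, and Flyte's retry machinery, not flytekit APIs.

```python
class RecoverableError(Exception):
    """Stand-in for FlyteRecoverableException in this local simulation."""


class InMemoryCheckpoint:
    """Stand-in for ctx.checkpoint: stores a single bytes payload."""

    def __init__(self):
        self._data = None

    def read(self):
        return self._data

    def write(self, data: bytes):
        self._data = data


def run_with_retries(task_fn, cp, retries=3):
    # Retry only on RecoverableError, mirroring Flyte's behavior;
    # any other exception propagates immediately (no retry).
    for attempt in range(retries + 1):
        try:
            return task_fn(cp)
        except RecoverableError:
            if attempt == retries:
                raise


def counting_task(cp):
    prev = cp.read()
    start = int(prev.decode()) if prev else 0
    for i in range(start, 10):
        if i == 5 and start == 0:
            raise RecoverableError("simulated failure")
        cp.write(str(i + 1).encode())
    return 10


cp = InMemoryCheckpoint()
result = run_with_retries(counting_task, cp)
# The first attempt fails at i == 5; the retry reads b"5" from the
# checkpoint and resumes at iteration 5 instead of iteration 0.
```

Note how the second attempt skips iterations 0–4 entirely: the checkpoint read at the top of the task is what turns a retry into a resume.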
Create a workflow that calls the task:
```python
from flytekit import workflow


@workflow
def checkpointing_example(n_iterations: int = 10) -> int:
    return use_checkpoint(n_iterations=n_iterations)
```
Local execution note:
```python
if __name__ == "__main__":
    # Local execution does not support retries, so the checkpoint behavior
    # is not exercised here. Run on a Flyte cluster to see retries and
    # checkpoint recovery in action.
    print(checkpointing_example(n_iterations=10))
```
Set `interruptible=True` on a task to allow Flyte to schedule it on spot or preemptible instances. Combined with checkpointing, this gives you fault-tolerant execution at reduced cost:
```python
import flytekit
from flytekit import task


@task(retries=3, interruptible=True)
def training_task(n_epochs: int) -> float:
    ctx = flytekit.current_context()
    cp = ctx.checkpoint

    start_epoch = 0
    best_loss = float("inf")
    prev = cp.read()
    if prev:
        # Restore: prev stores "epoch:loss" as bytes
        epoch_str, loss_str = prev.decode().split(":")
        start_epoch = int(epoch_str) + 1
        best_loss = float(loss_str)

    for epoch in range(start_epoch, n_epochs):
        # ... training logic ...
        loss = 1.0 / (epoch + 1)  # placeholder
        best_loss = min(best_loss, loss)
        cp.write(f"{epoch}:{best_loss}".encode())
    return best_loss
```
For tasks running under 10 minutes, task-boundary retries usually provide sufficient fault tolerance. Intratask checkpointing pays off most for tasks running 30+ minutes on spot instances.
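A back-of-envelope calculation shows why duration matters. All numbers below are illustrative assumptions, not measurements; the model assumes a preemption is equally likely at any point during the run.

```python
# Hypothetical 2-hour task that checkpoints once per 5-minute epoch.
task_minutes = 120
checkpoint_interval_minutes = 5

# Without intratask checkpoints, a preemption discards all progress in
# the current attempt -- on average, half the task's runtime.
expected_lost_without = task_minutes / 2

# With checkpoints, at most one interval of work is recomputed.
max_lost_with = checkpoint_interval_minutes
```

For a 5-minute task the same model gives an average loss of only 2.5 minutes without checkpointing, which is why task-boundary retries are usually enough for short tasks.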
The checkpoint system provides additional APIs. See the full reference in the checkpointer source code. Future Flyte versions aim to integrate higher-level checkpointing APIs from popular frameworks such as PyTorch, Keras, Scikit-learn, Spark, and Flink.