Flyte’s workflow model provides implicit checkpoints at task boundaries — if one task fails, only that task retries. However, when a single task runs a long computation (such as a training loop over many epochs), you may want checkpoints within the task itself. Intratask checkpointing lets you save progress as a file during task execution and resume from the latest checkpoint on retry, rather than restarting from scratch.

Why intratask checkpoints?

Spot and preemptible instances

AWS spot and GCP preemptible instances can be reclaimed at any time. Checkpointing makes long tasks resilient to interruption at a fraction of the cost.

Long training loops

ML model training over many epochs or large datasets is expensive to restart. Checkpoint each epoch to resume from the last completed one.

Avoid task fan-out overhead

Breaking a loop into individual tasks via dynamic workflows adds overhead per iteration. A single checkpointed task avoids that cost.

Tight computation loops

Some computations are logically a single unit but run long enough to need fault tolerance. Checkpointing bridges both requirements.

How it works

Flytekit exposes a checkpoint API through the execution context:
ctx = flytekit.current_context()
cp = ctx.checkpoint
The checkpoint object provides:
  • cp.write(bytes) — save bytes to a checkpoint file
  • cp.read() — read the latest checkpoint bytes (returns None if no checkpoint exists)
On retry, cp.read() returns the bytes written during the previous attempt, allowing the task to resume from where it left off.
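The read-then-resume pattern can be sketched with a minimal in-memory stand-in for the checkpoint object (the FakeCheckpoint class below is hypothetical, for illustration only; in a real task you get the real object from ctx.checkpoint):

```python
from typing import Optional


class FakeCheckpoint:
    """Hypothetical in-memory stand-in for Flyte's checkpoint object."""

    def __init__(self) -> None:
        self._data: Optional[bytes] = None

    def write(self, data: bytes) -> None:
        self._data = data

    def read(self) -> Optional[bytes]:
        return self._data


cp = FakeCheckpoint()

# First attempt: no checkpoint exists yet, so read() returns None
assert cp.read() is None

# Save progress, then simulate a retry reading it back
cp.write(b"7")
resumed_start = int(cp.read().decode())  # the retry resumes at iteration 7
```

The key property is that the checkpoint written by one attempt is visible to the next attempt of the same task, which is what makes resumption possible.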

Example: checkpointed iteration task

Import the required libraries (retries are configured on the task decorator below):
import flytekit
from flytekit import task, workflow
from flytekit.exceptions.user import FlyteRecoverableException
Define a task that iterates n_iterations times, checkpoints its progress, and can recover from simulated failures:
@task(retries=3)
def use_checkpoint(n_iterations: int) -> int:
    ctx = flytekit.current_context()
    cp = ctx.checkpoint

    # Try to restore state from a previous attempt
    start = 0
    prev = cp.read()
    if prev:
        start = int(prev.decode())

    # Simulate a failure on the first attempt at iteration 5
    for i in range(start, n_iterations):
        if i == 5 and start == 0:
            raise FlyteRecoverableException(f"Simulated failure at iteration {i}")
        # Save progress after each iteration
        cp.write(str(i + 1).encode())

    return n_iterations
Only FlyteRecoverableException triggers a retry. Other exceptions cause the task to fail immediately without retry.
Create a workflow that calls the task:
@workflow
def checkpointing_example(n_iterations: int = 10) -> int:
    return use_checkpoint(n_iterations=n_iterations)
Local execution note:
if __name__ == "__main__":
    # Local execution does not support retries, so the checkpoint behavior
    # is not exercised here. Run on a Flyte cluster to see retries and
    # checkpoint recovery in action.
    print(checkpointing_example(n_iterations=10))

Execution flow on the Flyte cluster

1. First attempt: The task starts, checkpoints progress through iterations 0–4, then raises FlyteRecoverableException at iteration 5.
2. Retry: Flyte retries the task. On startup, cp.read() returns "5" (the last checkpoint), so the task resumes from iteration 5 rather than 0.
3. Completion: The task completes all remaining iterations and returns n_iterations.
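This control flow can be simulated end to end in plain Python, without a Flyte cluster. The StubCheckpoint class, RecoverableError, and run_with_retries below are hypothetical stand-ins that mimic what Flyte does for you (persisting the checkpoint across attempts and retrying on recoverable failures); they are not Flyte APIs:

```python
from typing import Optional


class RecoverableError(Exception):
    """Stand-in for FlyteRecoverableException."""


class StubCheckpoint:
    """In-memory checkpoint that survives across simulated attempts."""

    def __init__(self) -> None:
        self._data: Optional[bytes] = None

    def write(self, data: bytes) -> None:
        self._data = data

    def read(self) -> Optional[bytes]:
        return self._data


def use_checkpoint(cp: StubCheckpoint, n_iterations: int) -> int:
    # Restore progress from a previous attempt, if any
    prev = cp.read()
    start = int(prev.decode()) if prev else 0
    for i in range(start, n_iterations):
        if i == 5 and start == 0:
            # Fail once on the first attempt, as in the example above
            raise RecoverableError(f"Simulated failure at iteration {i}")
        cp.write(str(i + 1).encode())
    return n_iterations


def run_with_retries(n_iterations: int, retries: int = 3) -> int:
    cp = StubCheckpoint()  # one checkpoint shared by all attempts
    for _attempt in range(retries + 1):
        try:
            return use_checkpoint(cp, n_iterations)
        except RecoverableError:
            continue  # retry; the checkpoint persists
    raise RuntimeError("retries exhausted")


result = run_with_retries(10)  # fails at iteration 5, resumes, returns 10
```

The first attempt writes checkpoints "1" through "5" and fails; the retry reads "5" and finishes iterations 5–9, so the work done before the failure is never repeated.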

Using interruptible tasks with checkpoints

Mark a task as interruptible=True to allow Flyte to schedule it on spot or preemptible instances. Combined with checkpointing, this gives you fault-tolerant execution at reduced cost:
@task(retries=3, interruptible=True)
def training_task(n_epochs: int) -> float:
    ctx = flytekit.current_context()
    cp = ctx.checkpoint

    start_epoch = 0
    best_loss = float("inf")
    prev = cp.read()
    if prev:
        # Restore: prev stores "epoch:best_loss" as bytes
        parts = prev.decode().split(":")
        start_epoch = int(parts[0]) + 1
        best_loss = float(parts[1])

    for epoch in range(start_epoch, n_epochs):
        # ... training logic ...
        loss = 1.0 / (epoch + 1)  # placeholder
        best_loss = min(best_loss, loss)
        cp.write(f"{epoch}:{best_loss}".encode())

    return best_loss
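Restoring both the epoch and the best loss matters: if only the epoch were restored, a resumed run would reset best_loss and could report a worse result than the pre-interruption attempt had already achieved. The resume logic can be exercised in plain Python with a stub checkpoint and a simulated preemption (InterruptedError and the stub are hypothetical illustrations, not Flyte APIs):

```python
from typing import Optional


class StubCheckpoint:
    """In-memory checkpoint that survives the simulated preemption."""

    def __init__(self) -> None:
        self._data: Optional[bytes] = None

    def write(self, data: bytes) -> None:
        self._data = data

    def read(self) -> Optional[bytes]:
        return self._data


def train(cp: StubCheckpoint, n_epochs: int, fail_at: Optional[int] = None) -> float:
    start_epoch = 0
    best_loss = float("inf")
    prev = cp.read()
    if prev:
        # Restore both pieces of state from "epoch:best_loss"
        parts = prev.decode().split(":")
        start_epoch = int(parts[0]) + 1
        best_loss = float(parts[1])

    for epoch in range(start_epoch, n_epochs):
        if fail_at is not None and epoch == fail_at:
            raise InterruptedError("simulated spot-instance reclaim")
        loss = 1.0 / (epoch + 1)  # same placeholder as above
        best_loss = min(best_loss, loss)
        cp.write(f"{epoch}:{best_loss}".encode())
    return best_loss


cp = StubCheckpoint()
try:
    train(cp, n_epochs=3, fail_at=2)  # preempted during epoch 2
except InterruptedError:
    pass
result = train(cp, n_epochs=3)  # resumes at epoch 2 with best_loss restored
```

After the preemption the checkpoint holds "1:0.5"; the second call resumes at epoch 2 instead of epoch 0 and keeps the previously computed best loss as its starting point.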
For tasks running under 10 minutes, task-boundary retries usually provide sufficient fault tolerance. Intratask checkpointing pays off most for tasks running 30+ minutes on spot instances.

Additional checkpoint APIs

The checkpoint system provides additional APIs. See the full reference in the checkpointer source code. Future Flyte versions aim to integrate higher-level checkpointing APIs from popular frameworks like PyTorch, Keras, Scikit-learn, Spark, and Flink.

Run on the Flyte cluster

pyflyte run --remote \
  https://raw.githubusercontent.com/flyteorg/flytesnacks/69dbe4840031a85d79d9ded25f80397c6834752d/examples/advanced_composition/advanced_composition/checkpoint.py \
  checkpointing_example --n_iterations 10
