MaxDiffusion uses Orbax as the default checkpointing system to save and restore model states during training and inference.

Overview

Orbax is a flexible, high-performance checkpointing library designed for JAX applications. It replaced the legacy pipeline.save_pretrained method as MaxDiffusion's default checkpointer on August 1, 2024. You can still call pipeline.save_pretrained after training to save models in the diffusers format.

Configuration

Checkpointing is controlled by two main parameters in your config file:
checkpoint_every: -1
enable_single_replica_ckpt_restoring: False

checkpoint_every

  • Type: integer
  • Default: -1 (disabled)
  • Description: Save a checkpoint every N training steps
Set to a positive integer to enable periodic checkpointing:
# Save a checkpoint every 1000 steps
checkpoint_every: 1000

# Disable checkpointing
checkpoint_every: -1
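In the training loop this setting amounts to a simple guard; a minimal sketch (the helper name is illustrative, not MaxDiffusion's actual code):

```python
def should_checkpoint(step: int, checkpoint_every: int) -> bool:
    """Return True when a checkpoint should be saved at this step.

    A non-positive checkpoint_every disables periodic checkpointing.
    """
    return checkpoint_every > 0 and step % checkpoint_every == 0

# With checkpoint_every=1000, step 1000 saves; step 500 does not.
print(should_checkpoint(1000, 1000))  # True
print(should_checkpoint(500, 1000))   # False
print(should_checkpoint(500, -1))     # False (disabled)
```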

enable_single_replica_ckpt_restoring

  • Type: boolean
  • Default: False
  • Description: Enables optimized checkpoint restoration where one replica reads the checkpoint and broadcasts to others
This can significantly speed up checkpoint loading in multi-host setups:
# Enable single-replica restoration
enable_single_replica_ckpt_restoring: True

Checkpoint structure

Orbax organizes checkpoints by model type, storing different components separately.

Stable Diffusion checkpoints

output_dir/
└── run_name/
    └── checkpoints/
        └── step_1000/
            ├── unet_state/
            ├── unet_config/
            ├── vae_state/
            ├── vae_config/
            ├── text_encoder_state/
            ├── text_encoder_config/
            ├── scheduler_config/
            └── tokenizer_config/

SDXL checkpoints

SDXL includes an additional text encoder:
output_dir/
└── run_name/
    └── checkpoints/
        └── step_1000/
            ├── ... (same components as Stable Diffusion)
            ├── text_encoder_2_state/
            └── text_encoder_2_config/

Flux checkpoints

output_dir/
└── run_name/
    └── checkpoints/
        └── step_1000/
            ├── flux_state/
            ├── flux_config/
            ├── vae_state/
            ├── vae_config/
            ├── scheduler/
            ├── scheduler_config/
            ├── text_encoder_2_state/
            └── text_encoder_2_config/

Wan checkpoints

output_dir/
└── run_name/
    └── checkpoints/
        └── step_1000/
            ├── low_noise_transformer_state/
            ├── high_noise_transformer_state/
            ├── wan_state/
            └── wan_config/
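Before pointing generation at a checkpoint, it can be worth verifying that the expected component directories are present. A small stdlib sketch for local paths (the helper and component list are illustrative, based on the Stable Diffusion layout above; gs:// paths would need a GCS client instead of os.listdir):

```python
import os

# Component directories expected in a Stable Diffusion checkpoint (see above).
SD_COMPONENTS = [
    "unet_state", "unet_config",
    "vae_state", "vae_config",
    "text_encoder_state", "text_encoder_config",
    "scheduler_config", "tokenizer_config",
]

def missing_components(step_dir: str, expected=SD_COMPONENTS) -> list[str]:
    """Return the expected component directories missing from step_dir."""
    present = set(os.listdir(step_dir)) if os.path.isdir(step_dir) else set()
    return [c for c in expected if c not in present]
```

An empty return value means the checkpoint looks structurally complete.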

Training with checkpoints

SDXL training example

python -m src.maxdiffusion.train_sdxl \
  src/maxdiffusion/configs/base_xl.yml \
  run_name="my_xl_run" \
  output_dir="gs://your-bucket/" \
  per_device_batch_size=1 \
  checkpoint_every=500 \
  max_train_steps=2000
This will create checkpoints at steps 500, 1000, 1500, and 2000.
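The schedule is easy to compute directly; a quick sketch:

```python
def checkpoint_steps(max_train_steps: int, checkpoint_every: int) -> list[int]:
    """Steps at which periodic checkpoints are written."""
    if checkpoint_every <= 0:
        return []  # checkpointing disabled
    return list(range(checkpoint_every, max_train_steps + 1, checkpoint_every))

print(checkpoint_steps(2000, 500))  # [500, 1000, 1500, 2000]
```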

Flux training example

python src/maxdiffusion/train_flux.py \
  src/maxdiffusion/configs/base_flux_dev.yml \
  run_name="test-flux-train" \
  output_dir="gs://your-bucket/" \
  checkpoint_every=100 \
  save_final_checkpoint=True \
  jax_cache_dir="/tmp/jax_cache"

Restoring from checkpoints

Orbax automatically detects and restores from the latest checkpoint when you resume training.

Automatic restoration

Simply run the training command again with the same run_name and output_dir:
python -m src.maxdiffusion.train_sdxl \
  src/maxdiffusion/configs/base_xl.yml \
  run_name="my_xl_run" \
  output_dir="gs://your-bucket/" \
  checkpoint_every=500
MaxDiffusion will:
  1. Check for existing checkpoints in the output directory
  2. Find the latest checkpoint step
  3. Restore model state and optimizer state
  4. Continue training from that step

Generate with checkpoints

python -m src.maxdiffusion.generate \
  src/maxdiffusion/configs/base_xl.yml \
  run_name="my_run" \
  pretrained_model_name_or_path=gs://your-bucket/my_xl_run/checkpoints/step_2000 \
  from_pt=False \
  attention=dot_product

Advanced features

Async checkpointing

Orbax supports asynchronous checkpointing to avoid blocking training:
import orbax.checkpoint as ocp

mngr = ocp.CheckpointManager(
    checkpoint_dir,
    options=ocp.CheckpointManagerOptions(
        create=True,
        save_interval_steps=checkpoint_every,
        enable_async_checkpointing=True,  # save in a background thread
    ),
)
This is enabled by default in MaxDiffusion.

Single-replica restoration

For large multi-host setups, enable single-replica restoration to reduce checkpoint loading time:
python -m src.maxdiffusion.train_sdxl \
  src/maxdiffusion/configs/base_xl.yml \
  run_name="my_xl_run" \
  output_dir="gs://your-bucket/" \
  enable_single_replica_ckpt_restoring=True
How it works:
  1. A single replica reads the checkpoint from storage
  2. The restored state is broadcast to all other replicas over the interconnect
This cuts storage reads and network traffic to the storage backend, since only one host touches the checkpoint files instead of every host reading them independently.

Grain dataset checkpointing

When using Grain datasets, iterator state is also checkpointed:
dataset_type: 'grain'
checkpoint_every: 500
The checkpoint then includes an iter item that stores the dataset iterator position, so a resumed run continues from the same point in the data.

Best practices

Checkpoint frequency

Balance checkpoint frequency against storage and performance:
# Frequent (good for debugging, expensive storage)
checkpoint_every: 100

# Moderate (good balance)
checkpoint_every: 500

# Infrequent (saves storage, fewer recovery points)
checkpoint_every: 2000
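The storage trade-off is the number of retained checkpoints times the per-checkpoint size. A back-of-the-envelope sketch (the 7 GB figure is an assumption for illustration, not a measured SDXL checkpoint size):

```python
def checkpoint_storage_gb(max_train_steps: int, checkpoint_every: int,
                          ckpt_size_gb: float) -> float:
    """Total storage if every periodic checkpoint is retained."""
    if checkpoint_every <= 0:
        return 0.0
    num_ckpts = max_train_steps // checkpoint_every
    return num_ckpts * ckpt_size_gb

# 2000 training steps, assuming ~7 GB per checkpoint:
print(checkpoint_storage_gb(2000, 100, 7.0))  # 140.0 GB at every 100 steps
print(checkpoint_storage_gb(2000, 500, 7.0))  # 28.0 GB at every 500 steps
```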

Storage considerations

  • Use Google Cloud Storage (gs://) for multi-host training
  • Local storage (/tmp/ or local paths) works for single-host
  • Checkpoints can be large (multiple GB for SDXL/Flux)
  • Consider retention policies to manage storage costs

Multi-host training

For multi-host training:
  1. Always use GCS for checkpoint storage
  2. Enable single-replica restoration for faster loading
  3. Use appropriate checkpoint frequency based on step time
python -m src.maxdiffusion.train_sdxl \
  src/maxdiffusion/configs/base_xl.yml \
  run_name="multihost_run" \
  output_dir="gs://your-bucket/" \
  checkpoint_every=1000 \
  enable_single_replica_ckpt_restoring=True

Converting to diffusers format

After training with Orbax, you can convert checkpoints to diffusers format:
# In your training script or separately
pipeline.save_pretrained("path/to/diffusers/format")
This creates a checkpoint compatible with HuggingFace diffusers library.

Troubleshooting

Checkpoint not found

If restoration fails:
  1. Verify the output_dir and run_name match the training run
  2. Check that checkpoint directories exist in GCS/local storage
  3. Ensure proper permissions to read from the checkpoint location

Out of memory during restoration

For large models:
  1. Enable enable_single_replica_ckpt_restoring=True
  2. Ensure sufficient host memory
  3. Consider using a smaller batch size during initial loading

Mismatched shapes

If checkpoint shapes don’t match:
  1. Verify the model configuration hasn’t changed
  2. Check that parallelism settings match (ici_fsdp_parallelism, etc.)
  3. Ensure you’re loading the correct checkpoint type for your model
