Overview
Orbax is a flexible, high-performance checkpointing library designed for JAX applications. It replaced the legacy pipeline.save_pretrained method as the default in MaxDiffusion.
Orbax became the default checkpointer on August 1, 2024. You can still use pipeline.save_pretrained after training to save models in diffusers format.
Configuration
Checkpointing is controlled by two main parameters in your config file:
checkpoint_every
- Type: integer
- Default: -1 (disabled)
- Description: Save a checkpoint every N training samples
enable_single_replica_ckpt_restoring
- Type: boolean
- Default: False
- Description: Enables optimized checkpoint restoration, where one replica reads the checkpoint and broadcasts it to the others
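In a MaxDiffusion YAML config these two parameters look like the following (the values shown are illustrative, not defaults):

```yaml
# Save a checkpoint every 1000 training samples.
checkpoint_every: 1000
# One replica reads the checkpoint from storage and broadcasts it to the others.
enable_single_replica_ckpt_restoring: True
```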
Checkpoint structure
Orbax organizes checkpoints by model type, storing different components separately.
Stable Diffusion checkpoints
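An illustrative layout for a Stable Diffusion checkpoint (the component names below are an assumption based on the standard Stable Diffusion pipeline; inspect your own output_dir for the exact item names):

```
output_dir/run_name/checkpoints/
└── <step>/
    ├── unet_state/           # UNet weights (plus optimizer state during training)
    ├── vae_state/
    ├── text_encoder_state/
    ├── scheduler_config/
    └── tokenizer_config/
```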
SDXL checkpoints
SDXL includes an additional text encoder.
Flux checkpoints
Wan checkpoints
Training with checkpoints
SDXL training example
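A training invocation along these lines enables checkpointing (the script and config paths here are illustrative assumptions; check your MaxDiffusion checkout for the exact entry point):

```shell
python src/maxdiffusion/train_sdxl.py src/maxdiffusion/configs/base_xl.yml \
  run_name=sdxl-run output_dir=gs://my-bucket/maxdiffusion \
  checkpoint_every=1000
```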
Flux training example
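Flux training follows the same pattern (again, the script and config paths are illustrative assumptions, not verified names):

```shell
python src/maxdiffusion/train_flux.py src/maxdiffusion/configs/base_flux_dev.yml \
  run_name=flux-run output_dir=gs://my-bucket/maxdiffusion \
  checkpoint_every=1000
```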
Restoring from checkpoints
Orbax automatically detects and restores from the latest checkpoint when you resume training.
Automatic restoration
Simply run the training command again with the same run_name and output_dir. On startup, the trainer will:
- Check for existing checkpoints in the output directory
- Find the latest checkpoint step
- Restore model state and optimizer state
- Continue training from that step
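The "find the latest checkpoint step" part of this sequence amounts to scanning for integer-named step directories, which is the default Orbax CheckpointManager layout. A minimal sketch:

```python
import os
import re
from typing import Optional

def latest_checkpoint_step(checkpoint_dir: str) -> Optional[int]:
    """Return the highest integer step found, or None if there is none.

    Orbax's CheckpointManager names each step's directory after the
    integer step number, so the latest checkpoint is simply the maximum.
    """
    if not os.path.isdir(checkpoint_dir):
        return None
    steps = [int(name) for name in os.listdir(checkpoint_dir)
             if re.fullmatch(r"\d+", name)]
    return max(steps, default=None)
```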
Generate with checkpoints
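Checkpoints saved during training can be loaded for inference. A sketch of the invocation (the script name, config path, and prompt parameter are illustrative assumptions; use the same run_name and output_dir as the training run):

```shell
python src/maxdiffusion/generate.py src/maxdiffusion/configs/base_xl.yml \
  run_name=sdxl-run output_dir=gs://my-bucket/maxdiffusion \
  prompt="a photo of an astronaut riding a horse"
```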
Advanced features
Async checkpointing
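The pattern behind async checkpointing can be sketched with a background writer thread. This is a simplified stand-in for Orbax's AsyncCheckpointer, which adds proper cross-host synchronization and error handling:

```python
import threading

class AsyncSaver:
    """Hand checkpoint writes to a background thread so training keeps stepping."""

    def __init__(self, write_fn):
        self._write_fn = write_fn  # callable(step, state) that performs the write
        self._thread = None

    def save(self, step, state):
        self.wait()  # ensure the previous write finished before starting a new one
        self._thread = threading.Thread(target=self._write_fn, args=(step, state))
        self._thread.start()

    def wait(self):
        """Block until the in-flight write (if any) completes."""
        if self._thread is not None:
            self._thread.join()
            self._thread = None
```

Training only blocks in wait(), i.e. when a new save starts before the previous write has finished.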
Orbax supports asynchronous checkpointing to avoid blocking training.
Single-replica restoration
For large multi-host setups, enable single-replica restoration to reduce checkpoint loading time:
- One replica reads the checkpoint from storage
- The checkpoint is broadcast to all other replicas
- Reduces network bandwidth and storage reads
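The read-then-broadcast pattern can be sketched with jax.experimental.multihost_utils. Here read_fn and the placeholder are illustrative assumptions (every process must supply a pytree of matching structure and shape); on a single process the broadcast is a no-op:

```python
import jax
import numpy as np
from jax.experimental import multihost_utils

def restore_with_broadcast(read_fn, placeholder):
    # Only process 0's values are actually used by the broadcast; the
    # placeholder just satisfies the matching-structure requirement.
    if jax.process_index() == 0:
        params = read_fn()        # one replica reads from storage
    else:
        params = placeholder      # others skip the storage read
    return multihost_utils.broadcast_one_to_all(params)
```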
Grain dataset checkpointing
When using Grain datasets, the iterator state is also checkpointed: the checkpoint includes an iter item that stores the dataset iterator position.
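Conceptually, the iter item captures how far the input pipeline has advanced, so restoration can resume mid-epoch instead of replaying data. A plain-Python stand-in for Grain's get_state/set_state pattern (this mirrors the idea, not the Grain API itself):

```python
class CheckpointableIterator:
    """Iterator whose position can be saved and restored."""

    def __init__(self, data):
        self._data = list(data)
        self._pos = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self._pos >= len(self._data):
            raise StopIteration
        item = self._data[self._pos]
        self._pos += 1
        return item

    def get_state(self):
        # What gets written into the checkpoint's iter item.
        return {"pos": self._pos}

    def set_state(self, state):
        # Restores the position so iteration resumes where it left off.
        self._pos = state["pos"]
```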
Best practices
Checkpoint frequency
Balance checkpoint frequency against storage and performance.
Storage considerations
- Use Google Cloud Storage (gs://) for multi-host training
- Local storage (/tmp/ or other local paths) works for single-host training
- Checkpoints can be large (multiple GB for SDXL/Flux)
- Consider retention policies to manage storage costs
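As a rough rule of thumb, total storage is the number of retained checkpoints times checkpoint size. A small helper to make the trade-off concrete (the function name and retention model are mine, not part of MaxDiffusion):

```python
def checkpoint_storage_gb(total_samples, checkpoint_every, ckpt_size_gb, keep_last=None):
    """Estimate storage used by checkpoints over a run.

    keep_last models a retention policy that prunes all but the N most
    recent checkpoints; None means everything is kept.
    """
    n_checkpoints = total_samples // checkpoint_every
    if keep_last is not None:
        n_checkpoints = min(n_checkpoints, keep_last)
    return n_checkpoints * ckpt_size_gb
```

For example, saving a 7 GB checkpoint every 1000 samples over a 10000-sample run keeps 70 GB without a retention policy, but only 21 GB when retaining the last three checkpoints.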
Multi-host training
For multi-host training:
- Always use GCS for checkpoint storage
- Enable single-replica restoration for faster loading
- Use appropriate checkpoint frequency based on step time
Converting to diffusers format
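The conversion boils down to restoring the Orbax checkpoint into the Flax pipeline and re-saving it with save_pretrained. A minimal sketch, assuming the pipeline object and its restored parameter pytree are already in hand (the function name is mine):

```python
def export_to_diffusers(pipeline, params, out_dir):
    # Flax diffusers pipelines take the parameter pytree explicitly;
    # save_pretrained writes weights and configs in diffusers layout.
    pipeline.save_pretrained(out_dir, params=params)
```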
After training with Orbax, you can convert checkpoints to diffusers format.
Troubleshooting
Checkpoint not found
If restoration fails:
- Verify the output_dir and run_name match the training run
- Check that checkpoint directories exist in GCS/local storage
- Ensure proper permissions to read from the checkpoint location
Out of memory during restoration
For large models:
- Enable enable_single_replica_ckpt_restoring=True
- Ensure sufficient host memory
- Consider using a smaller batch size during initial loading
Mismatched shapes
If checkpoint shapes don’t match:
- Verify the model configuration hasn’t changed
- Check that parallelism settings match (ici_fsdp_parallelism, etc.)
- Ensure you’re loading the correct checkpoint type for your model
Related resources
- Profiling - Monitor checkpoint save/load performance
- Quantization - Reduce checkpoint size