Overview
MaxDiffusion is a high-performance latent diffusion framework written in pure Python/JAX that runs on XLA devices, including Cloud TPUs and GPUs. It provides reference implementations for various latent diffusion models, designed for both ambitious research projects and production deployments.
Design principles
MaxDiffusion is built on several key design principles:
- XLA-native execution: Written entirely in Python/JAX, MaxDiffusion compiles to XLA for optimal performance on TPUs and GPUs. This enables efficient execution across different hardware accelerators without platform-specific code.
- Scalable parallelism: The framework supports multiple parallelism strategies (data, FSDP, tensor, and context parallelism) to scale from single devices to large TPU pods with thousands of chips.
- Modular design: MaxDiffusion separates concerns into distinct components (models, attention mechanisms, schedulers, data pipelines) that can be customized independently.
- Memory efficiency: Features like gradient checkpointing, mixed precision training, and encoder offloading enable training and inference on large models with limited memory.
Core components
MaxDiffusion’s architecture consists of several interconnected components.
Model layers
The framework implements multiple diffusion model architectures:
- Transformers: Flux and Wan models use transformer-based architectures with self-attention and cross-attention blocks
- U-Nets: Stable Diffusion models (1.x, 2.x, XL) use convolutional U-Net architectures
- VAE: All models use variational autoencoders for encoding/decoding between pixel and latent space
- Text encoders: CLIP, T5, and other encoders convert text prompts to embeddings
Attention mechanisms
Attention is the computational bottleneck in diffusion models. MaxDiffusion provides multiple attention implementations optimized for different hardware:
- Flash attention: TPU-optimized attention using Pallas kernels
- Dot product attention: Standard attention with optional memory-efficient variants
- cuDNN flash attention: GPU-optimized attention via Transformer Engine
- Ring attention: Distributed attention for extremely long sequences
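As a reference point for what these kernels compute, here is a minimal sketch of standard scaled dot-product attention in JAX. This is a generic illustration of the math, not MaxDiffusion's actual implementation; the optimized variants above compute the same result without materializing the full score matrix:

```python
import jax
import jax.numpy as jnp

def dot_product_attention(q, k, v):
    """Scaled dot-product attention.

    q, k, v: arrays of shape [batch, heads, seq_len, head_dim].
    Flash and ring attention produce the same output while avoiding
    the explicit [seq_len, seq_len] score matrix built here.
    """
    scale = 1.0 / jnp.sqrt(q.shape[-1])
    scores = jnp.einsum("bhqd,bhkd->bhqk", q, k) * scale
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("bhqk,bhkd->bhqd", weights, v)

key = jax.random.PRNGKey(0)
q, k, v = (jax.random.normal(k_, (1, 2, 16, 8))
           for k_ in jax.random.split(key, 3))
out = dot_product_attention(q, k, v)  # shape (1, 2, 16, 8)
```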
Sharding and parallelism
MaxDiffusion uses JAX’s sharding APIs to distribute computation across devices.
Data pipeline
MaxDiffusion supports multiple data formats and loading strategies:
- TFRecords: Efficient format for large-scale training with caching support
- HuggingFace datasets: Direct loading from the HuggingFace Hub
- Grain: Google’s data loading library for high throughput
- Synthetic data: For benchmarking and testing without real datasets
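For the synthetic-data path, a batch can be sketched as random tensors with the shapes the model expects. The shapes below (64×64 latents, 77-token CLIP-style embeddings) are illustrative assumptions, not fixed MaxDiffusion values:

```python
import jax
import jax.numpy as jnp

def synthetic_batch(key, batch_size=4, latent_size=64, latent_channels=4,
                    seq_len=77, embed_dim=768):
    """Generate a random batch for benchmarking without a real dataset.

    All shapes here are illustrative; real values depend on the model config.
    """
    k_lat, k_emb, k_t = jax.random.split(key, 3)
    return {
        "latents": jax.random.normal(
            k_lat, (batch_size, latent_size, latent_size, latent_channels)),
        "text_embeddings": jax.random.normal(
            k_emb, (batch_size, seq_len, embed_dim)),
        "timesteps": jax.random.randint(k_t, (batch_size,), 0, 1000),
    }

batch = synthetic_batch(jax.random.PRNGKey(0))
```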
Schedulers
Diffusion schedulers control the noise schedule during training and inference:
- Euler Discrete: Default scheduler for most models
- DDIM: Deterministic sampling for faster inference
- Flow matching: Used in Flux and Wan models
The scheduler is configured via the `diffusion_scheduler_config` parameter.
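A sketch of what such a scheduler entry might look like in the YAML config; the keys and values below are illustrative assumptions, not the verified schema, so consult the model's base config for the real fields:

```yaml
# Hypothetical sketch only; check the model's base YAML for the actual keys.
diffusion_scheduler_config:
  _class_name: "EulerDiscreteScheduler"
  timestep_spacing: "trailing"
```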
Memory optimization
MaxDiffusion provides several strategies to reduce memory usage.
Gradient checkpointing
The `remat_policy` parameter controls gradient checkpointing.
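Under the hood this maps onto JAX's rematerialization machinery. A generic illustration of the idea with `jax.checkpoint` (not MaxDiffusion's actual wiring):

```python
import jax
import jax.numpy as jnp

# With jax.checkpoint (a.k.a. jax.remat), activations inside `block` are
# recomputed during the backward pass instead of being stored, trading
# compute for memory. Gradient values are unchanged.
@jax.checkpoint
def block(x, w):
    return jnp.tanh(x @ w)

def loss(w, x):
    return jnp.sum(block(x, w))

w = jnp.ones((8, 8))
x = jnp.ones((4, 8))
grads = jax.grad(loss)(w, x)  # same values as without checkpointing
```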
Mixed precision
Weight and activation dtypes can be controlled separately.
Encoder offloading
For Flux models, text encoders can be offloaded after encoding.
Training workflow
A typical MaxDiffusion training workflow:
- Dataset preparation: Convert data to TFRecords or load from HuggingFace
- Configuration: Set model, parallelism, and training parameters in YAML
- Initialization: Load pretrained weights or initialize from scratch
- Training loop: Run training with automatic checkpointing and metrics
- Generation: Use trained checkpoints for inference
MaxDiffusion automatically handles multi-host coordination, checkpointing, and metrics logging to Google Cloud Storage.
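Reduced to its essentials, the training loop has the following shape. This is a toy sketch (plain SGD on a synthetic regression objective), not MaxDiffusion's real loop, which adds sharding, checkpointing, and metrics:

```python
import jax
import jax.numpy as jnp

def loss_fn(params, batch):
    # Toy objective standing in for the diffusion denoising loss.
    pred = batch["x"] @ params["w"]
    return jnp.mean((pred - batch["y"]) ** 2)

@jax.jit
def train_step(params, batch, lr=0.1):
    # One JIT-compiled step: compute loss and grads, apply an SGD update.
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return params, loss

params = {"w": jax.random.normal(jax.random.PRNGKey(0), (8, 1)) * 0.1}
batch = {"x": jnp.ones((16, 8)), "y": jnp.zeros((16, 1))}
losses = []
for _ in range(5):
    params, loss = train_step(params, batch)
    losses.append(float(loss))
# On this convex toy problem the recorded losses decrease step over step.
```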
Inference workflow
Inference can use different sharding strategies than training.
Integration with HuggingFace
MaxDiffusion maintains compatibility with HuggingFace Diffusers:
- Weight loading: Load pretrained PyTorch weights with `from_pt: True`
- Model export: Save checkpoints in Diffusers format with `pipeline.save_pretrained()`
- Configuration: Use Diffusers model configs for architecture parameters
Performance optimization
MaxDiffusion includes several optimizations for maximum throughput.
XLA compiler flags
The `LIBTPU_INIT_ARGS` environment variable configures TPU-specific optimizations.
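The variable must be set before JAX initializes the TPU backend, for example from the launch script or at the top of the training entry point. The specific flag below is illustrative; the useful set of flags is workload-dependent:

```python
import os

# Set before any JAX TPU initialization happens.
# The flag shown is an illustrative example, not a required or verified set.
os.environ["LIBTPU_INIT_ARGS"] = (
    "--xla_tpu_enable_async_collective_fusion=true"
)
```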
JIT compilation
All model initialization and training steps are JIT-compiled for performance. On multi-host setups, `jit_initializers` must be `True` for proper weight synchronization.
Scan for transformer layers
The `scan_layers` parameter uses `jax.lax.scan` to reduce compilation time for transformer models.
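The idea: instead of unrolling N identical transformer blocks so XLA compiles each one separately, stack their parameters along a leading axis and run a single block body under `jax.lax.scan`, so the layer is compiled once. A generic sketch (a stand-in block, not MaxDiffusion's actual layer code):

```python
import jax
import jax.numpy as jnp

def layer(x, w):
    # Stand-in for one transformer block.
    return jnp.tanh(x @ w)

def apply_layers(x, stacked_w):
    """Apply all layers with one scanned body.

    stacked_w: [num_layers, d, d] -- per-layer weights stacked on axis 0.
    XLA compiles `body` once and reuses it for every layer.
    """
    def body(carry, w):
        return layer(carry, w), None
    out, _ = jax.lax.scan(body, x, stacked_w)
    return out

x = jnp.ones((2, 8))
stacked_w = jnp.stack([jnp.eye(8) * 0.5 for _ in range(4)])
y = apply_layers(x, stacked_w)  # shape (2, 8)
```

The result is identical to applying the layers one by one in a Python loop; only the compilation strategy changes.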
Next steps
- Explore supported models
- Learn about parallelism strategies
- Understand attention mechanisms