Overview

MaxDiffusion is a high-performance latent diffusion framework written in pure Python/JAX that runs on XLA devices including Cloud TPUs and GPUs. It provides reference implementations for various latent diffusion models, designed for both ambitious research projects and production deployments.

Design principles

MaxDiffusion is built on several key design principles:
  • XLA-native execution: Written entirely in Python/JAX, MaxDiffusion compiles to XLA for optimal performance on TPUs and GPUs. This enables efficient execution across different hardware accelerators without platform-specific code.
  • Scalable parallelism: The framework supports multiple parallelism strategies (data, FSDP, tensor, and context parallelism) to scale from single devices to large TPU pods with thousands of chips.
  • Modular design: MaxDiffusion separates concerns into distinct components (models, attention mechanisms, schedulers, data pipelines) that can be customized independently.
  • Memory efficiency: Features like gradient checkpointing, mixed-precision training, and encoder offloading enable training and inference on large models with limited memory.

Core components

MaxDiffusion’s architecture consists of several interconnected components:

Model layers

The framework implements multiple diffusion model architectures:
  • Transformers: Flux and Wan models use transformer-based architectures with self-attention and cross-attention blocks
  • U-Nets: Stable Diffusion models (1.x, 2.x, XL) use convolutional U-Net architectures
  • VAE: All models use variational autoencoders for encoding/decoding between pixel and latent space
  • Text encoders: CLIP, T5, and other encoders convert text prompts to embeddings

Attention mechanisms

Attention is typically the computational bottleneck in diffusion models, especially at high resolutions and long sequence lengths. MaxDiffusion provides multiple attention implementations optimized for different hardware:
  • Flash attention: TPU-optimized attention using Pallas kernels
  • Dot product attention: Standard attention with optional memory-efficient variants
  • cuDNN flash attention: GPU-optimized attention via Transformer Engine
  • Ring attention: Distributed attention for extremely long sequences
See the attention mechanisms page for details.
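All of these implementations compute the same underlying operation. As a point of reference, here is a minimal, non-fused sketch of scaled dot-product attention in JAX (the baseline that Flash and cuDNN variants optimize); shapes and names are illustrative, not MaxDiffusion's internal API:

```python
import jax
import jax.numpy as jnp

def dot_product_attention(q, k, v):
    """Plain scaled dot-product attention.

    q, k, v have shape (batch, seq, heads, head_dim). Fused variants
    (Flash, cuDNN) avoid materializing the full (seq, seq) logits matrix.
    """
    scale = q.shape[-1] ** -0.5
    logits = jnp.einsum("bqhd,bkhd->bhqk", q, k) * scale
    weights = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", weights, v)

key = jax.random.PRNGKey(0)
q = jax.random.normal(key, (2, 16, 4, 32))
out = dot_product_attention(q, q, q)  # self-attention on a toy input
```

The memory cost of the intermediate `logits` tensor is what motivates the fused kernels listed above.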

Sharding and parallelism

MaxDiffusion uses JAX’s sharding APIs to distribute computation across devices:
# Mesh defines the device grid
mesh_axes: ['data', 'fsdp', 'context', 'tensor']

# Logical axis rules map tensor dimensions to mesh axes
logical_axis_rules: [
  ['batch', ['data', 'fsdp']],
  ['activation_heads', 'tensor'],
  ['embed', ['context', 'fsdp']],
  ...
]
The mesh partitioning strategy is configured via YAML files and determines how model weights and activations are distributed. See parallelism for configuration details.
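The same idea can be sketched directly with JAX's sharding APIs. The toy example below builds a two-axis mesh and shards a tensor's batch dimension over both axes, mirroring the `['batch', ['data', 'fsdp']]` logical rule; the mesh shape and array sizes are illustrative only:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Toy 2-axis mesh; real configs use the full data/fsdp/context/tensor grid.
devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, axis_names=("data", "fsdp"))

# Shard dim 0 (batch) across both mesh axes, replicate dim 1,
# as in the ['batch', ['data', 'fsdp']] rule above.
x = jnp.ones((8, 128))
x_sharded = jax.device_put(x, NamedSharding(mesh, P(("data", "fsdp"), None)))
```

On a single device the mesh collapses to size 1 per axis, so the same code runs unchanged from a laptop to a pod slice.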

Data pipeline

MaxDiffusion supports multiple data formats and loading strategies:
  • TFRecords: Efficient format for large-scale training with caching support
  • HuggingFace datasets: Direct loading from the HuggingFace Hub
  • Grain: Google’s data loading library for high throughput
  • Synthetic data: For benchmarking and testing without real datasets
Data can be preprocessed to cache latents and text embeddings, reducing training time by avoiding redundant VAE/encoder passes.
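The caching idea is simple to sketch: run the expensive encoder once and persist the results. The snippet below is a generic illustration, not MaxDiffusion's preprocessing code; `encode_fn` stands in for a real VAE or text encoder:

```python
import numpy as np

def cache_latents(images, encode_fn, out_path):
    """Precompute latents once so training can skip the encoder forward pass.

    encode_fn is a placeholder for a real VAE/text encoder (hypothetical).
    """
    latents = np.stack([encode_fn(img) for img in images])
    np.save(out_path, latents)
    return latents
```

At training time the loader then reads the cached `.npy` (or TFRecord) latents instead of re-encoding every image and prompt each epoch.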

Schedulers

Diffusion schedulers control the noise schedule during training and inference:
  • Euler Discrete: Default scheduler for most models
  • DDIM: Deterministic sampling for faster inference
  • Flow matching: Used in Flux and Wan models
Schedulers are configurable via the diffusion_scheduler_config parameter.
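To make the flow-matching case concrete: rectified-flow training (the style used by Flux-family models) interpolates linearly between data and noise and regresses a constant velocity. A minimal sketch of the noising step, with illustrative function names:

```python
import jax.numpy as jnp

def flow_matching_targets(x0, noise, t):
    """Rectified-flow interpolation for flow-matching training.

    x_t moves linearly from clean data x0 (t=0) to pure noise (t=1);
    the model's regression target is the constant velocity (noise - x0).
    """
    t = t.reshape((-1,) + (1,) * (x0.ndim - 1))  # broadcast t per-sample
    x_t = (1.0 - t) * x0 + t * noise
    velocity = noise - x0
    return x_t, velocity
```

Discrete schedulers like Euler Discrete and DDIM instead follow a precomputed noise schedule over a fixed set of timesteps.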

Memory optimization

MaxDiffusion provides several strategies to reduce memory usage:

Gradient checkpointing

The remat_policy parameter controls gradient checkpointing:
remat_policy: "NONE"  # No checkpointing (highest memory)
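Under the hood this maps onto JAX's rematerialization primitive: a checkpointed block discards its activations after the forward pass and recomputes them during the backward pass, trading compute for memory. A small sketch (toy shapes, not MaxDiffusion internals):

```python
import jax
import jax.numpy as jnp

def block(x, w):
    return jnp.tanh(x @ w)

# jax.checkpoint (a.k.a. jax.remat) recomputes block's activations in the
# backward pass instead of storing them.
block_remat = jax.checkpoint(block)

w = 0.1 * jnp.ones((8, 8))
x = jnp.ones((4, 8))
grads = jax.grad(lambda w: block_remat(x, w).sum())(w)
```

Gradients are numerically identical to the non-checkpointed version; only the memory/compute trade-off changes.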

Mixed precision

Separate control over weight and activation dtypes:
weights_dtype: 'bfloat16'      # Model parameters
activations_dtype: 'bfloat16'  # Intermediate activations
precision: "DEFAULT"            # Matmul precision
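As a rough JAX analogue of what these config keys control: array dtypes set storage precision, while the `precision` argument to matmuls controls how the MXU/accumulator treats the operands. Illustrative only:

```python
import jax
import jax.numpy as jnp

# Weights and activations stored in bfloat16.
w = jnp.ones((128, 64), dtype=jnp.bfloat16)
x = jnp.ones((8, 128), dtype=jnp.bfloat16)

y_default = jnp.matmul(x, w)                                        # DEFAULT
y_highest = jnp.matmul(x, w, precision=jax.lax.Precision.HIGHEST)   # slower, more accurate
```

Keeping weights and activations in bfloat16 roughly halves memory versus float32 while XLA handles the precision of the accumulation.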

Encoder offloading

For Flux models, text encoders can be offloaded after encoding:
offload_encoders: True  # Offload T5 after text encoding

Training workflow

A typical MaxDiffusion training workflow:
  1. Dataset preparation: Convert data to TFRecords or load from HuggingFace
  2. Configuration: Set model, parallelism, and training parameters in YAML
  3. Initialization: Load pretrained weights or initialize from scratch
  4. Training loop: Run training with automatic checkpointing and metrics
  5. Generation: Use trained checkpoints for inference
MaxDiffusion automatically handles multi-host coordination, checkpointing, and metrics logging to Google Cloud Storage.

Inference workflow

Inference can use different sharding strategies than training:
# Training: Data parallelism
ici_data_parallelism: 4
ici_fsdp_parallelism: 1

# Inference: Model parallelism
ici_data_parallelism: 1
ici_fsdp_parallelism: 4
This allows efficient batch inference with data parallelism or large model inference with FSDP.

Integration with HuggingFace

MaxDiffusion maintains compatibility with HuggingFace Diffusers:
  • Weight loading: Load pretrained PyTorch weights with from_pt: True
  • Model export: Save checkpoints in Diffusers format with pipeline.save_pretrained()
  • Configuration: Use Diffusers model configs for architecture parameters
This enables easy experimentation with pretrained models and seamless transitions between frameworks.

Performance optimization

MaxDiffusion includes several optimizations for maximum throughput:

XLA compiler flags

The LIBTPU_INIT_ARGS environment variable configures TPU-specific optimizations:
export LIBTPU_INIT_ARGS='--xla_tpu_enable_async_collective_fusion=true \
  --xla_tpu_overlap_compute_collective_tc=true \
  --xla_enable_async_all_gather=true'

JIT compilation

All model initialization and training steps are JIT-compiled for performance:
jit_initializers: True  # JIT compile model initialization
On multi-host setups, jit_initializers must be True for proper weight synchronization.
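The training step itself follows the standard JAX pattern: a pure function compiled once per input shape and reused every step. A minimal sketch with a toy linear model (not MaxDiffusion's actual step function):

```python
import jax
import jax.numpy as jnp

@jax.jit  # compiled to XLA on first call, cached thereafter
def train_step(params, x, y):
    def loss_fn(p):
        pred = x @ p["w"]
        return jnp.mean((pred - y) ** 2)
    loss, grads = jax.value_and_grad(loss_fn)(params)
    params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
    return params, loss
```

Because the step is pure and jitted, XLA can fuse the forward, backward, and optimizer update into one program.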

Scan for transformer layers

The scan_layers parameter uses jax.lax.scan to reduce compilation time for transformer models:
scan_layers: True  # Use scan for transformer layers
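The trick is that identically shaped layers can be stacked into one weight array and applied with `jax.lax.scan`, so XLA traces and compiles the layer body once instead of unrolling it per layer. A toy sketch (illustrative shapes, not the real transformer block):

```python
import jax
import jax.numpy as jnp

def forward(x, stacked_w):
    """Apply num_layers layers with a single traced layer body.

    stacked_w: (num_layers, d, d). scan compiles the body once,
    which is what cuts compilation time for deep transformers.
    """
    def layer(carry, w):
        return jnp.tanh(carry @ w), None

    y, _ = jax.lax.scan(layer, x, stacked_w)
    return y

x = jnp.ones((4, 8))
stacked_w = 0.1 * jnp.ones((6, 8, 8))  # 6 "layers"
y = forward(x, stacked_w)
```

The trade-off is that all scanned layers must share shapes and structure, which is why this is a transformer-specific option.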
