Overview
MaxDiffusion supports multiple parallelism strategies to scale training and inference across TPU pods and GPU clusters. The framework uses JAX’s sharding APIs to distribute model weights and activations across devices efficiently.
Mesh axes
MaxDiffusion defines a 4-dimensional device mesh for sharding:
- data: Data parallelism (replicate model, shard batch)
- fsdp: Fully Sharded Data Parallelism (shard model parameters)
- context: Context/sequence parallelism (shard sequence dimension)
- tensor: Tensor parallelism (shard attention heads and FFN)
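As a rough illustration (plain Python, not MaxDiffusion’s actual API; the axis sizes below are made up), the four axes form a mesh whose sizes must multiply to the device count:

```python
from math import prod

# Illustrative 4-axis mesh for 16 devices; sizes are hypothetical.
mesh_axes = ("data", "fsdp", "context", "tensor")
mesh_shape = (4, 2, 1, 2)
num_devices = 16

# The product of all axis sizes must equal the device count.
assert prod(mesh_shape) == num_devices

# Map each axis name to its size.
mesh = dict(zip(mesh_axes, mesh_shape))
```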
Parallelism types
Data parallelism
How it works: Each device has a complete copy of the model and processes a different subset of the batch.
Benefits:
- Simple to implement and debug
- Excellent scaling for large batch sizes
- No parameter communication during the forward/backward pass (only a gradient all-reduce after backward)
Data parallelism is the default strategy and recommended for most workloads when the model fits in device memory.
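A toy sketch of the idea (plain Python, not framework code): every “device” holds the same weight, processes its own batch shard, and the local gradients are averaged afterwards, which stands in for the all-reduce step:

```python
batch = [1.0, 2.0, 3.0, 4.0]                  # global batch
num_devices = 2
# Each device receives a disjoint shard of the batch.
shards = [batch[i::num_devices] for i in range(num_devices)]

weight = 0.5                                   # replicated on every device
# Toy loss 0.5 * (weight * x)^2 per sample -> gradient weight * x^2.
local_grads = [sum(weight * x * x for x in shard) for shard in shards]
global_grad = sum(local_grads) / num_devices   # gradient all-reduce (mean)
```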
FSDP (Fully Sharded Data Parallelism)
How it works: Model parameters are sharded across devices, with each device owning a fraction of the weights. During computation, parameters are gathered as needed.
Benefits:
- Trains models larger than single-device memory
- Reduces memory per device
- Enables larger batch sizes
FSDP adds communication overhead during forward/backward passes. Use it when the model doesn’t fit in memory, not as a speed optimization.
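A minimal sketch of the mechanism (pure Python with hypothetical shapes, not MaxDiffusion code): each device persistently stores only its slice of the parameters, and an all-gather reassembles the full vector when a layer needs it:

```python
params = [0.1 * i for i in range(8)]            # full parameter vector
num_devices = 4
shard_size = len(params) // num_devices

# Each device owns a contiguous slice of the parameters.
owned = [params[d * shard_size:(d + 1) * shard_size]
         for d in range(num_devices)]

def all_gather(shards):
    """Reassemble the full parameter vector for a layer's compute."""
    return [p for shard in shards for p in shard]

gathered = all_gather(owned)
assert gathered == params   # every device sees full params during compute
```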
Context/sequence parallelism
How it works: The sequence dimension is sharded across devices. Each device processes a portion of the sequence in attention layers.
Benefits:
- Enables extremely long sequences (videos, high-resolution images)
- Reduces memory for attention QKV matrices
- Critical for video models like Wan
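Schematically (pure Python; real implementations also exchange key/value blocks between devices), the sequence axis is simply chunked:

```python
seq_len, num_devices = 8, 2
tokens = list(range(seq_len))          # one token id per sequence position
chunk = seq_len // num_devices

# Each device holds a contiguous slice of the sequence.
local_seqs = [tokens[d * chunk:(d + 1) * chunk] for d in range(num_devices)]
```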
For Wan models, the ici_fsdp_parallelism axis is used for sequence parallelism, while ici_tensor_parallelism is used for head parallelism.
Tensor parallelism
How it works: Individual layers are split across devices. Attention heads and FFN dimensions are partitioned.
Benefits:
- Reduces memory for large layers
- Can combine with other strategies
- Useful for models with many attention heads
- Number of attention heads must be divisible by ici_tensor_parallelism
- For Wan2.1: 40 heads, so use 1, 2, 4, 5, 8, 10, 20, or 40
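The divisibility constraint can be checked directly; for example, for Wan2.1’s 40 heads:

```python
num_heads = 40  # Wan2.1

# Valid ici_tensor_parallelism values are exactly the divisors of num_heads.
valid = [t for t in range(1, num_heads + 1) if num_heads % t == 0]
# Each device then gets num_heads // t attention heads.

assert 3 not in valid        # ici_tensor_parallelism=3 would fail
```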
ICI vs DCN parallelism
MaxDiffusion distinguishes between two network types:
ICI (Inter-Chip Interconnect):
- High-bandwidth, low-latency network within a single TPU pod
- Use for compute-heavy parallelism (tensor, context)
DCN (Data Center Network):
- Lower bandwidth between TPU pods
- Use for data parallelism (less communication)
Configuration example
- Data parallelism across 4 pods × 2 devices = 8-way data parallel
- Context parallelism across 2 devices within each pod
- Tensor parallelism across 2 devices within each pod
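Assuming the ici_*/dcn_* flag convention used elsewhere on this page (the exact flag names here are hypothetical), the bullets above correspond to values like the following, and the axis product must match the real device count:

```python
from math import prod

# Hypothetical flag values matching the example above.
config = {
    "dcn_data_parallelism": 4,     # data parallelism across 4 pods
    "ici_data_parallelism": 2,     # x2 devices per pod -> 8-way data parallel
    "ici_context_parallelism": 2,  # context parallelism within each pod
    "ici_tensor_parallelism": 2,   # tensor parallelism within each pod
}

data_parallel_ways = (config["dcn_data_parallelism"]
                      * config["ici_data_parallelism"])
total_devices = prod(config.values())  # must equal the actual device count
```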
Auto-sharding
Use -1 to automatically shard an axis: the axis size is then inferred from the total device count (for example, resolving ici_data_parallelism to 8).
Only one axis (ICI or DCN) can use auto-sharding. The product of all axes must equal the total number of devices.
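A sketch of how -1 is resolved (a hypothetical helper, not MaxDiffusion’s code): the single -1 entry is replaced by whatever size makes the axis product equal the device count.

```python
def resolve_auto_axis(axes, num_devices):
    """Replace a single -1 entry so that the product of axes == num_devices."""
    assert axes.count(-1) <= 1, "only one axis may be auto-sharded"
    known = 1
    for size in axes:
        if size != -1:
            known *= size
    assert num_devices % known == 0, "axes must divide the device count"
    return [num_devices // known if size == -1 else size for size in axes]

resolved = resolve_auto_axis([-1, 2, 1, 2], 16)
```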
Logical axis rules
Logical axis rules map tensor dimensions to mesh axes:
How it works
- Tensor dimensions (left side) are named in the model code
- Mesh axes (right side) determine physical device placement
- JAX automatically inserts communication as needed
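Conceptually, the rules form a lookup table from logical dimension names to physical mesh axes (the dimension and axis pairings below are illustrative, not MaxDiffusion’s actual rule set):

```python
# Hypothetical rules: logical tensor-dimension name -> physical mesh axis.
logical_axis_rules = [
    ("batch",  "data"),
    ("embed",  "fsdp"),
    ("heads",  "tensor"),
    ("length", "context"),
]

def mesh_axis_for(dim_name, rules):
    """Return the mesh axis a logical dimension is sharded over."""
    for logical, physical in rules:
        if logical == dim_name:
            return physical
    return None   # unmapped dimensions stay replicated

sharding = [mesh_axis_for(d, logical_axis_rules) for d in ("batch", "length")]
```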
Model-specific configurations
Stable Diffusion XL
For single-host training, pure data parallelism is recommended.
Flux
- Data parallelism
- FSDP for large models
Wan2.1
Video models require sequence and head parallelism (ici_fsdp_parallelism shards the sequence axis, ici_tensor_parallelism the attention heads). Wan also supports fractional batch sizes:
per_device_batch_size=0.25 on 4 devices = global batch of 1.
GPU-specific parallelism
For Wan on GPU with cuDNN flash attention, the sequence length must be divisible by ici_fsdp_parallelism: the cuDNN flash attention kernel does not support padding.
Fractional batch sizes
Some models support a fractional per_device_batch_size:
- per_device_batch_size * num_devices must be a whole number
- Not supported with ici_fsdp_batch_parallelism (GPU batch parallelism)
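The whole-number rule is easy to check, using the Wan example from above:

```python
per_device_batch_size = 0.25
num_devices = 4

# The global batch must be a whole number of samples.
global_batch = per_device_batch_size * num_devices
assert global_batch == int(global_batch)
```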
Common pitfalls
Batch size not divisible
The global batch (per_device_batch_size * num_devices) must work out to a whole number of samples; adjust the batch size or the device count if it does not.
Head count not divisible
Wan2.1 has 40 attention heads, so ici_tensor_parallelism=3 fails: 40 is not evenly divisible by 3. Choose a divisor of the head count for ici_tensor_parallelism.
Wrong axis for auto-shard
Only one axis across ICI and DCN can be set to -1; marking more than one axis as -1 fails.
Performance tips
Maximize data parallelism first
Start with pure data parallelism and only add model parallelism when necessary.
Use context parallelism for long sequences
For videos or high-resolution images, shard the sequence axis so that per-device attention memory stays bounded.
Combine strategies strategically
For maximum efficiency, use data parallelism across pods (DCN) and model parallelism (FSDP, context, tensor) within each pod.
Debugging sharding
Enable verbose logging to see sharding decisions.
Next steps
- Learn about attention mechanisms
- Explore supported models
- Understand MaxDiffusion’s architecture