Overview

MaxDiffusion supports multiple parallelism strategies to scale training and inference across TPU pods and GPU clusters. The framework uses JAX’s sharding APIs to distribute model weights and activations across devices efficiently.

Mesh axes

MaxDiffusion defines a 4-dimensional device mesh for sharding:
mesh_axes: ['data', 'fsdp', 'context', 'tensor']
Each axis represents a different parallelism strategy:
  • data: Data parallelism (replicate model, shard batch)
  • fsdp: Fully Sharded Data Parallelism (shard model parameters)
  • context: Context/sequence parallelism (shard sequence dimension)
  • tensor: Tensor parallelism (shard attention heads and FFN)
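As a point of reference, the sketch below shows roughly how such a 4-axis mesh is built with JAX's standard mesh utilities. The axis sizes are illustrative (they assume an 8-device host) and this is not MaxDiffusion's exact setup code.
import jax
from jax.sharding import Mesh
from jax.experimental import mesh_utils

# Illustrative sizes: 2-way data x 2-way fsdp x 1-way context x 2-way tensor.
# The product of the axis sizes must equal jax.device_count().
devices = mesh_utils.create_device_mesh((2, 2, 1, 2))
mesh = Mesh(devices, ('data', 'fsdp', 'context', 'tensor'))
print(mesh.shape)  # axis sizes: data=2, fsdp=2, context=1, tensor=2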

Parallelism types

Data parallelism

How it works: Each device holds a complete copy of the model and processes a different subset of the batch.
Benefits:
  • Simple to implement and debug
  • Excellent scaling for large batch sizes
  • No parameter communication during the forward/backward pass (only a gradient all-reduce at the end of each step)
Configuration:
ici_data_parallelism: 4
ici_fsdp_parallelism: 1
ici_context_parallelism: 1
ici_tensor_parallelism: 1
per_device_batch_size: 2  # Global batch = 4 * 2 = 8
Example: 4 TPU chips, each processing 2 samples independently.
Data parallelism is the default strategy and recommended for most workloads when the model fits in device memory.
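In JAX terms, pure data parallelism means the batch is sharded along the 'data' axis while the weights stay replicated. A minimal sketch, assuming 4 devices and illustrative shapes:
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental import mesh_utils

mesh = Mesh(mesh_utils.create_device_mesh((4,)), ('data',))

# Batch of 8 sharded along 'data': each of the 4 devices holds 2 samples.
batch = jax.device_put(jnp.zeros((8, 1024, 768)), NamedSharding(mesh, P('data')))
# An empty PartitionSpec means the weights are fully replicated on every device.
weights = jax.device_put(jnp.zeros((768, 768)), NamedSharding(mesh, P()))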

FSDP (Fully Sharded Data Parallelism)

How it works: Model parameters are sharded across devices, with each device owning a fraction of the weights. During computation, parameters are gathered as needed.
Benefits:
  • Trains models larger than single-device memory
  • Reduces memory per device
  • Enables larger batch sizes
Configuration:
ici_data_parallelism: 1
ici_fsdp_parallelism: 4
ici_context_parallelism: 1
ici_tensor_parallelism: 1
Example use case: Training SDXL on limited memory by sharding the U-Net across 4 devices.
FSDP adds communication overhead during the forward and backward passes. Use it when the model does not fit in memory, not as a speed optimization.
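A minimal sketch of the idea (not MaxDiffusion's actual sharding code), assuming 4 devices assigned to the 'fsdp' axis:
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental import mesh_utils

mesh = Mesh(mesh_utils.create_device_mesh((4,)), ('fsdp',))

# Each device owns a quarter of this weight's rows; XLA inserts the
# all-gathers needed to reassemble the full matrix during compute.
w = jax.device_put(jnp.zeros((4096, 1024)), NamedSharding(mesh, P('fsdp', None)))
print(w.addressable_shards[0].data.shape)  # (1024, 1024) on each device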

Context/sequence parallelism

How it works: The sequence dimension is sharded across devices. Each device processes a portion of the sequence in attention layers.
Benefits:
  • Enables extremely long sequences (videos, high-resolution images)
  • Reduces memory for attention QKV matrices
  • Critical for video models like Wan
Configuration:
ici_data_parallelism: 1
ici_fsdp_parallelism: 1
ici_context_parallelism: 4
ici_tensor_parallelism: 1
Example: Wan2.1 with 81 frames at 720p:
# Sequence length ≈ 183,600 tokens after VAE and patch compression
ici_context_parallelism: 4  # Shard sequence across 4 devices
For Wan models, the ici_fsdp_parallelism axis is used for sequence parallelism, while ici_tensor_parallelism is used for head parallelism.
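A minimal sketch of sequence sharding at the array level, assuming 4 devices on the 'context' axis and an illustrative activation shape:
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental import mesh_utils

mesh = Mesh(mesh_utils.create_device_mesh((4,)), ('context',))

# (batch, sequence, embed): only the sequence dimension is split, so each
# device attends over a quarter of the tokens and exchanges the rest as needed.
x = jax.device_put(jnp.zeros((1, 4096, 1024)), NamedSharding(mesh, P(None, 'context', None)))
print(x.addressable_shards[0].data.shape)  # (1, 1024, 1024)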

Tensor parallelism

How it works: Individual layers are split across devices. Attention heads and FFN dimensions are partitioned.
Benefits:
  • Reduces memory for large layers
  • Can combine with other strategies
  • Useful for models with many attention heads
Configuration:
ici_data_parallelism: 1
ici_fsdp_parallelism: 1
ici_context_parallelism: 1
ici_tensor_parallelism: 4
Constraints:
  • Number of attention heads must be divisible by ici_tensor_parallelism
  • For Wan2.1: 40 heads, so use 1, 2, 4, 5, 8, 10, 20, or 40
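A hypothetical validation helper (not part of MaxDiffusion) that captures this constraint:
# Hypothetical helper: reject tensor-parallel settings that do not divide the head count.
def check_tensor_parallelism(num_heads: int, tensor_parallelism: int) -> None:
    if num_heads % tensor_parallelism != 0:
        raise ValueError(
            f"{num_heads} attention heads are not divisible by "
            f"ici_tensor_parallelism={tensor_parallelism}")

check_tensor_parallelism(num_heads=40, tensor_parallelism=4)  # OK for Wan2.1
check_tensor_parallelism(num_heads=40, tensor_parallelism=3)  # raises ValueError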

ICI vs DCN parallelism

MaxDiffusion distinguishes between two network types:
ICI (Inter-Chip Interconnect):
  • High-bandwidth, low-latency network within a single TPU pod
  • Use for compute-heavy parallelism (tensor, context)
DCN (Data Center Network):
  • Lower bandwidth between TPU pods
  • Use for data parallelism (less communication)

Configuration example

# ICI: 8 devices per pod
ici_data_parallelism: 2
ici_fsdp_parallelism: 1
ici_context_parallelism: 2
ici_tensor_parallelism: 2

# DCN: 4 pods
dcn_data_parallelism: 4
dcn_fsdp_parallelism: 1
dcn_context_parallelism: 1
dcn_tensor_parallelism: 1
This configuration uses 4 pods × 8 devices = 32 devices in total, arranged in a combined (data, fsdp, context, tensor) mesh of shape (8, 1, 2, 2):
  • Data parallelism across 4 pods × 2 devices = 8-way data parallel
  • Context parallelism across 2 devices within each pod
  • Tensor parallelism across 2 devices within each pod
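Under the hood this corresponds to a hybrid device mesh, which JAX can build directly. A sketch assuming 4 slices of 8 devices each (it only runs in a multi-slice environment):
from jax.sharding import Mesh
from jax.experimental import mesh_utils

ici_shape = (2, 1, 2, 2)   # per-slice: data, fsdp, context, tensor
dcn_shape = (4, 1, 1, 1)   # across slices: 4-way data parallel

devices = mesh_utils.create_hybrid_device_mesh(ici_shape, dcn_shape)
mesh = Mesh(devices, ('data', 'fsdp', 'context', 'tensor'))
print(mesh.shape)  # combined axis sizes: data=8, fsdp=1, context=2, tensor=2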

Auto-sharding

Use -1 to automatically shard an axis:
ici_data_parallelism: -1  # Auto-shard remaining devices
ici_fsdp_parallelism: 1
ici_context_parallelism: 1
ici_tensor_parallelism: 1
Example: On an 8-device pod, this automatically sets ici_data_parallelism: 8.
Only one axis (ICI or DCN) can use auto-sharding. The product of all axes must equal the total number of devices.
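Conceptually, the -1 is resolved at startup by dividing the device count by the product of the explicitly sized axes. A hypothetical sketch of that logic (not MaxDiffusion's code):
def resolve_auto_axes(axis_sizes: list[int], num_devices: int) -> list[int]:
    """Replace a single -1 entry so the product of sizes equals num_devices."""
    if axis_sizes.count(-1) > 1:
        raise ValueError("Only one axis may use -1 (auto-sharding).")
    fixed = 1
    for size in axis_sizes:
        if size != -1:
            fixed *= size
    if num_devices % fixed != 0:
        raise ValueError("Explicit axis sizes do not divide the device count.")
    return [num_devices // fixed if size == -1 else size for size in axis_sizes]

print(resolve_auto_axes([-1, 1, 1, 1], num_devices=8))  # [8, 1, 1, 1]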

Logical axis rules

Logical axis rules map tensor dimensions to mesh axes:
logical_axis_rules: [
  ['batch', ['data', 'fsdp']],              # Batch sharded across data+fsdp
  ['activation_heads', 'tensor'],            # Attention heads on tensor axis
  ['activation_length', 'context'],          # Sequence on context axis
  ['embed', ['context', 'fsdp']],           # Embeddings on context+fsdp
  ['heads', 'tensor'],                      # Weight heads on tensor axis
  ['norm', 'tensor'],                       # Norm params on tensor axis
]

How it works

  1. Tensor dimensions (left side) are named in the model code
  2. Mesh axes (right side) determine physical device placement
  3. JAX automatically inserts communication as needed
Example:
# A batch of shape (8, 1024, 768) with logical axes:
# (batch=8, length=1024, embed=768)

# With ici_data_parallelism=2, ici_context_parallelism=2:
# - batch split across 2 data devices: (4, 1024, 768) per device
# - length split across 2 context devices: (4, 512, 768) per shard
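Applying those rules by hand to this example yields a PartitionSpec of ('data', 'context', None) for the (batch, length, embed) activation. A minimal sketch, assuming 4 devices arranged as 2-way data × 2-way context:
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental import mesh_utils

mesh = Mesh(mesh_utils.create_device_mesh((2, 2)), ('data', 'context'))

# batch -> 'data', activation_length -> 'context', embed -> unsharded
x = jax.device_put(jnp.zeros((8, 1024, 768)), NamedSharding(mesh, P('data', 'context', None)))
for shard in x.addressable_shards:
    print(shard.device, shard.data.shape)  # each device holds a (4, 512, 768) block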

Model-specific configurations

Stable Diffusion XL

Recommended for single-host training:
mesh_axes: ['data', 'fsdp', 'context', 'tensor']
ici_data_parallelism: -1
ici_fsdp_parallelism: 1
ici_context_parallelism: 1
ici_tensor_parallelism: 1

logical_axis_rules: [
  ['batch', 'data'],
  ['activation_batch', ['data','fsdp']],
  ['activation_heads', 'tensor'],
  ['embed','fsdp'],
  ['heads', 'tensor'],
]

Flux

# 8 devices, batch size 8
ici_data_parallelism: -1
ici_fsdp_parallelism: 1
ici_context_parallelism: 1
ici_tensor_parallelism: 1
per_device_batch_size: 1

Wan2.1

Video models require sequence and head parallelism:
mesh_axes: ['data', 'fsdp', 'context', 'tensor']

# v5p-256 slice: 128 chips (32 * 4 * 1 * 1)
ici_data_parallelism: 32      # 32-way data parallel
ici_fsdp_parallelism: 4       # Sequence parallelism
ici_context_parallelism: 1
ici_tensor_parallelism: 1     # Head parallelism

per_device_batch_size: 0.25   # Fractional batch size
# Global batch = 32 * 0.25 = 8

logical_axis_rules: [
  ['batch', ['data', 'fsdp']],
  ['activation_self_attn_heads', ['context', 'tensor']],
  ['activation_cross_attn_q_length', ['context', 'tensor']],
  ['activation_length', 'context'],
  ['embed', ['context', 'fsdp']],
  ['heads', 'tensor'],
]
Wan supports fractional batch sizes: per_device_batch_size=0.25 on 4 devices = global batch of 1.

GPU-specific parallelism

For Wan on GPU with cuDNN flash attention:
# Batch parallelism (no fractional batch support)
ici_fsdp_batch_parallelism: 2  # GPU-specific batch axis
ici_fsdp_parallelism: 2         # Can combine with sequence
attention: "cudnn_flash_te"
The cuDNN flash attention kernel requires sequence length divisible by ici_fsdp_parallelism (no padding support).

Fractional batch sizes

Some models support fractional per_device_batch_size:
per_device_batch_size: 0.25
ici_data_parallelism: 4
# Effective global batch = 0.25 * 4 = 1
Requirements:
  • per_device_batch_size * num_devices must be a whole number
  • Not supported with ici_fsdp_batch_parallelism (GPU batch parallelism)
Use case: Large models where even batch size 1 per device is too big.
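A hypothetical sketch of the check this implies (illustrative only, not MaxDiffusion's code):
def global_batch_size(per_device_batch_size: float, num_devices: int) -> int:
    """Compute the effective global batch and reject non-integer results."""
    total = per_device_batch_size * num_devices
    if total != int(total) or total < 1:
        raise ValueError(
            f"per_device_batch_size={per_device_batch_size} * {num_devices} devices "
            f"= {total}, which is not a whole number >= 1")
    return int(total)

print(global_batch_size(0.25, 4))   # 1
print(global_batch_size(0.25, 32))  # 8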

Common pitfalls

Batch size not divisible

Warning: batch dimension should be shardable among the devices in data and fsdp axis
batch dimension: 1, devices_in_data_fsdp: 4
Solution: Adjust batch size or use fractional batch sizes.

Head count not divisible

Wan2.1 has 40 attention heads. Using ici_tensor_parallelism=3 fails:
# Error: 40 heads not divisible by 3
Solution: Use 1, 2, 4, 5, 8, 10, 20, or 40 for ici_tensor_parallelism.

Wrong axis for auto-shard

dcn_data_parallelism: -1
ici_data_parallelism: -1  # Error: can't auto-shard both
Solution: Only one axis can use -1.

Performance tips

Maximize data parallelism first

Start with pure data parallelism and only add model parallelism when necessary:
# Try this first
ici_data_parallelism: -1

# Only if OOM, then try
ici_data_parallelism: 4
ici_fsdp_parallelism: 2

Use context parallelism for long sequences

For videos or high-resolution images:
# Video: 81 frames at 720p
ici_context_parallelism: 4  # Split 183K sequence

Combine strategies strategically

For maximum efficiency:
# 256 device cluster
dcn_data_parallelism: 32    # Data parallel across pods
ici_fsdp_parallelism: 4     # Sequence parallel within pod
ici_context_parallelism: 2  # Additional sequence sharding

Debugging sharding

Enable verbose logging to see sharding decisions:
import jax
jax.config.update('jax_log_compiles', True)
Check actual tensor shardings (state.params is a pytree, so map over it):
print(jax.tree_util.tree_map(lambda p: p.sharding, state.params))
