MaxDiffusion supports various parallelism strategies for distributed training and inference using JAX’s sharding capabilities.

Mesh configuration

JAX uses a Mesh to define the device topology for distributed computation.

Mesh parameters

mesh (jax.sharding.Mesh): JAX mesh that defines the device layout for distributed computation. Required for models using flash attention or distributed training.
mesh_axes (tuple[str]): Names of mesh axes, typically ("data", "model") or ("fsdp", "tensor").

Creating a mesh

import jax
from jax.sharding import Mesh
import numpy as np
from maxdiffusion.models import FlaxUNet2DConditionModel

# Get all devices
devices = jax.devices()

# Create a 2D mesh for data and model parallelism (assumes 8 devices)
devices_array = np.array(devices).reshape(4, 2)  # 4-way data, 2-way model
mesh = Mesh(devices_array, ("data", "model"))

# Use in model
model = FlaxUNet2DConditionModel(
    mesh=mesh,
    attention_kernel="flash",
    # ... other config
)

Logical axis rules

Logical axis rules map logical parameter names to mesh axis names for sharding.

Example rules

from maxdiffusion.common_types import LogicalAxisRules

logical_axis_rules = [
    # Embeddings
    ("embed", "model"),
    
    # Attention
    ("heads", "model"),
    ("kv", None),  # Don't shard key/value dimension
    
    # MLPs
    ("mlp", "model"),
    
    # Activations
    ("activation_batch", "data"),
    ("activation_length", None),
    ("activation_embed", "model"),
    
    # Convolutions
    ("conv_in", None),
    ("conv_out", "model"),
]
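
Flax's logical_to_mesh_axes helper shows how a rule set resolves a tensor's logical axis names into a concrete PartitionSpec; a small check using the rules above:

import flax.linen as nn

# Resolve the logical names of an activation tensor against the rules above
spec = nn.logical_to_mesh_axes(
    ("activation_batch", "activation_length", "activation_embed"),
    rules=logical_axis_rules,
)
print(spec)  # PartitionSpec('data', None, 'model')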

Sharding strategies

Data parallelism

Replicates the model across devices, sharding only the batch dimension.
data_sharding (tuple): Sharding specification for data. Example: ("data", None, None, None) shards the batch dimension only.

from jax.sharding import NamedSharding, PartitionSpec as P

# Shard batch dimension across data axis
data_sharding = NamedSharding(mesh, P("data", None, None, None))
latents = jax.device_put(latents, data_sharding)

Model parallelism (tensor parallelism)

Shards model weights across devices.
# Shard embedding dimension across model axis
param_sharding = NamedSharding(mesh, P(None, "model"))
params = jax.device_put(params, param_sharding)

Fully sharded data parallelism (FSDP)

Combines data and model parallelism by sharding both activations and parameters.
from functools import partial
from jax.experimental.pjit import pjit

# Define sharding for inputs, parameters, and outputs
# (this sketch assumes a mesh created with axes ("fsdp", "tensor"))
in_shardings = (
    NamedSharding(mesh, P("fsdp", None)),      # Input data, batch split along fsdp
    NamedSharding(mesh, P("fsdp", "tensor")),  # Parameters, sharded along both axes
)
out_shardings = NamedSharding(mesh, P("fsdp", None))

# JIT compile with sharding
@partial(pjit, in_shardings=in_shardings, out_shardings=out_shardings)
def forward(x, params):
    return model.apply({"params": params}, x)
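
On current JAX releases pjit is merged into jax.jit, which accepts the same sharding arguments, so an equivalent form is:

import jax

@partial(jax.jit, in_shardings=in_shardings, out_shardings=out_shardings)
def forward(x, params):
    return model.apply({"params": params}, x)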

Pipeline parallelism

For very large models, pipeline parallelism stages different parts of the model on different devices.
scan_layers (bool, default: False): Use JAX scan to stack transformer layers, enabling pipeline parallelism.
remat_policy (str): Rematerialization policy for gradient checkpointing. Options:
  • "none": No rematerialization
  • "minimal": Recompute minimal activations
  • "full": Recompute all activations

Quantization configuration

quant (AqtQuantization): AQT quantization configuration for reduced precision. See github.com/google/aqt.
use_qwix_quantization (bool, default: False): Enable Qwix quantization (alternative to AQT).
quantization (str): Quantization mode for Qwix. Options: "int8", "fp8", "fp8_full".
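
For illustration only, enabling Qwix int8 quantization amounts to setting the values documented above; how they are passed (YAML config versus command-line overrides) depends on your MaxDiffusion setup:

# Illustrative settings using the config names documented above
quantization_settings = {
    "use_qwix_quantization": True,
    "quantization": "int8",  # or "fp8", "fp8_full"
}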

Block sizes for flash attention

flash_block_sizes (BlockSizes): Custom block sizes for flash attention. Overrides the hardware-based defaults.

BlockSizes structure

from maxdiffusion.common_types import BlockSizes

flash_block_sizes = BlockSizes(
    block_q=128,  # Query block size
    block_k=128,  # Key block size  
    block_b=1,    # Batch block size
    block_q_dkv=128,  # Query block for key/value grad
    block_k_dkv=128,  # Key block for key/value grad
    block_k_dq=128,   # Key block for query grad
    block_q_dq=128,   # Query block for query grad
)
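
A sketch of wiring these block sizes into a model through the flash_block_sizes parameter documented above (other constructor arguments omitted; treat the exact call as illustrative):

model = FlaxUNet2DConditionModel(
    mesh=mesh,
    attention_kernel="flash",
    flash_block_sizes=flash_block_sizes,
)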

Constraints and annotations

Use with_logical_constraint to annotate intermediate tensors with logical sharding constraints; the logical names are resolved to mesh axes through the logical axis rules:
import flax.linen as nn

# Constrain activation sharding
hidden_states = nn.with_logical_constraint(
    hidden_states,
    ("activation_batch", "activation_length", "activation_embed")
)
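
The constraint only takes effect when the code runs under an active mesh and a set of logical axis rules. A minimal sketch of that wiring (ToyBlock and the shapes are illustrative, not MaxDiffusion's code):

import jax
import jax.numpy as jnp
import numpy as np
import flax.linen as nn
from jax.sharding import Mesh

class ToyBlock(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(x.shape[-1])(x)
        # Logical names resolve to mesh axes via the active rules
        return nn.with_logical_constraint(
            x, ("activation_batch", "activation_length", "activation_embed")
        )

mesh = Mesh(np.array(jax.devices()).reshape(-1, 1), ("data", "model"))
rules = [
    ("activation_batch", "data"),
    ("activation_length", None),
    ("activation_embed", "model"),
]

x = jnp.ones((jax.device_count(), 16, 32))
with mesh, nn.logical_axis_rules(rules):
    variables = ToyBlock().init(jax.random.PRNGKey(0), x)
    y = jax.jit(ToyBlock().apply)(variables, x)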

Example: Multi-device training setup

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from maxdiffusion.models import FlaxUNet2DConditionModel

# Setup mesh
devices = jax.devices()
devices_array = np.array(devices).reshape(4, 2)  # 8 devices total
mesh = Mesh(devices_array, ("data", "model"))

# Define logical axis rules
logical_axis_rules = [
    ("embed", "model"),
    ("mlp", "model"),
    ("heads", "model"),
    ("activation_batch", "data"),
]

# Create model with mesh
model = FlaxUNet2DConditionModel(
    mesh=mesh,
    attention_kernel="flash",
    flash_min_seq_length=1024,
    dtype=jnp.bfloat16,
    weights_dtype=jnp.bfloat16,
)

# Shard example inputs (latents shape: batch, channels, height, width)
latents = jnp.ones((8, 4, 64, 64), dtype=jnp.bfloat16)
data_sharding = NamedSharding(mesh, P("data", None, None, None))
latents = jax.device_put(latents, data_sharding)

# Model will automatically shard parameters according to rules
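
To confirm the input ended up laid out as intended, inspect the array's sharding directly:

# The batch dimension should be split across the "data" axis
print(latents.sharding)  # NamedSharding(..., PartitionSpec('data', None, None, None))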
