MaxDiffusion supports various parallelism strategies for distributed training and inference using JAX’s sharding capabilities.
Mesh configuration
JAX uses a Mesh to define the device topology for distributed computation.
Mesh parameters
- `mesh`: a JAX `Mesh` that defines the device layout for distributed computation. Required for models using flash attention or distributed training.
- Mesh axis names: typically `("data", "model")` or `("fsdp", "tensor")`.
Creating a mesh
```python
import jax
import numpy as np
from jax.sharding import Mesh

from maxdiffusion.models import FlaxUNet2DConditionModel

# Get all devices (this example assumes 8)
devices = jax.devices()

# Create a 2D mesh for data and model parallelism
devices_array = np.array(devices).reshape(4, 2)  # 4-way data, 2-way model
mesh = Mesh(devices_array, ("data", "model"))

# Use in model
model = FlaxUNet2DConditionModel(
    mesh=mesh,
    attention_kernel="flash",
    # ... other config
)
```
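As an alternative to reshaping `jax.devices()` by hand, `jax.experimental.mesh_utils` can build the device array for you; a minimal sketch (it places all devices on the `"data"` axis so it runs on any device count):

```python
import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh

# create_device_mesh orders devices to match the physical interconnect
# (e.g. TPU ICI rings), which usually performs better than a plain np.reshape.
devices = mesh_utils.create_device_mesh((len(jax.devices()), 1))
mesh = Mesh(devices, ("data", "model"))
```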
Logical axis rules
Logical axis rules map logical parameter names to mesh axis names for sharding.
Example rules
```python
from maxdiffusion.common_types import LogicalAxisRules

logical_axis_rules = [
    # Embeddings
    ("embed", "model"),
    # Attention
    ("heads", "model"),
    ("kv", None),  # Don't shard the key/value dimension
    # MLPs
    ("mlp", "model"),
    # Activations
    ("activation_batch", "data"),
    ("activation_length", None),
    ("activation_embed", "model"),
    # Convolutions
    ("conv_in", None),
    ("conv_out", "model"),
]
```
Sharding strategies
Data parallelism
Replicates the model across devices, sharding only the batch dimension.
The sharding specification for inputs, e.g. `("data", None, None, None)`, shards only the batch dimension:

```python
from jax.sharding import NamedSharding, PartitionSpec as P

# Shard the batch dimension across the "data" axis
data_sharding = NamedSharding(mesh, P("data", None, None, None))
latents = jax.device_put(latents, data_sharding)
```
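A quick way to check that a sharding took effect is to inspect the array's `.sharding` attribute. A self-contained sketch that adapts to however many devices are available:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Put all available devices on the "data" axis so the sketch runs anywhere
devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, ("data", "model"))

latents = jnp.zeros((8, 4, 64, 64))  # example latents: (batch, channels, height, width)
latents = jax.device_put(latents, NamedSharding(mesh, P("data", None, None, None)))

print(latents.sharding)  # NamedSharding splitting the batch dimension across "data"
```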
Model parallelism (tensor parallelism)
Shards model weights across devices.
```python
# Shard the embedding dimension across the "model" axis
param_sharding = NamedSharding(mesh, P(None, "model"))
params = jax.device_put(params, param_sharding)
```
Fully sharded data parallelism (FSDP)
Combines data and model parallelism by sharding both activations and parameters.
```python
from functools import partial

import jax
from jax.sharding import NamedSharding, PartitionSpec as P

# This example assumes a mesh with ("fsdp", "tensor") axes
in_shardings = (
    NamedSharding(mesh, P("fsdp", None)),      # Input data: batch sharded across "fsdp"
    NamedSharding(mesh, P("fsdp", "tensor")),  # Parameters: sharded across both axes
)
out_shardings = NamedSharding(mesh, P("fsdp", None))

# JIT compile with explicit shardings (pjit is deprecated; jax.jit takes shardings directly)
@partial(jax.jit, in_shardings=in_shardings, out_shardings=out_shardings)
def forward(x, params):
    return model.apply({"params": params}, x)
```
Pipeline parallelism
For very large models, pipeline parallelism stages different parts of the model on different devices.
- Use JAX `scan` to stack transformer layers, enabling pipeline parallelism.
- Rematerialization policy for gradient checkpointing. Options:
  - `"none"`: no rematerialization
  - `"minimal"`: recompute only minimal activations
  - `"full"`: recompute all activations
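These policies correspond to JAX's `jax.checkpoint` (remat) machinery. A minimal sketch of the "full" and "minimal" ends of the spectrum (the mapping from config strings to policies here is an assumption for illustration, not MaxDiffusion's exact code):

```python
import jax
import jax.numpy as jnp

def layer(x, w):
    return jnp.tanh(x @ w)

# "full"-style: recompute all activations in the backward pass
layer_full = jax.checkpoint(layer)

# "minimal"-style: save matmul outputs, recompute only the cheap elementwise ops
layer_minimal = jax.checkpoint(layer, policy=jax.checkpoint_policies.checkpoint_dots)

x, w = jnp.ones((4, 8)), jnp.ones((8, 16))
grads = jax.grad(lambda x: layer_minimal(x, w).sum())(x)
```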
Quantization configuration
- Enable Qwix quantization (an alternative to AQT).
- Quantization mode for Qwix. Options: `"int8"`, `"fp8"`, `"fp8_full"`.
Block sizes for flash attention
Custom block sizes for flash attention override the defaults selected for your hardware.
BlockSizes structure
```python
from maxdiffusion.common_types import BlockSizes

flash_block_sizes = BlockSizes(
    block_q=128,      # Query block size (forward)
    block_k=128,      # Key block size (forward)
    block_b=1,        # Batch block size
    block_q_dkv=128,  # Query block for the key/value gradient
    block_k_dkv=128,  # Key block for the key/value gradient
    block_k_dq=128,   # Key block for the query gradient
    block_q_dq=128,   # Query block for the query gradient
)
```
Constraints and annotations
Use with_logical_constraint to annotate tensors with sharding constraints:
```python
import flax.linen as nn

# Constrain activation sharding to the logical axes defined by the rules
hidden_states = nn.with_logical_constraint(
    hidden_states,
    ("activation_batch", "activation_length", "activation_embed"),
)
```
Example: Multi-device training setup
```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

from maxdiffusion.models import FlaxUNet2DConditionModel

# Set up the mesh (assumes 8 devices)
devices = jax.devices()
devices_array = np.array(devices).reshape(4, 2)  # 4-way data, 2-way model
mesh = Mesh(devices_array, ("data", "model"))

# Define logical axis rules
logical_axis_rules = [
    ("embed", "model"),
    ("mlp", "model"),
    ("heads", "model"),
    ("activation_batch", "data"),
]

# Create the model with the mesh
model = FlaxUNet2DConditionModel(
    mesh=mesh,
    attention_kernel="flash",
    flash_min_seq_length=1024,
    dtype=jnp.bfloat16,
    weights_dtype=jnp.bfloat16,
)

# Shard inputs
latents = jnp.zeros((8, 4, 64, 64))  # example latents: (batch, channels, height, width)
data_sharding = NamedSharding(mesh, P("data", None, None, None))
latents = jax.device_put(latents, data_sharding)

# The model shards parameters automatically according to the rules
```