
Overview

Attention is the computational bottleneck in diffusion models. MaxDiffusion provides multiple attention implementations optimized for different hardware accelerators, sequence lengths, and memory constraints.

Supported attention kernels

MaxDiffusion supports five attention implementations:
| Kernel | Hardware | Use case | Memory efficiency |
| --- | --- | --- | --- |
| `dot_product` | TPU, GPU | Short sequences, debugging | Standard |
| `flash` | TPU | All sequence lengths | High |
| `tokamax_flash` | TPU | Advanced TPU optimization | Highest |
| `cudnn_flash_te` | GPU | GPU with Transformer Engine | High |
| `ring` | TPU | Extremely long sequences | Highest |

Dot product attention

Standard attention implementation using matrix multiplication.

Configuration

```yaml
attention: 'dot_product'
split_head_dim: True
float32_qk_product: True
```

How it works

  1. Reshape: Convert from [B, S, H*D] to [B, H, S, D]
  2. QK matmul: Compute attention_scores = Q @ K^T
  3. Scale: Multiply by 1/sqrt(head_dim)
  4. Softmax: Normalize attention weights
  5. Weighted sum: Compute output = attention_probs @ V
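
The five steps can be sketched in a few lines of NumPy. This is an illustrative toy, not the MaxDiffusion implementation, which also handles masking, dtype casting, and sharding:

```python
import numpy as np

def dot_product_attention(q, k, v, num_heads):
    """Toy dot-product attention: [B, S, H*D] -> [B, S, H*D]."""
    B, S, HD = q.shape
    H, D = num_heads, HD // num_heads
    # 1. Reshape [B, S, H*D] -> [B, H, S, D]
    q, k, v = (x.reshape(B, S, H, D).transpose(0, 2, 1, 3) for x in (q, k, v))
    # 2.-3. QK^T matmul, scaled by 1/sqrt(head_dim)
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(D)
    # 4. Softmax over the key axis (subtract the max for numerical stability)
    scores -= scores.max(axis=-1, keepdims=True)
    probs = np.exp(scores)
    probs /= probs.sum(axis=-1, keepdims=True)
    # 5. Weighted sum over V, then reshape back to [B, S, H*D]
    out = probs @ v
    return out.transpose(0, 2, 1, 3).reshape(B, S, HD)
```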

Memory-efficient variant

For long sequences, use chunked attention:

```yaml
attention: 'dot_product'
use_memory_efficient_attention: True
```

This implementation chunks the query and key/value sequences to reduce peak memory usage. Chunk sizes are computed as:

```python
# Automatic chunk size based on sequence length
query_chunk_size = sequence_length // 64  # Adaptive
key_chunk_size = 4096 * 4                 # Fixed
```
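
The idea behind query chunking can be illustrated with a simplified single-head NumPy sketch (this one chunks only the query axis; the actual implementation chunks key/value as well):

```python
import numpy as np

def chunked_attention(q, k, v, query_chunk_size):
    """Process queries chunk by chunk so the full [S, S] score
    matrix is never materialized at once."""
    S, D = q.shape
    out = np.empty_like(q)
    for start in range(0, S, query_chunk_size):
        qc = q[start:start + query_chunk_size]        # [C, D]
        scores = qc @ k.T / np.sqrt(D)                # [C, S]: peak memory is C*S, not S*S
        scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
        probs = np.exp(scores)
        probs /= probs.sum(axis=-1, keepdims=True)
        out[start:start + query_chunk_size] = probs @ v
    return out
```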

When to use

  • Short sequences (< 4K tokens)
  • Debugging (easier to understand than flash attention)
  • Compatibility (works on all hardware without special kernels)
Dot product attention has O(N²) memory complexity. For sequences > 4K, use flash attention.
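
To see why, consider the attention-score matrix alone at a video-scale sequence length (back-of-envelope estimate, per head, in bf16):

```python
# Attention scores are [S, S] per head; at Wan 720p scale this alone
# exceeds device HBM, before counting activations or gradients.
S = 183_600              # approximate Wan 720p sequence length
score_bytes = S * S * 2  # bf16 = 2 bytes per element
print(score_bytes / 1e9) # ≈ 67.4 GB per head
```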

Flash attention (TPU)

Memory-efficient attention using Pallas kernels optimized for TPU.

Configuration

```yaml
attention: 'flash'
flash_min_seq_length: 0
mask_padding_tokens: True

flash_block_sizes: {
  "block_q": 512,
  "block_kv_compute": 512,
  "block_kv": 512,
  "block_q_dkv": 512,
  "block_kv_dkv": 512,
  "block_kv_dkv_compute": 512,
  "block_q_dq": 512,
  "block_kv_dq": 512,
  "use_fused_bwd_kernel": False
}
```

Block sizes explained

  • block_q: Block size for Q sequence (HBM → VMEM → VREG)
  • block_kv: Block size for K/V sequence (HBM → VMEM)
  • block_kv_compute: Sub-block for K/V computation (VMEM → VREG)
Must satisfy: block_kv_compute ≤ block_kv

Tuning block sizes

Block sizes significantly impact performance. Start from the defaults and tune for your hardware and sequence length:

```yaml
flash_block_sizes: {
  "block_q": 512,
  "block_kv_compute": 512,
  "block_kv": 512,
  "block_q_dkv": 512,
  "block_kv_dkv": 512,
  "block_kv_dkv_compute": 512,
  "block_q_dq": 512,
  "block_kv_dq": 512,
  "use_fused_bwd_kernel": False
}
```

How block sizes affect performance

HBM bandwidth saturation:
  • Larger blocks → fewer memory transfers → better bandwidth utilization
  • But: blocks must fit in VMEM (vector memory)
Sequence length padding:
  • Sequences are padded to multiples of block sizes
  • Choose block sizes that divide your sequence length evenly
Example:

```yaml
# Wan 720p: sequence length ≈ 183,600
block_q: 3024  # 183,600 / 3024 ≈ 61 blocks (good)

# Bad choice:
block_q: 1000  # Requires padding to 184,000 (wasteful)
```
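
Padded lengths for candidate block sizes can be checked with a few lines (a hypothetical helper, not part of MaxDiffusion):

```python
import math

def padded_length(seq_len: int, block_q: int) -> int:
    """Sequence length after padding up to the next multiple of block_q."""
    return math.ceil(seq_len / block_q) * block_q

print(padded_length(183_600, 3024))  # 184464 (61 blocks)
print(padded_length(183_600, 1000))  # 184000 (184 blocks)
```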

Minimum sequence length

Control when to use flash attention:

```yaml
flash_min_seq_length: 0  # Always use flash (recommended)
```
Behavior:
  • If sequence_length >= flash_min_seq_length: Use flash attention
  • Otherwise: Fall back to dot product attention
Set flash_min_seq_length: 0 for video models and high-resolution images. Flash attention is faster than dot product even for short sequences on modern TPUs.
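
The fallback rule above amounts to a single comparison; conceptually (hypothetical helper name, for illustration only):

```python
def select_attention_kernel(seq_len: int, flash_min_seq_length: int) -> str:
    """Use flash when the sequence meets the threshold, else dot product."""
    return "flash" if seq_len >= flash_min_seq_length else "dot_product"

print(select_attention_kernel(256, 512))  # dot_product
print(select_attention_kernel(256, 0))    # flash
```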

Padding token masking

```yaml
mask_padding_tokens: True
```
When True:
  • Passes segment IDs to splash attention
  • Avoids attending to padding tokens
  • Slightly slower but better quality
When False:
  • No segment IDs passed
  • Faster on VPU-bound hardware (Trillium)
  • Use only when padding is minimal

When to use

  • All sequence lengths on TPU (v4, v5p, v6e)
  • Training and inference for all supported models
  • Default choice for TPU workloads

Tokamax flash attention

Advanced flash attention using the Tokamax library with fused backward pass.

Configuration

```yaml
attention: 'tokamax_flash'

flash_block_sizes: {
  "block_q": 512,
  "block_kv_compute": 512,
  "block_kv": 512,
  "block_q_dkv": 512,
  "block_kv_dkv": 512,
  "block_kv_dkv_compute": 512,
  "use_fused_bwd_kernel": True  # Required for tokamax
}
```

Differences from standard flash

  1. Fused backward pass: Computes dQ, dK, dV in a single kernel
  2. Better memory: Slightly lower HBM usage during training
  3. No unfused blocks: block_q_dq and block_kv_dq are ignored
Tokamax flash attention requires use_fused_bwd_kernel: True. The unfused backward pass is not supported.

When to use

  • TPU training when maximum memory efficiency is needed
  • Large models where standard flash attention OOMs
  • Same performance as standard flash, but slightly better memory

cuDNN flash attention (GPU)

GPU-optimized attention using NVIDIA Transformer Engine.

Installation

```bash
pip install -U "jax[cuda12]"
pip install "transformer_engine[jax]"
```

Configuration

```yaml
attention: 'cudnn_flash_te'
hardware: 'gpu'
split_head_dim: True
```

Running

```bash
NVTE_FUSED_ATTN=1 python src/maxdiffusion/train_sdxl.py \
  src/maxdiffusion/configs/base_xl.yml \
  attention="cudnn_flash_te" \
  hardware=gpu
```

The NVTE_FUSED_ATTN=1 environment variable enables fused attention in Transformer Engine.

GPU-specific parallelism

For Wan models on GPU:

```yaml
attention: "cudnn_flash_te"
ici_fsdp_batch_parallelism: 2  # Batch parallelism (no fractional)
ici_fsdp_parallelism: 2        # Sequence parallelism
```
Constraints:
  • Sequence length must be divisible by ici_fsdp_parallelism (no padding)
  • Fractional batch sizes not supported with batch parallelism

When to use

  • GPU training and inference (H100, A100)
  • All supported models on GPU
  • Best GPU performance with Transformer Engine

Ring attention

Distributed attention for extremely long sequences across multiple devices.

Configuration

```yaml
attention: 'ring'
ici_context_parallelism: 4  # Must be > 1

flash_block_sizes: {
  "use_fused_bwd_kernel": True  # Required
}
```

How it works

  1. Shard K/V: Split key and value across devices on context axis
  2. Local attention: Each device computes attention with its local K/V shard
  3. Ring exchange: Rotate K/V shards between devices
  4. Aggregate: Combine partial attention outputs using LSE (log-sum-exp)
Visualization:

```text
Device 0: Q @ K0/V0 → output0, lse0
Device 1: Q @ K1/V1 → output1, lse1
...

# Ring: rotate K/V
Device 0: Q @ K1/V1 → output0', lse0'
...

# Combine using log-sum-exp
final_output = stable_combine([output0, output0', ...], [lse0, lse0', ...])
```
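
The combine step can be sketched as a toy single-process NumPy function (`stable_combine` above is pseudocode; the real kernel fuses this accumulation into the ring loop):

```python
import numpy as np

def stable_combine(outputs, lses):
    """Merge per-shard attention outputs using their log-sum-exp values.

    outputs: list of [S, D] partial outputs (softmax over a local K/V shard)
    lses:    list of [S] log-sum-exp of the local attention scores
    """
    lses = np.stack(lses)             # [P, S]
    outputs = np.stack(outputs)       # [P, S, D]
    max_lse = lses.max(axis=0)        # subtract the max for stability
    weights = np.exp(lses - max_lse)  # [P, S] relative shard weights
    combined = (weights[..., None] * outputs).sum(axis=0)
    return combined / weights.sum(axis=0)[..., None]
```

Each shard's weight is proportional to exp(lse), i.e. to the local softmax denominator, so the merged result equals attention over the full K/V sequence.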

Requirements

  • ici_context_parallelism > 1
  • use_fused_bwd_kernel: True
  • Sequence must be shardable across context axis

When to use

  • Extremely long sequences (> 1M tokens)
  • Video models with many frames
  • Memory-constrained training
Ring attention adds communication overhead. Only use when sequences don’t fit with standard flash attention.

Automatic adjustments

MaxDiffusion automatically adjusts block sizes that are incompatible with the model's sequence lengths, and logs any modifications:

```text
Modified block sizes for cross-attention:
  block_kv: 512 → 256 (text sequence length)
  block_kv_compute: 512 → 256
```

Performance comparison

TPU (v5p-8) - Wan2.1 training

| Attention | Step time | Memory | TFLOP/s/device |
| --- | --- | --- | --- |
| `dot_product` | 142s | OOM | - |
| `flash` | 36s | 95% | 135.9 |
| `tokamax_flash` | 36s | 93% | 135.9 |

GPU (H100) - SDXL training

| Attention | Step time | Memory |
| --- | --- | --- |
| `dot_product` | 3.2s | 78 GB |
| `cudnn_flash_te` | 1.8s | 62 GB |

Debugging attention

Enable attention logging

```yaml
enable_jax_named_scopes: True
```
This adds named scopes around attention operations for profiler visibility.

Check actual shapes

Add prints in attention_flax.py:524:

```python
print(f"Query shape: {query.shape}")
print(f"Key shape: {key.shape}")
print(f"Value shape: {value.shape}")
```

Profile attention

```yaml
enable_profiler: True
profiler_steps: 3
skip_first_n_steps_for_profiler: 5
```
View traces in TensorBoard to see attention kernel performance.

Common issues

Sequence not divisible by block size

Symptom: Excessive padding, high memory usage
Solution: Choose block sizes that divide the sequence length evenly:

```yaml
# Sequence length = 183,600
block_q: 3024  # 183,600 / 3024 = 60.7 → padded to 61 blocks
block_q: 2048  # 183,600 / 2048 = 89.6 → worse padding
```

Out of memory with flash attention

Symptom: OOM during attention forward/backward
Solutions:
  1. Reduce block sizes (less VMEM usage)
  2. Enable gradient checkpointing: remat_policy: "FULL"
  3. Use tokamax flash: attention: "tokamax_flash"
  4. Use ring attention: attention: "ring"

cuDNN flash not working

Symptom: Error about missing Transformer Engine
Solution:

```bash
pip install "transformer_engine[jax]"
export NVTE_FUSED_ATTN=1
```

Quality degradation with mask_padding_tokens=False

Symptom: Blurry outputs, attention to padding
Solution: Set mask_padding_tokens: True if padding is significant (> 10% of sequence).

Best practices

  1. Use flash attention by default on TPU
  2. Tune block sizes for your sequence length
  3. Set flash_min_seq_length: 0 for video models
  4. Enable mask_padding_tokens for variable-length sequences
  5. Use cudnn_flash_te on GPU with Transformer Engine
  6. Profile attention kernels to verify performance
