Overview
Attention is the computational bottleneck in diffusion models. MaxDiffusion provides multiple attention implementations optimized for different hardware accelerators, sequence lengths, and memory constraints.

Supported attention kernels
MaxDiffusion supports five attention implementations:

| Kernel | Hardware | Use case | Memory efficiency |
|---|---|---|---|
| dot_product | TPU, GPU | Short sequences, debugging | Standard |
| flash | TPU | All sequence lengths | High |
| tokamax_flash | TPU | Advanced TPU optimization | Highest |
| cudnn_flash_te | GPU | GPU with Transformer Engine | High |
| ring | TPU | Extremely long sequences | Highest |
Dot product attention
Standard attention implementation using matrix multiplication.

Configuration
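A minimal config sketch, assuming the `attention` key used elsewhere on this page; verify key names against your MaxDiffusion base config:

```yaml
# Hedged sketch; key names may differ in your MaxDiffusion version.
attention: "dot_product"
```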
How it works
- Reshape: Convert from `[B, S, H*D]` to `[B, H, S, D]`
- QK matmul: Compute `attention_scores = Q @ K^T`
- Scale: Multiply by `1/sqrt(head_dim)`
- Softmax: Normalize attention weights
- Weighted sum: Compute `output = attention_probs @ V`
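The five steps above can be sketched in NumPy (a reference sketch, not the MaxDiffusion implementation):

```python
import numpy as np

def dot_product_attention(q, k, v, num_heads):
    """Reference dot-product attention following the five steps above."""
    B, S, HD = q.shape
    D = HD // num_heads

    # 1. Reshape [B, S, H*D] -> [B, H, S, D]
    def split_heads(t):
        return t.reshape(B, S, num_heads, D).transpose(0, 2, 1, 3)
    q, k, v = split_heads(q), split_heads(k), split_heads(v)

    # 2. QK matmul: attention_scores has shape [B, H, S, S]
    scores = q @ k.transpose(0, 1, 3, 2)
    # 3. Scale by 1/sqrt(head_dim)
    scores = scores / np.sqrt(D)
    # 4. Softmax over the key axis (numerically stabilized)
    scores = scores - scores.max(axis=-1, keepdims=True)
    probs = np.exp(scores)
    probs = probs / probs.sum(axis=-1, keepdims=True)
    # 5. Weighted sum, then merge heads back to [B, S, H*D]
    out = probs @ v
    return out.transpose(0, 2, 1, 3).reshape(B, S, HD)
```

Note the O(N²) score matrix in step 2, which is why this kernel is limited to short sequences.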
Memory-efficient variant
For long sequences, use chunked attention.

When to use
- Short sequences (< 4K tokens)
- Debugging (easier to understand than flash attention)
- Compatibility (works on all hardware without special kernels)
Dot product attention has O(N²) memory complexity. For sequences > 4K, use flash attention.
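The chunked variant mentioned above can be sketched in NumPy: queries are processed in chunks, so the peak score matrix is `[chunk, S]` instead of `[S, S]` (a reference sketch, not the MaxDiffusion implementation):

```python
import numpy as np

def chunked_attention(q, k, v, chunk_size):
    """Attention computed chunk-by-chunk over the query axis."""
    S, D = q.shape[-2], q.shape[-1]
    outs = []
    for start in range(0, S, chunk_size):
        qc = q[..., start:start + chunk_size, :]
        # Scores for this chunk only: shape [..., chunk, S]
        scores = qc @ np.swapaxes(k, -1, -2) / np.sqrt(D)
        scores = scores - scores.max(axis=-1, keepdims=True)
        p = np.exp(scores)
        p = p / p.sum(axis=-1, keepdims=True)
        outs.append(p @ v)
    return np.concatenate(outs, axis=-2)
```

Because softmax is computed per query row, chunking over queries is exact: the result matches unchunked attention bit-for-bit up to floating-point ordering.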
Flash attention (TPU)
Memory-efficient attention using Pallas kernels optimized for TPU.

Configuration
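A hedged config sketch using keys referenced in this section (`flash_min_seq_length`, `mask_padding_tokens`); their exact nesting may differ in your base config:

```yaml
# Hedged sketch; verify key names against your MaxDiffusion base config.
attention: "flash"
flash_min_seq_length: 0      # 0 = always use flash attention
mask_padding_tokens: True    # mask padding via segment IDs
```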
Block sizes explained

Block sizes can be set separately for the forward pass, the fused backward pass, and the unfused backward pass:

- `block_q`: Block size for Q sequence (HBM → VMEM → VREG)
- `block_kv`: Block size for K/V sequence (HBM → VMEM)
- `block_kv_compute`: Sub-block for K/V computation (VMEM → VREG)

`block_kv_compute` must satisfy `block_kv_compute ≤ block_kv`.

Tuning block sizes
Block sizes significantly impact performance. Optimize for your hardware.

How block sizes affect performance
HBM bandwidth saturation:

- Larger blocks → fewer memory transfers → better bandwidth utilization
- But: blocks must fit in VMEM (vector memory)
- Sequences are padded to multiples of block sizes
- Choose block sizes that divide your sequence length evenly
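For example, assuming the block-size keys described above sit at the top level of the config (an assumption; check your base config), a 4096-token sequence divides evenly with:

```yaml
# Hedged sketch: 4096 is divisible by all three block sizes, so no padding is added.
block_q: 1024
block_kv: 512
block_kv_compute: 256   # must be <= block_kv
```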
Minimum sequence length
Control when to use flash attention:

- If `sequence_length >= flash_min_seq_length`: Use flash attention
- Otherwise: Fall back to dot product attention
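A sketch of this fallback rule (`select_attention_kernel` is a hypothetical helper for illustration, not a MaxDiffusion function):

```python
def select_attention_kernel(sequence_length: int, flash_min_seq_length: int) -> str:
    """Mirror the dispatch rule described above."""
    if sequence_length >= flash_min_seq_length:
        return "flash"
    return "dot_product"
```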
Set `flash_min_seq_length: 0` for video models and high-resolution images. Flash attention is faster than dot product even for short sequences on modern TPUs.

Padding token masking
`mask_padding_tokens: True`:
- Passes segment IDs to splash attention
- Avoids attending to padding tokens
- Slightly slower but better quality
`mask_padding_tokens: False`:
- No segment IDs passed
- Faster on VPU-bound hardware (Trillium)
- Use only when padding is minimal
When to use
- All sequence lengths on TPU (v4, v5p, v6e)
- Training and inference for all supported models
- Default choice for TPU workloads
Tokamax flash attention
Advanced flash attention using the Tokamax library with a fused backward pass.

Configuration
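A hedged config sketch using the keys named in this section; verify against your base config:

```yaml
# Hedged sketch; tokamax flash requires the fused backward kernel.
attention: "tokamax_flash"
use_fused_bwd_kernel: True
```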
Differences from standard flash
- Fused backward pass: Computes dQ, dK, dV in a single kernel
- Better memory: Slightly lower HBM usage during training
- No unfused blocks: `block_q_dq` and `block_kv_dq` are ignored
Tokamax flash attention requires `use_fused_bwd_kernel: True`. The unfused backward pass is not supported.

When to use
- TPU training when maximum memory efficiency is needed
- Large models where standard flash attention OOMs
- Same performance as standard flash, but slightly better memory
cuDNN flash attention (GPU)
GPU-optimized attention using NVIDIA Transformer Engine.

Installation
Configuration
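A hedged config sketch; only the `attention` value is documented on this page:

```yaml
# Hedged sketch; launch with NVTE_FUSED_ATTN=1 set in the environment.
attention: "cudnn_flash_te"
```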
Running
The `NVTE_FUSED_ATTN=1` environment variable enables fused attention in Transformer Engine.

GPU-specific parallelism
For Wan models on GPU:

- Sequence length must be divisible by `ici_fsdp_parallelism` (no padding)
- Fractional batch sizes are not supported with batch parallelism
When to use
- GPU training and inference (H100, A100)
- All supported models on GPU
- Best GPU performance with Transformer Engine
Ring attention
Distributed attention for extremely long sequences across multiple devices.

Configuration
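A hedged config sketch using the keys listed under Requirements below; verify names and values against your base config:

```yaml
# Hedged sketch; ring attention needs context parallelism and the fused backward kernel.
attention: "ring"
ici_context_parallelism: 4   # must be > 1
use_fused_bwd_kernel: True
```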
How it works
- Shard K/V: Split key and value across devices on context axis
- Local attention: Each device computes attention with its local K/V shard
- Ring exchange: Rotate K/V shards between devices
- Aggregate: Combine partial attention outputs using LSE (log-sum-exp)
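The LSE-based aggregation in step 4 can be sketched in NumPy (`partial_attention` and `combine_partials` are illustrative names, not MaxDiffusion APIs):

```python
import numpy as np

def partial_attention(q, k, v):
    """Attention over one K/V shard; also returns the log-sum-exp of its scores."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    m = scores.max(axis=-1, keepdims=True)
    p = np.exp(scores - m)
    denom = p.sum(axis=-1, keepdims=True)
    out = (p / denom) @ v
    lse = (m + np.log(denom)).squeeze(-1)  # log-sum-exp per query row
    return out, lse

def combine_partials(o1, lse1, o2, lse2):
    """Merge two partial outputs; each shard is weighted by exp(lse_i - lse_total)."""
    lse = np.logaddexp(lse1, lse2)
    w1 = np.exp(lse1 - lse)[..., None]
    w2 = np.exp(lse2 - lse)[..., None]
    return o1 * w1 + o2 * w2, lse
```

Because the weights renormalize each shard's softmax by the global log-sum-exp, combining shard outputs reproduces attention over the full K/V exactly.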
Requirements
ici_context_parallelism > 1use_fused_bwd_kernel: True- Sequence must be shardable across context axis
When to use
- Extremely long sequences (> 1M tokens)
- Video models with many frames
- Memory-constrained training
Ring attention adds communication overhead. Only use when sequences don’t fit with standard flash attention.
Attention flowchart
MaxDiffusion automatically adjusts block sizes following this flowchart:

Automatic adjustments

MaxDiffusion logs modifications to block sizes.

Performance comparison
TPU (v5p-8) - Wan2.1 training
| Attention | Step time | Memory | TFLOP/s/device |
|---|---|---|---|
| dot_product | 142s | OOM | - |
| flash | 36s | 95% | 135.9 |
| tokamax_flash | 36s | 93% | 135.9 |
GPU (H100) - SDXL training
| Attention | Step time | Memory |
|---|---|---|
| dot_product | 3.2s | 78 GB |
| cudnn_flash_te | 1.8s | 62 GB |
Debugging attention
Enable attention logging
Check actual shapes
Add prints in `attention_flax.py:524`:
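A standalone sketch with dummy tensors (the actual variable names at that line may differ):

```python
import numpy as np

# Example values only; substitute the real tensors inside attention_flax.py.
B, H, S, D = 1, 8, 1024, 64  # batch, heads, sequence, head_dim
query = np.zeros((B, H, S, D), dtype=np.float32)
key = np.zeros((B, H, S, D), dtype=np.float32)
value = np.zeros((B, H, S, D), dtype=np.float32)

print("query:", query.shape, "key:", key.shape, "value:", value.shape)
```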
Profile attention
Common issues
Sequence not divisible by block size
Symptom: Excessive padding, high memory usage
Solution: Choose block sizes that divide the sequence length evenly.

Out of memory with flash attention

Symptom: OOM during attention forward/backward
Solutions:

- Reduce block sizes (less VMEM usage)
- Enable gradient checkpointing: `remat_policy: "FULL"`
- Use tokamax flash: `attention: "tokamax_flash"`
- Use ring attention: `attention: "ring"`
cuDNN flash not working
Symptom: Error about missing Transformer Engine
Solution: Install Transformer Engine (see Installation above).

Quality degradation with mask_padding_tokens=False

Symptom: Blurry outputs, attention to padding
Solution: Set `mask_padding_tokens: True` if padding is significant (> 10% of sequence).
Best practices
- Use flash attention by default on TPU
- Tune block sizes for your sequence length
- Set flash_min_seq_length: 0 for video models
- Enable mask_padding_tokens for variable-length sequences
- Use cudnn_flash_te on GPU with Transformer Engine
- Profile attention kernels to verify performance
Next steps
- Configure parallelism strategies
- Explore supported models
- Understand MaxDiffusion’s architecture