MaxDiffusion provides several optimization strategies to maximize training and inference performance on TPU and GPU hardware.

LIBTPU_INIT_ARGS flags

The LIBTPU_INIT_ARGS environment variable configures XLA compiler optimizations for TPU training. These flags control collective operations, memory management, and scheduling behavior. For Wan2.1 training on TPU v5p:
export LIBTPU_INIT_ARGS='--xla_tpu_enable_async_collective_fusion_fuse_all_gather=true \
--xla_tpu_megacore_fusion_allow_ags=false \
--xla_enable_async_collective_permute=true \
--xla_tpu_enable_ag_backward_pipelining=true \
--xla_tpu_enable_data_parallel_all_reduce_opt=true \
--xla_tpu_data_parallel_opt_different_sized_ops=true \
--xla_tpu_enable_async_collective_fusion=true \
--xla_tpu_enable_async_collective_fusion_multiple_steps=true \
--xla_tpu_overlap_compute_collective_tc=true \
--xla_enable_async_all_gather=true \
--xla_tpu_scoped_vmem_limit_kib=65536 \
--xla_tpu_enable_async_all_to_all=true \
--xla_tpu_enable_all_experimental_scheduler_features=true \
--xla_tpu_enable_scheduler_memory_pressure_tracking=true \
--xla_tpu_host_transfer_overlap_limit=24 \
--xla_tpu_aggressive_opt_barrier_removal=ENABLED \
--xla_lhs_prioritize_async_depth_over_stall=ENABLED \
--xla_should_allow_loop_variant_parameter_in_chain=ENABLED \
--xla_should_add_loop_invariant_op_in_chain=ENABLED \
--xla_max_concurrent_host_send_recv=100 \
--xla_tpu_scheduler_percent_shared_memory_limit=100 \
--xla_latency_hiding_scheduler_rerun=2 \
--xla_tpu_use_minor_sharding_for_major_trivial_input=true \
--xla_tpu_relayout_group_size_threshold_for_reduce_scatter=1 \
--xla_tpu_assign_all_reduce_scatter_layout=true'
For Wan inference and lighter workloads:
export LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true \
--xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true \
--xla_tpu_enable_async_collective_fusion_multiple_steps=true \
--xla_tpu_overlap_compute_collective_tc=true \
--xla_enable_async_all_reduce=true"
For SDXL and Stable Diffusion training, you can disable these flags:
export LIBTPU_INIT_ARGS=""
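Because libtpu reads this variable at initialization time, it can also be set from Python, as long as that happens before JAX is first imported. A minimal sketch (the flag subset here is chosen purely for illustration):

```python
import os

# LIBTPU_INIT_ARGS must be in the environment before JAX (and therefore
# libtpu) initializes; setting it here is equivalent to the shell export.
flags = [
    "--xla_tpu_enable_async_collective_fusion=true",
    "--xla_tpu_overlap_compute_collective_tc=true",
]
os.environ["LIBTPU_INIT_ARGS"] = " ".join(flags)

# Import JAX only after the variable is in place.
# import jax
```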

Key flags explained

  • xla_tpu_enable_async_collective_fusion - Enables fusion of asynchronous collective operations for better performance
  • xla_enable_async_all_gather - Allows all-gather operations to run asynchronously
  • xla_tpu_scoped_vmem_limit_kib - Sets the scoped VMEM limit in KiB (65536 KiB = 64 MiB)
  • xla_tpu_enable_scheduler_memory_pressure_tracking - Optimizes scheduling based on tracked memory pressure
  • xla_latency_hiding_scheduler_rerun - Number of times the latency-hiding scheduler is rerun

Flash attention block sizes

Flash attention block sizes significantly impact memory usage and performance. Different TPU generations require different configurations.

TPU v6e (Trillium) - Wan models

flash_block_sizes='{
  "block_q" : 3024,
  "block_kv_compute" : 1024,
  "block_kv" : 2048,
  "block_q_dkv" : 3024,
  "block_kv_dkv" : 2048,
  "block_kv_dkv_compute" : 1024,
  "block_q_dq" : 3024,
  "block_kv_dq" : 2048,
  "use_fused_bwd_kernel": False
}'

TPU v5p - Wan models

flash_block_sizes='{
  "block_q" : 3024,
  "block_kv_compute" : 1024,
  "block_kv" : 2048,
  "block_q_dkv" : 1024,
  "block_kv_dkv" : 3072,
  "block_kv_dkv_compute" : 256,
  "block_q_dq" : 1024,
  "block_kv_dq" : 3072
}'

Default configuration

For other models or when unsure:
flash_block_sizes: {
  "block_q" : 512,
  "block_kv_compute" : 512,
  "block_kv" : 512,
  "block_q_dkv" : 512,
  "block_kv_dkv" : 512,
  "block_kv_dkv_compute" : 512,
  "block_q_dq" : 512,
  "block_kv_dq" : 512,
  "use_fused_bwd_kernel": False
}
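The block-size strings above use Python-style booleans (`False`), so strict JSON parsers reject them; `ast.literal_eval` handles them. A hypothetical sanity check (not part of MaxDiffusion) that also verifies the compute blocks tile their kv blocks evenly, which is an assumption about the kernel stated here as a check:

```python
import ast

# Parse the default block-size string; ast.literal_eval accepts the
# Python-style "False" that json.loads would reject.
raw = """{
  "block_q": 512,
  "block_kv_compute": 512,
  "block_kv": 512,
  "block_q_dkv": 512,
  "block_kv_dkv": 512,
  "block_kv_dkv_compute": 512,
  "block_q_dq": 512,
  "block_kv_dq": 512,
  "use_fused_bwd_kernel": False,
}"""
sizes = ast.literal_eval(raw)

# Assumed invariant: each *_compute block tiles its kv block evenly.
assert sizes["block_kv"] % sizes["block_kv_compute"] == 0
assert sizes["block_kv_dkv"] % sizes["block_kv_dkv_compute"] == 0
```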

Setting flash attention

Enable flash attention in your training command:
python src/maxdiffusion/train_wan.py \
  src/maxdiffusion/configs/base_wan_14b.yml \
  attention='flash' \
  flash_min_seq_length=0 \
  flash_block_sizes='{...}'

Remat policies

Gradient checkpointing (rematerialization) trades computation for memory. MaxDiffusion supports several remat policies.

Available policies

  • NONE - No gradient checkpointing (fastest, highest memory usage)
  • FULL - Full gradient checkpointing (slowest, lowest memory usage)
  • MATMUL_WITHOUT_BATCH - Checkpoint linear/matmul operations except those involving batch dimension
  • OFFLOAD_MATMUL_WITHOUT_BATCH - Same as MATMUL_WITHOUT_BATCH, but offloads the saved activations to host memory instead of recomputing them
  • HIDDEN_STATE_WITH_OFFLOAD - Offloads hidden states (recommended for Wan training)
  • CUSTOM - Define specific operations to save or offload
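These policies are built on JAX's rematerialization machinery. The sketch below illustrates the underlying mechanism with plain `jax.checkpoint` and JAX's own dots-without-batch-dims policy (a rough analogue of MATMUL_WITHOUT_BATCH); it is an illustration of the trade-off, not MaxDiffusion's implementation:

```python
import jax
import jax.numpy as jnp

# A toy "layer": two matmuls with a nonlinearity between them.
def block(w, x):
    h = jnp.tanh(x @ w)
    return jnp.sum(h @ w)

# Remat version: intermediates not covered by the policy are
# recomputed during the backward pass instead of being stored.
remat_block = jax.checkpoint(
    block, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable
)

w = jnp.ones((8, 8))
x = jnp.ones((4, 8))

# Gradients are identical; only the memory/compute trade-off changes.
g_plain = jax.grad(block)(w, x)
g_remat = jax.grad(remat_block)(w, x)
```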

Configuration

Set the remat policy in your config file or command line:
remat_policy: 'HIDDEN_STATE_WITH_OFFLOAD'
Or via command line:
python src/maxdiffusion/train_wan.py \
  src/maxdiffusion/configs/base_wan_14b.yml \
  remat_policy='HIDDEN_STATE_WITH_OFFLOAD'

Custom policy

For fine-grained control, use CUSTOM policy:
remat_policy: "CUSTOM"
names_which_can_be_saved: ['attn_output', 'query_proj']
names_which_can_be_offloaded: ['xq_out', 'xk_out', 'ffn_activation']
Available annotations: attn_output, query_proj, key_proj, value_proj, xq_out, xk_out, ffn_activation

Data type optimization

Weight and activation dtypes

Choose dtypes based on your hardware and quality requirements:
# Recommended for TPU v5p/v6e
weights_dtype: bfloat16
activations_dtype: bfloat16

# For higher precision (slower)
weights_dtype: float32
activations_dtype: float32
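The memory side of this trade-off is easy to see directly: bfloat16 keeps float32's exponent range but stores half as many bytes per element. A small illustration:

```python
import jax.numpy as jnp

# A 1024x1024 weight matrix in float32 vs bfloat16.
w32 = jnp.zeros((1024, 1024), dtype=jnp.float32)
w16 = w32.astype(jnp.bfloat16)

# bfloat16 uses 2 bytes per element instead of 4.
assert w16.nbytes == w32.nbytes // 2
```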

Precision settings

Control matmul and conv precision:
# Options: DEFAULT, HIGH, HIGHEST
precision: "DEFAULT"  # Fastest
precision: "HIGHEST"  # Most accurate with fp32
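These strings map onto JAX's `jax.lax.Precision` enum, which can also be passed per-operation. This sketch shows the direct JAX form (MaxDiffusion itself sets precision via the config, not per-call):

```python
import jax
import jax.numpy as jnp

a = jnp.ones((64, 64), dtype=jnp.float32)
b = jnp.ones((64, 64), dtype=jnp.float32)

# DEFAULT allows faster, lower-precision matmul paths on accelerators;
# HIGHEST forces the most accurate available algorithm.
fast = jnp.dot(a, b, precision=jax.lax.Precision.DEFAULT)
exact = jnp.dot(a, b, precision=jax.lax.Precision.HIGHEST)
```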

Parallelism strategies

Wan models

Wan2.1 uses specialized parallelism:
  • ici_fsdp_parallelism - Sequence parallelism (try 2 or 4)
  • ici_tensor_parallelism - Head parallelism (must evenly divide the model's 40 attention heads)
  • ici_data_parallelism - Data parallelism
python src/maxdiffusion/train_wan.py \
  src/maxdiffusion/configs/base_wan_14b.yml \
  per_device_batch_size=0.25 \
  ici_data_parallelism=32 \
  ici_fsdp_parallelism=4 \
  ici_tensor_parallelism=1
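A hypothetical sanity check for a mesh like the one above: the three ici_* degrees must multiply to the number of chips in the slice, and tensor parallelism must divide the attention-head count. The 128-chip slice size here is an assumption for illustration:

```python
# Assumed slice size for illustration (e.g. 128 chips).
num_devices = 128
ici_data, ici_fsdp, ici_tensor = 32, 4, 1

# The mesh axes must exactly tile the device grid.
assert ici_data * ici_fsdp * ici_tensor == num_devices
# Head parallelism must divide the 40 attention heads.
assert 40 % ici_tensor == 0
```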

SDXL and Stable Diffusion

python -m src.maxdiffusion.train_sdxl \
  src/maxdiffusion/configs/base_xl.yml \
  per_device_batch_size=1 \
  ici_data_parallelism=-1  # Auto-shard

Performance tips

Fractional batch sizes

Wan training supports fractional batch sizes:
per_device_batch_size: 0.25  # Effective global batch = 0.25 * num_devices
The effective global batch size (per_device_batch_size * num_devices) must be a whole number.
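For example, with a batch size of 0.25 the global batch works out as follows (the 128-device count is illustrative):

```python
# A fractional per-device batch shards one sample across several
# devices; the global batch must still come out whole.
per_device_batch_size = 0.25
num_devices = 128  # illustrative slice size

global_batch = per_device_batch_size * num_devices
assert global_batch == int(global_batch)  # must be a whole number
```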

Caching latents

For faster training on small datasets, cache VAE latents and text encoder outputs:
cache_latents_text_encoder_outputs: True
dataset_save_location: '/tmp/cached_dataset'

HuggingFace transfer acceleration

Speed up model downloads:
export HF_HUB_ENABLE_HF_TRANSFER=1

JAX compilation cache

Avoid recompilation across runs:
python src/maxdiffusion/train_wan.py \
  src/maxdiffusion/configs/base_wan_14b.yml \
  jax_cache_dir=gs://your-bucket/jax_cache/
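The jax_cache_dir option feeds JAX's persistent compilation cache; setting the equivalent JAX option directly looks like this (a local path is used here for illustration):

```python
import jax

# Point JAX's persistent compilation cache at a directory so compiled
# programs are reused across runs instead of being rebuilt.
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
```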

GPU-specific optimizations

Fused attention with Transformer Engine

For NVIDIA GPUs, use cudnn_flash_te attention:
NVTE_FUSED_ATTN=1 python -m src.maxdiffusion.train_sdxl \
  src/maxdiffusion/configs/base_xl.yml \
  hardware=gpu \
  attention="cudnn_flash_te" \
  weights_dtype=bfloat16

Batch parallelism

Enable batch parallelism on GPUs:
ici_fsdp_batch_parallelism: 2  # Does not support fractional batch sizes

Build docs developers (and LLMs) love