LIBTPU_INIT_ARGS flags
The LIBTPU_INIT_ARGS environment variable configures XLA compiler optimizations for TPU training. These flags control collective operations, memory management, and scheduling behavior.
Recommended configuration
For Wan2.1 training on TPU v5p, the flags described below are recommended.

Key flags explained
| Flag | Purpose |
|---|---|
| xla_tpu_enable_async_collective_fusion | Enables fusion of async collective operations for better performance |
| xla_enable_async_all_gather | Allows all-gather operations to run asynchronously |
| xla_tpu_scoped_vmem_limit_kib | Sets the scoped vector memory (VMEM) limit (65536 KiB = 64 MB) |
| xla_tpu_enable_scheduler_memory_pressure_tracking | Optimizes scheduler decisions based on memory usage |
| xla_latency_hiding_scheduler_rerun | Reruns latency-hiding scheduler optimization passes |
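Putting these together, one plausible way to set the variable before launching training (the specific values here are illustrative, not tuned recommendations):

```shell
# Illustrative LIBTPU_INIT_ARGS combining the flags above.
# The values (true / 65536 / 1) are examples only - tune for your workload.
export LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true \
--xla_enable_async_all_gather=true \
--xla_tpu_scoped_vmem_limit_kib=65536 \
--xla_tpu_enable_scheduler_memory_pressure_tracking=true \
--xla_latency_hiding_scheduler_rerun=1"
```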
Flash attention block sizes
Flash attention block sizes significantly impact memory usage and performance. Different TPU generations require different configurations.

TPU v6e (Trillium) - Wan models
TPU v5p - Wan models
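Concrete tuned values differ between v6e and v5p and are not reproduced here; structurally, both generations are configured through the same block-size keys. A sketch, assuming a MaxDiffusion-style flash_block_sizes config section (the key names follow JAX splash attention's BlockSizes; the values below are placeholders, not tuned recommendations):

```yaml
# Structural sketch only - values are placeholders, not tuned for v6e or v5p.
flash_block_sizes:
  block_q: 1024
  block_kv_compute: 256
  block_kv: 1024
  block_q_dkv: 1024
  block_kv_dkv: 1024
  block_kv_dkv_compute: 256
  block_q_dq: 1024
  block_kv_dq: 1024
```

Larger blocks generally improve throughput until they exceed the VMEM budget; smaller blocks reduce memory pressure at some performance cost.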
Default configuration
For other models, or when unsure, start from the default block sizes.

Setting flash attention
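Flash attention itself is switched on with a single setting; a sketch, assuming a MaxDiffusion-style attention config key:

```yaml
# Sketch - the attention key and its accepted values are assumptions
# based on MaxDiffusion-style configs; check your version's base config.
attention: "flash"   # e.g. instead of "dot_product"
```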
Enable flash attention in your training command.

Remat policies
Gradient checkpointing (rematerialization) trades computation for memory. MaxDiffusion supports several remat policies.

Available policies
- NONE - No gradient checkpointing (fastest, highest memory usage)
- FULL - Full gradient checkpointing (slowest, lowest memory usage)
- MATMUL_WITHOUT_BATCH - Checkpoint linear/matmul operations except those involving the batch dimension
- OFFLOAD_MATMUL_WITHOUT_BATCH - Same as MATMUL_WITHOUT_BATCH but offloads to HBM instead of recomputing
- HIDDEN_STATE_WITH_OFFLOAD - Offloads hidden states (recommended for Wan training)
- CUSTOM - Define specific operations to save or offload
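Selecting one of these policies is a one-line config change; a sketch, assuming a MaxDiffusion-style remat_policy key:

```yaml
# Sketch - key name assumed from MaxDiffusion-style configs.
# HIDDEN_STATE_WITH_OFFLOAD is the policy recommended above for Wan training.
remat_policy: "HIDDEN_STATE_WITH_OFFLOAD"
```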
Configuration
Set the remat policy in your config file or on the command line.

Custom policy
For fine-grained control, use the CUSTOM policy and name the tensors to save or offload: attn_output, query_proj, key_proj, value_proj, xq_out, xk_out, ffn_activation.
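As a sketch of what a CUSTOM policy might look like in config form (the tensors_to_save and tensors_to_offload key names are hypothetical, chosen for illustration; the tensor names come from the list above):

```yaml
# Hypothetical CUSTOM remat configuration - key names are illustrative,
# not verified MaxDiffusion config keys.
remat_policy: "CUSTOM"
tensors_to_save: ["attn_output"]                  # keep; recompute everything else
tensors_to_offload: ["ffn_activation", "xq_out"]  # offload instead of recomputing
```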
Data type optimization
Weight and activation dtypes
Choose dtypes based on your hardware and quality requirements.

Precision settings
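A combined sketch of dtype and precision settings (key names are assumptions based on MaxDiffusion-style configs; bfloat16 is a common TPU choice, not a universal recommendation):

```yaml
# Sketch - key names are assumptions, not verified config keys.
weights_dtype: "bfloat16"       # model parameters
activations_dtype: "bfloat16"   # intermediate activations
matmul_precision: "default"     # JAX precision levels: "default", "high", "highest"
```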
Control matmul and conv precision.

Parallelism strategies
Wan models
Wan2.1 uses specialized parallelism:

- ici_fsdp_parallelism - Sequence parallelism (try 2 or 4)
- ici_tensor_parallelism - Head parallelism (must divide 40 evenly)
- ici_data_parallelism - Data parallelism
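The constraints above can be sanity-checked before launching a job: the three degrees must multiply to the chip count, and tensor parallelism must divide the head count. A hedged sketch (validate_layout is a hypothetical helper, not part of MaxDiffusion; 40 is the head count implied by the note above):

```python
# Hypothetical helper to sanity-check an ICI parallelism layout for Wan2.1.
# num_heads=40 comes from the "must divide 40 evenly" note; the chip count
# and the example split are illustrative.

def validate_layout(data: int, fsdp: int, tensor: int,
                    num_chips: int, num_heads: int = 40) -> bool:
    """True if the degrees exactly fill the slice and tensor parallelism
    divides the attention-head count evenly."""
    return data * fsdp * tensor == num_chips and num_heads % tensor == 0

# Example: an 8-chip v5p slice split as data=1, fsdp=2, tensor=4.
print(validate_layout(1, 2, 4, num_chips=8))  # True: 1*2*4 == 8 and 40 % 4 == 0
```

A split like tensor=3 would be rejected even if it filled the slice, since 3 does not divide 40.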