This guide covers common issues you may encounter when using MaxDiffusion and how to resolve them.

Compilation issues

Symptoms: Training or inference hangs during compilation, or compilation takes over 30 minutes.

Solutions:
  1. Use JAX compilation cache to avoid recompiling:
    python src/maxdiffusion/train_wan.py \
      src/maxdiffusion/configs/base_wan_14b.yml \
      jax_cache_dir=gs://your-bucket/jax_cache/
    
  2. Reduce model or batch size during initial testing:
    per_device_batch_size=0.125  # Smaller batch for faster compilation
    
  3. Check LIBTPU_INIT_ARGS - some flag combinations can slow compilation:
    # Try disabling all flags first
    export LIBTPU_INIT_ARGS=""
    
  4. Enable profiler to see where it’s stuck:
    enable_profiler: True
    skip_first_n_steps_for_profiler: 1
    
Symptoms: Errors like “Shape mismatch” or “XLA compilation failed”.

Solutions:
  1. Verify parallelism settings match your hardware:
    # Check that product of ICI axes equals devices per slice
    ici_data_parallelism=2
    ici_fsdp_parallelism=4  # 2 * 4 = 8 devices
    ici_tensor_parallelism=1
    
  2. Check batch size divisibility:
    # Global batch must be evenly divisible by (data * fsdp) parallelism
    per_device_batch_size * num_devices % (ici_data_parallelism * ici_fsdp_parallelism) == 0
    
  3. For Wan models, verify head parallelism divides 40:
    # Valid values: 1, 2, 4, 5, 8, 10, 20, 40
    ici_tensor_parallelism=5  # OK
    ici_tensor_parallelism=3  # ERROR: 40 % 3 != 0
    
  4. Disable jit_initializers for debugging:
    jit_initializers: False  # Only for single-host debugging
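The constraints in items 1-3 above can be checked before launching a job. The sketch below is illustrative, not part of MaxDiffusion; the parameter names mirror the config keys, and the 40-head count applies to the Wan models as noted above.

```python
def check_sharding(num_devices, ici_data, ici_fsdp, ici_tensor,
                   per_device_batch_size, num_heads=40):
    """Sanity-check the parallelism constraints described above."""
    # 1. The product of the ICI axes must equal the devices per slice.
    assert ici_data * ici_fsdp * ici_tensor == num_devices, \
        "ICI axes product must equal devices per slice"
    # 2. The global batch must divide evenly across (data * fsdp) shards.
    global_batch = per_device_batch_size * num_devices
    assert global_batch % (ici_data * ici_fsdp) == 0, \
        "global batch not divisible by data * fsdp parallelism"
    # 3. For Wan (40 attention heads), tensor parallelism must divide 40.
    assert num_heads % ici_tensor == 0, \
        "tensor parallelism must divide the number of heads"
    return int(global_batch)

# 8-device slice, 2-way data x 4-way FSDP, one sample per device:
print(check_sharding(8, 2, 4, 1, per_device_batch_size=1))  # -> 8
```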
    
Symptoms: Errors about bfloat16/float32 incompatibility.

Solutions:
  1. Match weights and activations dtypes:
    weights_dtype: bfloat16
    activations_dtype: bfloat16
    
  2. Use float32 for higher precision (slower):
    weights_dtype: float32
    activations_dtype: float32
    precision: "HIGHEST"
    
  3. For GPU, ensure Transformer Engine is installed when using cudnn_flash_te:
    pip install "transformer_engine[jax]"
    NVTE_FUSED_ATTN=1 python src/maxdiffusion/train_sdxl.py ...
    

Out of memory (OOM) errors

Symptoms: “Out of memory” or “HBM allocation failed” errors.

Solutions:
  1. Reduce batch size:
    per_device_batch_size=0.125  # Or even smaller like 0.0625
    
  2. Enable gradient checkpointing (rematerialization):
    remat_policy: "HIDDEN_STATE_WITH_OFFLOAD"  # For Wan
    remat_policy: "FULL"  # For maximum memory savings
    
  3. Use smaller flash block sizes:
    flash_block_sizes: {
      "block_q" : 512,
      "block_kv_compute" : 512,
      "block_kv" : 512,
      "block_q_dkv" : 512,
      "block_kv_dkv" : 512,
      "block_kv_dkv_compute" : 512,
      "block_q_dq" : 512,
      "block_kv_dq" : 512
    }
    
  4. Reduce resolution or number of frames:
    # For Wan models
    height=720  # Instead of 1280
    width=480   # Instead of 720
    num_frames=49  # Instead of 81
    
  5. Increase FSDP parallelism to shard model across more devices:
    ici_fsdp_parallelism=8  # More sharding = less memory per device
    
  6. For Wan, adjust scoped_vmem_limit:
    export LIBTPU_INIT_ARGS="--xla_tpu_scoped_vmem_limit_kib=32768"  # Reduce from 65536
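To see why items 2, 3, and 5 help, a back-of-the-envelope estimate of the per-device weight footprint is useful. This is a rough sketch, not a MaxDiffusion API; the 14B parameter count is an assumption based on the Wan 14B model name, and it ignores optimizer state and activations.

```python
def per_device_param_bytes(num_params, bytes_per_element, fsdp_shards):
    # With FSDP, parameters are sharded evenly across devices, so the
    # per-device footprint shrinks linearly with the shard count.
    return num_params * bytes_per_element / fsdp_shards

# ~14B parameters in bfloat16 (2 bytes each) across 8-way FSDP:
gib = per_device_param_bytes(14e9, 2, 8) / 2**30
print(f"~{gib:.2f} GiB per device for weights alone")
```

Doubling `ici_fsdp_parallelism` or halving the dtype width each roughly halves this number, which is why bfloat16 plus more FSDP sharding is usually the first OOM remedy to try.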
    
Symptoms: OOM when loading pretrained weights.

Solutions:
  1. Enable single replica checkpoint restoring:
    enable_single_replica_ckpt_restoring: True
    
  2. For Wan models, use external disk for HuggingFace cache:
    HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ python ...
    
  3. Load weights in bfloat16:
    weights_dtype: bfloat16
    from_pt: True
    
Symptoms: OOM when creating TFRecord datasets.

Solutions:
  1. Process in smaller batches:
    # In wan_txt2vid_data_preprocessing.py, reduce batch_size
    batch_size = 5  # Default is 10
    
  2. Increase number of shards:
    no_records_per_shard=5  # Smaller shards = less memory
    
  3. Use streaming dataset instead of in-memory:
    dataset_type: hf  # Instead of tf
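The relationship between records per shard and the number of shard files (item 2) is simple ceiling division. This helper is purely illustrative, not part of the preprocessing script:

```python
import math

def num_shards(total_records, records_per_shard):
    # Fewer records per shard means more, smaller files: each shard is
    # built in memory before being written, so smaller shards use less.
    return math.ceil(total_records / records_per_shard)

print(num_shards(100, 10))  # -> 10 shards
print(num_shards(100, 5))   # -> 20 shards
```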
    

Disk space issues

Symptoms: “No space left on device” errors.

Solutions:
  1. Attach external disk to VM:
    # Follow: https://cloud.google.com/tpu/docs/attach-durable-block-storage
    # Then mount and use for cache:
    HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/
    
  2. Save checkpoints to GCS instead of local disk:
    output_dir: gs://my-bucket/checkpoints/
    jax_cache_dir: gs://my-bucket/jax_cache/
    
  3. Disable checkpoint saving during debugging:
    checkpoint_every: -1
    save_final_checkpoint: False
    
  4. Clean up HuggingFace cache:
    rm -rf ~/.cache/huggingface/hub/*
    # Or set cache to GCS bucket
    
  5. Use smaller dataset or streaming:
    dataset_type: hf  # Streams data without downloading
    max_train_samples: 1000  # Limit dataset size
    
Symptoms: Disk full when downloading datasets from HuggingFace.

Solutions:
  1. Use streaming dataset:
    dataset_type: hf  # No download needed
    dataset_name: BleachNick/UltraEdit_500k
    
  2. Download to external disk:
    export HF_DATASET_DIR=/mnt/disks/external_disk/datasets/
    huggingface-cli download RaphaelLiu/PusaV1_training --local-dir $HF_DATASET_DIR
    
  3. Download directly to GCS:
    # Download locally first, then upload and delete
    huggingface-cli download ... --local-dir /tmp/dataset
    gsutil -m cp -r /tmp/dataset gs://my-bucket/
    rm -rf /tmp/dataset
    

Permission and access errors

Symptoms: “401 Client Error: Unauthorized” or “Access denied”.

Solutions:
  1. Obtain access to the model on HuggingFace (e.g., Flux, Wan).
  2. Create a HuggingFace token at https://huggingface.co/settings/tokens.
  3. Set token in config or environment:
    hf_access_token: 'hf_xxxxxxxxxxxxxxxxxxxx'
    
    Or:
    export HF_TOKEN='hf_xxxxxxxxxxxxxxxxxxxx'
    huggingface-cli login --token $HF_TOKEN
    
Symptoms: “403 Forbidden” or “Permission denied” when accessing GCS buckets.

Solutions:
  1. Authenticate gcloud:
    gcloud auth login
    gcloud auth application-default login
    
  2. Set project:
    gcloud config set project YOUR_PROJECT_ID
    
  3. Grant VM service account permissions:
    # Give Storage Admin role to TPU service account
    gcloud projects add-iam-policy-binding YOUR_PROJECT_ID \
      --member serviceAccount:SERVICE_ACCOUNT_EMAIL \
      --role roles/storage.admin
    
  4. Check bucket exists and is accessible:
    gsutil ls gs://my-bucket/
    
Symptoms: “Permission denied” when saving checkpoints locally.

Solutions:
  1. Check directory permissions:
    ls -la /tmp/
    chmod 777 /tmp/output  # Or appropriate permissions
    
  2. Use home directory or /tmp:
    output_dir: /tmp/checkpoints/
    dataset_save_location: /tmp/dataset/
    
  3. Run with appropriate user:
    sudo chown -R $USER:$USER /path/to/output
    

Training and inference issues

Symptoms: Loss shows as NaN or increases dramatically.

Solutions:
  1. Reduce learning rate:
    learning_rate: 1.e-6  # Instead of 1.e-5
    
  2. Enable gradient clipping:
    max_grad_norm: 1.0  # Default, try 0.5 for more aggressive clipping
    
  3. Use float32 instead of bfloat16:
    weights_dtype: float32
    activations_dtype: float32
    
  4. Check data preprocessing - ensure images/videos are normalized correctly.
  5. Reduce batch size - very large batches can cause instability.
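Item 2's `max_grad_norm` implements global-norm gradient clipping. The sketch below shows the math on plain Python lists; it is a minimal illustration of the technique, not MaxDiffusion's actual implementation (which operates on the full gradient pytree).

```python
import math

def clip_by_global_norm(grads, max_grad_norm=1.0):
    """Rescale gradients so their combined L2 norm never exceeds
    max_grad_norm; gradients below the threshold pass through unchanged."""
    global_norm = math.sqrt(sum(g * g for g in grads))
    scale = min(1.0, max_grad_norm / max(global_norm, 1e-12))
    return [g * scale for g in grads], global_norm

clipped, norm = clip_by_global_norm([3.0, 4.0], max_grad_norm=1.0)
print(norm, clipped)  # -> 5.0 [0.6, 0.8]
```

Lowering `max_grad_norm` from 1.0 to 0.5 simply halves the cap, shrinking the occasional exploding gradient more aggressively.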
Symptoms: Outputs are blurry, distorted, or don’t match prompts.

Solutions:
  1. Increase inference steps:
    num_inference_steps=50  # Instead of 20
    
  2. Adjust guidance scale:
guidance_scale=7.5  # Try values between 5 and 15
    
  3. For Wan models, set flow_shift:
    flow_shift=5.0  # Wan2.1 recommended value
    
  4. Use higher precision:
    weights_dtype: float32
    activations_dtype: float32
    
  5. Check if model loaded correctly - verify checkpoint path and weights.
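Item 2's `guidance_scale` controls classifier-free guidance. The sketch below shows the standard formula on scalar lists, as an illustration only (the real models apply it elementwise to noise predictions):

```python
def apply_guidance(uncond, cond, guidance_scale):
    # Classifier-free guidance: push the prediction away from the
    # unconditional output, toward the prompt-conditioned one.
    return [u + guidance_scale * (c - u) for u, c in zip(uncond, cond)]

# scale 1.0 returns the conditional prediction unchanged; larger values
# amplify prompt adherence, at the cost of artifacts when set too high.
print(apply_guidance([0.0], [1.0], 7.5))  # -> [7.5]
print(apply_guidance([0.0], [1.0], 1.0))  # -> [1.0]
```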
Symptoms: Step time is much slower than expected.

Solutions:
  1. Enable flash attention:
    attention='flash'
    flash_min_seq_length=0
    
  2. Optimize LIBTPU_INIT_ARGS - see optimization guide.
  3. Use appropriate flash block sizes for your TPU generation.
  4. Cache latents and text encodings:
    cache_latents_text_encoder_outputs: True
    
  5. Enable profiler to identify bottlenecks:
    enable_profiler: True
    skip_first_n_steps_for_profiler: 5
    profiler_steps: 10
    
  6. For GPU, use fused attention:
    NVTE_FUSED_ATTN=1 python ... attention="cudnn_flash_te"
    

Multihost issues

Symptoms: Training hangs when running on multiple hosts.

Solutions:
  1. Enable distributed system initialization:
    skip_jax_distributed_system: False
    
  2. Ensure all hosts have same code version:
    # On all workers:
    cd maxdiffusion && git pull && pip install -e .
    
  3. Check DCN parallelism settings:
    dcn_data_parallelism=-1  # Auto-shard across slices
    dcn_fsdp_parallelism=1
    dcn_tensor_parallelism=1
    
  4. Verify network connectivity between hosts.
  5. Use GCS for checkpoints, not local disk:
    output_dir: gs://my-bucket/output/
    
Symptoms: Slow step times with multiple hosts.

Solutions:
  1. Ensure there are at least as many data files as hosts:
    # If 8 hosts, need at least 8+ TFRecord files
    no_records_per_shard=10  # Reduce to create more files
    
  2. Use GCS for data storage, not local disk:
    train_data_dir: gs://my-bucket/dataset/
    
  3. Enable data shuffling:
    enable_data_shuffling: True
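Item 1 follows from each host reading a disjoint subset of the TFRecord files: with fewer files than hosts, some hosts have nothing to read. A quick illustrative check (not part of MaxDiffusion):

```python
import math

def files_per_host(total_records, records_per_shard, num_hosts):
    shards = math.ceil(total_records / records_per_shard)
    # Each host needs at least one file; 0 here means some hosts idle.
    return shards // num_hosts

# 100 records, 8 hosts:
print(files_per_host(100, 20, 8))  # -> 0, too few files; reduce shard size
print(files_per_host(100, 10, 8))  # -> 1, at least one file per host
```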
    

Getting help

If you’re still experiencing issues:
  1. Check the logs for detailed error messages
  2. Enable profiler to identify performance bottlenecks
  3. Search GitHub issues: https://github.com/AI-Hypercomputer/maxdiffusion/issues
  4. File a bug report with:
    • Complete error message and stack trace
    • Hardware type (TPU v5p, v6e, GPU model)
    • MaxDiffusion version and commit hash
    • Full command or config used
    • Steps to reproduce
