
Overview

The train_wan.py script trains Wan text-to-video models using the MaxDiffusion framework. It supports Wan 2.1 and 2.2 models and provides gradient checkpointing, flash attention, and synthetic data generation.

Command-line usage

python src/maxdiffusion/train_wan.py src/maxdiffusion/configs/base_wan_14b.yml [OPTIONS]

Basic example (single VM)

export LIBTPU_INIT_ARGS='--xla_tpu_enable_async_collective_fusion_fuse_all_gather=true
--xla_tpu_megacore_fusion_allow_ags=false
--xla_enable_async_collective_permute=true
--xla_tpu_enable_ag_backward_pipelining=true
--xla_tpu_enable_data_parallel_all_reduce_opt=true
--xla_tpu_data_parallel_opt_different_sized_ops=true
--xla_tpu_enable_async_collective_fusion=true
--xla_tpu_enable_async_collective_fusion_multiple_steps=true
--xla_tpu_overlap_compute_collective_tc=true
--xla_enable_async_all_gather=true
--xla_tpu_scoped_vmem_limit_kib=65536'

HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ \
python src/maxdiffusion/train_wan.py \
  src/maxdiffusion/configs/base_wan_14b.yml \
  attention='flash' \
  weights_dtype=bfloat16 \
  activations_dtype=bfloat16 \
  guidance_scale=5.0 \
  flow_shift=5.0 \
  fps=16 \
  skip_jax_distributed_system=False \
  run_name=wan-training \
  output_dir=gs://your-bucket/wan/ \
  train_data_dir=gs://your-bucket/dataset/train/ \
  load_tfrecord_cached=True \
  height=1280 \
  width=720 \
  num_frames=81 \
  num_inference_steps=50 \
  max_train_steps=1000 \
  per_device_batch_size=0.25 \
  ici_data_parallelism=1 \
  ici_fsdp_parallelism=4 \
  ici_tensor_parallelism=1

Configuration parameters

Run configuration

run_name
string
required
Name for this training run. Used for organizing outputs and metrics.
base_output_directory
string
GCS bucket path for outputs (e.g., gs://my-bucket/). Checkpoints and metrics will be saved here.
output_dir
string
default:"sdxl-model-finetuned"
Local or GCS directory for model outputs.

Model configuration

pretrained_model_name_or_path
string
default:"Wan-AI/Wan2.1-T2V-14B-Diffusers"
HuggingFace model identifier or local path to pretrained Wan model.
model_name
string
default:"wan2.1"
Wan model version identifier.
model_type
string
default:"T2V"
Model type. Options: T2V (text-to-video), I2V (image-to-video).
wan_transformer_pretrained_model_name_or_path
string
default:""
Override the transformer from pretrained_model_name_or_path with a different checkpoint.
weights_dtype
string
default:"bfloat16"
Data type for model weights. Options: float32, bfloat16.
activations_dtype
string
default:"bfloat16"
Data type for layer activations. Options: float32, bfloat16.
precision
string
default:"DEFAULT"
JAX precision for matmul and conv operations. Options: DEFAULT, HIGH, HIGHEST.
from_pt
boolean
default:true
Load weights from PyTorch format.
scan_layers
boolean
default:true
Use jax.lax.scan for transformer layers to reduce compilation memory.
replicate_vae
boolean
default:false
Replicate VAE across devices instead of using model’s sharding annotations.

Attention configuration

attention
string
default:"flash"
Attention mechanism to use. Options: dot_product, flash, cudnn_flash_te, ring.
split_head_dim
boolean
default:true
Whether to split attention head dimensions for sharding.
attention_sharding_uniform
boolean
default:true
Use uniform sequence sharding for both self and cross attention.
mask_padding_tokens
boolean
default:true
Pass segment IDs to attention to avoid attending to padding tokens.
flash_min_seq_length
number
default:0
Minimum sequence length to use flash attention.
flash_block_sizes
object
Custom block sizes for flash attention. Default for v5p:
flash_block_sizes: {
  "block_q" : 512,
  "block_kv_compute" : 512,
  "block_kv" : 512,
  "block_q_dkv" : 512,
  "block_kv_dkv" : 512,
  "block_kv_dkv_compute" : 512,
  "block_q_dq" : 512,
  "block_kv_dq" : 512,
  "use_fused_bwd_kernel": False
}
For v6e (Trillium), use larger blocks like 3024.

Training hyperparameters

learning_rate
number
Learning rate for the optimizer.
scale_lr
boolean
default:false
Scale learning rate by the number of GPUs/TPUs and batch size.
max_train_steps
number
default:1500
Maximum number of training steps. Takes priority over num_train_epochs.
num_train_epochs
number
default:1
Number of training epochs.
per_device_batch_size
number
Batch size per device. May be fractional (e.g., 0.25), but its product with the total device count must be a whole number.
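Because per_device_batch_size can be fractional, the whole-number constraint is easy to check by hand. A quick sketch (illustrative arithmetic, not MaxDiffusion code):

```python
# Sketch (not MaxDiffusion code): a fractional per-device batch size is
# valid only if it multiplies out to a whole number of samples per step.
def global_batch(per_device_batch_size: float, device_count: int) -> int:
    total = per_device_batch_size * device_count
    if total < 1 or total != int(total):
        raise ValueError(
            f"{per_device_batch_size} x {device_count} = {total} is not a whole batch"
        )
    return int(total)

print(global_batch(0.25, 4))  # matches the basic example above: 4 chips -> global batch 1
```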
global_batch_size
number
default:0
If non-zero, overrides the global batch size. If the value is not evenly divisible by the device count, FSDP sharding is used.
warmup_steps_fraction
number
Fraction of total steps to use for learning rate warmup.
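Assuming the fraction is applied to max_train_steps (a plausible reading of this parameter, not verified against the source), the warmup length works out as:

```python
# Hedged sketch: warmup length as a fraction of total training steps.
def warmup_steps(warmup_steps_fraction: float, max_train_steps: int) -> int:
    return int(warmup_steps_fraction * max_train_steps)

print(warmup_steps(0.1, 1000))  # 100 warmup steps out of 1000
```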
seed
number
default:0
Random seed for reproducibility.

Optimizer parameters

adam_b1
number
Exponential decay rate for first moment estimates.
adam_b2
number
Exponential decay rate for second moment estimates.
adam_eps
number
Small constant for numerical stability.
adam_weight_decay
number
default:0
Weight decay coefficient for AdamW optimizer.
max_grad_norm
number
Maximum gradient norm for gradient clipping.
save_optimizer
boolean
default:false
Save optimizer state in checkpoints.

Gradient checkpointing (remat)

remat_policy
string
default:"NONE"
Gradient checkpoint policy. Options:
  • NONE: No gradient checkpointing
  • FULL: Full gradient checkpointing (minimum memory)
  • MATMUL_WITHOUT_BATCH: Checkpoint matmul ops without batch dimension
  • OFFLOAD_MATMUL_WITHOUT_BATCH: Same as above but offload instead of recompute
  • CUSTOM: Use custom names from names_which_can_be_saved and names_which_can_be_offloaded
names_which_can_be_saved
array
default:[]
For CUSTOM remat policy: list of operation names to save. Options include: attn_output, query_proj, key_proj, value_proj, xq_out, xk_out, ffn_activation.
names_which_can_be_offloaded
array
default:[]
For CUSTOM remat policy: list of operation names to offload.
dropout
number
Dropout rate for training.

Dataset configuration

dataset_name
string
default:"diffusers/pokemon-gpt4-captions"
HuggingFace dataset identifier.
train_data_dir
string
required
GCS or local path to TFRecord training data.
train_split
string
default:"train"
Dataset split to use for training.
dataset_type
string
default:"tfrecord"
Dataset format type. Options: tfrecord, hf, tf, grain, synthetic.
load_tfrecord_cached
boolean
default:true
Load preprocessed TFRecord files.
dataset_save_location
string
default:""
Path to save or load cached dataset.
image_column
string
default:"image"
Name of the image column in the dataset.
caption_column
string
default:"text"
Name of the caption column in the dataset.
resolution
number
default:1024
Spatial resolution for video frames.
height
number
default:480
Video frame height.
width
number
default:832
Video frame width.
num_frames
number
default:81
Number of frames in video.
enable_data_shuffling
boolean
default:true
Shuffle the dataset during training.

Synthetic data configuration

synthetic_num_samples
number
For dataset_type='synthetic': number of synthetic samples. Set to null for infinite samples.
synthetic_override_height
number
Override height for synthetic data.
synthetic_override_width
number
Override width for synthetic data.
synthetic_override_num_frames
number
Override number of frames for synthetic data.
synthetic_override_max_sequence_length
number
Override max sequence length for synthetic data.

Parallelism and sharding

hardware
string
default:"tpu"
Hardware type. Options: tpu, gpu.
mesh_axes
array
default:["data","fsdp","context","tensor"]
Logical mesh axes for parallelism.
dcn_data_parallelism
number
default:1
Data parallelism across DCN.
dcn_fsdp_parallelism
number
default:1
FSDP parallelism across DCN.
dcn_context_parallelism
number
Context (sequence) parallelism across DCN. -1 for auto-sharding.
dcn_tensor_parallelism
number
default:1
Tensor parallelism across DCN.
ici_data_parallelism
number
default:1
Data parallelism within ICI.
ici_fsdp_parallelism
number
default:1
FSDP parallelism within ICI. In Wan 2.1, this axis is used for sequence parallelism.
ici_context_parallelism
number
Context parallelism within ICI. -1 for auto-sharding.
ici_tensor_parallelism
number
default:1
Tensor (head) parallelism within ICI. For Wan 2.1, must evenly divide 40 heads.
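The four ICI axes must multiply to the number of chips in the slice, and tensor parallelism must evenly divide Wan 2.1's 40 attention heads. A validation sketch (not MaxDiffusion code; it ignores the -1 auto-sharding option):

```python
# Sketch: basic consistency checks for an ICI mesh configuration.
def validate_ici_mesh(dp, fsdp, cp, tp, num_devices, num_heads=40):
    if dp * fsdp * cp * tp != num_devices:
        raise ValueError(f"mesh {dp}x{fsdp}x{cp}x{tp} != {num_devices} devices")
    if num_heads % tp != 0:
        raise ValueError(f"tensor parallelism {tp} does not divide {num_heads} heads")
    return True

# The basic example above: dp=1, fsdp=4, cp=1, tp=1 on a 4-chip slice.
print(validate_ici_mesh(1, 4, 1, 1, num_devices=4))
```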

Checkpointing

checkpoint_every
number
Save checkpoint every N samples. -1 disables checkpointing.
checkpoint_dir
string
default:""
Directory to save checkpoints.
enable_single_replica_ckpt_restoring
boolean
default:false
Allow a single replica to read the checkpoint and broadcast it to the other replicas.

Evaluation

eval_every
number
Evaluate model every N steps. -1 disables evaluation during training.
eval_data_dir
string
default:""
Path to evaluation dataset with timesteps.
enable_generate_video_for_eval
boolean
default:false
Generate videos during evaluation. Increases TPU memory usage.
enable_eval_timesteps
boolean
default:false
Enable timestep-based evaluation as described in Scaling Rectified Flow Transformers paper.
timesteps_list
array
List of timesteps to evaluate.
num_eval_samples
number
default:420
Number of samples to use for evaluation.
eval_max_number_of_samples_in_bucket
number
default:60
Maximum samples per timestep bucket for evaluation.
enable_ssim
boolean
default:false
Enable SSIM metric calculation during evaluation.

Metrics and logging

write_metrics
boolean
default:true
Save metrics such as loss and TFLOPS.
gcs_metrics
boolean
default:false
Write metrics to GCS.
log_period
number
default:100
TensorBoard flush period.

Profiling

enable_profiler
boolean
default:false
Enable JAX profiler.
skip_first_n_steps_for_profiler
number
default:5
Skip first N steps when profiling to exclude compilation.
profiler_steps
number
default:10
Number of steps to profile.
enable_jax_named_scopes
boolean
default:false
Enable JAX named scopes for detailed profiling and debugging.

Generation parameters

prompt
string
Prompt for test video generation during training.
negative_prompt
string
Negative prompt for guidance.
guidance_scale
number
Classifier-free guidance scale.
flow_shift
number
Flow shift parameter for Wan models.
num_inference_steps
number
default:30
Number of denoising steps for test generation.
fps
number
default:16
Frames per second for generated videos.

LoRA configuration

enable_lora
boolean
default:false
Enable LoRA adapters for training or inference.
lora_config
object
Configuration for LoRA adapters.
  • rank: List of LoRA ranks
  • lora_model_name_or_path: List of LoRA model paths or HuggingFace repos
  • weight_name: List of weight filenames
  • adapter_name: List of adapter names
  • scale: List of scaling factors
  • from_pt: List of booleans indicating PyTorch format

Quantization

quantization
string
default:""
Quantization configuration.
use_qwix_quantization
boolean
default:false
Enable qwix quantization for Wan transformer.
weight_quantization_calibration_method
string
default:"absmax"
Calibration method for weight quantization.
act_quantization_calibration_method
string
default:"absmax"
Calibration method for activation quantization.
bwd_quantization_calibration_method
string
default:"absmax"
Calibration method for backward pass quantization.
qwix_module_path
string
default:".*"
Regex pattern for modules to quantize with qwix.

TFRecord creation

tfrecords_dir
string
default:""
Output directory for TFRecord creation.
no_records_per_shard
number
default:0
Number of records per TFRecord shard.
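Given no_records_per_shard, the number of output shards follows from a ceiling division (illustrative arithmetic, not the script's actual code):

```python
import math

# Illustrative: how many TFRecord shards a dataset splits into.
def num_shards(num_records: int, records_per_shard: int) -> int:
    return math.ceil(num_records / records_per_shard)

print(num_shards(95, 10))  # 10 shards: nine full shards plus one with 5 records
```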

Expected outputs

Training output

During training, you’ll see output like:
***** Running training *****
Instantaneous batch size per device = 0.25
Total train batch size (w. parallel & distributed) = 1
Total optimization steps = 1000
Calculated TFLOPs per pass: 4893.2719
completed step: 0, seconds: 142.395, TFLOP/s/device: 34.364, loss: 0.270
completed step: 1, seconds: 137.207, TFLOP/s/device: 35.664, loss: 0.144

Checkpoints

When checkpoint_every is set, checkpoints are saved to:
{base_output_directory}/{run_name}/checkpoints/

Metrics

Training metrics are saved to:
{base_output_directory}/{run_name}/metrics/
Metrics include:
  • Training loss
  • Learning rate
  • TFLOPS per device
  • Step time
  • Evaluation metrics (if enabled)

TensorBoard logs

View with:
tensorboard --logdir=gs://your-bucket/{run_name}/tensorboard/

Dataset preparation

Before training, convert your dataset to TFRecords:
python src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py \
  src/maxdiffusion/configs/base_wan_14b.yml \
  train_data_dir=$HF_DATASET_DIR \
  tfrecords_dir=$TFRECORDS_DATASET_DIR/train \
  no_records_per_shard=10 \
  enable_eval_timesteps=False
For evaluation dataset with timesteps:
python src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py \
  src/maxdiffusion/configs/base_wan_14b.yml \
  train_data_dir=$HF_DATASET_DIR \
  tfrecords_dir=$TFRECORDS_DATASET_DIR/eval \
  no_records_per_shard=10 \
  enable_eval_timesteps=True
