Overview
The train_wan.py script trains Wan text-to-video models using the MaxDiffusion framework. It supports Wan 2.1 and 2.2 models with advanced features including gradient checkpointing, flash attention, and synthetic data generation.
Command-line usage
Basic example (single VM)
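The exact invocation is not shown here; as a sketch, MaxDiffusion training scripts take a base YAML config followed by key=value overrides. The config path and override names below are assumptions — verify them against your checkout:

```shell
# Hypothetical single-VM run; the config path and override names may differ
# in your checkout of MaxDiffusion.
python src/maxdiffusion/train_wan.py \
  src/maxdiffusion/configs/base_wan_14b.yml \
  run_name=my-wan-run \
  output_dir=gs://my-bucket/wan-training/ \
  max_train_steps=1000
```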
Configuration parameters
Run configuration
Name for this training run. Used for organizing outputs and metrics.
GCS bucket path for outputs (e.g., gs://my-bucket/). Checkpoints and metrics will be saved here.
Local or GCS directory for model outputs.
Model configuration
HuggingFace model identifier or local path to pretrained Wan model.
Wan model version identifier.
Model type. Options: T2V (text-to-video), I2V (image-to-video).
Override the transformer from pretrained_model_name_or_path with a different checkpoint.
Data type for model weights. Options: float32, bfloat16.
Data type for layer activations. Options: float32, bfloat16.
JAX precision for matmul and conv operations. Options: DEFAULT, HIGH, HIGHEST.
Load weights from PyTorch format.
Use jax.lax.scan for transformer layers to reduce compilation memory.
Replicate VAE across devices instead of using model’s sharding annotations.
Attention configuration
Attention mechanism to use. Options: dot_product, flash, cudnn_flash_te, ring.
Whether to split attention head dimensions for sharding.
Use uniform sequence sharding for both self and cross attention.
Pass segment IDs to attention to avoid attending to padding tokens.
Minimum sequence length to use flash attention.
Custom block sizes for flash attention. The default is tuned for v5p; for v6e (Trillium), use larger blocks such as 3024.
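As an illustration, larger blocks for v6e might be set via config overrides. The flag name and the dict's inner field names below are assumptions — check the base config for the real structure:

```shell
# Hypothetical override: larger flash-attention block sizes for v6e (Trillium).
# The key "flash_block_sizes" and its inner field names are illustrative.
python src/maxdiffusion/train_wan.py src/maxdiffusion/configs/base_wan_14b.yml \
  attention=flash \
  flash_block_sizes='{"block_q": 3024, "block_kv": 3024}'
```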
Training hyperparameters
Learning rate for the optimizer.
Scale learning rate by the number of GPUs/TPUs and batch size.
Maximum number of training steps. Takes priority over num_train_epochs.
Number of training epochs.
Batch size per device. May be fractional (e.g., 0.25), as long as the global batch size (per-device batch size × device count) is a whole number.
If non-zero, overrides the global batch size. If the value is not evenly divisible by the device count, FSDP sharding is used.
Fraction of total steps to use for learning rate warmup.
Random seed for reproducibility.
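Putting the hyperparameters above together, a run might be configured as follows. learning_rate, per_device_batch_size, max_train_steps, and seed follow common MaxDiffusion naming; the warmup flag name is an assumption:

```shell
# Hypothetical hyperparameter overrides; flag names follow the descriptions above.
python src/maxdiffusion/train_wan.py src/maxdiffusion/configs/base_wan_14b.yml \
  learning_rate=1e-5 \
  per_device_batch_size=1 \
  max_train_steps=5000 \
  warmup_steps_fraction=0.1 \
  seed=42
```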
Optimizer parameters
Exponential decay rate for first moment estimates.
Exponential decay rate for second moment estimates.
Small constant for numerical stability.
Weight decay coefficient for AdamW optimizer.
Maximum gradient norm for gradient clipping.
Save optimizer state in checkpoints.
Gradient checkpointing (remat)
Gradient checkpoint policy. Options:
- NONE: No gradient checkpointing
- FULL: Full gradient checkpointing (minimum memory)
- MATMUL_WITHOUT_BATCH: Checkpoint matmul ops without the batch dimension
- OFFLOAD_MATMUL_WITHOUT_BATCH: Same as above, but offload instead of recompute
- CUSTOM: Use custom names from names_which_can_be_saved and names_which_can_be_offloaded
For CUSTOM remat policy: list of operation names to save. Options include: attn_output, query_proj, key_proj, value_proj, xq_out, xk_out, ffn_activation.
For CUSTOM remat policy: list of operation names to offload.
Dropout rate for training.
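The CUSTOM policy above can be sketched as follows; names_which_can_be_saved and names_which_can_be_offloaded come from this section, while the remat_policy flag name is an assumption:

```shell
# Hypothetical CUSTOM remat setup: keep attention outputs in memory,
# offload FFN activations to host instead of recomputing them.
python src/maxdiffusion/train_wan.py src/maxdiffusion/configs/base_wan_14b.yml \
  remat_policy=CUSTOM \
  names_which_can_be_saved='["attn_output"]' \
  names_which_can_be_offloaded='["ffn_activation"]'
```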
Dataset configuration
HuggingFace dataset identifier.
GCS or local path to TFRecord training data.
Dataset split to use for training.
Dataset format type. Options: tfrecord, hf, tf, grain, synthetic.
Load preprocessed TFRecord files.
Path to save or load cached dataset.
Name of the image column in the dataset.
Name of the caption column in the dataset.
Spatial resolution for video frames.
Video frame height.
Video frame width.
Number of frames in video.
Shuffle the dataset during training.
Synthetic data configuration
For dataset_type='synthetic': number of synthetic samples. Set to null for infinite samples.
Override height for synthetic data.
Override width for synthetic data.
Override number of frames for synthetic data.
Override max sequence length for synthetic data.
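Synthetic data is handy for smoke tests because no dataset needs downloading. A sketch (the override names other than dataset_type are assumptions):

```shell
# Hypothetical smoke test on synthetic data; runs a few steps with no real dataset.
python src/maxdiffusion/train_wan.py src/maxdiffusion/configs/base_wan_14b.yml \
  dataset_type=synthetic \
  height=480 width=832 num_frames=81 \
  max_train_steps=20
```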
Parallelism and sharding
Hardware type. Options: tpu, gpu.
Logical mesh axes for parallelism.
Data parallelism across DCN.
FSDP parallelism across DCN.
Context (sequence) parallelism across DCN. -1 for auto-sharding.
Tensor parallelism across DCN.
Data parallelism within ICI.
FSDP parallelism within ICI. In Wan 2.1, this axis is used for sequence parallelism.
Context parallelism within ICI. -1 for auto-sharding.
Tensor (head) parallelism within ICI. For Wan 2.1, must evenly divide 40 heads.
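For example, on a single 8-chip slice you could shard entirely with FSDP, or split the 40 attention heads 8 ways (8 evenly divides 40). The flag names follow the MaxText-style convention implied above but are assumptions:

```shell
# Hypothetical mesh for one 8-chip slice. The product of the ici_* axes
# must equal the number of chips in the slice.
python src/maxdiffusion/train_wan.py src/maxdiffusion/configs/base_wan_14b.yml \
  dcn_data_parallelism=1 \
  ici_data_parallelism=1 \
  ici_fsdp_parallelism=8 \
  ici_tensor_parallelism=1
```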
Checkpointing
Save checkpoint every N samples. -1 disables checkpointing.
Directory to save checkpoints.
Enable one replica to read checkpoint and broadcast to others.
Evaluation
Evaluate model every N steps. -1 disables evaluation during training.
Path to evaluation dataset with timesteps.
Generate videos during evaluation. Increases TPU memory usage.
Enable timestep-based evaluation as described in Scaling Rectified Flow Transformers paper.
List of timesteps to evaluate.
Number of samples to use for evaluation.
Maximum samples per timestep bucket for evaluation.
Enable SSIM metric calculation during evaluation.
Metrics and logging
Save metrics such as loss and TFLOPS.
Write metrics to GCS.
TensorBoard flush period.
Profiling
Enable JAX profiler.
Skip first N steps when profiling to exclude compilation.
Number of steps to profile.
Enable JAX named scopes for detailed profiling and debugging.
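A profiling run might look like the following; the flag names mirror the descriptions above but should be checked against the config:

```shell
# Hypothetical profiling setup: skip the first steps (compilation), then
# capture a short trace for inspection in TensorBoard or Perfetto.
python src/maxdiffusion/train_wan.py src/maxdiffusion/configs/base_wan_14b.yml \
  enable_profiler=True \
  skip_first_n_steps_for_profiler=10 \
  profiler_steps=5
```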
Generation parameters
Prompt for test video generation during training.
Negative prompt for guidance.
Classifier-free guidance scale.
Flow shift parameter for Wan models.
Number of denoising steps for test generation.
Frames per second for generated videos.
LoRA configuration
Enable LoRA adapters for training or inference.
Configuration for LoRA adapters.
- rank: List of LoRA ranks
- lora_model_name_or_path: List of LoRA model paths or HuggingFace repos
- weight_name: List of weight filenames
- adapter_name: List of adapter names
- scale: List of scaling factors
- from_pt: List of booleans indicating PyTorch format
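The list-valued fields above are parallel: entry i of each list describes adapter i. A sketch loading one adapter — the enable flag name, repo, and file names are placeholders:

```shell
# Hypothetical LoRA configuration loading a single adapter.
# "some-user/wan-lora" and "adapter.safetensors" are placeholder names.
python src/maxdiffusion/train_wan.py src/maxdiffusion/configs/base_wan_14b.yml \
  enable_lora=True \
  lora_config='{"rank": [64], "lora_model_name_or_path": ["some-user/wan-lora"], "weight_name": ["adapter.safetensors"], "adapter_name": ["style"], "scale": [1.0], "from_pt": [true]}'
```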
Quantization
Quantization configuration.
Enable qwix quantization for Wan transformer.
Calibration method for weight quantization.
Calibration method for activation quantization.
Calibration method for backward pass quantization.
Regex pattern for modules to quantize with qwix.
TFRecord creation
Output directory for TFRecord creation.
Number of records per TFRecord shard.
Expected outputs
Training output
During training, the script logs per-step metrics such as the loss, learning rate, step time, and TFLOPS per device.
Checkpoints
When checkpoint_every is set, checkpoints are saved under the configured checkpoint directory.
Metrics
Training metrics are saved with the run outputs and include:
- Training loss
- Learning rate
- TFLOPS per device
- Step time
- Evaluation metrics (if enabled)