Wan 2.1 is a video generation model supporting both text-to-video (T2V) and image-to-video (I2V) generation. This guide covers training the 14B parameter T2V model.
Attaching an external disk is recommended, as the model weights take up significant disk space. This guide was tested on a v5p-8 with a 500GB disk.
Dataset preparation
This example uses the PusaV1 dataset.
Set up directories
export HF_DATASET_DIR=/mnt/disks/external_disk/PusaV1_training/
export TFRECORDS_DATASET_DIR=/mnt/disks/external_disk/wan_tfr_dataset_pusa_v1
Download the dataset
huggingface-cli download RaphaelLiu/PusaV1_training --repo-type dataset --local-dir $HF_DATASET_DIR
Convert to TFRecords (training)
python src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py \
src/maxdiffusion/configs/base_wan_14b.yml \
train_data_dir=$HF_DATASET_DIR \
tfrecords_dir=$TFRECORDS_DATASET_DIR/train \
no_records_per_shard=10 \
enable_eval_timesteps=False
Check progress:
ls -l $TFRECORDS_DATASET_DIR/train
Convert to TFRecords (evaluation)
python src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py \
src/maxdiffusion/configs/base_wan_14b.yml \
train_data_dir=$HF_DATASET_DIR \
tfrecords_dir=$TFRECORDS_DATASET_DIR/eval_timesteps \
no_records_per_shard=10 \
enable_eval_timesteps=True
This creates the first 420 samples with timestep fields for validation (as described in Scaling Rectified Flow Transformers).
Remove duplicates from training set
Delete the first 420 samples from training to avoid overlap:
printf "%s\n" $TFRECORDS_DATASET_DIR/train/file_*-*.tfrec | awk -F '[-.]' '$2+0 <= 420' | xargs -d '\n' rm
Verify deletion:
printf "%s\n" $TFRECORDS_DATASET_DIR/train/file_*-*.tfrec | awk -F '[-.]' '$2+0 <= 420' | xargs -d '\n' echo
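The filter works by splitting each filename on - and . and comparing the second field (the trailing record number in the file_*-*.tfrec names) against 420. A minimal illustration with hypothetical shard names:

```shell
# Hypothetical shard names in the file_<index>-<record>.tfrec pattern;
# the awk filter keeps only those whose trailing record number is <= 420.
printf "%s\n" file_0-10.tfrec file_41-420.tfrec file_42-430.tfrec \
  | awk -F '[-.]' '$2+0 <= 420'
# prints: file_0-10.tfrec
#         file_41-420.tfrec
```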
Clean up empty files
In some cases, an empty file is created in eval_timesteps:
rm $TFRECORDS_DATASET_DIR/eval_timesteps/file_42-430.tfrec
Verify directory structure
You should see:
wan_tfr_dataset_pusa_v1/
├── train/
└── eval_timesteps/
Training on a single VM
Upload data to GCS
BUCKET_NAME=my-bucket
gcloud storage cp --recursive $TFRECORDS_DATASET_DIR gs://$BUCKET_NAME/${TFRECORDS_DATASET_DIR##*/}
Set environment variables
RUN_NAME=jfacevedo-wan-v5p-8-${RANDOM}
OUTPUT_DIR=gs://$BUCKET_NAME/wan/
DATASET_DIR=gs://$BUCKET_NAME/${TFRECORDS_DATASET_DIR##*/}/train/
EVAL_DATA_DIR=gs://$BUCKET_NAME/${TFRECORDS_DATASET_DIR##*/}/eval_timesteps/
SAVE_DATASET_DIR=gs://$BUCKET_NAME/${TFRECORDS_DATASET_DIR##*/}/save/
Configure LIBTPU_INIT_ARGS
export LIBTPU_INIT_ARGS='--xla_tpu_enable_async_collective_fusion_fuse_all_gather=true \
--xla_tpu_megacore_fusion_allow_ags=false \
--xla_enable_async_collective_permute=true \
--xla_tpu_enable_ag_backward_pipelining=true \
--xla_tpu_enable_data_parallel_all_reduce_opt=true \
--xla_tpu_data_parallel_opt_different_sized_ops=true \
--xla_tpu_enable_async_collective_fusion=true \
--xla_tpu_enable_async_collective_fusion_multiple_steps=true \
--xla_tpu_overlap_compute_collective_tc=true \
--xla_enable_async_all_gather=true \
--xla_tpu_scoped_vmem_limit_kib=65536 \
--xla_tpu_enable_async_all_to_all=true \
--xla_tpu_enable_all_experimental_scheduler_features=true \
--xla_tpu_enable_scheduler_memory_pressure_tracking=true \
--xla_tpu_host_transfer_overlap_limit=24 \
--xla_tpu_aggressive_opt_barrier_removal=ENABLED \
--xla_lhs_prioritize_async_depth_over_stall=ENABLED \
--xla_should_allow_loop_variant_parameter_in_chain=ENABLED \
--xla_should_add_loop_invariant_op_in_chain=ENABLED \
--xla_max_concurrent_host_send_recv=100 \
--xla_tpu_scheduler_percent_shared_memory_limit=100 \
--xla_latency_hiding_scheduler_rerun=2 \
--xla_tpu_use_minor_sharding_for_major_trivial_input=true \
--xla_tpu_relayout_group_size_threshold_for_reduce_scatter=1 \
--xla_tpu_assign_all_reduce_scatter_layout=true'
Run training
HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ python src/maxdiffusion/train_wan.py \
src/maxdiffusion/configs/base_wan_14b.yml \
attention='flash' \
weights_dtype=bfloat16 \
activations_dtype=bfloat16 \
guidance_scale=5.0 \
flow_shift=5.0 \
fps=16 \
skip_jax_distributed_system=False \
run_name=${RUN_NAME} \
output_dir=${OUTPUT_DIR} \
train_data_dir=${DATASET_DIR} \
load_tfrecord_cached=True \
height=1280 \
width=720 \
num_frames=81 \
num_inference_steps=50 \
jax_cache_dir=${OUTPUT_DIR}/jax_cache/ \
max_train_steps=1000 \
enable_profiler=True \
dataset_save_location=${SAVE_DATASET_DIR} \
remat_policy='HIDDEN_STATE_WITH_OFFLOAD' \
flash_min_seq_length=0 \
seed=$RANDOM \
skip_first_n_steps_for_profiler=3 \
profiler_steps=3 \
per_device_batch_size=0.25 \
ici_data_parallelism=1 \
ici_fsdp_parallelism=4 \
ici_tensor_parallelism=1
Expected output
***** Running training *****
Instantaneous batch size per device = 0.25
Total train batch size (w. parallel & distributed) = 1
Total optimization steps = 1000
Calculated TFLOPs per pass: 4893.2719
completed step: 0, seconds: 142.395, TFLOP/s/device: 34.364, loss: 0.270
To see full metrics 'tensorboard --logdir=gs://bucket/wan/run-name/tensorboard/'
completed step: 1, seconds: 137.207, TFLOP/s/device: 35.664, loss: 0.144
completed step: 2, seconds: 36.014, TFLOP/s/device: 135.871, loss: 0.210
Training with XPK
For large-scale training on v5p-256:
Set environment variables
RUN_NAME=jfacevedo-wan-v5p-256-${RANDOM}
OUTPUT_DIR=gs://$BUCKET_NAME/wan/
DATASET_DIR=gs://$BUCKET_NAME/${TFRECORDS_DATASET_DIR##*/}/train/
EVAL_DATA_DIR=gs://$BUCKET_NAME/${TFRECORDS_DATASET_DIR##*/}/eval_timesteps/
SAVE_DATASET_DIR=gs://$BUCKET_NAME/${TFRECORDS_DATASET_DIR##*/}/save/
Set LIBTPU_INIT_ARGS
Use the same LIBTPU_INIT_ARGS as single VM training (see above).
Create XPK workload
python3 ~/xpk/xpk.py workload create \
--cluster=$CLUSTER_NAME \
--project=$PROJECT \
--zone=$ZONE \
--device-type=$DEVICE_TYPE \
--num-slices=1 \
--command=" \
HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ python src/maxdiffusion/train_wan.py \
src/maxdiffusion/configs/base_wan_14b.yml \
attention='flash' \
weights_dtype=bfloat16 \
activations_dtype=bfloat16 \
guidance_scale=5.0 \
flow_shift=5.0 \
fps=16 \
skip_jax_distributed_system=False \
run_name=${RUN_NAME} \
output_dir=${OUTPUT_DIR} \
train_data_dir=${DATASET_DIR} \
load_tfrecord_cached=True \
height=1280 \
width=720 \
num_frames=81 \
num_inference_steps=50 \
jax_cache_dir=${OUTPUT_DIR}/jax_cache/ \
enable_profiler=True \
dataset_save_location=${SAVE_DATASET_DIR} \
remat_policy='HIDDEN_STATE_WITH_OFFLOAD' \
flash_min_seq_length=0 \
seed=$RANDOM \
skip_first_n_steps_for_profiler=3 \
profiler_steps=3 \
per_device_batch_size=0.25 \
ici_data_parallelism=32 \
ici_fsdp_parallelism=4 \
ici_tensor_parallelism=1 \
max_train_steps=5000 \
eval_every=100 \
eval_data_dir=${EVAL_DATA_DIR} \
enable_generate_video_for_eval=True" \
--base-docker-image=${IMAGE_DIR} \
--enable-debug-logs \
--workload=${RUN_NAME} \
--priority=medium \
--max-restarts=0
Configuration
The base config is located at src/maxdiffusion/configs/base_wan_14b.yml.
Key parameters
| Parameter | Default | Description |
|---|---|---|
| pretrained_model_name_or_path | Wan-AI/Wan2.1-T2V-14B-Diffusers | Base Wan model |
| model_name | wan2.1 | Model variant |
| model_type | T2V | T2V or I2V |
| weights_dtype | bfloat16 | Weight precision |
| activations_dtype | bfloat16 | Activation precision |
| attention | flash | Attention mechanism |
| height | 1280 | Video height |
| width | 720 | Video width |
| num_frames | 81 | Number of frames |
| fps | 16 | Frames per second |
| per_device_batch_size | 1.0 | Batch size (can be fractional) |
| learning_rate | 1.e-5 | Initial learning rate |
| max_train_steps | 1500 | Maximum training steps |
Important notes
Fractional batch sizes: per_device_batch_size can be fractional but must result in a whole number when multiplied by the number of devices. Example: 0.25 × 4 devices = 1 effective global batch size.
Parallelism axes: In Wan 2.1, ici_fsdp_parallelism is used for sequence parallelism, and ici_tensor_parallelism is used for head parallelism. The model has 40 heads, so ici_tensor_parallelism must evenly divide 40.
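The two constraints above can be sketched as quick shell checks, using the values from this guide (adjust them for your topology):

```shell
# Sanity checks for the two constraints above.
PDBS=0.25; DEVICES=4; HEADS=40; TP=1

# Global batch size must come out whole: 0.25 x 4 = 1.
GLOBAL=$(awk -v b="$PDBS" -v d="$DEVICES" 'BEGIN { print b * d }')
awk -v t="$GLOBAL" 'BEGIN { exit (t == int(t)) ? 0 : 1 }' \
  && echo "global batch size = $GLOBAL"

# Tensor parallelism must evenly divide the 40 attention heads.
[ $(( HEADS % TP )) -eq 0 ] && echo "ici_tensor_parallelism=$TP divides $HEADS heads"
```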
Parallelism configuration
# Sequence parallelism via ici_fsdp_parallelism
ici_data_parallelism: 1
ici_fsdp_parallelism: 4 # Sequence parallelism (try 2 or 4)
ici_tensor_parallelism: 1 # Head parallelism (must divide 40)
ici_context_parallelism: -1 # Auto-shard
# Multi-host parallelism
dcn_data_parallelism: 1
dcn_fsdp_parallelism: 1
dcn_context_parallelism: -1 # Auto-shard
dcn_tensor_parallelism: 1
Gradient checkpointing (remat)
Wan supports multiple remat policies:
- NONE - No gradient checkpointing (maximum memory usage)
- FULL - Full gradient checkpointing (minimum memory usage)
- MATMUL_WITHOUT_BATCH - Checkpoint linear/matmul operations except the batch dimension
- OFFLOAD_MATMUL_WITHOUT_BATCH - Same as above, but offload instead of recompute
- HIDDEN_STATE_WITH_OFFLOAD - Offload hidden states
- CUSTOM - Custom policy (set names_which_can_be_saved and names_which_can_be_offloaded)
remat_policy: "HIDDEN_STATE_WITH_OFFLOAD"
Flash attention configuration
Default block sizes for v5p:
flash_block_sizes:
block_q: 512
block_kv_compute: 512
block_kv: 512
block_q_dkv: 512
block_kv_dkv: 512
block_kv_dkv_compute: 512
block_q_dq: 512
block_kv_dq: 512
use_fused_bwd_kernel: False
For v6e (Trillium), use larger block sizes:
flash_block_sizes:
block_q: 3024
block_kv_compute: 1024
block_kv: 2048
block_q_dkv: 3024
block_kv_dkv: 2048
block_kv_dkv_compute: 1024
block_q_dq: 3024
block_kv_dq: 2048
use_fused_bwd_kernel: False
Evaluation configuration
Enable evaluation during training:
eval_every: 100 # Evaluate every 100 steps (-1 to disable)
eval_data_dir: "gs://bucket/eval_timesteps/"
enable_generate_video_for_eval: True # Increases TPU memory usage
eval_max_number_of_samples_in_bucket: 60
Timesteps for evaluation
timesteps_list: [125, 250, 375, 500, 625, 750, 875]
num_eval_samples: 420
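A plausible reading of these values (an inference from the config, not stated explicitly in it) is that the 420 eval samples come from 7 timestep buckets of 60 samples each:

```shell
# Likely origin of the 420 eval samples (an inference from the values
# above): 7 timestep buckets x 60 samples per bucket.
TIMESTEPS=7
SAMPLES_PER_BUCKET=60
echo $(( TIMESTEPS * SAMPLES_PER_BUCKET ))
# prints: 420
```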
GPU-specific configuration
For GPU training, install the cudnn_te_flash attention kernel for optimal performance.
Batch parallelism
GPUs support batch parallelism via ici_fsdp_batch_parallelism:
ici_fsdp_batch_parallelism: 2 # Does not support fractional batch sizes
Combine with sequence parallelism:
ici_fsdp_batch_parallelism: 2
ici_fsdp_parallelism: 2 # For fractional batch sizes
Padding is not currently supported for cudnn_te_flash attention. The sequence length must be divisible by the number of devices in ici_fsdp_parallelism.
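A quick divisibility check can catch this before launch. The sequence length below is a hypothetical token count, not a value taken from this guide:

```shell
# Divisibility check; SEQ_LEN here is a hypothetical token count.
SEQ_LEN=75600
FSDP=4
if [ $(( SEQ_LEN % FSDP )) -eq 0 ]; then
  echo "ok: sequence splits evenly across $FSDP devices"
else
  echo "not divisible: adjust ici_fsdp_parallelism or the resolution"
fi
```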
Synthetic data for benchmarking
Benchmark training performance without downloading datasets:
dataset_type: 'synthetic'
synthetic_num_samples: null # null for infinite samples
# Override data dimensions:
synthetic_override_height: 720
synthetic_override_width: 1280
synthetic_override_num_frames: 85
synthetic_override_max_sequence_length: 512
synthetic_override_text_embed_dim: 4096
synthetic_override_num_channels_latents: 16
synthetic_override_vae_scale_factor_spatial: 8
synthetic_override_vae_scale_factor_temporal: 4
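As a rough sanity check of these overrides, assuming the common video-VAE downscaling relations (spatial factor 8; temporal frames mapped as (frames - 1)/4 + 1 — an assumption, not confirmed by this config), the implied latent shape is:

```shell
# Implied latent shape from the synthetic overrides, assuming spatial /8
# and temporal (frames - 1)/4 + 1 (an assumption, not from the config).
H=720; W=1280; F=85
echo "latent frames x height x width: $(( (F - 1) / 4 + 1 )) x $(( H / 8 )) x $(( W / 8 ))"
# prints: latent frames x height x width: 22 x 90 x 160
```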
LoRA support
Wan supports LoRA fine-tuning:
enable_lora: True
lora_config:
rank: [64]
lora_model_name_or_path: ["lightx2v/Wan2.1-Distill-Loras"]
weight_name: ["wan2.1_t2v_14b_lora_rank64_lightx2v_4step.safetensors"]
adapter_name: ["wan21-distill-lora"]
scale: [1.0]
from_pt: []
Quantization (experimental)
Quantization support is experimental. Use with caution.
use_qwix_quantization: True
weight_quantization_calibration_method: "absmax"
act_quantization_calibration_method: "absmax"
bwd_quantization_calibration_method: "absmax"
qwix_module_path: ".*"
Monitoring
View training metrics:
tensorboard --logdir=gs://bucket/wan/run-name/tensorboard/
Best practices
- Attach an external disk: Wan weights are large (14B parameters)
- Use fractional batch sizes: Set per_device_batch_size=0.25 for memory efficiency
- Enable remat: Use HIDDEN_STATE_WITH_OFFLOAD for balanced memory/speed
- Use sequence parallelism: Set ici_fsdp_parallelism=4 for best performance
- Enable evaluation: Set eval_every=100 to track training quality
- Monitor TFLOP/s: Target >100 TFLOP/s/device after warmup