Wan 2.1 is a video generation model supporting both text-to-video (T2V) and image-to-video (I2V) generation. This guide covers training the 14B parameter T2V model.
Attaching an external disk is recommended, as the model weights take up significant disk space. This guide was tested on a v5p-8 with a 500 GB disk.

Dataset preparation

This example uses the PusaV1 dataset.
1. Set up directories

export HF_DATASET_DIR=/mnt/disks/external_disk/PusaV1_training/
export TFRECORDS_DATASET_DIR=/mnt/disks/external_disk/wan_tfr_dataset_pusa_v1
2. Download the dataset

huggingface-cli download RaphaelLiu/PusaV1_training --repo-type dataset --local-dir $HF_DATASET_DIR
3. Convert to TFRecords (training)

python src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py \
  src/maxdiffusion/configs/base_wan_14b.yml \
  train_data_dir=$HF_DATASET_DIR \
  tfrecords_dir=$TFRECORDS_DATASET_DIR/train \
  no_records_per_shard=10 \
  enable_eval_timesteps=False
Check progress:
ls -ll $TFRECORDS_DATASET_DIR/train
4. Convert to TFRecords (evaluation)

python src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py \
  src/maxdiffusion/configs/base_wan_14b.yml \
  train_data_dir=$HF_DATASET_DIR \
  tfrecords_dir=$TFRECORDS_DATASET_DIR/eval_timesteps \
  no_records_per_shard=10 \
  enable_eval_timesteps=True
This converts the first 420 samples, adding fixed timestep fields for validation (as described in Scaling Rectified Flow Transformers).
5. Remove duplicates from the training set

Delete the shards covering the first 420 samples from the training set so they don't overlap with the evaluation split. Preview the files that will be removed:
printf "%s\n" $TFRECORDS_DATASET_DIR/train/file_*-*.tfrec | awk -F '[-.]' '$2+0 <= 420' | xargs -d '\n' echo
Then delete them:
printf "%s\n" $TFRECORDS_DATASET_DIR/train/file_*-*.tfrec | awk -F '[-.]' '$2+0 <= 420' | xargs -d '\n' rm
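The awk filter keys on the record index embedded in each shard filename (`file_<start>-<end>.tfrec`): because the fields are split on `-` and `.`, the second field is the shard's ending record index (assuming the rest of the path contains neither character). A minimal Python equivalent of that selection logic, for illustration:

```python
import re

def overlapping_shards(filenames, cutoff=420):
    """Select shard files whose ending record index is <= cutoff.

    Mirrors the shell filter above: shards named file_<start>-<end>.tfrec
    whose <end> index falls within the first `cutoff` records.
    """
    pattern = re.compile(r"file_(\d+)-(\d+)\.tfrec$")
    return [name for name in filenames
            if (m := pattern.search(name)) and int(m.group(2)) <= cutoff]

# With 10 records per shard, shards file_0-10 .. file_410-420 cover
# exactly the first 420 records.
files = ["file_0-10.tfrec", "file_410-420.tfrec", "file_420-430.tfrec"]
print(overlapping_shards(files))  # ['file_0-10.tfrec', 'file_410-420.tfrec']
```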
6. Clean up empty files

In some cases, an empty shard file is created in eval_timesteps; delete it:
rm $TFRECORDS_DATASET_DIR/eval_timesteps/file_42-430.tfrec
7. Verify the directory structure

You should see:
wan_tfr_dataset_pusa_v1/
├── train/
└── eval_timesteps/

Training on a single VM

1. Upload data to GCS

BUCKET_NAME=my-bucket
gcloud storage cp --recursive $TFRECORDS_DATASET_DIR gs://$BUCKET_NAME/${TFRECORDS_DATASET_DIR##*/}
2. Set environment variables

RUN_NAME=jfacevedo-wan-v5p-8-${RANDOM}
OUTPUT_DIR=gs://$BUCKET_NAME/wan/
DATASET_DIR=gs://$BUCKET_NAME/${TFRECORDS_DATASET_DIR##*/}/train/
EVAL_DATA_DIR=gs://$BUCKET_NAME/${TFRECORDS_DATASET_DIR##*/}/eval_timesteps/
SAVE_DATASET_DIR=gs://$BUCKET_NAME/${TFRECORDS_DATASET_DIR##*/}/save/
3. Configure LIBTPU_INIT_ARGS

export LIBTPU_INIT_ARGS='--xla_tpu_enable_async_collective_fusion_fuse_all_gather=true \
--xla_tpu_megacore_fusion_allow_ags=false \
--xla_enable_async_collective_permute=true \
--xla_tpu_enable_ag_backward_pipelining=true \
--xla_tpu_enable_data_parallel_all_reduce_opt=true \
--xla_tpu_data_parallel_opt_different_sized_ops=true \
--xla_tpu_enable_async_collective_fusion=true \
--xla_tpu_enable_async_collective_fusion_multiple_steps=true \
--xla_tpu_overlap_compute_collective_tc=true \
--xla_enable_async_all_gather=true \
--xla_tpu_scoped_vmem_limit_kib=65536 \
--xla_tpu_enable_async_all_to_all=true \
--xla_tpu_enable_all_experimental_scheduler_features=true \
--xla_tpu_enable_scheduler_memory_pressure_tracking=true \
--xla_tpu_host_transfer_overlap_limit=24 \
--xla_tpu_aggressive_opt_barrier_removal=ENABLED \
--xla_lhs_prioritize_async_depth_over_stall=ENABLED \
--xla_should_allow_loop_variant_parameter_in_chain=ENABLED \
--xla_should_add_loop_invariant_op_in_chain=ENABLED \
--xla_max_concurrent_host_send_recv=100 \
--xla_tpu_scheduler_percent_shared_memory_limit=100 \
--xla_latency_hiding_scheduler_rerun=2 \
--xla_tpu_use_minor_sharding_for_major_trivial_input=true \
--xla_tpu_relayout_group_size_threshold_for_reduce_scatter=1 \
--xla_tpu_assign_all_reduce_scatter_layout=true'
4. Run training

HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ python src/maxdiffusion/train_wan.py \
src/maxdiffusion/configs/base_wan_14b.yml \
attention='flash' \
weights_dtype=bfloat16 \
activations_dtype=bfloat16 \
guidance_scale=5.0 \
flow_shift=5.0 \
fps=16 \
skip_jax_distributed_system=False \
run_name=${RUN_NAME} \
output_dir=${OUTPUT_DIR} \
train_data_dir=${DATASET_DIR} \
load_tfrecord_cached=True \
height=1280 \
width=720 \
num_frames=81 \
num_inference_steps=50 \
jax_cache_dir=${OUTPUT_DIR}/jax_cache/ \
max_train_steps=1000 \
enable_profiler=True \
dataset_save_location=${SAVE_DATASET_DIR} \
remat_policy='HIDDEN_STATE_WITH_OFFLOAD' \
flash_min_seq_length=0 \
seed=$RANDOM \
skip_first_n_steps_for_profiler=3 \
profiler_steps=3 \
per_device_batch_size=0.25 \
ici_data_parallelism=1 \
ici_fsdp_parallelism=4 \
ici_tensor_parallelism=1

Expected output

***** Running training *****
Instantaneous batch size per device = 0.25
Total train batch size (w. parallel & distributed) = 1
Total optimization steps = 1000
Calculated TFLOPs per pass: 4893.2719
completed step: 0, seconds: 142.395, TFLOP/s/device: 34.364, loss: 0.270
To see full metrics 'tensorboard --logdir=gs://bucket/wan/run-name/tensorboard/'
completed step: 1, seconds: 137.207, TFLOP/s/device: 35.664, loss: 0.144
completed step: 2, seconds: 36.014, TFLOP/s/device: 135.871, loss: 0.210
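The logged throughput can be sanity-checked directly: as the log suggests, TFLOP/s/device is the calculated TFLOPs per pass divided by the step's wall-clock time. Note steps 0 and 1 are slow due to compilation; step 2 reflects steady-state throughput.

```python
# Reproduce the TFLOP/s/device figures from the log above.
tflops_per_pass = 4893.2719

for seconds, reported in [(142.395, 34.364), (137.207, 35.664), (36.014, 135.871)]:
    computed = tflops_per_pass / seconds
    print(f"{seconds:>8.3f}s -> {computed:.3f} TFLOP/s/device (logged: {reported})")
    assert abs(computed - reported) < 1e-3
```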

Training with XPK

For large-scale training on v5p-256:
1. Set environment variables

RUN_NAME=jfacevedo-wan-v5p-8-${RANDOM}
OUTPUT_DIR=gs://$BUCKET_NAME/wan/
DATASET_DIR=gs://$BUCKET_NAME/${TFRECORDS_DATASET_DIR##*/}/train/
EVAL_DATA_DIR=gs://$BUCKET_NAME/${TFRECORDS_DATASET_DIR##*/}/eval_timesteps/
SAVE_DATASET_DIR=gs://$BUCKET_NAME/${TFRECORDS_DATASET_DIR##*/}/save/
2. Set LIBTPU_INIT_ARGS

Use the same LIBTPU_INIT_ARGS as single VM training (see above).
3. Create the XPK workload

python3 ~/xpk/xpk.py workload create \
--cluster=$CLUSTER_NAME \
--project=$PROJECT \
--zone=$ZONE \
--device-type=$DEVICE_TYPE \
--num-slices=1 \
--command=" \
HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ python src/maxdiffusion/train_wan.py \
src/maxdiffusion/configs/base_wan_14b.yml \
attention='flash' \
weights_dtype=bfloat16 \
activations_dtype=bfloat16 \
guidance_scale=5.0 \
flow_shift=5.0 \
fps=16 \
skip_jax_distributed_system=False \
run_name=${RUN_NAME} \
output_dir=${OUTPUT_DIR} \
train_data_dir=${DATASET_DIR} \
load_tfrecord_cached=True \
height=1280 \
width=720 \
num_frames=81 \
num_inference_steps=50 \
jax_cache_dir=${OUTPUT_DIR}/jax_cache/ \
enable_profiler=True \
dataset_save_location=${SAVE_DATASET_DIR} \
remat_policy='HIDDEN_STATE_WITH_OFFLOAD' \
flash_min_seq_length=0 \
seed=$RANDOM \
skip_first_n_steps_for_profiler=3 \
profiler_steps=3 \
per_device_batch_size=0.25 \
ici_data_parallelism=32 \
ici_fsdp_parallelism=4 \
ici_tensor_parallelism=1 \
max_train_steps=5000 \
eval_every=100 \
eval_data_dir=${EVAL_DATA_DIR} \
enable_generate_video_for_eval=True" \
--base-docker-image=${IMAGE_DIR} \
--enable-debug-logs \
--workload=${RUN_NAME} \
--priority=medium \
--max-restarts=0

Configuration

The base config is located at src/maxdiffusion/configs/base_wan_14b.yml.

Key parameters

| Parameter | Default | Description |
| --- | --- | --- |
| `pretrained_model_name_or_path` | `Wan-AI/Wan2.1-T2V-14B-Diffusers` | Base Wan model |
| `model_name` | `wan2.1` | Model variant |
| `model_type` | `T2V` | `T2V` or `I2V` |
| `weights_dtype` | `bfloat16` | Weight precision |
| `activations_dtype` | `bfloat16` | Activation precision |
| `attention` | `flash` | Attention mechanism |
| `height` | `1280` | Video height |
| `width` | `720` | Video width |
| `num_frames` | `81` | Number of frames |
| `fps` | `16` | Frames per second |
| `per_device_batch_size` | `1.0` | Batch size (can be fractional) |
| `learning_rate` | `1.e-5` | Initial learning rate |
| `max_train_steps` | `1500` | Maximum training steps |

Important notes

Fractional batch sizes: per_device_batch_size can be fractional but must result in a whole number when multiplied by the number of devices. Example: 0.25 × 4 devices = 1 effective global batch size.
Parallelism axes: In Wan 2.1, ici_fsdp_parallelism is used for sequence parallelism, and ici_tensor_parallelism is used for head parallelism. The model has 40 heads, so ici_tensor_parallelism must evenly divide 40.
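Both notes reduce to simple divisibility checks. A small illustrative sketch (the function and variable names here are made up for the example):

```python
def global_batch_size(per_device_batch_size, num_devices):
    """Fractional per-device batch sizes must multiply out to a whole number."""
    total = per_device_batch_size * num_devices
    if total != int(total):
        raise ValueError(f"{per_device_batch_size} x {num_devices} = {total} is not whole")
    return int(total)

print(global_batch_size(0.25, 4))  # v5p-8 (4 chips): global batch size of 1

# Head parallelism must evenly divide Wan 2.1's 40 attention heads,
# so valid ici_tensor_parallelism values are the divisors of 40.
num_heads = 40
valid_tp = [tp for tp in range(1, num_heads + 1) if num_heads % tp == 0]
print(valid_tp)  # [1, 2, 4, 5, 8, 10, 20, 40]
```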

Parallelism configuration

# Sequence parallelism via ici_fsdp_parallelism
ici_data_parallelism: 1
ici_fsdp_parallelism: 4  # Sequence parallelism (try 2 or 4)
ici_tensor_parallelism: 1  # Head parallelism (must divide 40)
ici_context_parallelism: -1  # Auto-shard

# Multi-host parallelism
dcn_data_parallelism: 1
dcn_fsdp_parallelism: 1
dcn_context_parallelism: -1  # Auto-shard
dcn_tensor_parallelism: 1
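As a rule of thumb, the explicitly sized ICI axes should multiply out to the number of chips in one slice (axes set to `-1` are auto-sized and excluded). A quick check for the v5p-8 configuration above (assumed convention; values are illustrative):

```python
import math

# ICI mesh axes from the config above; -1 (auto-shard) axes omitted.
ici_axes = {"data": 1, "fsdp": 4, "tensor": 1}
chips_per_slice = 4  # a v5p-8 slice exposes 4 chips to JAX

assert math.prod(ici_axes.values()) == chips_per_slice
print("mesh product matches slice size:", math.prod(ici_axes.values()))
```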

Gradient checkpointing (remat)

Wan supports multiple remat policies:
  • NONE - No gradient checkpointing (maximum memory usage)
  • FULL - Full gradient checkpointing (minimum memory usage)
  • MATMUL_WITHOUT_BATCH - Checkpoint linear/matmul operations except batch dimension
  • OFFLOAD_MATMUL_WITHOUT_BATCH - Same as above but offload instead of recompute
  • HIDDEN_STATE_WITH_OFFLOAD - Offload hidden states
  • CUSTOM - Custom policy (set names_which_can_be_saved and names_which_can_be_offloaded)
remat_policy: "HIDDEN_STATE_WITH_OFFLOAD"

Flash attention configuration

Default block sizes for v5p:
flash_block_sizes:
  block_q: 512
  block_kv_compute: 512
  block_kv: 512
  block_q_dkv: 512
  block_kv_dkv: 512
  block_kv_dkv_compute: 512
  block_q_dq: 512
  block_kv_dq: 512
  use_fused_bwd_kernel: False
For v6e (Trillium), use larger block sizes:
flash_block_sizes:
  block_q: 3024
  block_kv_compute: 1024
  block_kv: 2048
  block_q_dkv: 3024
  block_kv_dkv: 2048
  block_kv_dkv_compute: 1024
  block_q_dq: 3024
  block_kv_dq: 2048
  use_fused_bwd_kernel: False

Evaluation configuration

Enable evaluation during training:
eval_every: 100  # Evaluate every 100 steps (-1 to disable)
eval_data_dir: "gs://bucket/eval_timesteps/"
enable_generate_video_for_eval: True  # Increases TPU memory usage
eval_max_number_of_samples_in_bucket: 60

Timesteps for evaluation

timesteps_list: [125, 250, 375, 500, 625, 750, 875]
num_eval_samples: 420
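The seven timesteps are evenly spaced eighths of the 1000-step diffusion schedule, and the 420 eval samples split evenly across them, which matches `eval_max_number_of_samples_in_bucket: 60` above. A quick arithmetic check:

```python
# Evenly spaced eval timesteps: multiples of 1000/8, excluding the endpoints.
timesteps = [i * 1000 // 8 for i in range(1, 8)]
print(timesteps)  # [125, 250, 375, 500, 625, 750, 875]

num_eval_samples = 420
assert num_eval_samples % len(timesteps) == 0
print(num_eval_samples // len(timesteps))  # 60 samples per timestep bucket
```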

GPU-specific configuration

For GPU training, use the cudnn_te_flash attention kernel for best performance.

Batch parallelism

GPUs support batch parallelism via ici_fsdp_batch_parallelism:
ici_fsdp_batch_parallelism: 2  # Does not support fractional batch sizes
Combine with sequence parallelism:
ici_fsdp_batch_parallelism: 2
ici_fsdp_parallelism: 2  # For fractional batch sizes
Padding is not currently supported for cudnn_te_flash attention. The sequence length must be divisible by the number of devices in ici_fsdp_parallelism.
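Since cudnn_te_flash does no padding, the divisibility constraint has to hold up front. A minimal sketch of the check (the helper name is invented for illustration):

```python
def check_seq_parallel(seq_len, ici_fsdp_parallelism):
    """Sequence length must split evenly across sequence-parallel devices."""
    if seq_len % ici_fsdp_parallelism != 0:
        raise ValueError(
            f"sequence length {seq_len} is not divisible by "
            f"ici_fsdp_parallelism={ici_fsdp_parallelism}")
    return seq_len // ici_fsdp_parallelism

print(check_seq_parallel(512, 2))  # 256 tokens per device
```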

Synthetic data for benchmarking

Benchmark training performance without downloading datasets:
dataset_type: 'synthetic'
synthetic_num_samples: null  # null for infinite samples

# Override data dimensions:
synthetic_override_height: 720
synthetic_override_width: 1280
synthetic_override_num_frames: 85
synthetic_override_max_sequence_length: 512
synthetic_override_text_embed_dim: 4096
synthetic_override_num_channels_latents: 16
synthetic_override_vae_scale_factor_spatial: 8
synthetic_override_vae_scale_factor_temporal: 4
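To see how these overrides relate, here is a rough sketch of how the pixel-space dimensions map to latent shapes, assuming the common causal-VAE conventions (spatial downscale by the spatial factor, temporal downscale of `(frames - 1) / factor + 1`); treat the exact formula as an assumption, not a guarantee of the implementation:

```python
# Override values from the synthetic config above.
height, width, num_frames = 720, 1280, 85
spatial, temporal = 8, 4  # vae_scale_factor_spatial / _temporal

latent_h = height // spatial                  # 720 / 8 = 90
latent_w = width // spatial                   # 1280 / 8 = 160
latent_f = (num_frames - 1) // temporal + 1   # (85 - 1) / 4 + 1 = 22

print(latent_h, latent_w, latent_f)  # 90 160 22
```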

LoRA support

Wan supports LoRA fine-tuning:
enable_lora: True
lora_config:
  rank: [64]
  lora_model_name_or_path: ["lightx2v/Wan2.1-Distill-Loras"]
  weight_name: ["wan2.1_t2v_14b_lora_rank64_lightx2v_4step.safetensors"]
  adapter_name: ["wan21-distill-lora"]
  scale: [1.0]
  from_pt: []

Quantization (experimental)

Quantization support is experimental. Use with caution.
use_qwix_quantization: True
weight_quantization_calibration_method: "absmax"
act_quantization_calibration_method: "absmax"
bwd_quantization_calibration_method: "absmax"
qwix_module_path: ".*"

Monitoring

View training metrics:
tensorboard --logdir=gs://bucket/wan/run-name/tensorboard/

Best practices

  1. Attach external disk: Wan weights are large (14B parameters)
  2. Use fractional batch sizes: Set per_device_batch_size=0.25 for memory efficiency
  3. Enable remat: Use HIDDEN_STATE_WITH_OFFLOAD for balanced memory/speed
  4. Sequence parallelism: Set ici_fsdp_parallelism=4 for best performance
  5. Enable evaluation: Set eval_every=100 to track training quality
  6. Monitor TFLOP/s: Target >100 TFLOP/s/device after warmup
