
Overview

The generate_wan.py script provides inference for Wan models (Wan2.1 and Wan2.2), supporting both text-to-video (T2V) and image-to-video (I2V) generation. Wan models are state-of-the-art video diffusion models capable of generating high-quality, temporally consistent videos.

Usage

Wan2.1 text-to-video

HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ \
LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true" \
HF_HUB_ENABLE_HF_TRANSFER=1 \
python src/maxdiffusion/generate_wan.py \
  src/maxdiffusion/configs/base_wan_14b.yml \
  attention="flash" \
  num_inference_steps=50 \
  num_frames=81 \
  width=1280 \
  height=720 \
  per_device_batch_size=0.125 \
  ici_data_parallelism=2 \
  ici_context_parallelism=2 \
  flow_shift=5.0 \
  fps=16 \
  seed=118445

Wan2.1 image-to-video

python src/maxdiffusion/generate_wan.py \
  src/maxdiffusion/configs/base_wan_i2v_14b.yml \
  image_url="path/to/image.jpg" \
  num_inference_steps=50 \
  num_frames=81

Wan2.2 text-to-video

python src/maxdiffusion/generate_wan.py \
  src/maxdiffusion/configs/base_wan_27b.yml \
  num_inference_steps=50

With LoRA

python src/maxdiffusion/generate_wan.py \
  src/maxdiffusion/configs/base_wan_14b.yml \
  enable_lora=True

The LoRA weights, adapter names, and scales are read from lora_config in the YAML file.

Configuration parameters

Model parameters

pretrained_model_name_or_path
string
default:"Wan-AI/Wan2.1-T2V-14B-Diffusers"
Model path. Options:
  • Wan-AI/Wan2.1-T2V-14B-Diffusers (Wan2.1 T2V)
  • Wan-AI/Wan2.1-I2V-14B-Diffusers (Wan2.1 I2V)
  • Wan-AI/Wan2.2-T2V-27B-Diffusers (Wan2.2 T2V)
  • Wan-AI/Wan2.2-I2V-27B-Diffusers (Wan2.2 I2V)
model_name
string
default:"wan2.1"
Model variant. Options: wan2.1, wan2.2.
model_type
string
default:"T2V"
Pipeline type. Options: T2V (text-to-video), I2V (image-to-video).
wan_transformer_pretrained_model_name_or_path
string
default:""
Override path for transformer weights.

Generation parameters

prompt
string
required
Text prompt describing the video to generate. Example:
"A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon."
negative_prompt
string
Negative prompt to guide what should not appear in the video. Example:
"Bright tones, overexposed, static, blurred details, worst quality, low quality, ugly, incomplete, poorly drawn"
image_url
string
default:""
Path to input image for I2V generation. Required when model_type=I2V.
height
integer
default:480
Video height in pixels. Common values: 480, 720, 1080.
width
integer
default:832
Video width in pixels. Common values: 832, 1280, 1920.
num_frames
integer
default:81
Number of frames to generate. Must be compatible with temporal compression.
num_inference_steps
integer
default:30
Number of denoising steps. More steps = better quality but slower.
fps
integer
default:16
Frames per second for output video.
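The constraints on num_frames and fps can be sanity-checked before launching a run. A minimal sketch, assuming the common 4x temporal compression of Wan's 3D VAE (so valid frame counts have the form 4k + 1, matching the default of 81 = 4 × 20 + 1); the stride value is an assumption, not read from the config:

```python
def is_valid_num_frames(num_frames: int, temporal_stride: int = 4) -> bool:
    """Check frame-count compatibility with an assumed 4x temporal compression."""
    return num_frames >= 1 and (num_frames - 1) % temporal_stride == 0

def clip_duration_seconds(num_frames: int, fps: int) -> float:
    """Approximate output clip length."""
    return num_frames / fps

# The defaults (81 frames at 16 fps) give a clip of about 5 seconds.
```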

Wan2.1 guidance parameters

guidance_scale
float
Classifier-free guidance scale for Wan2.1 models.

Wan2.2 guidance parameters

guidance_scale_low
float
Low-frequency guidance scale for Wan2.2 models (dual guidance system).
guidance_scale_high
float
High-frequency guidance scale for Wan2.2 models (dual guidance system).

Flow parameters

flow_shift
float
Flow shift parameter for rectified flow. Controls the noise schedule.

Performance parameters

weights_dtype
string
default:"bfloat16"
Data type for model weights. Use bfloat16 for optimal performance.
activations_dtype
string
default:"bfloat16"
Data type for activations.
attention
string
default:"flash"
Attention implementation. Options: flash, cudnn_flash_te (GPU), ring, dot_product.
scan_layers
boolean
default:true
Use jax.lax.scan for transformer layers to reduce memory.
replicate_vae
boolean
default:false
Replicate VAE across devices instead of sharding.
flash_block_sizes
object
Flash attention block sizes. Optimal values differ between TPU v5p and v6e. Example for v6e (Trillium):
flash_block_sizes: {
  "block_q": 3024,
  "block_kv_compute": 1024,
  "block_kv": 2048,
  "block_q_dkv": 3024,
  "block_kv_dkv": 2048,
  "block_kv_dkv_compute": 1024,
  "block_q_dq": 3024,
  "block_kv_dq": 2048
}
flash_min_seq_length
integer
default:0
Minimum sequence length for flash attention.

Parallelism parameters

ici_data_parallelism
integer
default:1
Data parallelism across ICI devices.
ici_fsdp_parallelism
integer
default:1
FSDP parallelism. Used for sequence parallelism in Wan2.1; values of 2 or 4 work best. The sequence length is padded so it divides evenly across shards.
ici_context_parallelism
integer
Context parallelism. Recommended for auto-sharding.
ici_tensor_parallelism
integer
default:1
Tensor parallelism. Used for head parallelism in Wan2.1. Must evenly divide 40 (number of attention heads).
per_device_batch_size
float
Batch size per device. Can be fractional (e.g., 0.25, 0.125) but must yield a whole number when multiplied by the device count.
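The two divisibility rules above can be checked up front. A minimal sketch based on the constraints stated in this reference (a whole global batch, and head parallelism dividing the 40 attention heads); the helper names are illustrative, not part of generate_wan.py:

```python
NUM_ATTENTION_HEADS = 40  # Wan2.1 14B, per the parameter description above

def global_batch_size(per_device_batch_size: float, device_count: int) -> int:
    """Reject fractional settings that do not yield a whole global batch."""
    gbs = per_device_batch_size * device_count
    if abs(gbs - round(gbs)) > 1e-9:
        raise ValueError(
            f"per_device_batch_size={per_device_batch_size} does not give a "
            f"whole batch on {device_count} devices"
        )
    return round(gbs)

def valid_tensor_parallelism(tp: int) -> bool:
    """Head parallelism must divide the attention-head count evenly."""
    return NUM_ATTENTION_HEADS % tp == 0

# e.g. 0.125 per device on 8 chips gives a global batch of 1.
```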

LoRA parameters

enable_lora
boolean
default:false
Enable LoRA loading for inference.
lora_config
object
LoRA configuration. Supports ComfyUI and AI Toolkit formats. Wan2.1 properties:
  • rank (array): LoRA rank values
  • lora_model_name_or_path (array): Paths to LoRA models
  • weight_name (array): Weight filenames
  • adapter_name (array): Adapter names
  • scale (array): Scale factors
Wan2.2 properties:
  • rank (array): LoRA rank values
  • high_noise_weight_name (array): High-noise transformer weights
  • low_noise_weight_name (array): Low-noise transformer weights
  • scale (array): Scale factors
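A hypothetical YAML sketch of a Wan2.1 lora_config, using the property names listed above; the repository path, weight filename, and adapter name are placeholders, not real artifacts:

```yaml
enable_lora: True
lora_config:
  lora_model_name_or_path: ["path/to/lora_repo"]   # placeholder
  weight_name: ["lora_weights.safetensors"]        # placeholder
  adapter_name: ["my_adapter"]                     # placeholder
  rank: [64]
  scale: [1.0]
```

Each property is an array so that multiple LoRA adapters can be listed in parallel; a Wan2.2 config would instead use high_noise_weight_name and low_noise_weight_name for its two transformers.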

System parameters

run_name
string
required
Unique run identifier.
output_dir
string
default:""
Output directory for videos. Supports GCS paths (gs://).
seed
integer
default:0
Random seed for reproducibility.
jax_cache_dir
string
default:""
JAX compilation cache directory.
enable_profiler
boolean
default:false
Enable performance profiling.

Output

Videos are saved as MP4 files with the naming pattern wan_output_{seed}_{i}.mp4. Files can be automatically uploaded to GCS if output_dir starts with gs://.

Implementation details

Wan models use:
  • NNX-based transformer architecture
  • Rectified flow for video generation
  • 3D VAE for video compression
  • Dual guidance system (Wan2.2 only)
  • Sequence and head parallelism for efficient training/inference
  • Optional LoRA support for customization

Performance notes

  • Fractional batch sizes allow fine-grained memory optimization
  • Sequence parallelism (ici_fsdp_parallelism) shards the sequence dimension
  • Head parallelism (ici_tensor_parallelism) must divide 40 evenly
  • Use external disk for model weights (models are large)
  • GPU users: cudnn_flash_te attention recommended with ici_fsdp_batch_parallelism
Source: src/maxdiffusion/generate_wan.py
