Wan models enable text-to-video (T2V) and image-to-video (I2V) generation with support for multiple resolutions and frame counts.
Quick start
Attaching an external disk is recommended as Wan weights require significant storage. Follow these instructions to attach external storage.
The following command runs Wan 2.1 T2V; the same invocation works for Wan 2.1 I2V and Wan 2.2 T2V with the corresponding config file.
HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ \
LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true \
--xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true \
--xla_tpu_enable_async_collective_fusion_multiple_steps=true \
--xla_tpu_overlap_compute_collective_tc=true \
--xla_enable_async_all_reduce=true" \
HF_HUB_ENABLE_HF_TRANSFER=1 \
python src/maxdiffusion/generate_wan.py \
src/maxdiffusion/configs/base_wan_14b.yml \
attention="flash" \
num_inference_steps=50 \
num_frames=81 \
width=1280 \
height=720 \
jax_cache_dir=gs://your-bucket/jax_cache/ \
per_device_batch_size=0.125 \
ici_data_parallelism=2 \
ici_context_parallelism=2 \
flow_shift=5.0 \
run_name=wan-inference-720p \
output_dir=gs://your-bucket/outputs \
fps=16 \
flash_min_seq_length=0 \
flash_block_sizes='{"block_q": 3024, "block_kv_compute": 1024, "block_kv": 2048, "block_q_dkv": 3024, "block_kv_dkv": 2048, "block_kv_dkv_compute": 2048, "block_q_dq": 3024, "block_kv_dq": 2048}' \
seed=118445
Model variants
Wan offers multiple model sizes and capabilities:
| Model | Parameters | Config file | Use case |
|---|---|---|---|
| Wan 2.1 T2V | 14B | base_wan_14b.yml | Text-to-video generation |
| Wan 2.1 I2V | 14B | base_wan_i2v_14b.yml | Image-to-video animation |
| Wan 2.2 T2V | 27B | base_wan_27b.yml | Higher quality T2V |
| Wan 2.2 I2V | 27B | base_wan_i2v_27b.yml | Higher quality I2V |
Parameters
| Parameter | Description | Default |
|---|---|---|
| prompt | Text description of video | Required |
| negative_prompt | Concepts to avoid | Empty |
| num_inference_steps | Denoising steps | 50 |
| num_frames | Number of video frames | 81 |
| width | Video width in pixels | 1280 |
| height | Video height in pixels | 720 |
| fps | Frames per second | 16 |
| guidance_scale | CFG strength (Wan 2.1) | 7.0 |
| guidance_scale_low | Low-noise CFG (Wan 2.2) | 3.0 |
| guidance_scale_high | High-noise CFG (Wan 2.2) | 7.0 |
| flow_shift | Flow matching shift parameter | 5.0 |
| per_device_batch_size | Batch size per device | 0.125 |
| seed | Random seed | 0 |
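Note that per_device_batch_size can be fractional: values below 1 spread a single video across several devices. A minimal sketch of the arithmetic, under the assumption that the global batch is simply per_device_batch_size times the device count:

```python
# Sketch: how a fractional per-device batch size translates to a global
# batch, assuming global_batch = per_device_batch_size * num_devices.
def global_batch_size(per_device_batch_size: float, num_devices: int) -> int:
    total = per_device_batch_size * num_devices
    if total < 1 or total != int(total):
        raise ValueError("per_device_batch_size * num_devices must be a whole number >= 1")
    return int(total)

# With per_device_batch_size=0.125 on an 8-device slice, one video is
# generated per step, sharded across all 8 devices.
print(global_batch_size(0.125, 8))   # -> 1
print(global_batch_size(0.125, 32))  # -> 4
```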
Sharding strategies
Wan models require advanced parallelism for efficient inference:
Data and context parallelism
ici_data_parallelism=2     # Replicate across 2 devices
ici_context_parallelism=2  # Split sequence across 2 devices
This configuration enables processing longer videos and larger batch sizes.
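As an illustration of how the two axes combine, here is a minimal JAX sketch of a 2x2 device mesh with a data axis and a context axis. The axis names, shapes, and virtual-device setup are assumptions for illustration, not MaxDiffusion's internal mesh layout:

```python
import os
# Force 4 virtual CPU devices so the sketch runs on a single host; on a
# real TPU slice the physical devices would be used instead.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical 2x2 mesh mirroring ici_data_parallelism=2 and
# ici_context_parallelism=2.
devices = np.array(jax.devices()[:4]).reshape(2, 2)
mesh = Mesh(devices, axis_names=("data", "context"))

# Shard a (batch, seq, hidden) activation: batch over "data",
# sequence over "context", hidden replicated.
x = np.zeros((2, 8, 16), dtype=np.float32)
sharded = jax.device_put(x, NamedSharding(mesh, P("data", "context", None)))
assert sharded.shape == (2, 8, 16)
```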
Flash attention
Custom flash attention block sizes optimize TPU memory usage:
flash_block_sizes='{
  "block_q": 3024,
  "block_kv_compute": 1024,
  "block_kv": 2048,
  "block_q_dkv": 3024,
  "block_kv_dkv": 2048,
  "block_kv_dkv_compute": 2048,
  "block_q_dq": 3024,
  "block_kv_dq": 2048
}'
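The flag value is a JSON object passed as a string on the command line, so a malformed value fails at parse time. A quick sanity check that a block-size string parses and carries the expected keys:

```python
import json

# The flash_block_sizes override is a JSON object string; verify it
# parses and that every block size is a positive integer.
flash_block_sizes = '''{
  "block_q": 3024, "block_kv_compute": 1024, "block_kv": 2048,
  "block_q_dkv": 3024, "block_kv_dkv": 2048, "block_kv_dkv_compute": 2048,
  "block_q_dq": 3024, "block_kv_dq": 2048
}'''
sizes = json.loads(flash_block_sizes)
assert all(isinstance(v, int) and v > 0 for v in sizes.values())
print(sorted(sizes))
```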
LoRA support
Wan supports LoRA adapters for style customization, in both ComfyUI and AI Toolkit formats.
Single LoRA
Create a copy of the relevant config file and update LoRA details:
HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ \
LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true \
--xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true \
--xla_tpu_enable_async_collective_fusion_multiple_steps=true \
--xla_tpu_overlap_compute_collective_tc=true \
--xla_enable_async_all_reduce=true" \
HF_HUB_ENABLE_HF_TRANSFER=1 \
python src/maxdiffusion/generate_wan.py \
src/maxdiffusion/configs/base_wan_i2v_14b.yml \
jax_cache_dir=gs://your-bucket/jax_cache/ \
per_device_batch_size=0.125 \
ici_data_parallelism=2 \
ici_context_parallelism=2 \
run_name=wan-lora-inference \
output_dir=gs://your-bucket/outputs \
seed=118445 \
enable_lora=True
Multiple LoRAs
Wan supports loading multiple LoRA adapters simultaneously. Configure LoRA details in the config file:
enable_lora: True
lora_config:
  lora_model_name_or_path: ["path/to/lora1", "path/to/lora2"]
  weight_name: ["lora1.safetensors", "lora2.safetensors"]
  rank: [64, 64]
  scale: [0.8, 0.6]
Not all LoRA formats have been tested. If a specific LoRA doesn’t load, please report it.
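Behind the rank and scale entries above, a LoRA adapter adds a low-rank update to each target weight matrix: W' = W + scale * (B @ A). A minimal numpy sketch of that arithmetic; the shapes and values are illustrative only:

```python
import numpy as np

# Minimal sketch of applying one LoRA update to a weight matrix:
# W' = W + scale * (B @ A), where A is (rank, d_in) and B is (d_out, rank).
rank, d_in, d_out, scale = 64, 256, 256, 0.8
rng = np.random.default_rng(0)
W = rng.standard_normal((d_out, d_in))
A = rng.standard_normal((rank, d_in))
B = np.zeros((d_out, rank))  # B typically starts at zero, so the update begins as a no-op

W_merged = W + scale * (B @ A)
assert np.allclose(W_merged, W)  # zero B => weights unchanged
```

With multiple adapters, the updates simply sum: W + sum_i scale_i * (B_i @ A_i), which is why each adapter carries its own scale entry.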
LIBTPU optimizations
The LIBTPU_INIT_ARGS environment variable enables critical TPU optimizations:
LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true \
--xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true \
--xla_tpu_enable_async_collective_fusion_multiple_steps=true \
--xla_tpu_overlap_compute_collective_tc=true \
--xla_enable_async_all_reduce=true"

--xla_tpu_enable_async_collective_fusion: fuse async collective operations
--xla_tpu_enable_async_collective_fusion_fuse_all_reduce: fuse all-reduce ops
--xla_tpu_enable_async_collective_fusion_multiple_steps: fuse collectives across multiple steps
--xla_tpu_overlap_compute_collective_tc: overlap compute and communication
--xla_enable_async_all_reduce: run all-reduce asynchronously
These flags significantly improve performance for multi-device inference.
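The same flags can also be set from Python, as long as that happens before JAX initializes the TPU runtime. A small sketch; on non-TPU hosts the variable is simply ignored:

```python
import os

# Set TPU runtime flags before importing JAX so libtpu picks them up at
# initialization; on non-TPU hosts these flags have no effect.
flags = [
    "--xla_tpu_enable_async_collective_fusion=true",
    "--xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true",
    "--xla_tpu_enable_async_collective_fusion_multiple_steps=true",
    "--xla_tpu_overlap_compute_collective_tc=true",
    "--xla_enable_async_all_reduce=true",
]
os.environ["LIBTPU_INIT_ARGS"] = " ".join(flags)
```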
Image-to-video
I2V models animate static images. Provide an image URL in the config:
image_url: "https://example.com/image.jpg"
# or a local path
image_url: "/path/to/image.jpg"
The model generates video starting from the input image.
Videos are saved as MP4 files:
Format: H.264 MP4
Filename: wan_output_{seed}_{index}.mp4
Location: Specified output_dir (local or GCS)
For GCS output directories (gs://...), videos are automatically uploaded and local copies are deleted.
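The upload destination is composed from output_dir and run_name. A small sketch of splitting a gs:// URI into bucket and object-prefix parts; split_gcs_uri is a hypothetical helper for illustration, not MaxDiffusion's actual upload code:

```python
def split_gcs_uri(uri: str) -> tuple[str, str]:
    """Split a gs://bucket/prefix URI into (bucket, prefix).

    Hypothetical helper for illustration only.
    """
    if not uri.startswith("gs://"):
        raise ValueError(f"not a GCS URI: {uri}")
    bucket, _, prefix = uri[len("gs://"):].partition("/")
    return bucket, prefix

bucket, prefix = split_gcs_uri("gs://your-bucket/outputs/wan-inference-720p")
print(bucket)  # -> your-bucket
print(prefix)  # -> outputs/wan-inference-720p
```

In the real pipeline the object name would be this prefix plus the video filename, and the local copy is deleted once the upload succeeds.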
Implementation details
The Wan pipeline (src/maxdiffusion/generate_wan.py) implements the following:
Model selection
The pipeline selects the appropriate checkpointer based on model configuration (generate_wan.py:177-190):
if model_key == WAN2_1:
    if model_type == "I2V":
        checkpoint_loader = WanCheckpointerI2V_2_1(config=config)
    else:
        checkpoint_loader = WanCheckpointer2_1(config=config)
elif model_key == WAN2_2:
    if model_type == "I2V":
        checkpoint_loader = WanCheckpointerI2V_2_2(config=config)
    else:
        checkpoint_loader = WanCheckpointer2_2(config=config)
Guidance scale selection
Wan 2.1 uses a single guidance scale, while Wan 2.2 uses dual guidance scales (generate_wan.py:87-140):
def call_pipeline(config, pipeline, prompt, negative_prompt):
    model_key = config.model_name
    if model_key == WAN2_1:
        return pipeline(
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=config.height,
            width=config.width,
            num_frames=config.num_frames,
            num_inference_steps=config.num_inference_steps,
            guidance_scale=config.guidance_scale,
        )
    elif model_key == WAN2_2:
        return pipeline(
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=config.height,
            width=config.width,
            num_frames=config.num_frames,
            num_inference_steps=config.num_inference_steps,
            guidance_scale_low=config.guidance_scale_low,
            guidance_scale_high=config.guidance_scale_high,
        )
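Both variants rest on classifier-free guidance, which blends conditional and unconditional predictions; Wan 2.2's two scales apply different strengths at different noise levels. A numpy sketch of the combination; the switching rule between high and low scales is an assumption for illustration, not the pipeline's actual schedule:

```python
import numpy as np

def cfg(uncond: np.ndarray, cond: np.ndarray, scale: float) -> np.ndarray:
    # Standard classifier-free guidance combination.
    return uncond + scale * (cond - uncond)

def pick_scale(step: int, num_steps: int, low: float = 3.0, high: float = 7.0) -> float:
    # Illustrative dual-scale selection: stronger guidance in the early
    # (high-noise) steps, weaker later. The exact switch point is an
    # assumption, not Wan 2.2's actual rule.
    return high if step < num_steps // 2 else low

uncond = np.zeros(4)
cond = np.ones(4)
assert np.allclose(cfg(uncond, cond, 7.0), 7.0 * cond)
```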
Video export
Generated videos are exported with configurable FPS (generate_wan.py:156-162):
for i in range(len(videos)):
    video_path = f"wan_output_{config.seed}_{i}.mp4"
    export_to_video(videos[i], video_path, fps=config.fps)
    if config.output_dir.startswith("gs://"):
        upload_video_to_gcs(os.path.join(config.output_dir, config.run_name), video_path)
        delete_file(f"./{video_path}")
The pipeline reports:
Compile time: Initial compilation duration
Generation time: Video generation time
Generation time per video: Average time per video in batch
Profiling can be enabled through the config; this generates XLA HLO profiles for performance analysis.
Next steps
Wan training: Fine-tune Wan 2.1 on custom video datasets
LTX Video: Alternative video generation model
Flux inference: High-quality image generation
Configuration: Full configuration reference