Wan models enable text-to-video (T2V) and image-to-video (I2V) generation with support for multiple resolutions and frame counts.

Quick start

Attaching an external disk is recommended, as the Wan weights require significant storage. Follow the external storage instructions to attach a disk.
HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ \
LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true \
--xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true \
--xla_tpu_enable_async_collective_fusion_multiple_steps=true \
--xla_tpu_overlap_compute_collective_tc=true \
--xla_enable_async_all_reduce=true" \
HF_HUB_ENABLE_HF_TRANSFER=1 \
python src/maxdiffusion/generate_wan.py \
  src/maxdiffusion/configs/base_wan_14b.yml \
  attention="flash" \
  num_inference_steps=50 \
  num_frames=81 \
  width=1280 \
  height=720 \
  jax_cache_dir=gs://your-bucket/jax_cache/ \
  per_device_batch_size=.125 \
  ici_data_parallelism=2 \
  ici_context_parallelism=2 \
  flow_shift=5.0 \
  run_name=wan-inference-720p \
  output_dir=gs://your-bucket/outputs \
  fps=16 \
  flash_min_seq_length=0 \
  flash_block_sizes='{"block_q" : 3024, "block_kv_compute" : 1024, "block_kv" : 2048, "block_q_dkv": 3024, "block_kv_dkv" : 2048, "block_kv_dkv_compute" : 2048, "block_q_dq" : 3024, "block_kv_dq" : 2048 }' \
  seed=118445

Model variants

Wan offers multiple model sizes and capabilities:
| Model | Parameters | Config file | Use case |
|-------|------------|-------------|----------|
| Wan 2.1 T2V | 14B | base_wan_14b.yml | Text-to-video generation |
| Wan 2.1 I2V | 14B | base_wan_i2v_14b.yml | Image-to-video animation |
| Wan 2.2 T2V | 27B | base_wan_27b.yml | Higher-quality T2V |
| Wan 2.2 I2V | 27B | base_wan_i2v_27b.yml | Higher-quality I2V |

Parameters

| Parameter | Description | Default |
|-----------|-------------|---------|
| prompt | Text description of the video | Required |
| negative_prompt | Concepts to avoid | Empty |
| num_inference_steps | Denoising steps | 50 |
| num_frames | Number of video frames | 81 |
| width | Video width in pixels | 1280 |
| height | Video height in pixels | 720 |
| fps | Frames per second | 16 |
| guidance_scale | CFG strength (Wan 2.1) | 7.0 |
| guidance_scale_low | Low-noise CFG (Wan 2.2) | 3.0 |
| guidance_scale_high | High-noise CFG (Wan 2.2) | 7.0 |
| flow_shift | Flow matching shift parameter | 5.0 |
| per_device_batch_size | Batch size per device | 0.125 |
| seed | Random seed | 0 |
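Note that per_device_batch_size can be fractional because the batch is defined globally across all devices. A minimal sketch of the arithmetic (the helper below is illustrative, not a MaxDiffusion function; it assumes the usual convention that the global batch is the per-device value times the device count):

```python
# Sketch: how a fractional per_device_batch_size maps to a global batch size.
# Assumes global_batch = per_device_batch_size * num_devices, which is the
# usual convention for MaxDiffusion-style configs.

def global_batch_size(per_device_batch_size: float, num_devices: int) -> int:
    total = per_device_batch_size * num_devices
    if total != int(total) or total < 1:
        raise ValueError(
            f"per_device_batch_size={per_device_batch_size} does not yield "
            f"a whole batch on {num_devices} devices"
        )
    return int(total)

print(global_batch_size(0.125, 8))   # -> 1: one video per step on 8 chips
print(global_batch_size(0.125, 16))  # -> 2
```

With the quick-start value of 0.125, each step therefore produces one video per eight devices.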

Sharding strategies

Wan models require advanced parallelism for efficient inference:

Data and context parallelism

ici_data_parallelism=2        # Replicate across 2 devices
ici_context_parallelism=2     # Split sequence across 2 devices
This configuration enables processing longer videos and larger batch sizes.
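The split can be pictured as a 2x2 device grid: data parallelism partitions the batch dimension while context parallelism partitions each sample's token sequence. The sketch below is purely illustrative bookkeeping (not MaxDiffusion's sharding code) showing which slice of work each device holds:

```python
# Illustrative sketch (not MaxDiffusion internals): with ici_data_parallelism=2
# and ici_context_parallelism=2 on 4 chips, the batch is split 2 ways and each
# sample's token sequence is split 2 ways.

def shard_assignment(batch, seq_len, dp, cp):
    """Map (sample indices, token range) shards onto a dp*cp device grid."""
    assert batch % dp == 0 and seq_len % cp == 0
    shards = {}
    for d in range(dp):
        for c in range(cp):
            device = d * cp + c
            samples = list(range(d * batch // dp, (d + 1) * batch // dp))
            tokens = (c * seq_len // cp, (c + 1) * seq_len // cp)
            shards[device] = (samples, tokens)
    return shards

# Device 0 holds sample 0, tokens [0, 4); device 3 holds sample 1, tokens [4, 8).
for dev, (samples, tok) in shard_assignment(batch=2, seq_len=8, dp=2, cp=2).items():
    print(dev, samples, tok)
```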

Flash attention

Custom flash attention block sizes optimize TPU memory usage:
flash_block_sizes='{
  "block_q" : 3024, 
  "block_kv_compute" : 1024, 
  "block_kv" : 2048, 
  "block_q_dkv": 3024, 
  "block_kv_dkv" : 2048, 
  "block_kv_dkv_compute" : 2048, 
  "block_q_dq" : 3024, 
  "block_kv_dq" : 2048
}'
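The rough arithmetic below relates these block sizes to the attention sequence length at 720p with 81 frames. The token count assumes Wan 2.1's commonly cited 4x temporal / 8x spatial VAE compression and a (1, 2, 2) transformer patch size; treat those factors as assumptions for illustration, not MaxDiffusion internals:

```python
# Rough sketch: relate flash attention block sizes to the 720p/81-frame
# sequence length, assuming 4x temporal VAE compression, 8x spatial VAE
# compression, and a (1, 2, 2) patch size (assumptions, not source values).
import math

frames, height, width = 81, 720, 1280
latent_frames = (frames - 1) // 4 + 1                    # 4x temporal -> 21
tokens = latent_frames * (height // 16) * (width // 16)  # 8x VAE * 2x patch
print(tokens)                     # 21 * 45 * 80 = 75600 query tokens
print(math.ceil(tokens / 3024))   # block_q=3024 -> 25 query blocks, no remainder
```

Under these assumptions, block_q=3024 divides the 75,600-token sequence into exactly 25 blocks, which is the kind of even tiling that keeps TPU memory usage predictable.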

LoRA support

Wan supports LoRA adapters for style customization, in both ComfyUI and AI Toolkit formats.

Single LoRA

Create a copy of the relevant config file and update LoRA details:
HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ \
LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true \
--xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true \
--xla_tpu_enable_async_collective_fusion_multiple_steps=true \
--xla_tpu_overlap_compute_collective_tc=true \
--xla_enable_async_all_reduce=true" \
HF_HUB_ENABLE_HF_TRANSFER=1 \
python src/maxdiffusion/generate_wan.py \
  src/maxdiffusion/configs/base_wan_i2v_14b.yml \
  jax_cache_dir=gs://your-bucket/jax_cache/ \
  per_device_batch_size=.125 \
  ici_data_parallelism=2 \
  ici_context_parallelism=2 \
  run_name=wan-lora-inference \
  output_dir=gs://your-bucket/outputs \
  seed=118445 \
  enable_lora=True
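Assuming the config copy uses the same lora_config schema as the multi-adapter case, a single-adapter fragment might look like this (the path, weight name, rank, and scale below are placeholders, not values from the source):

```yaml
# Hypothetical single-adapter fragment; all values are placeholders.
enable_lora: True
lora_config:
  lora_model_name_or_path: ["path/to/lora"]
  weight_name: ["lora.safetensors"]
  rank: [64]
  scale: [1.0]
```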

Multiple LoRAs

Wan supports loading multiple LoRA adapters simultaneously. Configure LoRA details in the config file:
enable_lora: True
lora_config:
  lora_model_name_or_path: ["path/to/lora1", "path/to/lora2"]
  weight_name: ["lora1.safetensors", "lora2.safetensors"]
  rank: [64, 64]
  scale: [0.8, 0.6]
Not all LoRA formats have been tested. If a specific LoRA doesn’t load, please report it.
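Conceptually, each scaled adapter contributes a low-rank update to the base weight, W' = W + Σᵢ scaleᵢ · Bᵢ Aᵢ. The NumPy sketch below illustrates that combination; it is not MaxDiffusion's loader, and the shapes and scales are arbitrary:

```python
# Illustrative sketch (not MaxDiffusion's LoRA loader): combining several
# scaled low-rank adapters into one effective weight, W' = W + sum_i s_i * B_i A_i.
import numpy as np

rng = np.random.default_rng(0)
d, rank = 16, 4
W = rng.normal(size=(d, d))                 # base weight
adapters = [
    (rng.normal(size=(d, rank)), rng.normal(size=(rank, d)), 0.8),  # lora1
    (rng.normal(size=(d, rank)), rng.normal(size=(rank, d)), 0.6),  # lora2
]
W_eff = W + sum(scale * (B @ A) for B, A, scale in adapters)
print(W_eff.shape)  # (16, 16): same shape as the base weight
```

Setting a scale to 0.0 disables that adapter's contribution entirely, which is a quick way to A/B-test adapters without editing the path list.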

LIBTPU optimizations

The LIBTPU_INIT_ARGS environment variable enables critical TPU optimizations:
LIBTPU_INIT_ARGS="
  --xla_tpu_enable_async_collective_fusion=true          # Fuse collectives
  --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true  # Fuse all-reduce ops
  --xla_tpu_enable_async_collective_fusion_multiple_steps=true   # Multi-step fusion
  --xla_tpu_overlap_compute_collective_tc=true           # Overlap compute and communication
  --xla_enable_async_all_reduce=true                     # Async all-reduce
"
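The same flags can be set from Python before the TPU runtime starts. LIBTPU_INIT_ARGS is read when libtpu initializes, so it must be in the environment before JAX first touches the TPU backend:

```python
# Sketch: setting LIBTPU_INIT_ARGS programmatically. This must run before
# JAX initializes the TPU backend, or the flags are silently ignored.
import os

tpu_flags = [
    "--xla_tpu_enable_async_collective_fusion=true",
    "--xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true",
    "--xla_tpu_enable_async_collective_fusion_multiple_steps=true",
    "--xla_tpu_overlap_compute_collective_tc=true",
    "--xla_enable_async_all_reduce=true",
]
os.environ["LIBTPU_INIT_ARGS"] = " ".join(tpu_flags)
```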
These flags significantly improve performance for multi-device inference.

Image-to-video

I2V models animate static images. Provide an image URL in the config:
image_url: "https://example.com/image.jpg"
# or local path
image_url: "/path/to/image.jpg"
The model generates video starting from the input image.
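A front end for this config field would typically resolve image_url by fetching over HTTP for URLs and reading from disk otherwise. The sketch below shows one way to do that with PIL and requests; it is an assumption about the loading step, not MaxDiffusion's actual loader:

```python
# Sketch (assumed behavior, not MaxDiffusion code): resolve image_url to a
# PIL image, fetching remote URLs and opening local paths directly.
from io import BytesIO

import requests
from PIL import Image


def load_conditioning_image(image_url: str) -> Image.Image:
    if image_url.startswith(("http://", "https://")):
        resp = requests.get(image_url, timeout=30)
        resp.raise_for_status()
        img = Image.open(BytesIO(resp.content))
    else:
        img = Image.open(image_url)
    return img.convert("RGB")  # normalize mode for downstream preprocessing
```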

Output format

Videos are saved as MP4 files:
  • Format: H.264 MP4
  • Filename: wan_output_{seed}_{index}.mp4
  • Location: Specified output_dir (local or GCS)
For GCS output directories (gs://...), videos are automatically uploaded and local copies are deleted.

Implementation details

The Wan pipeline (src/maxdiffusion/generate_wan.py) implements:

Model selection

The pipeline selects the appropriate checkpointer based on model configuration (generate_wan.py:177-190):
if model_key == WAN2_1:
  if model_type == "I2V":
    checkpoint_loader = WanCheckpointerI2V_2_1(config=config)
  else:
    checkpoint_loader = WanCheckpointer2_1(config=config)
elif model_key == WAN2_2:
  if model_type == "I2V":
    checkpoint_loader = WanCheckpointerI2V_2_2(config=config)
  else:
    checkpoint_loader = WanCheckpointer2_2(config=config)

Guidance scale selection

Wan 2.1 uses a single guidance scale, while Wan 2.2 uses separate scales for its low- and high-noise stages (generate_wan.py:87-140):
def call_pipeline(config, pipeline, prompt, negative_prompt):
  model_key = config.model_name
  if model_key == WAN2_1:
    return pipeline(
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=config.height,
        width=config.width,
        num_frames=config.num_frames,
        num_inference_steps=config.num_inference_steps,
        guidance_scale=config.guidance_scale,
    )
  elif model_key == WAN2_2:
    return pipeline(
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=config.height,
        width=config.width,
        num_frames=config.num_frames,
        num_inference_steps=config.num_inference_steps,
        guidance_scale_low=config.guidance_scale_low,
        guidance_scale_high=config.guidance_scale_high,
    )
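The dual-scale behavior can be pictured as a per-step schedule: an assumption (not code from generate_wan.py) is that high-noise denoising steps early in sampling use guidance_scale_high and later low-noise steps use guidance_scale_low, switching at some boundary timestep. The boundary value below is a placeholder for illustration:

```python
# Illustrative sketch of dual guidance (assumed mechanism, not source code):
# early high-noise steps use the high scale, later steps the low scale.
# boundary=0.875 is a placeholder, not a value taken from generate_wan.py.
def guidance_for_step(t_normalized, low=3.0, high=7.0, boundary=0.875):
    """t_normalized in [0, 1], where 1.0 is pure noise."""
    return high if t_normalized >= boundary else low

print([guidance_for_step(t) for t in (1.0, 0.9, 0.5, 0.0)])
# -> [7.0, 7.0, 3.0, 3.0]
```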

Video export

Generated videos are exported with configurable FPS (generate_wan.py:156-162):
for i in range(len(videos)):
  video_path = f"wan_output_{config.seed}_{i}.mp4"
  export_to_video(videos[i], video_path, fps=config.fps)
  if config.output_dir.startswith("gs://"):
    upload_video_to_gcs(os.path.join(config.output_dir, config.run_name), video_path)
    delete_file(f"./{video_path}")

Performance monitoring

The pipeline reports:
  • Compile time: Initial compilation duration
  • Generation time: Video generation time
  • Generation time per video: Average time per video in batch
Enable profiling with:
enable_profiler=True
This generates XLA HLO profiles for performance analysis.

Next steps

  • Wan training: fine-tune Wan 2.1 on custom video datasets
  • LTX Video: an alternative video generation model
  • Flux inference: high-quality image generation
  • Configuration: full configuration reference
