The WanPipeline is a text-to-video and image-to-video generation pipeline using the WAN (World Action Network) transformer.

WanPipeline

Located in src/maxdiffusion/pipelines/wan/wan_pipeline.py:191

Components

  • tokenizer (AutoTokenizer): UMT5 tokenizer for text encoding
  • text_encoder (UMT5EncoderModel): UMT5 text encoder (google/umt5-xxl)
  • vae (AutoencoderKLWan): video VAE for encoding/decoding video frames
  • vae_cache (AutoencoderKLWanCache): cache for VAE intermediate activations
  • scheduler (FlaxUniPCMultistepScheduler): UniPC multistep scheduler for diffusion
  • scheduler_state (UniPCMultistepSchedulerState): scheduler state
  • mesh (Mesh): JAX mesh for distributed computation
  • image_processor (CLIPImageProcessor): CLIP image processor (for I2V models)
  • image_encoder (FlaxCLIPVisionModel): CLIP image encoder (for I2V models)

Methods

encode_prompt

Encodes text prompts using the UMT5 text encoder.

Parameters:
  • prompt (str | List[str]): the prompt or prompts to guide video generation
  • negative_prompt (str | List[str]): negative prompts for classifier-free guidance
  • num_videos_per_prompt (int, default: 1): number of videos to generate per prompt
  • max_sequence_length (int, default: 226): maximum sequence length for the text encoder

Returns:
  • prompt_embeds (jnp.ndarray): encoded prompt embeddings
  • negative_prompt_embeds (jnp.ndarray): encoded negative prompt embeddings
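
The batch handling involved can be sketched in isolation (a toy with a hypothetical hidden size, not the actual maxdiffusion implementation): encoder output is padded or truncated to max_sequence_length, then tiled num_videos_per_prompt times along the batch axis.

```python
import numpy as np

def tile_prompt_embeds(embeds: np.ndarray, num_videos_per_prompt: int,
                       max_sequence_length: int = 226) -> np.ndarray:
    """Pad/truncate to max_sequence_length, then repeat once per requested video."""
    batch, seq_len, hidden = embeds.shape
    if seq_len < max_sequence_length:
        pad = np.zeros((batch, max_sequence_length - seq_len, hidden), embeds.dtype)
        embeds = np.concatenate([embeds, pad], axis=1)
    else:
        embeds = embeds[:, :max_sequence_length]
    # One copy of each prompt's embedding per video generated for that prompt.
    return np.repeat(embeds, num_videos_per_prompt, axis=0)

# Toy encoder output: 1 prompt, 10 tokens, hidden size 8 (hypothetical).
toy = np.ones((1, 10, 8), dtype=np.float32)
out = tile_prompt_embeds(toy, num_videos_per_prompt=2)
print(out.shape)  # (2, 226, 8)
```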

encode_image

Encodes images using the CLIP image encoder (for WAN 2.1 I2V).

Parameters:
  • image (PIL.Image | List[PIL.Image]): input image(s) for image-to-video generation
  • num_videos_per_prompt (int, default: 1): number of videos per prompt

Returns:
  • image_embeds (jnp.ndarray): CLIP image embeddings

prepare_latents

Prepares initial latent tensors for video generation.

Parameters:
  • batch_size (int): batch size
  • vae_scale_factor_temporal (int): temporal downsampling factor of the VAE
  • vae_scale_factor_spatial (int): spatial downsampling factor of the VAE
  • height (int, default: 480): height of generated videos
  • width (int, default: 832): width of generated videos
  • num_frames (int, default: 81): number of frames in the video
  • num_channels_latents (int, default: 16): number of channels in latent space

Returns:
  • latents (jnp.ndarray): random latent tensors for video generation
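
The latent shape follows directly from the VAE scale factors; a sketch of the arithmetic, assuming the Wan VAE's usual downsampling of 4x in time and 8x in space (the axis order shown is illustrative and may differ from maxdiffusion's internal layout):

```python
def latent_shape(batch_size, height=480, width=832, num_frames=81,
                 num_channels_latents=16,
                 vae_scale_factor_temporal=4, vae_scale_factor_spatial=8):
    """Shape of the random latents: (B, C, T_latent, H_latent, W_latent)."""
    # The first frame maps to one latent frame; every further group of
    # `vae_scale_factor_temporal` frames adds one more.
    num_latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1
    return (batch_size, num_channels_latents, num_latent_frames,
            height // vae_scale_factor_spatial, width // vae_scale_factor_spatial)

print(latent_shape(1))  # (1, 16, 21, 60, 104)
```

So the defaults (480x832, 81 frames, 16 latent channels) give 21 latent frames on a 60x104 spatial grid.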

prepare_latents_i2v_base

Prepares latent conditioning for image-to-video generation.

Parameters:
  • image (jax.Array): input image tensor
  • num_frames (int): number of frames to generate
  • dtype (jnp.dtype): data type for latents
  • last_image (jax.Array, optional): last frame for video bookending

Returns:
  • latent_condition (jnp.ndarray): VAE-encoded latents of the image(s)
  • video_condition (jnp.ndarray): input to the VAE

Class methods

load_transformer

Loads the WAN transformer model with sharding.

Parameters:
  • devices_array (np.array): array of devices
  • mesh (Mesh): JAX mesh
  • rngs (nnx.Rngs): random number generators
  • config (HyperParameters): configuration
  • restored_checkpoint (dict, optional): checkpoint to restore from

Returns:
  • wan_transformer (WanModel): loaded and sharded transformer model

load_vae

Loads the video VAE with sharding.

Returns:
  • wan_vae (AutoencoderKLWan): loaded video VAE
  • vae_cache (AutoencoderKLWanCache): VAE cache for intermediate activations

quantize_transformer

Quantizes the transformer using Qwix.

Parameters:
  • config (HyperParameters): configuration with quantization settings
  • model (WanModel): model to quantize

Returns:
  • quantized_model (WanModel): quantized model
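
Qwix handles the actual model rewriting, but the core INT8 idea can be sketched independently (illustrative arithmetic only, not the Qwix API): weights are mapped into the int8 range with a per-tensor scale that is kept for dequantization.

```python
import numpy as np

def quantize_int8(w: np.ndarray):
    """Symmetric per-tensor INT8 quantization: w is approximated by q * scale."""
    scale = np.abs(w).max() / 127.0
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    return q, scale

def dequantize(q: np.ndarray, scale: float) -> np.ndarray:
    return q.astype(np.float32) * scale

w = np.array([-1.0, -0.5, 0.0, 0.5, 1.0], dtype=np.float32)
q, s = quantize_int8(w)
print(q.tolist())  # [-127, -64, 0, 64, 127]
```

Per-channel scales and FP8 formats refine the same idea; the trade-off is a small rounding error in exchange for a 4x (vs float32) reduction in weight memory.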

Key features

  • Video generation: Generates high-quality videos from text or images
  • Temporal consistency: Uses 3D attention for temporal coherence
  • Flow matching scheduler: Uses flow matching for efficient sampling
  • Multiple model variants: Supports WAN 2.1 and 2.2 architectures
  • I2V conditioning: Image-to-video via CLIP embeddings (2.1) or VAE latents (2.2)
  • Quantization support: FP8/INT8 quantization via Qwix
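
The flow-matching view behind the sampler can be shown with a plain Euler integrator (a conceptual toy, not the UniPC multistep scheduler the pipeline actually uses): the model predicts a velocity field v(x, t) and the sampler integrates from noise at t=1 to data at t=0. For a point-mass target under the linear interpolation path, the ideal velocity is (x - target) / t.

```python
def euler_flow_sample(x, velocity_fn, num_steps=50):
    """Integrate dx/dt = v(x, t) backward from t=1 (noise) to t=0 (data)."""
    dt = 1.0 / num_steps
    t = 1.0
    for _ in range(num_steps):
        x = x - velocity_fn(x, t) * dt  # Euler step toward the data
        t -= dt
    return x

# Toy 1-D example: ideal velocity field for a point-mass target at 3.0.
target = 3.0
result = euler_flow_sample(x=10.0, velocity_fn=lambda x, t: (x - target) / t)
print(round(result, 6))  # 3.0
```

UniPC replaces the first-order Euler update with a higher-order multistep correction, which is why the pipeline can sample in fewer steps.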

Model variants

WAN 2.1

  • Uses CLIP image encoder for I2V conditioning
  • Image embeddings are passed to transformer

WAN 2.2

  • Uses VAE latent conditioning for I2V
  • No CLIP image encoder required
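
The variant difference amounts to a dispatch on the conditioning mechanism; a hypothetical helper summarizing the two cases described above (names are illustrative, not maxdiffusion API):

```python
def i2v_conditioning(wan_version: str) -> str:
    """How the input image conditions the transformer, per Wan variant."""
    mechanisms = {
        "2.1": "clip_image_embeds",     # CLIP encoder output passed to transformer
        "2.2": "vae_latent_condition",  # image VAE-encoded into the latent stream
    }
    if wan_version not in mechanisms:
        raise ValueError(f"unknown Wan version: {wan_version}")
    return mechanisms[wan_version]

print(i2v_conditioning("2.1"))  # clip_image_embeds
```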

Example usage

from flax import nnx
from jax.sharding import Mesh

from maxdiffusion import WanPipeline
# Note: the import paths for HyperParameters and create_device_mesh may vary
# between maxdiffusion versions; check your installed package.
from maxdiffusion.pyconfig import HyperParameters
from maxdiffusion.max_utils import create_device_mesh

# Load pipeline components (model id is illustrative; use a Wan checkpoint)
config = HyperParameters(
    pretrained_model_name_or_path="Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
    model_type="T2V",
    resolution=480,
    num_frames=81
)

devices_array = create_device_mesh(config)
mesh = Mesh(devices_array, config.mesh_axes)
rngs = nnx.Rngs(0)

# Initialize pipeline
components = WanPipeline._create_common_components(config)
pipeline = WanPipeline(**components, config=config)

# Load transformer with sharding
wan_transformer = WanPipeline.load_transformer(
    devices_array, mesh, rngs, config
)

# Generate video
videos = pipeline(
    prompt="a cat playing piano",
    num_inference_steps=50,
    height=480,
    width=832,
    num_frames=81
)
