
Overview

MaxDiffusion supports training and inference for multiple state-of-the-art diffusion models. Each model has specific requirements, configurations, and optimizations for different hardware accelerators.

Stable Diffusion models

MaxDiffusion supports the complete Stable Diffusion family with training and inference capabilities.

Stable Diffusion 1.x

Architecture: U-Net-based text-to-image model with a CLIP text encoder
Supported features:
  • Training and inference
  • Dreambooth fine-tuning
  • ControlNet conditioning
Configuration:
pretrained_model_name_or_path: 'runwayml/stable-diffusion-v1-5'
resolution: 512
attention: 'flash'
Training example:
python -m src.maxdiffusion.train src/maxdiffusion/configs/base14.yml \
  run_name="my_run" \
  output_dir="gs://your-bucket/" \
  attention=flash

Stable Diffusion 2.x

Architecture: Enhanced U-Net with an improved text encoder
Supported models:
  • SD 2 base (512x512)
  • SD 2.1 (768x768)
Configuration:
pretrained_model_name_or_path: 'stabilityai/stable-diffusion-2-base'
resolution: 512
weights_dtype: 'float32'
activations_dtype: 'float32'

Stable Diffusion XL

Architecture: Larger U-Net with dual text encoders (CLIP + OpenCLIP)
Supported features:
  • Training and inference
  • Multi-host distributed training
  • LoRA fine-tuning
  • SDXL Lightning (fast inference)
  • ControlNet
Configuration:
pretrained_model_name_or_path: 'stabilityai/stable-diffusion-xl-base-1.0'
resolution: 1024
weights_dtype: 'float32'
activations_dtype: 'bfloat16'
attention: 'dot_product'
Parallelism configuration:
mesh_axes: ['data', 'fsdp', 'context', 'tensor']
ici_data_parallelism: -1      # Auto-shard
ici_fsdp_parallelism: 1
ici_context_parallelism: 1
ici_tensor_parallelism: 1
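To make the parallelism settings concrete, here is a minimal pure-Python sketch of the convention the config above suggests: the product of the axis degrees must equal the chip count, and a single -1 axis (the auto-sharded data axis here) absorbs whatever remains. This is an illustration, not MaxDiffusion's actual mesh-construction code, and the 8-chip slice is hypothetical.

```python
from math import prod

def resolve_mesh(ici, num_devices):
    """Resolve ici_*_parallelism degrees into a concrete mesh shape.

    Sketch only: the product of all degrees must equal the device count,
    and a -1 axis absorbs the remainder (the 'Auto-shard' behavior).
    """
    fixed = prod(v for v in ici.values() if v != -1)
    if num_devices % fixed != 0:
        raise ValueError("fixed degrees must divide the device count")
    return {axis: (num_devices // fixed if v == -1 else v)
            for axis, v in ici.items()}

# The SDXL config above, on a hypothetical 8-chip slice:
mesh = resolve_mesh({"data": -1, "fsdp": 1, "context": 1, "tensor": 1}, 8)
print(mesh)  # the data axis absorbs all 8 devices
```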
Training example:
python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \
  run_name="my_xl_run" \
  output_dir="gs://your-bucket/" \
  per_device_batch_size=1
GPU training with fused attention:
NVTE_FUSED_ATTN=1 python -m src.maxdiffusion.train_sdxl \
  src/maxdiffusion/configs/base_xl.yml \
  hardware=gpu \
  attention="cudnn_flash_te" \
  per_device_batch_size=1

Flux models

Architecture: Transformer-based model trained with flow matching
Variants:
  • Flux.1-dev: 28-step high-quality generation
  • Flux.1-schnell: 4-step fast generation
Key features:
  • Transformer architecture (not U-Net)
  • T5-XXL and CLIP text encoders
  • Flow matching instead of DDPM
  • LoRA support
  • GPU acceleration with Transformer Engine

Configuration

pretrained_model_name_or_path: 'black-forest-labs/FLUX.1-dev'
flux_name: "flux-dev"
max_sequence_length: 512
weights_dtype: 'bfloat16'
activations_dtype: 'bfloat16'
attention: 'flash'
offload_encoders: True  # Save memory

Parallelism strategies

ici_data_parallelism: 4
ici_fsdp_parallelism: 1
ici_context_parallelism: 1
ici_tensor_parallelism: 1

Training example

python src/maxdiffusion/train_flux.py \
  src/maxdiffusion/configs/base_flux_dev.yml \
  run_name="test-flux-train" \
  output_dir="gs://your-bucket/" \
  jax_cache_dir="/tmp/jax_cache"

Inference examples

python src/maxdiffusion/generate_flux.py \
  src/maxdiffusion/configs/base_flux_dev.yml \
  prompt="photograph of an electronics chip in the shape of a race car" \
  per_device_batch_size=1

Performance benchmarks

| Model | Accelerator | Strategy | Batch size | Steps | Time (secs) |
| --- | --- | --- | --- | --- | --- |
| Flux-dev | v4-8 | DDP | 4 | 28 | 23 |
| Flux-schnell | v4-8 | DDP | 4 | 4 | 2.2 |
| Flux-dev | v6e-4 | DDP | 4 | 28 | 5.5 |
| Flux-schnell | v6e-4 | DDP | 4 | 4 | 0.8 |
| Flux-schnell | v6e-4 | FSDP | 4 | 4 | 1.2 |
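A rough derived metric can help compare these runs: images generated per wall-clock second, computed from the batch size and total time in the table above. The snippet below is a simple back-of-the-envelope calculation, not an official benchmark tool.

```python
# (model, accelerator, strategy, batch, steps, secs) rows from the table above
runs = [
    ("Flux-dev",     "v4-8",  "DDP",  4, 28, 23.0),
    ("Flux-schnell", "v4-8",  "DDP",  4,  4,  2.2),
    ("Flux-dev",     "v6e-4", "DDP",  4, 28,  5.5),
    ("Flux-schnell", "v6e-4", "DDP",  4,  4,  0.8),
    ("Flux-schnell", "v6e-4", "FSDP", 4,  4,  1.2),
]

def throughput(batch, secs):
    """Images generated per wall-clock second for one run."""
    return batch / secs

for model, acc, strat, batch, steps, secs in runs:
    print(f"{model} on {acc} ({strat}): {throughput(batch, secs):.2f} images/sec")
```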

Wan models

Architecture: Video diffusion transformers using Rectified Flow
Supported models:
  • Wan2.1: 14B parameter text-to-video and image-to-video
  • Wan2.2: 27B parameter enhanced quality
Key features:
  • Text-to-video and image-to-video generation
  • Up to 121 frames at 720p/1080p
  • Sequence parallelism for long videos
  • LoRA support
  • Evaluation during training

Configuration

pretrained_model_name_or_path: 'Wan-AI/Wan2.1-T2V-14B-Diffusers'
model_name: wan2.1
model_type: 'T2V'
weights_dtype: 'bfloat16'
activations_dtype: 'bfloat16'
attention: 'flash'
flash_min_seq_length: 0
mask_padding_tokens: True
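The mask_padding_tokens option implies that positions past each prompt's true length are excluded from attention. The pure-Python sketch below illustrates the general idea (padded positions receive -inf scores so softmax ignores them); it is not the library's actual attention kernel, and the function name is hypothetical.

```python
def mask_padding_scores(scores, token_lengths):
    """Set attention scores for padded text tokens to -inf.

    Illustration of the idea behind mask_padding_tokens: True — positions
    beyond each prompt's real token count are masked out before softmax.
    """
    NEG_INF = float("-inf")
    masked = []
    for row_scores, length in zip(scores, token_lengths):
        masked.append([s if j < length else NEG_INF
                       for j, s in enumerate(row_scores)])
    return masked

# One query row over 4 text tokens, where the prompt is only 2 tokens long
print(mask_padding_scores([[0.5, 0.1, 0.9, 0.3]], [2]))
```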

Parallelism for video models

Wan models use specialized parallelism for long sequences:
# Sequence parallelism (splits sequence across devices)
ici_fsdp_parallelism: 4       # Used for sequence parallelism

# Head parallelism (40 heads must divide evenly)
ici_tensor_parallelism: 1     # Used for head parallelism

# Context parallelism for even longer sequences
ici_context_parallelism: 2

# Data parallelism
ici_data_parallelism: 1

# Fractional batch sizes supported
per_device_batch_size: 0.25   # 0.25 * 4 devices = 1 global batch
In Wan2.1, ici_fsdp_parallelism drives sequence parallelism and ici_tensor_parallelism drives head parallelism; the model's 40 attention heads must divide evenly by the tensor-parallel degree.
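The two constraints above can be sketched as simple arithmetic checks. This is illustrative pseudologic, not MaxDiffusion's actual validation code: fractional per-device batches are meaningful because sequence parallelism splits one sample across several devices, but the global batch must still come out to a whole number.

```python
def global_batch_size(per_device_batch_size, num_devices):
    """Check that a (possibly fractional) per-device batch yields a
    positive integer global batch. Illustrative sketch only."""
    gbs = per_device_batch_size * num_devices
    if gbs != int(gbs) or gbs < 1:
        raise ValueError("global batch size must be a positive integer")
    return int(gbs)

def heads_per_device(num_heads, ici_tensor_parallelism):
    """Wan2.1's 40 attention heads must split evenly across the tensor axis."""
    if num_heads % ici_tensor_parallelism != 0:
        raise ValueError("num_heads must be divisible by the tensor degree")
    return num_heads // ici_tensor_parallelism

print(global_batch_size(0.25, 4))  # 0.25 * 4 devices = 1 global batch
print(heads_per_device(40, 4))     # 10 heads per tensor-parallel shard
```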

Training example

python src/maxdiffusion/train_wan.py \
  src/maxdiffusion/configs/base_wan_14b.yml \
  attention='flash' \
  height=1280 \
  width=720 \
  num_frames=81 \
  per_device_batch_size=0.25 \
  ici_data_parallelism=1 \
  ici_fsdp_parallelism=4 \
  ici_tensor_parallelism=1 \
  remat_policy='HIDDEN_STATE_WITH_OFFLOAD'

Inference examples

python src/maxdiffusion/generate_wan.py \
  src/maxdiffusion/configs/base_wan_14b.yml \
  attention="flash" \
  num_frames=81 \
  width=1280 \
  height=720

Synthetic data for benchmarking

Test different video dimensions without downloading datasets:
dataset_type: 'synthetic'
synthetic_num_samples: null  # Infinite
synthetic_override_height: 720
synthetic_override_width: 1280
synthetic_override_num_frames: 85
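The idea behind the synthetic dataset can be sketched as a generator that fabricates correctly shaped samples on demand, so input-pipeline and step-time benchmarks run without downloading anything. The sample layout below is hypothetical, not MaxDiffusion's actual record format; num_samples=None mimics synthetic_num_samples: null (an endless stream).

```python
import itertools

def synthetic_video_samples(height=720, width=1280, num_frames=85,
                            num_samples=None):
    """Yield dummy samples shaped like a real video dataset would be.

    Sketch of dataset_type: 'synthetic' — only shape metadata here;
    a real loader would allocate pixel arrays.
    """
    for i in itertools.count():
        if num_samples is not None and i >= num_samples:
            return
        yield {"video_shape": (num_frames, height, width, 3),
               "caption": f"synthetic sample {i}"}

# Peek at the first two samples of an "infinite" stream
for sample in itertools.islice(synthetic_video_samples(), 2):
    print(sample["video_shape"], sample["caption"])
```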

LTX-Video

Architecture: Latent video diffusion model
Supported features:
  • Text-to-video generation
  • Image-to-video generation
Inference example:
python src/maxdiffusion/generate_ltx_video.py \
  src/maxdiffusion/configs/ltx_video.yml \
  output_dir="/path/to/weights" \
  config_path="src/maxdiffusion/models/ltx_video/ltxv-13B.json"

LoRA support

MaxDiffusion supports Low-Rank Adaptation (LoRA) for efficient fine-tuning.
Supported models: SDXL, Flux, Wan
Single LoRA configuration:
enable_lora: True
lora_config: {
  lora_model_name_or_path: ["ByteDance/Hyper-SD"],
  weight_name: ["Hyper-SDXL-2steps-lora.safetensors"],
  adapter_name: ["hyper-sdxl"],
  scale: [0.7],
  from_pt: [true]
}
Multiple LoRA loading:
lora_config: {
  lora_model_name_or_path: [
    "/path/to/lora1.safetensors",
    "user/lora-repo"
  ],
  weight_name: [
    "lora1.safetensors",
    "lora2.safetensors"
  ],
  adapter_name: ["style1", "style2"],
  scale: [0.8, 0.6],
  from_pt: [true, true]
}
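Conceptually, loading multiple LoRAs folds each adapter's low-rank update into the base weights as W' = W + Σᵢ scaleᵢ · (Bᵢ @ Aᵢ). The pure-Python sketch below illustrates how the per-adapter `scale` values combine; the real library applies this per layer to safetensors checkpoints, not to plain lists, and the adapters shown are made up.

```python
def merge_loras(weight, adapters):
    """Conceptual multi-LoRA merge: W' = W + sum_i scale_i * (B_i @ A_i).

    `adapters` is a list of (scale, B, A) with B: (out, rank) and
    A: (rank, in), matching the weight's (out, in) shape.
    """
    merged = [row[:] for row in weight]
    for scale, B, A in adapters:
        rank = len(A)
        for i in range(len(weight)):
            for j in range(len(weight[0])):
                delta = sum(B[i][k] * A[k][j] for k in range(rank))
                merged[i][j] += scale * delta
    return merged

# Two hypothetical rank-1 adapters on a 2x2 weight, scales 0.8 and 0.6
W = [[1.0, 0.0], [0.0, 1.0]]
style1 = (0.8, [[1.0], [0.0]], [[1.0, 1.0]])  # B: 2x1, A: 1x2
style2 = (0.6, [[0.0], [1.0]], [[1.0, 0.0]])
result = merge_loras(W, [style1, style2])
print([[round(x, 3) for x in row] for row in result])  # [[1.8, 0.8], [0.6, 1.0]]
```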

Model comparison

| Model | Architecture | Resolution | Training | Inference | Hardware |
| --- | --- | --- | --- | --- | --- |
| SD 1.x | U-Net | 512x512 | ✓ | ✓ | TPU, GPU |
| SD 2.x | U-Net | 512-768 | ✓ | ✓ | TPU, GPU |
| SDXL | U-Net | 1024x1024 | ✓ | ✓ | TPU, GPU |
| Flux | Transformer | 1024x1024 | ✓ | ✓ | TPU, GPU |
| Wan2.1 | Transformer | 720p-1080p video | ✓ | ✓ | TPU |
| Wan2.2 | Transformer | 720p-1080p video | - | ✓ | TPU |
| LTX-Video | Transformer | Video | - | ✓ | TPU |
