## Overview
MaxDiffusion supports training and inference for multiple state-of-the-art diffusion models. Each model has specific requirements, configurations, and optimizations for different hardware accelerators.
## Stable Diffusion models
MaxDiffusion supports the complete Stable Diffusion family with training and inference capabilities.
### Stable Diffusion 1.x

**Architecture**: U-Net-based text-to-image model with a CLIP text encoder

**Supported features**:

- Training and inference
- DreamBooth fine-tuning
- ControlNet conditioning

**Configuration**:

```yaml
pretrained_model_name_or_path: 'runwayml/stable-diffusion-v1-5'
resolution: 512
attention: 'flash'
```

**Training example**:

```bash
python -m src.maxdiffusion.train src/maxdiffusion/configs/base14.yml \
  run_name="my_run" \
  output_dir="gs://your-bucket/" \
  attention=flash
```
### Stable Diffusion 2.x

**Architecture**: Enhanced U-Net with an improved text encoder

**Supported models**:

- SD 2 base (512x512)
- SD 2.1 (768x768)

**Configuration**:

SD 2 base:

```yaml
pretrained_model_name_or_path: 'stabilityai/stable-diffusion-2-base'
resolution: 512
weights_dtype: 'float32'
activations_dtype: 'float32'
```

SD 2.1:

```yaml
pretrained_model_name_or_path: 'stabilityai/stable-diffusion-2-1'
resolution: 768
attention: 'flash'
```
### Stable Diffusion XL

**Architecture**: Larger U-Net with dual text encoders (CLIP + OpenCLIP)

**Supported features**:

- Training and inference
- Multi-host distributed training
- LoRA fine-tuning
- SDXL Lightning (fast inference)
- ControlNet

**Configuration**:

```yaml
pretrained_model_name_or_path: 'stabilityai/stable-diffusion-xl-base-1.0'
resolution: 1024
weights_dtype: 'float32'
activations_dtype: 'bfloat16'
attention: 'dot_product'
```
**Parallelism configuration**:

```yaml
mesh_axes: ['data', 'fsdp', 'context', 'tensor']
ici_data_parallelism: -1  # Auto-shard
ici_fsdp_parallelism: 1
ici_context_parallelism: 1
ici_tensor_parallelism: 1
```
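The `-1` value tells the framework to auto-shard that axis across whatever devices remain. As a rough illustration of that convention (this helper is not part of MaxDiffusion), resolving the mesh shape might look like:

```python
import math

def resolve_mesh(num_devices, data=-1, fsdp=1, context=1, tensor=1):
    """Resolve a single -1 ('auto-shard') axis so the mesh covers all devices.

    Mirrors the convention in the config above: at most one axis may be -1,
    and the product of all axes must equal the device count.
    """
    axes = {"data": data, "fsdp": fsdp, "context": context, "tensor": tensor}
    wildcards = [k for k, v in axes.items() if v == -1]
    if len(wildcards) > 1:
        raise ValueError("at most one axis may be -1")
    if wildcards:
        fixed = math.prod(v for v in axes.values() if v != -1)
        if num_devices % fixed:
            raise ValueError(f"{num_devices} devices not divisible by {fixed}")
        axes[wildcards[0]] = num_devices // fixed
    if math.prod(axes.values()) != num_devices:
        raise ValueError("mesh does not cover all devices")
    return axes

# With 4 devices, data=-1 and the other axes at 1 resolves to data=4.
print(resolve_mesh(4))  # {'data': 4, 'fsdp': 1, 'context': 1, 'tensor': 1}
```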
**Training example**:

```bash
python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \
  run_name="my_xl_run" \
  output_dir="gs://your-bucket/" \
  per_device_batch_size=1
```

**GPU training with fused attention**:

```bash
NVTE_FUSED_ATTN=1 python -m src.maxdiffusion.train_sdxl \
  src/maxdiffusion/configs/base_xl.yml \
  hardware=gpu \
  attention="cudnn_flash_te" \
  per_device_batch_size=1
```
## Flux models

**Architecture**: Transformer-based architecture with flow matching

**Variants**:

- Flux.1-dev: 28-step high-quality generation
- Flux.1-schnell: 4-step fast generation

**Key features**:

- Transformer architecture (not U-Net)
- T5-XXL and CLIP text encoders
- Flow matching instead of DDPM
- LoRA support
- GPU acceleration with Transformer Engine
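Unlike DDPM denoising, flow matching trains the model to predict a constant velocity along a straight path from data to noise, so sampling integrates a simple ODE. A minimal scalar sketch of the idea (illustrative only, not Flux's actual implementation):

```python
def interpolate(x0, x1, t):
    """Rectified-flow interpolant: straight line from data x0 (t=0) to noise x1 (t=1)."""
    return (1.0 - t) * x0 + t * x1

def velocity_target(x0, x1):
    """Flow-matching training target: the constant velocity along that line."""
    return x1 - x0

# On the straight path, one Euler step of the reverse ODE recovers x0:
x0, x1, t = 2.0, -1.0, 0.4
xt = interpolate(x0, x1, t)
v = velocity_target(x0, x1)
x0_rec = xt - t * v  # integrate dx/dt = v back from t to 0
print(abs(x0_rec - x0) < 1e-9)  # True
```

The few-step schnell variant works because a well-trained velocity field makes the sampling trajectory nearly straight, so very few integration steps suffice.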
### Configuration

```yaml
pretrained_model_name_or_path: 'black-forest-labs/FLUX.1-dev'
flux_name: "flux-dev"
max_sequence_length: 512
weights_dtype: 'bfloat16'
activations_dtype: 'bfloat16'
attention: 'flash'
offload_encoders: True  # Save memory
```
### Parallelism strategies

Data parallelism (DDP):

```yaml
ici_data_parallelism: 4
ici_fsdp_parallelism: 1
ici_context_parallelism: 1
ici_tensor_parallelism: 1
```

Model parallelism (FSDP):

```yaml
ici_data_parallelism: 1
ici_fsdp_parallelism: -1  # Auto-shard across devices
offload_encoders: False
```
### Training example

```bash
python src/maxdiffusion/train_flux.py \
  src/maxdiffusion/configs/base_flux_dev.yml \
  run_name="test-flux-train" \
  output_dir="gs://your-bucket/" \
  jax_cache_dir="/tmp/jax_cache"
```
### Inference examples

Flux-dev:

```bash
python src/maxdiffusion/generate_flux.py \
  src/maxdiffusion/configs/base_flux_dev.yml \
  prompt="photograph of an electronics chip in the shape of a race car" \
  per_device_batch_size=1
```
| Model | Accelerator | Strategy | Batch size | Steps | Time (secs) |
|---|---|---|---|---|---|
| Flux-dev | v4-8 | DDP | 4 | 28 | 23 |
| Flux-schnell | v4-8 | DDP | 4 | 4 | 2.2 |
| Flux-dev | v6e-4 | DDP | 4 | 28 | 5.5 |
| Flux-schnell | v6e-4 | DDP | 4 | 4 | 0.8 |
| Flux-schnell | v6e-4 | FSDP | 4 | 4 | 1.2 |
## Wan models

**Architecture**: Video diffusion transformers using rectified flow

**Supported models**:

- Wan2.1: 14B-parameter text-to-video and image-to-video
- Wan2.2: 27B-parameter enhanced quality

**Key features**:

- Text-to-video and image-to-video generation
- Up to 121 frames at 720p/1080p
- Sequence parallelism for long videos
- LoRA support
- Evaluation during training
### Configuration

```yaml
pretrained_model_name_or_path: 'Wan-AI/Wan2.1-T2V-14B-Diffusers'
model_name: wan2.1
model_type: 'T2V'
weights_dtype: 'bfloat16'
activations_dtype: 'bfloat16'
attention: 'flash'
flash_min_seq_length: 0
mask_padding_tokens: True
```
### Parallelism for video models

Wan models use specialized parallelism for long sequences:

```yaml
# Sequence parallelism (splits the sequence across devices)
ici_fsdp_parallelism: 4

# Head parallelism (the 40 attention heads must divide evenly)
ici_tensor_parallelism: 1

# Context parallelism for even longer sequences
ici_context_parallelism: 2

# Data parallelism
ici_data_parallelism: 1

# Fractional batch sizes are supported
per_device_batch_size: 0.25  # 0.25 * 4 devices = 1 global batch
```

In Wan2.1, `ici_fsdp_parallelism` controls sequence parallelism and `ici_tensor_parallelism` controls head parallelism (the 40 attention heads must be divisible by this value).
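These constraints (mesh axes multiplying to the device count, head count divisible by the tensor axis, and fractional per-device batches yielding an integer global batch) can be sanity-checked up front. An illustrative sketch, not MaxDiffusion code:

```python
def check_wan_parallelism(num_devices, num_heads, data, fsdp, context, tensor,
                          per_device_batch_size):
    """Validate the Wan parallelism rules described above; return the global batch."""
    if data * fsdp * context * tensor != num_devices:
        raise ValueError("mesh axes must multiply to the device count")
    if num_heads % tensor:
        raise ValueError(f"{num_heads} heads not divisible by tensor={tensor}")
    global_batch = per_device_batch_size * num_devices
    if global_batch != int(global_batch):
        raise ValueError("fractional batch must yield an integer global batch")
    return int(global_batch)

# 4 devices, 40 heads, sequence parallelism over fsdp=4, batch 0.25 per device:
print(check_wan_parallelism(4, 40, data=1, fsdp=4, context=1, tensor=1,
                            per_device_batch_size=0.25))  # 1
```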
### Training example

```bash
python src/maxdiffusion/train_wan.py \
  src/maxdiffusion/configs/base_wan_14b.yml \
  attention='flash' \
  height=1280 \
  width=720 \
  num_frames=81 \
  per_device_batch_size=0.25 \
  ici_data_parallelism=1 \
  ici_fsdp_parallelism=4 \
  ici_tensor_parallelism=1 \
  remat_policy='HIDDEN_STATE_WITH_OFFLOAD'
```
### Inference examples

Wan2.1 T2V:

```bash
python src/maxdiffusion/generate_wan.py \
  src/maxdiffusion/configs/base_wan_14b.yml \
  attention="flash" \
  num_frames=81 \
  width=1280 \
  height=720
```

Wan2.1 I2V:

```bash
python src/maxdiffusion/generate_wan.py \
  src/maxdiffusion/configs/base_wan_i2v_14b.yml \
  attention="flash"
```

Wan2.2 T2V:

```bash
python src/maxdiffusion/generate_wan.py \
  src/maxdiffusion/configs/base_wan_27b.yml \
  attention="flash"
```
### Synthetic data for benchmarking

Test different video dimensions without downloading datasets:

```yaml
dataset_type: 'synthetic'
synthetic_num_samples: null  # Infinite
synthetic_override_height: 720
synthetic_override_width: 1280
synthetic_override_num_frames: 85
```
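Conceptually, the synthetic loader streams placeholder samples of the overridden dimensions, infinitely when `synthetic_num_samples` is null. A sketch of that behavior (the function, channel count, and layout are illustrative assumptions, not MaxDiffusion's API):

```python
import itertools

def synthetic_video_batches(height, width, num_frames, num_samples=None, channels=3):
    """Yield sample shapes matching the synthetic overrides above.

    num_samples=None mirrors `synthetic_num_samples: null` (an infinite stream);
    real samples would be arrays, not bare shape tuples.
    """
    counter = itertools.count() if num_samples is None else range(num_samples)
    for _ in counter:
        yield (num_frames, height, width, channels)

gen = synthetic_video_batches(720, 1280, 85)
print(next(gen))  # (85, 720, 1280, 3)
```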
## LTX-Video

**Architecture**: Latent video diffusion model

**Supported features**:

- Text-to-video generation
- Image-to-video generation

**Inference example**:

```bash
python src/maxdiffusion/generate_ltx_video.py \
  src/maxdiffusion/configs/ltx_video.yml \
  output_dir="/path/to/weights" \
  config_path="src/maxdiffusion/models/ltx_video/ltxv-13B.json"
```
## LoRA support

MaxDiffusion supports Low-Rank Adaptation (LoRA) for efficient fine-tuning.

**Supported models**: SDXL, Flux, Wan

**Single LoRA configuration**:

```yaml
enable_lora: True
lora_config: {
  lora_model_name_or_path: ["ByteDance/Hyper-SD"],
  weight_name: ["Hyper-SDXL-2steps-lora.safetensors"],
  adapter_name: ["hyper-sdxl"],
  scale: [0.7],
  from_pt: [true]
}
```
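The `scale` entry controls how strongly an adapter's low-rank update is applied: merging one adapter amounts to W' = W + scale * (B @ A), where A and B are the small LoRA factors. A pure-Python sketch of that update (illustrative, not the library's implementation):

```python
def apply_lora(w, lora_a, lora_b, scale):
    """Merge one LoRA adapter into a weight matrix: W' = W + scale * (B @ A).

    w: (out x in) weight as lists of lists; lora_a: (rank x in); lora_b: (out x rank).
    """
    rows, cols = len(w), len(w[0])
    rank = len(lora_a)
    out = [row[:] for row in w]
    for i in range(rows):
        for j in range(cols):
            out[i][j] += scale * sum(lora_b[i][r] * lora_a[r][j] for r in range(rank))
    return out

# Rank-1 update on a 2x2 weight, with scale 0.7 as in the config above:
w = [[1.0, 0.0], [0.0, 1.0]]
lora_a = [[1.0, 2.0]]    # (rank=1, in=2)
lora_b = [[1.0], [0.0]]  # (out=2, rank=1)
print(apply_lora(w, lora_a, lora_b, 0.7))  # [[1.7, 1.4], [0.0, 1.0]]
```

Because the rank is small, the extra factors add only a tiny fraction of the base model's parameters, which is what makes LoRA fine-tuning cheap.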
**Multiple LoRA loading**:

```yaml
lora_config: {
  lora_model_name_or_path: [
    "/path/to/lora1.safetensors",
    "user/lora-repo"
  ],
  weight_name: [
    "lora1.safetensors",
    "lora2.safetensors"
  ],
  adapter_name: ["style1", "style2"],
  scale: [0.8, 0.6],
  from_pt: [true, true]
}
```
## Model comparison

| Model | Architecture | Resolution | Training | Inference | Hardware |
|---|---|---|---|---|---|
| SD 1.x | U-Net | 512x512 | ✓ | ✓ | TPU, GPU |
| SD 2.x | U-Net | 512-768 | ✓ | ✓ | TPU, GPU |
| SDXL | U-Net | 1024x1024 | ✓ | ✓ | TPU, GPU |
| Flux | Transformer | 1024x1024 | ✓ | ✓ | TPU, GPU |
| Wan2.1 | Transformer | 720p-1080p video | ✓ | ✓ | TPU |
| Wan2.2 | Transformer | 720p-1080p video | - | ✓ | TPU |
| LTX-Video | Transformer | Video | - | ✓ | TPU |
## Next steps