## Supported models

MaxDiffusion supports the following models for inference:

- **Stable Diffusion**: SD 2 base and SD 2.1 for 512×512 image generation
- **Stable Diffusion XL**: SDXL for high-quality 1024×1024 images with dual text encoders
- **Flux**: Flux dev and schnell variants with optimized flash attention
- **Wan**: Wan 2.1 and 2.2 for text-to-video and image-to-video generation
- **LTX Video**: LTX-Video for high-quality video generation with conditioning
- **ControlNet**: Conditional generation with ControlNet for SD 1.4 and SDXL
## Key features

### Sharding strategies

MaxDiffusion supports multiple parallelism strategies for efficient inference:

- **Data parallelism (DDP)**: Replicate the model across devices and process different prompts in parallel
- **FSDP**: Shard model parameters across devices to fit larger models in memory
- **Context parallelism**: Split the sequence dimension across devices to handle longer contexts
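As a minimal sketch of the first two strategies, here is how data-parallel and FSDP-style placement can be expressed with JAX's `jax.sharding` primitives. The axis name and the toy tensor shapes are illustrative, not MaxDiffusion's actual configuration:

```python
# Sketch: data parallelism vs. FSDP-style parameter sharding in JAX.
# Axis name "data" and all shapes are illustrative assumptions.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One mesh axis spanning all available devices.
mesh = Mesh(np.asarray(jax.devices()), axis_names=("data",))

# Data parallelism: split the batch across devices; weights stay replicated.
prompt_embeds = jax.device_put(
    jnp.zeros((8, 77, 768)),  # (batch, tokens, features)
    NamedSharding(mesh, P("data", None, None)))

# FSDP-style: shard the parameter tensor itself across the same devices,
# so each device holds only a slice of the weights.
weights = jax.device_put(
    jnp.zeros((1024, 768)),
    NamedSharding(mesh, P("data", None)))

print(prompt_embeds.shape, weights.shape)
```

Under `jax.jit`, computations on arrays placed this way are partitioned automatically according to their shardings.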
### Trillium optimizations

TPU v6e (Trillium) benefits from optimized flash attention block sizes. Enable them by uncommenting the `flash_block_sizes` configuration in the model config files:
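A config fragment might look like the following sketch; the field names and block values here are illustrative, so check the commented-out block in your model's actual config file for the exact keys and tuned values:

```yaml
# Illustrative sketch only -- uncomment the real flash_block_sizes block
# in the model's config file; field names and values may differ.
flash_block_sizes:
  block_q: 1024     # query block size for the flash attention kernel
  block_kv: 1024    # key/value block size
```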
### Encoder offloading

For models with large text encoders (like Flux), offload the encoders to keep the transformer and VAE in HBM.

### Precision control
All models use bfloat16 by default for optimal performance on TPUs:

- Activations: bfloat16
- Weights: bfloat16
- Latents: float32 for numerical stability
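The dtype scheme above can be sketched in a few lines of JAX; the shapes are toy values, not MaxDiffusion's model code:

```python
# Sketch of the default precision scheme: bfloat16 weights/activations,
# float32 latents. Toy shapes, for illustration only.
import jax.numpy as jnp

weights = jnp.ones((4, 4), dtype=jnp.bfloat16)      # weights in bfloat16
activations = jnp.ones((2, 4), dtype=jnp.bfloat16)  # activations in bfloat16
hidden = activations @ weights                      # compute stays in bfloat16
latents = hidden.astype(jnp.float32)                # latents in float32 for stability

print(hidden.dtype, latents.dtype)
```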
## Common parameters

All inference scripts accept these common parameters:

| Parameter | Description | Default |
|---|---|---|
| `prompt` | Text prompt for generation | Required |
| `negative_prompt` | Negative prompt to avoid concepts | Empty string |
| `num_inference_steps` | Number of denoising steps | Model-specific |
| `guidance_scale` | Classifier-free guidance strength | 7.5 |
| `per_device_batch_size` | Batch size per device | 1 |
| `seed` | Random seed for reproducibility | 0 |
| `output_dir` | Directory for saving outputs | `/tmp/` |
| `jax_cache_dir` | JAX compilation cache directory | Required |
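Put together, an invocation passing these parameters might look like the sketch below. The script name and config path are assumptions for illustration; substitute the actual entry point and config file for your model:

```shell
# Illustrative invocation only -- the script and config paths are assumptions;
# use the actual entry point and config for your model.
python src/maxdiffusion/generate.py src/maxdiffusion/configs/base_xl.yml \
  prompt="A watercolor fox in a snowy forest" \
  negative_prompt="" \
  num_inference_steps=30 \
  guidance_scale=7.5 \
  per_device_batch_size=1 \
  seed=0 \
  output_dir=/tmp/ \
  jax_cache_dir=/tmp/jax_cache
```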
## Performance tips

- **Use flash attention**: Set `attention="flash"` for a 2-4x speedup on supported hardware.
- **Enable HF transfer**: Set `HF_HUB_ENABLE_HF_TRANSFER=1` for faster model downloads.
- **Cache compilations**: Use `jax_cache_dir` to avoid recompiling on subsequent runs.
- **Optimize batch size**: Increase `per_device_batch_size` to maximize hardware utilization.
- **Use async collectives**: Set `LIBTPU_INIT_ARGS` for better communication overlap on TPUs.
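As a sketch of the compilation-caching tip, JAX's persistent compilation cache can be enabled directly; the cache path here is an arbitrary example:

```python
# Sketch: persist JAX compilation artifacts so later runs skip recompilation.
# The cache directory is an arbitrary example path.
import jax
import jax.numpy as jnp

jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")

@jax.jit
def step(x):
    return (x * 2).sum()

result = step(jnp.arange(4.0))  # first call compiles and populates the cache
print(result)
```

On subsequent runs with the same cache directory, compatible compilations are loaded from disk instead of being rebuilt.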
## Next steps

- **Stable Diffusion XL**: Generate high-quality images with SDXL
- **Flux**: Fast inference with Flux dev and schnell
- **Wan video generation**: Create videos with Wan 2.1 and 2.2
- **LoRA loading**: Load custom LoRA adapters