MaxDiffusion provides high-performance inference for latent diffusion models on Google Cloud TPUs and GPUs. All models are written in JAX and compiled with XLA for hardware-specific optimization.

Supported models

MaxDiffusion supports the following models for inference:

Stable Diffusion

SD 2 base and SD 2.1 for 512×512 image generation

Stable Diffusion XL

SDXL for high-quality 1024×1024 images with dual text encoders

Flux

Flux dev and schnell variants with optimized flash attention

Wan

Wan 2.1 and 2.2 for text-to-video and image-to-video generation

LTX Video

LTX-Video for high-quality video generation with conditioning

ControlNet

Conditional generation with ControlNet for SD 1.4 and SDXL

Key features

Sharding strategies

MaxDiffusion supports multiple parallelism strategies for efficient inference:
  • Data parallelism: Replicate the model across devices and process different prompts in parallel
  • FSDP: Shard model parameters across devices to fit larger models in memory
  • Context parallelism: Split sequence dimension for handling longer context
Configure sharding with these parameters:
ici_data_parallelism=4      # Number of data parallel devices
ici_fsdp_parallelism=-1     # Fully shard model parameters
ici_context_parallelism=2   # Context parallel degree
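To illustrate how these degrees combine, here is a minimal sketch (a hypothetical helper, not part of MaxDiffusion) of resolving a -1 entry against the total device count, following the MaxText-style convention where -1 means "use all remaining devices":

```python
def resolve_mesh_shape(num_devices, data=1, fsdp=1, context=1):
    """Resolve (data, fsdp, context) parallelism degrees into a mesh shape.

    A single -1 entry means "use every device not claimed by the other
    axes", mirroring the ici_*_parallelism convention described above.
    """
    dims = [data, fsdp, context]
    fixed = 1
    for d in dims:
        if d != -1:
            fixed *= d
    if -1 in dims:
        if num_devices % fixed != 0:
            raise ValueError(f"{num_devices} devices not divisible by {fixed}")
        dims[dims.index(-1)] = num_devices // fixed
    if dims[0] * dims[1] * dims[2] != num_devices:
        raise ValueError("parallelism degrees must multiply to the device count")
    return tuple(dims)

# On an 8-device slice, data=4 with fsdp=-1 shards parameters over the
# remaining 2 devices:
# resolve_mesh_shape(8, data=4, fsdp=-1, context=1) -> (4, 2, 1)
```

The product of all three degrees must equal the number of devices in the slice, so at most one axis can be left as -1.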

Trillium optimizations

TPU v6e (Trillium) benefits from optimized flash attention block sizes. Enable them by uncommenting the flash_block_sizes configuration in the model config files.
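The commented-out block in a config file looks roughly like the sketch below. The key names and values here are illustrative, drawn from typical JAX flash-attention tuning knobs rather than from a specific MaxDiffusion config, so consult your model's config file for the real ones:

```yaml
# Uncomment to use Trillium-tuned flash attention block sizes
# (illustrative keys and values, not canonical):
flash_block_sizes:
  block_q: 256
  block_kv: 256
  block_kv_compute: 256
```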

Encoder offloading

For models with large text encoders (such as Flux), offload the encoders to host memory so the transformer and VAE stay in HBM:
offload_encoders=True   # Offload text encoders; transformer and VAE stay in HBM

Precision control

All models use bfloat16 by default for optimal performance on TPUs:
  • Activations: bfloat16
  • Weights: bfloat16
  • Latents: float32 for numerical stability
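To see why latents stay in float32, note that bfloat16 keeps float32's 8 exponent bits but only 7 mantissa bits, so it has the same dynamic range but far less precision. A standalone sketch of the truncation (plain Python, no JAX required):

```python
import struct

def to_bfloat16(x: float) -> float:
    """Round-toward-zero emulation of bfloat16: keep only the top 16 bits
    of the float32 bit pattern (sign + 8 exponent bits + 7 mantissa bits)."""
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    return struct.unpack(">f", struct.pack(">I", bits & 0xFFFF0000))[0]

# bfloat16 gives roughly 2-3 decimal digits of precision, which is fine
# for weights and activations but can accumulate visible error in latents
# across many denoising steps.
print(to_bfloat16(3.14159))   # 3.140625
```

Real bfloat16 hardware rounds to nearest rather than truncating, but the precision loss shown here is the same order of magnitude.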

Common parameters

All inference scripts accept these common parameters:
| Parameter | Description | Default |
|---|---|---|
| prompt | Text prompt for generation | Required |
| negative_prompt | Negative prompt to avoid concepts | Empty string |
| num_inference_steps | Number of denoising steps | Model-specific |
| guidance_scale | Classifier-free guidance strength | 7.5 |
| per_device_batch_size | Batch size per device | 1 |
| seed | Random seed for reproducibility | 0 |
| output_dir | Directory for saving outputs | /tmp/ |
| jax_cache_dir | JAX compilation cache directory | Required |
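Parameters are passed as key=value overrides on the command line. The invocation below is a sketch: the script and config paths are assumptions and may differ between models and releases, so check the repository for the exact entry point:

```shell
# Illustrative invocation; script and config paths are assumptions.
python src/maxdiffusion/generate.py \
  src/maxdiffusion/configs/base_2_base.yml \
  prompt="a photograph of an astronaut riding a horse" \
  negative_prompt="blurry, low quality" \
  num_inference_steps=30 \
  guidance_scale=7.5 \
  per_device_batch_size=1 \
  seed=0 \
  output_dir=/tmp/ \
  jax_cache_dir=/tmp/jax_cache
```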

Performance tips

  1. Use flash attention: Set attention="flash" for 2-4x speedup on supported hardware
  2. Enable HF transfer: Set HF_HUB_ENABLE_HF_TRANSFER=1 for faster model downloads
  3. Cache compilations: Use jax_cache_dir to avoid recompiling on subsequent runs
  4. Optimize batch size: Increase per_device_batch_size to maximize hardware utilization
  5. Use async collectives: Set LIBTPU_INIT_ARGS for better communication overlap on TPUs
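Tips 2 and 5 are environment settings. A minimal sketch is below; the LIBTPU flag shown is one commonly cited example, and flag names vary by libtpu release, so treat it as illustrative rather than canonical:

```shell
# Faster model downloads from the Hugging Face Hub (requires the
# hf_transfer package to be installed)
export HF_HUB_ENABLE_HF_TRANSFER=1

# Example async-collective tuning for TPUs; consult the current
# MaxDiffusion docs for the recommended flag set on your hardware.
export LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true"
```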

Next steps

Stable Diffusion XL

Generate high-quality images with SDXL

Flux

Fast inference with Flux dev and schnell

Wan video generation

Create videos with Wan 2.1 and 2.2

LoRA loading

Load custom LoRA adapters
