Overview
The generate_sdxl.py script runs inference for Stable Diffusion XL (SDXL) models. It supports single- and multi-host deployment with JAX sharding annotations, LoRA loading, and SDXL Lightning for fast inference.
Usage
Basic command
With LoRA
SDXL Lightning
Configuration parameters
Model parameters
Path or identifier for the pretrained SDXL model.
Model revision to use from HuggingFace.
Set to true to load weights from PyTorch format.
Generation parameters
Text prompt describing the image to generate.
Negative prompt to guide what should not appear in the image.
Number of denoising steps. SDXL typically requires fewer steps than SD 2.x.
Classifier-free guidance scale. SDXL works well with higher values (7-12).
Guidance rescale factor. A value of 0.7 is recommended.
Enable classifier-free guidance for better prompt adherence.
Output image resolution in pixels. SDXL is optimized for 1024x1024.
Performance parameters
Data type for model weights. Use bfloat16 on TPU v5e for optimal performance.
Data type for layer activations.
Attention mechanism. Options: dot_product, flash, cudnn_flash_te (GPU only).
Enable head dimension splitting for better performance.
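A back-of-the-envelope sketch of why bfloat16 weights matter: weight memory scales with bytes per parameter, so bfloat16 halves it relative to float32. The 2.6 billion parameter count for the SDXL UNet is an approximation (assumption), and text encoders, the VAE, and activation buffers are ignored.

```python
# Bytes per element for common weight dtypes.
BYTES_PER_PARAM = {"float32": 4, "bfloat16": 2, "float16": 2}

# Approximate SDXL UNet parameter count (assumption: ~2.6 billion).
UNET_PARAMS = 2.6e9


def weight_memory_gib(num_params, dtype):
    """Memory needed to hold the weights alone, in GiB."""
    return num_params * BYTES_PER_PARAM[dtype] / 2**30


for dtype in ("float32", "bfloat16"):
    print(f"{dtype}: {weight_memory_gib(UNET_PARAMS, dtype):.1f} GiB")
```

This prints 9.7 GiB for float32 and 4.8 GiB for bfloat16, before any activations are counted.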
LoRA parameters
Configuration for loading LoRA adapters. Supports multiple LoRAs. Properties:
- lora_model_name_or_path (array): Paths to LoRA models
- weight_name (array): LoRA weight filenames
- adapter_name (array): Unique names for each adapter
- scale (array): Scale factors (0.0-1.0) for each LoRA
- from_pt (array): Whether to load from PyTorch format
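Because every property is an array with one entry per adapter, the i-th entries together describe the i-th LoRA. A small sanity-check sketch, assuming the configuration is a mapping from the property names above to equal-length lists; `check_lora_config` and the example paths and filenames are hypothetical and not part of maxdiffusion.

```python
LORA_KEYS = ("lora_model_name_or_path", "weight_name", "adapter_name", "scale", "from_pt")


def check_lora_config(cfg):
    """Validate a LoRA config whose properties are parallel arrays.

    Returns one dict per adapter.
    """
    missing = [k for k in LORA_KEYS if k not in cfg]
    if missing:
        raise ValueError(f"missing LoRA keys: {missing}")

    # Parallel arrays must all describe the same number of adapters.
    lengths = {k: len(cfg[k]) for k in LORA_KEYS}
    if len(set(lengths.values())) != 1:
        raise ValueError(f"LoRA arrays must have equal length, got {lengths}")

    names = cfg["adapter_name"]
    if len(set(names)) != len(names):
        raise ValueError("adapter_name entries must be unique")

    for name, scale in zip(names, cfg["scale"]):
        if not 0.0 <= scale <= 1.0:
            raise ValueError(f"scale for {name!r} must be in [0.0, 1.0], got {scale}")

    # Regroup the parallel arrays into one record per adapter.
    return [dict(zip(LORA_KEYS, values)) for values in zip(*(cfg[k] for k in LORA_KEYS))]


config = {
    "lora_model_name_or_path": ["path/to/style_lora", "path/to/detail_lora"],
    "weight_name": ["style.safetensors", "detail.safetensors"],
    "adapter_name": ["style", "detail"],
    "scale": [0.8, 0.5],
    "from_pt": [True, True],
}
for adapter in check_lora_config(config):
    print(adapter["adapter_name"], adapter["scale"])
```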
SDXL Lightning parameters
HuggingFace repository for SDXL Lightning weights. Example: ByteDance/SDXL-Lightning.
Checkpoint filename. Options: sdxl_lightning_2step_unet.safetensors, sdxl_lightning_4step_unet.safetensors, sdxl_lightning_8step_unet.safetensors.
Load Lightning weights from PyTorch format.
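Each Lightning checkpoint is distilled for a fixed number of denoising steps (2, 4, or 8, encoded in the filename), so the configured number of denoising steps should match the chosen checkpoint. A sketch of that consistency check, assuming the filenames follow the pattern listed above; `lightning_steps` and `check_steps` are hypothetical helpers.

```python
import re

LIGHTNING_CKPTS = (
    "sdxl_lightning_2step_unet.safetensors",
    "sdxl_lightning_4step_unet.safetensors",
    "sdxl_lightning_8step_unet.safetensors",
)


def lightning_steps(ckpt_name):
    """Extract the distilled step count from a Lightning checkpoint name."""
    match = re.fullmatch(r"sdxl_lightning_(\d+)step_unet\.safetensors", ckpt_name)
    if match is None:
        raise ValueError(f"unrecognized Lightning checkpoint: {ckpt_name}")
    return int(match.group(1))


def check_steps(ckpt_name, steps):
    """Raise if the requested step count differs from the checkpoint's."""
    expected = lightning_steps(ckpt_name)
    if steps != expected:
        raise ValueError(f"{ckpt_name} is distilled for {expected} steps, got {steps}")


for ckpt in LIGHTNING_CKPTS:
    print(ckpt, "->", lightning_steps(ckpt), "steps")
```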
Parallelism parameters
Data parallelism across ICI devices. Use -1 for auto-sharding.
FSDP parallelism across ICI devices.
Batch size per device. SDXL requires more memory than SD 2.x.
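A sketch of how a -1 parallelism entry is commonly resolved: it absorbs whatever devices the other axes leave over, so the axis sizes multiply to the device count. The axis names "data" and "fsdp", the device count of 8, and the global-batch formula (per-device batch size times device count) are assumptions for illustration, not read from generate_sdxl.py. The resolved sizes would form the shape of the device mesh handed to `jax.sharding.Mesh`.

```python
import math


def resolve_parallelism(axes, num_devices):
    """Replace a single -1 entry so the axis sizes multiply to num_devices."""
    unknown = [name for name, size in axes.items() if size == -1]
    if len(unknown) > 1:
        raise ValueError("at most one axis may be -1")
    known = math.prod(size for size in axes.values() if size != -1)
    if unknown:
        if num_devices % known:
            raise ValueError(f"{num_devices} devices not divisible by {known}")
        axes = {**axes, unknown[0]: num_devices // known}
    if math.prod(axes.values()) != num_devices:
        raise ValueError(f"mesh {axes} does not use exactly {num_devices} devices")
    return axes


num_devices = 8  # assumption; in practice this would come from jax.device_count()
mesh_shape = resolve_parallelism({"data": -1, "fsdp": 2}, num_devices)
per_device_batch_size = 1
global_batch = per_device_batch_size * num_devices
print(mesh_shape, global_batch)  # {'data': 4, 'fsdp': 2} 8
```

Because the per-device batch size multiplies by the device count, raising it by one adds a full image per device, which is why SDXL's larger memory footprint limits it.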
System parameters
Unique identifier for the inference run.
Directory to save generated images.
Random seed for reproducible generation.
Directory for JAX compilation cache.
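The random seed and the output resolution together determine the starting latents. A sketch assuming SDXL's VAE downsamples each side by 8 into 4 latent channels, a channel-first (batch, channels, height, width) layout, and a scheduler-provided initial noise scale. NumPy's generator stands in for `jax.random`, so the actual values differ from the script's; only the shape and determinism are illustrated. `initial_latents` is a hypothetical helper.

```python
import numpy as np

VAE_SCALE_FACTOR = 8  # SDXL's VAE downsamples each spatial side by 8
LATENT_CHANNELS = 4   # SDXL latents have 4 channels


def initial_latents(seed, batch_size, resolution, init_noise_sigma=1.0):
    """Draw reproducible Gaussian starting latents for a square image."""
    side = resolution // VAE_SCALE_FACTOR
    rng = np.random.default_rng(seed)
    noise = rng.standard_normal((batch_size, LATENT_CHANNELS, side, side))
    # Schedulers expect the starting noise scaled by their initial sigma.
    return (noise * init_noise_sigma).astype(np.float32)


a = initial_latents(seed=0, batch_size=2, resolution=1024)
b = initial_latents(seed=0, batch_size=2, resolution=1024)
print(a.shape)               # (2, 4, 128, 128)
print(np.array_equal(a, b))  # True
```

At 1024x1024 the UNet therefore denoises a 128x128 latent per image, and reusing the same seed reproduces the same starting noise.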
Output
Generated images are saved as PNG files with the naming pattern image_sdxl_{i}.png.
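Before saving, decoded images must be converted to 8-bit pixels. A sketch assuming the VAE output is a float array in [-1, 1] laid out as (batch, height, width, channels), that Pillow handles PNG encoding, and that numbering starts at 0; the helper names are hypothetical.

```python
from pathlib import Path

import numpy as np


def to_uint8(images):
    """Map decoder output in [-1, 1] to uint8 pixels in [0, 255]."""
    images = np.clip(images / 2 + 0.5, 0.0, 1.0)
    return (images * 255).round().astype(np.uint8)


def output_paths(output_dir, count):
    """File names following the image_sdxl_{i}.png pattern."""
    return [Path(output_dir) / f"image_sdxl_{i}.png" for i in range(count)]


def save_images(images, output_dir):
    """Write each image in the batch as a PNG and return the paths."""
    from PIL import Image  # imported lazily so the helpers above work without Pillow

    Path(output_dir).mkdir(parents=True, exist_ok=True)
    pixels = to_uint8(images)
    paths = output_paths(output_dir, len(pixels))
    for array, path in zip(pixels, paths):
        Image.fromarray(array).save(path)
    return paths
```

Clipping first guards against decoder values that overshoot the [-1, 1] range slightly.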
Performance data
SDXL inference performance varies by hardware and configuration. Typical results:
- TPU v4-8: ~20-30 seconds for 20 steps at 1024x1024
- TPU v5e: Optimized with bfloat16 weights
Implementation details
The script:
- Loads SDXL pipeline with dual text encoders (CLIP ViT-L and OpenCLIP ViT-bigG)
- Optionally loads LoRA adapters or Lightning weights
- Creates inference states with proper sharding
- Encodes prompts with both text encoders and generates pooled embeddings
- Adds time embeddings for SDXL’s conditioning
- Runs denoising loop with optional classifier-free guidance
- Decodes latents and saves images
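The prompt-encoding and time-embedding steps can be illustrated with placeholder arrays. Following the diffusers SDXL pipeline, the per-token hidden states of CLIP ViT-L (768 features) and OpenCLIP ViT-bigG (1280 features) are concatenated into 2048-dimensional embeddings, the pooled embedding comes from the second encoder, and the time embeddings start from six numbers: original size, crop offset, and target size. Zero arrays stand in for real encoder outputs, and the helper names are hypothetical.

```python
import numpy as np

SEQ_LEN = 77           # CLIP tokenizer context length
CLIP_L_DIM = 768       # hidden size of CLIP ViT-L
OPENCLIP_G_DIM = 1280  # hidden size of OpenCLIP ViT-bigG


def combine_prompt_embeddings(hidden_l, hidden_g):
    """Concatenate per-token states from both encoders on the feature axis."""
    return np.concatenate([hidden_l, hidden_g], axis=-1)


def time_ids(original_size, crop_top_left, target_size, batch_size):
    """SDXL micro-conditioning: (orig_h, orig_w, crop_top, crop_left, tgt_h, tgt_w)."""
    ids = np.array([*original_size, *crop_top_left, *target_size], dtype=np.float32)
    return np.tile(ids, (batch_size, 1))


batch = 1
hidden_l = np.zeros((batch, SEQ_LEN, CLIP_L_DIM), dtype=np.float32)
hidden_g = np.zeros((batch, SEQ_LEN, OPENCLIP_G_DIM), dtype=np.float32)
pooled = np.zeros((batch, OPENCLIP_G_DIM), dtype=np.float32)  # placeholder pooled output of encoder 2

prompt_embeds = combine_prompt_embeddings(hidden_l, hidden_g)
add_time_ids = time_ids((1024, 1024), (0, 0), (1024, 1024), batch)
print(prompt_embeds.shape)  # (1, 77, 2048)
print(add_time_ids.shape)   # (1, 6)
```

With classifier-free guidance enabled, diffusers-style pipelines typically stack the negative and positive copies of these tensors along the batch axis so that a single UNet call yields both noise predictions.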