
Overview

The generate_sdxl.py script provides inference capabilities for Stable Diffusion XL models. It supports single and multi-host deployment with JAX sharding annotations, LoRA loading, and SDXL Lightning for fast inference.

Usage

Basic command

python -m src.maxdiffusion.generate_sdxl \
  src/maxdiffusion/configs/base_xl.yml \
  run_name="my_run"

With LoRA

python src/maxdiffusion/generate_sdxl.py \
  src/maxdiffusion/configs/base_xl.yml \
  run_name="test-lora" \
  output_dir=/tmp/ \
  num_inference_steps=30 \
  prompt="ultra detailed diagram blueprint of a papercut cat" \
  lora_config='{"lora_model_name_or_path" : ["path/to/lora.safetensors"], "weight_name" : ["lora.safetensors"], "adapter_name" : ["my-lora"], "scale": [0.8], "from_pt": ["true"]}'
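Hand-quoting the lora_config JSON is error-prone, especially with several adapters. The sketch below generates the override string in Python instead. `build_lora_config` and its per-adapter dict keys (`path`, `weight_name`, `adapter_name`, `scale`, `from_pt`) are hypothetical helpers for illustration and are not part of maxdiffusion. The function emits the column-oriented layout used in the command above, with one array entry per adapter.

```python
import json


def build_lora_config(adapters):
    """Build the lora_config JSON string from a list of per-adapter dicts.

    Each dict uses hypothetical keys: path, weight_name, adapter_name, scale, from_pt.
    """
    names = [a["adapter_name"] for a in adapters]
    if len(set(names)) != len(names):
        raise ValueError("adapter_name values must be unique")
    config = {
        "lora_model_name_or_path": [a["path"] for a in adapters],
        "weight_name": [a["weight_name"] for a in adapters],
        "adapter_name": names,
        "scale": [float(a["scale"]) for a in adapters],
        # The documented examples pass the string "true"; "false" is assumed
        # to be the matching value for non-PyTorch weights.
        "from_pt": ["true" if a["from_pt"] else "false" for a in adapters],
    }
    return json.dumps(config)


print("lora_config='" + build_lora_config([
    {"path": "ByteDance/Hyper-SD", "weight_name": "Hyper-SDXL-2steps-lora.safetensors",
     "adapter_name": "hyper-sdxl", "scale": 0.7, "from_pt": True},
]) + "'")
```

The printed string can be pasted onto the command line as is, as long as no value contains a single quote.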

SDXL Lightning

python -m src.maxdiffusion.generate_sdxl \
  src/maxdiffusion/configs/base_xl_lightning.yml \
  run_name="my_run" \
  lightning_repo="ByteDance/SDXL-Lightning" \
  lightning_ckpt="sdxl_lightning_4step_unet.safetensors"

Configuration parameters

Model parameters

pretrained_model_name_or_path
string
default:"stabilityai/stable-diffusion-xl-base-1.0"
Path or identifier for the pretrained SDXL model.
revision
string
default:"refs/pr/95"
Model revision to use from HuggingFace.
from_pt
boolean
default:false
Set to true to load weights from PyTorch format.

Generation parameters

prompt
string
Text prompt describing the image to generate.
negative_prompt
string
default:"purple, red"
Negative prompt to guide what should not appear in the image.
num_inference_steps
integer
default:20
Number of denoising steps. SDXL typically requires fewer steps than SD 2.x.
guidance_scale
float
Classifier-free guidance scale. SDXL works well with higher values (7-12).
guidance_rescale
float
Guidance rescale factor, which rescales the guided noise prediction to reduce overexposure. Values around 0.7 are a common starting point.
do_classifier_free_guidance
boolean
default:true
Enable classifier-free guidance for better prompt adherence.
resolution
integer
default:1024
Output image resolution in pixels. SDXL is optimized for 1024x1024.
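The guidance parameters above combine two noise predictions at each step. The following pure-Python sketch shows the standard classifier-free guidance formula plus the standard-deviation-based rescaling used by diffusers' `rescale_noise_cfg`. It operates on flat lists that stand in for one sample's latent tensor. Whether generate_sdxl.py applies the rescale in exactly this way is an assumption.

```python
import statistics


def guided_noise(noise_uncond, noise_text, guidance_scale,
                 guidance_rescale=0.0, do_classifier_free_guidance=True):
    """Combine unconditional and text-conditioned noise predictions for one sample.

    Inputs are flat lists of floats standing in for one sample's latent tensor.
    """
    if not do_classifier_free_guidance:
        # Without CFG only the text-conditioned prediction is used.
        return list(noise_text)
    # Classifier-free guidance: push the prediction away from the unconditional one.
    cfg = [u + guidance_scale * (t - u) for u, t in zip(noise_uncond, noise_text)]
    if guidance_rescale == 0.0:
        return cfg
    # Match the std of the text prediction (population std; the normalization
    # constant cancels in the ratio) ...
    ratio = statistics.pstdev(noise_text) / statistics.pstdev(cfg)
    # ... then blend the rescaled and plain CFG results by guidance_rescale.
    return [guidance_rescale * c * ratio + (1.0 - guidance_rescale) * c for c in cfg]
```

With guidance_rescale=1.0 the output has the same spread as the text prediction. With 0.0 it is plain CFG, and intermediate values interpolate between the two.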

Performance parameters

weights_dtype
string
default:"float32"
Data type for model weights. Use bfloat16 on TPU v5e for optimal performance.
activations_dtype
string
default:"bfloat16"
Data type for layer activations.
attention
string
default:"dot_product"
Attention mechanism. Options: dot_product, flash, cudnn_flash_te (GPU only).
split_head_dim
boolean
default:true
Enable head dimension splitting for better performance.
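As a rough illustration of why weights_dtype matters, the arithmetic below estimates weight memory. The ~2.6B UNet parameter count is the figure reported for SDXL. `weight_gib` is a hypothetical helper. The estimate ignores the text encoders, the VAE, activations, and compiler workspace, and it assumes FSDP splits the weights evenly across devices.

```python
# Bytes per element for the dtype names used by weights_dtype / activations_dtype.
DTYPE_BYTES = {"float32": 4, "bfloat16": 2}


def weight_gib(num_params, weights_dtype, fsdp=1):
    """Approximate per-device weight memory in GiB.

    Assumes weights are split evenly across `fsdp` devices and ignores
    activations and compiler overhead.
    """
    return num_params * DTYPE_BYTES[weights_dtype] / fsdp / 2**30


SDXL_UNET_PARAMS = 2.6e9  # approximate UNet size reported for SDXL

for dtype in ("float32", "bfloat16"):
    print(dtype, round(weight_gib(SDXL_UNET_PARAMS, dtype), 1), "GiB")
```

For the UNet alone, this gives roughly 9.7 GiB in float32 versus 4.8 GiB in bfloat16.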

LoRA parameters

lora_config
object
Configuration for loading LoRA adapters. Supports multiple LoRAs: each property is an array with one entry per adapter, in matching order. Properties:
  • lora_model_name_or_path (array): Paths to LoRA models
  • weight_name (array): LoRA weight filenames
  • adapter_name (array): Unique names for each adapter
  • scale (array): Scale factors (0.0-1.0) for each LoRA
  • from_pt (array): Whether to load from PyTorch format
Example:
{
  "lora_model_name_or_path": ["ByteDance/Hyper-SD"],
  "weight_name": ["Hyper-SDXL-2steps-lora.safetensors"],
  "adapter_name": ["hyper-sdxl"],
  "scale": [0.7],
  "from_pt": ["true"]
}
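The scale entry controls how strongly each adapter modifies the base weights. The sketch below shows the usual LoRA update W' = W + scale · (B A) on plain nested lists. `merge_lora` is a hypothetical helper, and it omits the alpha/rank factor that some checkpoints carry. This section also does not specify whether the script merges weights up front or applies the update at runtime.

```python
def merge_lora(weight, lora_down, lora_up, scale):
    """Return W + scale * (up @ down) for plain nested-list matrices.

    weight:    out x in
    lora_down: r x in   (often called A)
    lora_up:   out x r  (often called B)
    """
    rank = len(lora_down)
    merged = []
    for i in range(len(weight)):
        row = []
        for j in range(len(weight[0])):
            # Low-rank update for element (i, j): sum over the rank dimension.
            delta = sum(lora_up[i][k] * lora_down[k][j] for k in range(rank))
            row.append(weight[i][j] + scale * delta)
        merged.append(row)
    return merged
```

With several adapters, the same update would be applied once per adapter, each with its own scale entry.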

SDXL Lightning parameters

lightning_repo
string
default:""
HuggingFace repository for SDXL Lightning weights. Example: ByteDance/SDXL-Lightning.
lightning_ckpt
string
default:""
Checkpoint filename. Options: sdxl_lightning_2step_unet.safetensors, sdxl_lightning_4step_unet.safetensors, sdxl_lightning_8step_unet.safetensors.
lightning_from_pt
boolean
default:true
Load Lightning weights from PyTorch format.
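Each Lightning UNet is distilled for a fixed step count, so num_inference_steps should normally match the number in the checkpoint name. The hypothetical helper below, which is not part of maxdiffusion, extracts that count from the filenames listed above. This section does not say whether base_xl_lightning.yml already sets a matching value.

```python
import re


def lightning_steps(lightning_ckpt):
    """Return the step count encoded in an SDXL Lightning UNet checkpoint filename."""
    match = re.fullmatch(r"sdxl_lightning_(\d+)step_unet\.safetensors", lightning_ckpt)
    if match is None:
        raise ValueError(f"unrecognized Lightning checkpoint name: {lightning_ckpt!r}")
    return int(match.group(1))
```

For example, lightning_steps("sdxl_lightning_4step_unet.safetensors") returns 4, which can then be passed on the command line as num_inference_steps=4.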

Parallelism parameters

ici_data_parallelism
integer
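Data parallelism across ICI (inter-chip interconnect) devices. Set to -1 to auto-shard, which sizes this axis to cover all devices not assigned to the other ICI parallelism axes.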
ici_fsdp_parallelism
integer
default:1
FSDP parallelism across ICI devices.
per_device_batch_size
integer
default:2
Batch size per device. SDXL requires more memory than SD 2.x.
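The ICI parallelism axes must multiply out to the number of devices. This sketch assumes MaxText-style semantics, in which a single -1 axis absorbs the remaining devices. It also assumes that each device generates per_device_batch_size images. Both helpers are hypothetical illustrations, not maxdiffusion functions.

```python
import math


def resolve_ici_axes(num_devices, axes):
    """Replace at most one -1 entry so the product of all axis sizes equals num_devices."""
    auto = [name for name, size in axes.items() if size == -1]
    if len(auto) > 1:
        raise ValueError("at most one parallelism axis may be -1")
    known = math.prod(size for size in axes.values() if size != -1)
    resolved = dict(axes)
    if auto:
        if num_devices % known != 0:
            raise ValueError(f"{num_devices} devices not divisible by {known}")
        resolved[auto[0]] = num_devices // known
    if math.prod(resolved.values()) != num_devices:
        raise ValueError(f"axis sizes {resolved} do not multiply to {num_devices}")
    return resolved


def global_batch_size(per_device_batch_size, num_devices):
    """Images generated per call when every device processes its own batch."""
    return per_device_batch_size * num_devices
```

Consider 8 devices with ici_fsdp_parallelism=2 and ici_data_parallelism=-1. The data axis resolves to 4, and the default per_device_batch_size of 2 yields 16 images per call.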

System parameters

run_name
string
required
Unique identifier for the inference run.
output_dir
string
default:"sdxl-model-finetuned"
Directory to save generated images.
seed
integer
default:0
Random seed for reproducible generation.
jax_cache_dir
string
default:""
Directory for JAX compilation cache.

Output

Generated images are saved as PNG files with the naming pattern image_sdxl_{i}.png.

Performance data

SDXL inference performance varies by hardware and configuration. Typical results:
  • TPU v4-8: ~20-30 seconds for 20 steps at 1024x1024
  • TPU v5e: use bfloat16 weights for best performance

Implementation details

The script:
  1. Loads the SDXL pipeline with its two text encoders (CLIP ViT-L and OpenCLIP ViT-bigG)
  2. Optionally loads LoRA adapters or Lightning weights
  3. Creates inference states with the appropriate sharding
  4. Encodes prompts with both text encoders and generates pooled embeddings
  5. Builds SDXL’s additional time-id conditioning (original size, crop coordinates, and target size)
  6. Runs the denoising loop, with optional classifier-free guidance
  7. Decodes the latents and saves the images
See source: src/maxdiffusion/generate_sdxl.py
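Steps 4 and 5 produce several conditioning tensors whose shapes are easy to lose track of. The sketch below only does the shape bookkeeping. It assumes the standard SDXL dimensions: 768-wide CLIP ViT-L and 1280-wide OpenCLIP ViT-bigG hidden states concatenated to 2048, a 1280-wide pooled embedding, 77-token prompts, and a 4-channel latent at 1/8 resolution. It also assumes the batch is doubled for classifier-free guidance. The function names are hypothetical and are not taken from generate_sdxl.py.

```python
# Hidden widths of SDXL's two text encoders (assumed standard SDXL values).
CLIP_VIT_L_DIM = 768
OPENCLIP_BIGG_DIM = 1280


def sdxl_add_time_ids(original_size, crops_coords_top_left, target_size):
    """Micro-conditioning vector: (orig_h, orig_w, crop_top, crop_left, target_h, target_w)."""
    return list(original_size) + list(crops_coords_top_left) + list(target_size)


def conditioning_shapes(batch_size, resolution, do_cfg, max_length=77):
    """Shapes of the tensors fed to the UNet at each denoising step."""
    # With CFG, unconditional and conditional rows are stacked into one batch.
    rows = 2 * batch_size if do_cfg else batch_size
    return {
        # Per-token hidden states of both encoders, concatenated on the last axis.
        "prompt_embeds": (rows, max_length, CLIP_VIT_L_DIM + OPENCLIP_BIGG_DIM),
        # Pooled embedding from the second (OpenCLIP ViT-bigG) encoder.
        "pooled_prompt_embeds": (rows, OPENCLIP_BIGG_DIM),
        "add_time_ids": (rows, 6),
        # Latents are duplicated to `rows` before each UNet call when CFG is on.
        "latents": (batch_size, 4, resolution // 8, resolution // 8),
    }
```

At the default 1024-pixel resolution with no cropping, the time ids are [1024, 1024, 0, 0, 1024, 1024] and each latent is 128×128.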
