Stable Diffusion XL generates 1024×1024 images with dual text encoders for improved prompt understanding and image quality.

Quick start

python -m src.maxdiffusion.generate_sdxl \
  src/maxdiffusion/configs/base_xl.yml \
  run_name="my_run"

Architecture

SDXL uses an improved architecture:
  • UNet: Larger model with more cross-attention layers
  • Text encoders: Dual encoders (OpenCLIP ViT-G/14 + CLIP ViT-L/14)
  • Pooled embeddings: Additional conditioning from text encoder pooled output
  • Time conditioning: Resolution and crop coordinates for better control
  • VAE: Higher quality decoder for 1024×1024 generation

Configuration

Customize generation parameters:
python -m src.maxdiffusion.generate_sdxl \
  src/maxdiffusion/configs/base_xl.yml \
  run_name="custom_sdxl" \
  prompt="a photograph of a cat wearing a hat riding a skateboard in a park" \
  negative_prompt="blurry, distorted, low quality" \
  num_inference_steps=30 \
  guidance_scale=7.5 \
  per_device_batch_size=1 \
  resolution=1024 \
  seed=42 \
  output_dir=/tmp/ \
  jax_cache_dir=/tmp/cache_dir/

Parameters

Parameter                      Description                             Default
prompt                         Text description of the desired image   Required
negative_prompt                Concepts to steer away from             Empty
num_inference_steps            Number of denoising steps               30
guidance_scale                 Classifier-free guidance strength       7.5
guidance_rescale               Noise rescale factor                    0.0
do_classifier_free_guidance    Enable classifier-free guidance         True
resolution                     Image height and width (pixels)         1024
per_device_batch_size          Images generated per device             1
seed                           Random seed                             0
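The guidance_scale and guidance_rescale parameters interact during sampling. The sketch below shows the standard classifier-free guidance combination and the rescale correction (which pulls the guided prediction's per-sample standard deviation back toward the conditional prediction's, counteracting over-saturation at high guidance scales). This is an illustrative NumPy implementation of the general technique, not the pipeline's exact code; the function name is hypothetical.

```python
import numpy as np

def classifier_free_guidance(noise_uncond, noise_cond, guidance_scale, guidance_rescale=0.0):
    """Combine unconditional and conditional noise predictions (standard CFG)."""
    noise = noise_uncond + guidance_scale * (noise_cond - noise_uncond)
    if guidance_rescale > 0.0:
        # Rescale so the guided prediction's std matches the conditional one,
        # then blend by guidance_rescale (0.0 = no rescaling, 1.0 = full).
        axes = tuple(range(1, noise.ndim))
        std_cond = noise_cond.std(axis=axes, keepdims=True)
        std_guided = noise.std(axis=axes, keepdims=True)
        rescaled = noise * (std_cond / std_guided)
        noise = guidance_rescale * rescaled + (1.0 - guidance_rescale) * noise
    return noise

# With guidance_scale=1.0 the guided prediction collapses to the conditional one.
u = np.random.randn(2, 4, 8, 8).astype(np.float32)
c = np.random.randn(2, 4, 8, 8).astype(np.float32)
assert np.allclose(classifier_free_guidance(u, c, 1.0), c)
```

Setting do_classifier_free_guidance=False skips this combination entirely, halving the UNet work per step, which is why the few-step Lightning and Hyper-SDXL recipes below disable it.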

SDXL Lightning

SDXL Lightning enables few-step generation (2-8 steps) with minimal quality loss.

4-step generation

python src/maxdiffusion/generate_sdxl.py \
  src/maxdiffusion/configs/base_xl_lightning.yml \
  run_name="lightning_4step" \
  lightning_repo="ByteDance/SDXL-Lightning" \
  lightning_ckpt="sdxl_lightning_4step_unet.safetensors"

2-step generation

For ultra-fast generation, use 2-step Lightning with classifier-free guidance disabled:
python src/maxdiffusion/generate_sdxl.py \
  src/maxdiffusion/configs/base_xl_lightning.yml \
  run_name="lightning_2step" \
  lightning_repo="ByteDance/SDXL-Lightning" \
  lightning_ckpt="sdxl_lightning_2step_unet.safetensors" \
  num_inference_steps=2 \
  do_classifier_free_guidance=False

LoRA support

Hyper-SDXL LoRA

Hyper-SDXL enables 2-step generation with LoRA:
python src/maxdiffusion/generate_sdxl.py \
  src/maxdiffusion/configs/base_xl.yml \
  run_name="hyper_sdxl" \
  output_dir=/tmp/ \
  jax_cache_dir=/tmp/cache_dir/ \
  num_inference_steps=2 \
  do_classifier_free_guidance=False \
  prompt="a photograph of a cat wearing a hat riding a skateboard in a park." \
  per_device_batch_size=1 \
  pretrained_model_name_or_path="Lykon/AAM_XL_AnimeMix" \
  from_pt=True \
  revision=main \
  diffusion_scheduler_config='{"_class_name" : "FlaxDDIMScheduler", "timestep_spacing" : "trailing"}' \
  lora_config='{"lora_model_name_or_path" : ["ByteDance/Hyper-SD"], "weight_name" : ["Hyper-SDXL-2steps-lora.safetensors"], "adapter_name" : ["hyper-sdxl"], "scale": [0.7], "from_pt": ["true"]}'

Multiple LoRA loading

Load multiple LoRA adapters simultaneously:
python src/maxdiffusion/generate_sdxl.py \
  src/maxdiffusion/configs/base_xl.yml \
  run_name="multi_lora" \
  output_dir=/tmp/ \
  jax_cache_dir=/tmp/cache_dir/ \
  num_inference_steps=30 \
  do_classifier_free_guidance=True \
  prompt="ultra detailed diagram blueprint of a papercut Sitting MaineCoon cat, wide canvas, ampereart, electrical diagram, bl3uprint, papercut" \
  per_device_batch_size=1 \
  diffusion_scheduler_config='{"_class_name" : "FlaxDDIMScheduler", "timestep_spacing" : "trailing"}' \
  lora_config='{"lora_model_name_or_path" : ["/path/to/blueprintify-sd-xl-10.safetensors","TheLastBen/Papercut_SDXL"], "weight_name" : ["/path/to/blueprintify-sd-xl-10.safetensors","papercut.safetensors"], "adapter_name" : ["blueprint","papercut"], "scale": [0.8, 0.7], "from_pt": ["true", "true"]}'
LoRA parameters:
  • lora_model_name_or_path: HuggingFace repo or local path
  • weight_name: Safetensors filename
  • adapter_name: Unique identifier for the adapter
  • scale: LoRA influence strength (0.0-1.0)
  • from_pt: Whether to convert from PyTorch format
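Because lora_config is parallel lists inside a JSON string, hand-writing it on the command line is error-prone. One way to avoid quoting mistakes is to build the value with json.dumps and verify the list lengths match before pasting it into the command (a convenience sketch, not part of the pipeline):

```python
import json

# Entry i of every list describes the same adapter.
lora_config = {
    "lora_model_name_or_path": ["ByteDance/Hyper-SD"],
    "weight_name": ["Hyper-SDXL-2steps-lora.safetensors"],
    "adapter_name": ["hyper-sdxl"],
    "scale": [0.7],
    "from_pt": ["true"],
}

# All lists must stay the same length -- one entry per adapter.
assert len({len(v) for v in lora_config.values()}) == 1

# Serialize for the CLI; single-quote the whole value in the shell.
flag = "lora_config='" + json.dumps(lora_config) + "'"
print(flag)
```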

Sharding strategies

Data parallelism (default)

SDXL supports single- and multi-host inference with sharding annotations. Data parallelism replicates the full model on every device and splits the batch across them:
python -m src.maxdiffusion.generate_sdxl \
  src/maxdiffusion/configs/base_xl.yml \
  run_name="ddp" \
  ici_data_parallelism=4

FSDP

Fully shard model parameters across devices to fit models that exceed a single device's memory:
python -m src.maxdiffusion.generate_sdxl \
  src/maxdiffusion/configs/base_xl.yml \
  run_name="fsdp" \
  ici_data_parallelism=1 \
  ici_fsdp_parallelism=-1
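In these configs, -1 on a parallelism axis conventionally means "use all devices left over after the other axes are fixed" (an assumption about the config semantics; the helper below is an illustrative sketch, not MaxDiffusion code):

```python
def resolve_mesh_axes(num_devices, data=1, fsdp=1, tensor=1):
    """Fill in a single -1 axis so the product of axis sizes equals num_devices."""
    axes = {"data": data, "fsdp": fsdp, "tensor": tensor}
    fixed = 1
    wildcard = None
    for name, size in axes.items():
        if size == -1:
            wildcard = name
        else:
            fixed *= size
    if wildcard is not None:
        assert num_devices % fixed == 0, "fixed axes must divide the device count"
        axes[wildcard] = num_devices // fixed
    return axes

# On 8 devices, data=1 and fsdp=-1 shards parameters across all 8.
assert resolve_mesh_axes(8, data=1, fsdp=-1)["fsdp"] == 8
```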

Implementation details

The SDXL pipeline (src/maxdiffusion/generate_sdxl.py) implements:

Dual text encoding

SDXL uses two text encoders (generate_sdxl.py:93-103):
def get_embeddings(prompt_ids, pipeline, params):
  te_1_inputs = prompt_ids[:, 0, :]
  te_2_inputs = prompt_ids[:, 1, :]
  
  # CLIP ViT-L/14 embeddings
  prompt_embeds = pipeline.text_encoder(te_1_inputs, params=params["text_encoder"], output_hidden_states=True)
  prompt_embeds = prompt_embeds["hidden_states"][-2]
  
  # OpenCLIP ViT-G/14 embeddings and pooled output
  prompt_embeds_2_out = pipeline.text_encoder_2(te_2_inputs, params=params["text_encoder_2"], output_hidden_states=True)
  prompt_embeds_2 = prompt_embeds_2_out["hidden_states"][-2]
  text_embeds = prompt_embeds_2_out["text_embeds"]
  
  # Concatenate embeddings
  prompt_embeds = jnp.concatenate([prompt_embeds, prompt_embeds_2], axis=-1)
  return prompt_embeds, text_embeds
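Concretely, the concatenation joins the per-token hidden states of the two encoders along the feature axis. Assuming the standard SDXL encoder widths (768 for CLIP ViT-L/14, 1280 for the OpenCLIP encoder), the UNet receives 2048-dim token embeddings, while the 1280-dim pooled text_embeds are passed separately as added conditioning:

```python
import numpy as np

batch, seq_len = 1, 77  # 77 is the CLIP tokenizer's context length
prompt_embeds_1 = np.zeros((batch, seq_len, 768), dtype=np.float32)   # CLIP ViT-L/14
prompt_embeds_2 = np.zeros((batch, seq_len, 1280), dtype=np.float32)  # OpenCLIP

# Same concatenation as in get_embeddings above, shown with NumPy.
prompt_embeds = np.concatenate([prompt_embeds_1, prompt_embeds_2], axis=-1)
assert prompt_embeds.shape == (1, 77, 2048)  # per-token context for the UNet
```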

Additional conditioning

SDXL adds time embeddings for resolution and crop coordinates (generate_sdxl.py:137-150):
add_time_ids = get_add_time_ids(
    (height, width),  # Original resolution
    (0, 0),           # Crop coordinates
    (height, width),  # Target resolution
    batch_size,
    dtype=prompt_embeds.dtype
)

added_cond_kwargs = {
    "text_embeds": pooled_embeds,
    "time_ids": add_time_ids
}
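A helper like get_add_time_ids typically flattens the three (height, width)-style pairs into six scalars per sample, tiled to the batch. The sketch below illustrates that shape contract with NumPy; the real implementation lives in generate_sdxl.py and may differ in detail:

```python
import numpy as np

def get_add_time_ids(original_size, crops_coords_top_left, target_size, bs, dtype=np.float32):
    # Six micro-conditioning scalars per sample:
    # [orig_h, orig_w, crop_top, crop_left, target_h, target_w]
    ids = list(original_size) + list(crops_coords_top_left) + list(target_size)
    return np.tile(np.asarray(ids, dtype=dtype), (bs, 1))

time_ids = get_add_time_ids((1024, 1024), (0, 0), (1024, 1024), bs=2)
assert time_ids.shape == (2, 6)
```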

Custom models

Load custom SDXL checkpoints from HuggingFace:
python -m src.maxdiffusion.generate_sdxl \
  src/maxdiffusion/configs/base_xl.yml \
  run_name="custom_model" \
  pretrained_model_name_or_path="Lykon/AAM_XL_AnimeMix" \
  from_pt=True \
  revision=main

Output

Generated images are saved as image_sdxl_{i}.png. The pipeline reports:
  • Compile time: Initial compilation duration
  • Inference time: Generation time after compilation

Next steps

ControlNet SDXL

Conditional generation with edge detection

SDXL training

Fine-tune SDXL on custom datasets

Flux inference

Next-generation image synthesis

LoRA training

Train custom LoRA adapters
