Stable Diffusion XL (SDXL) is a high-resolution text-to-image model that generates images at 1024x1024 resolution with improved quality and detail.
## Basic training

### TPU training

```shell
export LIBTPU_INIT_ARGS=""
python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \
  run_name="my_xl_run" \
  output_dir="gs://your-bucket/" \
  per_device_batch_size=1
```
### GPU training with fused attention

For optimal GPU performance, use Transformer Engine with fused attention:

1. Install Transformer Engine.
2. Run training with `cudnn_flash_te`:

```shell
NVTE_FUSED_ATTN=1 python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \
  hardware=gpu \
  run_name='test-sdxl-train' \
  output_dir=/tmp/ \
  train_new_unet=true \
  train_text_encoder=false \
  cache_latents_text_encoder_outputs=true \
  max_train_steps=200 \
  weights_dtype=bfloat16 \
  resolution=512 \
  per_device_batch_size=1 \
  attention="cudnn_flash_te" \
  jit_initializers=False
```

> **Note:** For GPU training, install Transformer Engine for optimal performance. Without it, training will be significantly slower.
## Configuration

The base config is located at `src/maxdiffusion/configs/base_xl.yml`.
### Key parameters

| Parameter | Default | Description |
|---|---|---|
| `pretrained_model_name_or_path` | `stabilityai/stable-diffusion-xl-base-1.0` | Base SDXL model |
| `revision` | `refs/pr/95` | Model revision |
| `weights_dtype` | `float32` | Weight precision |
| `activations_dtype` | `bfloat16` | Activation precision |
| `attention` | `dot_product` | Attention mechanism |
| `resolution` | `1024` | Training image resolution |
| `per_device_batch_size` | `2` | Batch size per device |
| `learning_rate` | `4.e-7` | Initial learning rate |
| `max_train_steps` | `200` | Maximum training steps |
### Dataset configuration

```yaml
dataset_name: 'diffusers/pokemon-gpt4-captions'
train_split: 'train'
dataset_type: 'tf'
cache_latents_text_encoder_outputs: True
dataset_save_location: '/tmp/pokemon-gpt4-captions_xl'
resolution: 1024
image_column: 'image'
caption_column: 'text'
```
### Scheduler configuration

```yaml
diffusion_scheduler_config:
  _class_name: 'FlaxEulerDiscreteScheduler'
  prediction_type: 'epsilon'
  rescale_zero_terminal_snr: False
  timestep_spacing: 'trailing'
```
### Training parameters

```yaml
learning_rate: 4.e-7
scale_lr: False
max_train_steps: 200
num_train_epochs: 1
warmup_steps_fraction: 0.0
per_device_batch_size: 2
```
### Optimizer configuration

```yaml
adam_b1: 0.9
adam_b2: 0.999
adam_eps: 1.e-8
adam_weight_decay: 1.e-2
max_grad_norm: 1.0
```
Parallelism configuration
# Mesh axes
mesh_axes: ['data', 'fsdp', 'context', 'tensor']
# ICI parallelism (within host)
ici_data_parallelism: -1 # Auto-shard
ici_fsdp_parallelism: 1
ici_context_parallelism: 1
ici_tensor_parallelism: 1
# DCN parallelism (across hosts)
dcn_data_parallelism: -1 # Auto-shard
dcn_fsdp_parallelism: 1
dcn_context_parallelism: 1
dcn_tensor_parallelism: 1
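A `-1` entry is auto-sharded: the runtime fills it in so that the product of all mesh axes equals the number of available devices. A minimal sketch of that semantics, assuming exactly one axis is set to `-1` (illustrative only, not MaxDiffusion's actual resolution logic):

```python
import math

def fill_auto(dims, num_devices):
    """Resolve a single -1 axis so the product of all axes equals num_devices."""
    fixed = math.prod(d for d in dims if d != -1)
    assert num_devices % fixed == 0, "device count must be divisible by fixed axes"
    return [num_devices // fixed if d == -1 else d for d in dims]

# [data, fsdp, context, tensor] with 8 chips: data auto-shards to 8
ici = fill_auto([-1, 1, 1, 1], num_devices=8)
```

With `ici_fsdp_parallelism: 2` on the same 8 chips, the data axis would resolve to 4 instead.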
## Advanced features

### Train from scratch

Initialize a new UNet with random weights by setting `train_new_unet=true`, as in the GPU training example above.

### Text encoder training

Text encoder training is currently not supported for SDXL:

```yaml
train_text_encoder: False # Not supported
```
### SNR gamma weighting

Enable SNR-based loss weighting:

```yaml
snr_gamma: 5.0 # set to -1.0 to disable
```
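`snr_gamma` enables min-SNR loss weighting, which caps the influence of high-SNR (low-noise) timesteps. A minimal NumPy sketch of the usual formulation for epsilon prediction, using a toy noise schedule (illustrative; the actual MaxDiffusion implementation may differ):

```python
import numpy as np

def min_snr_weights(alphas_cumprod, timesteps, snr_gamma=5.0):
    """Per-timestep loss weights: min(SNR, gamma) / SNR, for epsilon prediction."""
    alpha_bar = alphas_cumprod[timesteps]
    snr = alpha_bar / (1.0 - alpha_bar)  # signal-to-noise ratio at each timestep
    return np.minimum(snr, snr_gamma) / snr

# toy schedule: alpha_bar decays linearly over 1000 steps
alphas_cumprod = np.linspace(0.9999, 0.0001, 1000)
w = min_snr_weights(alphas_cumprod, np.array([0, 500, 999]))
```

Low-noise timesteps (high SNR) get a weight well below 1, while noisy timesteps keep full weight.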
### Timestep bias

Bias which timesteps are sampled during training:

```yaml
timestep_bias:
  strategy: "none" # Options: none, earlier, later, range
  multiplier: 1.0
  begin: 0
  end: 1000
  portion: 0.25
```
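One plausible reading of these options: build a non-uniform sampling distribution over timesteps, where the biased region's probability is scaled by `multiplier`, and either `portion` (for `earlier`/`later`) or `begin`/`end` (for `range`) selects that region. A hypothetical NumPy sketch, not MaxDiffusion's actual sampling code:

```python
import numpy as np

def timestep_sampling_weights(num_timesteps=1000, strategy="none",
                              multiplier=1.0, begin=0, end=1000, portion=0.25):
    """Unnormalized weights over timesteps; the biased region gets `multiplier`."""
    weights = np.ones(num_timesteps)
    if strategy == "later":          # bias toward late (high-noise) timesteps
        weights[int(num_timesteps * (1 - portion)):] = multiplier
    elif strategy == "earlier":      # bias toward early (low-noise) timesteps
        weights[: int(num_timesteps * portion)] = multiplier
    elif strategy == "range":        # bias an explicit [begin, end) window
        weights[begin:end] = multiplier
    return weights / weights.sum()   # normalize to a probability distribution

p = timestep_sampling_weights(strategy="later", multiplier=2.0, portion=0.25)
```

With `strategy: "none"` the distribution stays uniform, matching the default `multiplier: 1.0`.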
### Attention mechanisms

SDXL supports multiple attention implementations:

- `dot_product` - Standard attention (TPU/GPU)
- `flash` - Flash attention (TPU v5+)
- `cudnn_flash_te` - NVIDIA Transformer Engine fused attention (GPU only)
## Generate images from checkpoint

After training, generate images using your fine-tuned SDXL model:

```shell
python -m src.maxdiffusion.generate src/maxdiffusion/configs/base_xl.yml \
  run_name="my_run" \
  pretrained_model_name_or_path=<your_saved_checkpoint_path> \
  from_pt=False \
  attention=dot_product
```
### Generation parameters

```yaml
prompt: "A magical castle in the middle of a forest, artistic drawing"
negative_prompt: "purple, red"
do_classifier_free_guidance: True
guidance_scale: 9.0
guidance_rescale: 0.0
num_inference_steps: 20
```
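`guidance_scale` extrapolates from the unconditional toward the conditional noise prediction, and `guidance_rescale` optionally pulls the result back toward the conditional prediction's standard deviation to reduce over-saturation at high scales. A sketch of the common classifier-free guidance formulation (illustrative, not the library's exact code):

```python
import numpy as np

def classifier_free_guidance(noise_uncond, noise_cond,
                             guidance_scale=9.0, guidance_rescale=0.0):
    """Combine conditional/unconditional predictions with optional std rescaling."""
    noise = noise_uncond + guidance_scale * (noise_cond - noise_uncond)
    if guidance_rescale > 0.0:
        # rescale toward the conditional prediction's std, then blend
        rescaled = noise * (noise_cond.std() / noise.std())
        noise = guidance_rescale * rescaled + (1 - guidance_rescale) * noise
    return noise

rng = np.random.default_rng(0)
uncond, cond = rng.normal(size=16), rng.normal(size=16)
out = classifier_free_guidance(uncond, cond, guidance_scale=1.0)
```

At `guidance_scale=1.0` the combination reduces to the conditional prediction; larger scales push further from the unconditional one.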
### LoRA support

SDXL supports loading LoRA adapters during inference:

```yaml
lora_config:
  lora_model_name_or_path: ["ByteDance/Hyper-SD"]
  weight_name: ["Hyper-SDXL-2steps-lora.safetensors"]
  adapter_name: ["hyper-sdxl"]
  scale: [0.7]
  from_pt: [True]
```
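`scale` controls how strongly the low-rank update is blended into the base weights: the standard LoRA merge is W + scale · (B @ A). A toy NumPy sketch of that merge (illustrative only, not MaxDiffusion's loading code):

```python
import numpy as np

def apply_lora(base_weight, lora_a, lora_b, scale=0.7):
    """Merged weight: W + scale * (B @ A), the usual low-rank adapter update."""
    return base_weight + scale * (lora_b @ lora_a)

rank, d_in, d_out = 4, 8, 8
rng = np.random.default_rng(0)
w = rng.normal(size=(d_out, d_in))
a = rng.normal(size=(rank, d_in))   # down-projection
b = np.zeros((d_out, rank))         # zero-initialized up-projection
merged = apply_lora(w, a, b)        # with B == 0 the merge is a no-op
```

A `scale` of 0 disables the adapter entirely; 1.0 applies the full trained update.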
### SDXL Lightning

Load SDXL Lightning checkpoints for faster inference:

```yaml
lightning_from_pt: True
lightning_repo: "ByteDance/SDXL-Lightning"
lightning_ckpt: "sdxl_lightning_4step_unet.safetensors"
```
### ControlNet

Use ControlNet for guided image generation:

```yaml
controlnet_model_name_or_path: 'diffusers/controlnet-canny-sdxl-1.0'
controlnet_from_pt: True
controlnet_conditioning_scale: 0.5
controlnet_image: 'https://example.com/control-image.jpg'
```
## Profiling

Enable performance profiling:

```yaml
enable_profiler: True
skip_first_n_steps_for_profiler: 5
profiler_steps: 10
```
## Checkpointing

Control checkpoint frequency and restore behavior:

```yaml
checkpoint_every: 100 # Save every 100 steps
enable_single_replica_ckpt_restoring: False
```
## Custom datasets

### HuggingFace dataset

```shell
python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \
  dataset_name="your-username/your-dataset" \
  resolution=1024
```

### Local or GCS path

```shell
python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \
  train_data_dir="gs://your-bucket/training-data" \
  dataset_name=""
```
## Monitoring

View training metrics in TensorBoard:

```shell
tensorboard --logdir=gs://your-bucket/my_xl_run/tensorboard/
```
Metrics include:
- Training loss
- Learning rate
- Step time
- TFLOP/s per device
## Best practices

- **Use bfloat16 for faster training:** Set `weights_dtype=bfloat16` on TPU v5e.
- **Enable latent caching:** Set `cache_latents_text_encoder_outputs=True` for faster iteration.
- **Start with lower resolution:** Test with `resolution=512` before scaling to 1024.
- **Monitor GPU memory:** Reduce `per_device_batch_size` if you run out of memory.
- **Use GCS for outputs:** Store checkpoints in GCS buckets for durability.