Stable Diffusion XL (SDXL) is a high-resolution text-to-image model that generates images at 1024x1024 resolution with improved quality and detail.

Basic training

TPU training

export LIBTPU_INIT_ARGS=""
python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \
  run_name="my_xl_run" \
  output_dir="gs://your-bucket/" \
  per_device_batch_size=1

GPU training with fused attention

For optimal GPU performance, use Transformer Engine with fused attention:
  1. Install Transformer Engine.
  2. Run training with attention="cudnn_flash_te":

NVTE_FUSED_ATTN=1 python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \
  hardware=gpu \
  run_name='test-sdxl-train' \
  output_dir=/tmp/ \
  train_new_unet=true \
  train_text_encoder=false \
  cache_latents_text_encoder_outputs=true \
  max_train_steps=200 \
  weights_dtype=bfloat16 \
  resolution=512 \
  per_device_batch_size=1 \
  attention="cudnn_flash_te" \
  jit_initializers=False
For GPU training, install Transformer Engine for optimal performance. Without it, training will be significantly slower.

Configuration

The base config is located at src/maxdiffusion/configs/base_xl.yml.

Key parameters

| Parameter | Default | Description |
|---|---|---|
| pretrained_model_name_or_path | stabilityai/stable-diffusion-xl-base-1.0 | Base SDXL model |
| revision | refs/pr/95 | Model revision |
| weights_dtype | float32 | Weight precision |
| activations_dtype | bfloat16 | Activation precision |
| attention | dot_product | Attention mechanism |
| resolution | 1024 | Training image resolution |
| per_device_batch_size | 2 | Batch size per device |
| learning_rate | 4.e-7 | Initial learning rate |
| max_train_steps | 200 | Maximum training steps |
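As a quick sanity check when tuning these values, the effective global batch size is per_device_batch_size times the total device count. A minimal illustration (the device count below is an example, not from the config):

```python
# Effective global batch size: per-device batch size times device count.
# per_device_batch_size mirrors the table default; num_devices is illustrative.
per_device_batch_size = 2
num_devices = 8  # e.g. a single 8-chip TPU host (example value)

global_batch_size = per_device_batch_size * num_devices
print(global_batch_size)  # 16
```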

Dataset configuration

dataset_name: 'diffusers/pokemon-gpt4-captions'
train_split: 'train'
dataset_type: 'tf'
cache_latents_text_encoder_outputs: True
dataset_save_location: '/tmp/pokemon-gpt4-captions_xl'
resolution: 1024
image_column: 'image'
caption_column: 'text'

Scheduler configuration

diffusion_scheduler_config:
  _class_name: 'FlaxEulerDiscreteScheduler'
  prediction_type: 'epsilon'
  rescale_zero_terminal_snr: False
  timestep_spacing: 'trailing'

Training parameters

learning_rate: 4.e-7
scale_lr: False
max_train_steps: 200
num_train_epochs: 1
warmup_steps_fraction: 0.0
per_device_batch_size: 2

Optimizer configuration

adam_b1: 0.9
adam_b2: 0.999
adam_eps: 1.e-8
adam_weight_decay: 1.e-2
max_grad_norm: 1.0

Parallelism configuration

# Mesh axes
mesh_axes: ['data', 'fsdp', 'context', 'tensor']

# ICI parallelism (within host)
ici_data_parallelism: -1  # Auto-shard
ici_fsdp_parallelism: 1
ici_context_parallelism: 1
ici_tensor_parallelism: 1

# DCN parallelism (across hosts)
dcn_data_parallelism: -1  # Auto-shard
dcn_fsdp_parallelism: 1
dcn_context_parallelism: 1
dcn_tensor_parallelism: 1
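A value of -1 is conventionally resolved so that the product of all parallelism dimensions equals the device count. A sketch of that resolution under the semantics suggested by the comments above (assumed, not read from the maxdiffusion source):

```python
# Sketch: resolve a -1 ("auto-shard") axis so the mesh covers all devices.
# Assumes at most one axis is -1 and the fixed axes divide the device count.
def resolve_mesh(dims, num_devices):
    fixed = 1
    for d in dims:
        if d != -1:
            fixed *= d
    assert num_devices % fixed == 0, "fixed axes must divide device count"
    return [num_devices // fixed if d == -1 else d for d in dims]

# data=-1, fsdp=1, context=1, tensor=1 on an 8-device host (example)
print(resolve_mesh([-1, 1, 1, 1], 8))  # [8, 1, 1, 1]
```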

Advanced features

Train from scratch

Initialize a new UNet with random weights:
train_new_unet: True

Text encoder training

Text encoder training is currently not supported for SDXL.
train_text_encoder: False  # Not supported

SNR gamma weighting

Enable SNR-based loss weighting:
snr_gamma: 5.0  # -1.0 to disable
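For epsilon prediction, Min-SNR weighting (Hang et al.) is commonly implemented as a per-timestep loss weight of min(SNR(t), γ)/SNR(t). A sketch of that formula (taken from the Min-SNR literature, not read from the maxdiffusion source):

```python
# Sketch of Min-SNR loss weighting for epsilon prediction:
# weight(t) = min(SNR(t), gamma) / SNR(t), so high-SNR (low-noise)
# timesteps are down-weighted while noisy timesteps keep weight 1.
def min_snr_weight(snr: float, gamma: float = 5.0) -> float:
    return min(snr, gamma) / snr

print(min_snr_weight(0.5))   # 1.0  (snr below gamma: unchanged)
print(min_snr_weight(50.0))  # 0.1  (snr above gamma: down-weighted)
```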

Timestep bias

timestep_bias:
  strategy: "none"  # Options: none, earlier, later, range
  multiplier: 1.0
  begin: 0
  end: 1000
  portion: 0.25
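A plausible reading of these fields: timesteps inside the biased region have their sampling probability scaled by multiplier, with earlier/later boosting the first/last portion of the [begin, end) range. A sketch under those assumed semantics (not read from the trainer):

```python
# Sketch: build a biased timestep sampling distribution.
# Field semantics are assumed from their names, not from the trainer code.
def timestep_weights(strategy="none", multiplier=1.0,
                     begin=0, end=1000, portion=0.25):
    n = end - begin
    weights = [1.0] * n
    cutoff = int(n * portion)
    if strategy == "earlier":        # boost the first `portion` of timesteps
        for i in range(cutoff):
            weights[i] *= multiplier
    elif strategy == "later":        # boost the last `portion`
        for i in range(n - cutoff, n):
            weights[i] *= multiplier
    elif strategy == "range":        # boost the whole [begin, end) range
        weights = [w * multiplier for w in weights]
    total = sum(weights)
    return [w / total for w in weights]  # normalized probabilities

probs = timestep_weights(strategy="earlier", multiplier=2.0, end=4, portion=0.5)
print(probs)  # first half twice as likely as the second half
```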

Attention mechanisms

SDXL supports multiple attention implementations:
  • dot_product - Standard attention (TPU/GPU)
  • flash - Flash attention (TPU v5+)
  • cudnn_flash_te - NVIDIA Transformer Engine fused attention (GPU only)

Generate images from checkpoint

After training, generate images using your fine-tuned SDXL model:
python -m src.maxdiffusion.generate src/maxdiffusion/configs/base_xl.yml \
  run_name="my_run" \
  pretrained_model_name_or_path=<your_saved_checkpoint_path> \
  from_pt=False \
  attention=dot_product

Generation parameters

prompt: "A magical castle in the middle of a forest, artistic drawing"
negative_prompt: "purple, red"
do_classifier_free_guidance: True
guidance_scale: 9.0
guidance_rescale: 0.0
num_inference_steps: 20
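guidance_scale and guidance_rescale combine as in standard classifier-free guidance with the optional rescaling of Lin et al. (2023). A minimal numpy sketch of those formulas as they appear in the general CFG literature (assumed, not read from the maxdiffusion source):

```python
import numpy as np

# Sketch of classifier-free guidance with optional rescaling.
def cfg(noise_uncond, noise_cond, guidance_scale=9.0, guidance_rescale=0.0):
    # Standard CFG: push the prediction away from the unconditional branch.
    guided = noise_uncond + guidance_scale * (noise_cond - noise_uncond)
    if guidance_rescale > 0.0:
        # Rescale toward the conditional prediction's std to reduce
        # over-saturation at high guidance scales (Lin et al., 2023).
        factor = noise_cond.std() / guided.std()
        rescaled = guided * factor
        guided = guidance_rescale * rescaled + (1 - guidance_rescale) * guided
    return guided

u = np.zeros(4)
c = np.ones(4)
print(cfg(u, c))  # [9. 9. 9. 9.]
```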

LoRA support

SDXL supports loading LoRA adapters during inference:
lora_config:
  lora_model_name_or_path: ["ByteDance/Hyper-SD"]
  weight_name: ["Hyper-SDXL-2steps-lora.safetensors"]
  adapter_name: ["hyper-sdxl"]
  scale: [0.7]
  from_pt: [True]

SDXL Lightning

Load SDXL Lightning checkpoints for faster inference:
lightning_from_pt: True
lightning_repo: "ByteDance/SDXL-Lightning"
lightning_ckpt: "sdxl_lightning_4step_unet.safetensors"

ControlNet

Use ControlNet for guided image generation:
controlnet_model_name_or_path: 'diffusers/controlnet-canny-sdxl-1.0'
controlnet_from_pt: True
controlnet_conditioning_scale: 0.5
controlnet_image: 'https://example.com/control-image.jpg'

Profiling

Enable performance profiling:
enable_profiler: True
skip_first_n_steps_for_profiler: 5
profiler_steps: 10

Checkpointing

checkpoint_every: 100  # Save every 100 steps
enable_single_replica_ckpt_restoring: False

Custom datasets

HuggingFace dataset

python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \
  dataset_name="your-username/your-dataset" \
  resolution=1024

Local or GCS path

python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \
  train_data_dir="gs://your-bucket/training-data" \
  dataset_name=""

Monitoring

View training metrics in TensorBoard:
tensorboard --logdir=gs://your-bucket/my_xl_run/tensorboard/
Metrics include:
  • Training loss
  • Learning rate
  • Step time
  • TFLOP/s per device

Best practices

  1. Use bfloat16 for faster training: Set weights_dtype=bfloat16 on TPU v5e
  2. Enable latent caching: Set cache_latents_text_encoder_outputs=True for faster iteration
  3. Start with lower resolution: Test with resolution=512 before scaling to 1024
  4. Monitor GPU memory: Reduce per_device_batch_size if you run out of memory
  5. Use GCS for outputs: Store checkpoints in GCS buckets for durability
