Overview

The train_sdxl.py script trains Stable Diffusion XL models using the MaxDiffusion framework. It supports distributed training across TPU pods and GPU clusters, and offers advanced attention mechanisms, including fused attention via NVIDIA Transformer Engine.

Command-line usage

python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml [OPTIONS]

Basic example (TPU)

export LIBTPU_INIT_ARGS=""
python -m src.maxdiffusion.train_sdxl \
  src/maxdiffusion/configs/base_xl.yml \
  run_name="my_xl_run" \
  output_dir="gs://your-bucket/" \
  per_device_batch_size=1

GPU with fused attention

NVTE_FUSED_ATTN=1 python -m src.maxdiffusion.train_sdxl \
  src/maxdiffusion/configs/base_xl.yml \
  hardware=gpu \
  run_name='test-sdxl-train' \
  output_dir=/tmp/ \
  train_new_unet=true \
  train_text_encoder=false \
  cache_latents_text_encoder_outputs=true \
  max_train_steps=200 \
  weights_dtype=bfloat16 \
  resolution=512 \
  per_device_batch_size=1 \
  attention="cudnn_flash_te" \
  jit_initializers=False

Configuration parameters

Run configuration

run_name
string
required
Name for this training run. Used for organizing outputs and metrics.
base_output_directory
string
GCS bucket path for outputs (e.g., gs://my-bucket/). Checkpoints and metrics will be saved here.
output_dir
string
default:"sdxl-model-finetuned"
Local or GCS directory for model outputs.

Model configuration

pretrained_model_name_or_path
string
default:"stabilityai/stable-diffusion-xl-base-1.0"
HuggingFace model identifier or local path to pretrained SDXL model.
unet_checkpoint
string
default:""
Path to a specific UNet checkpoint to load.
revision
string
default:"refs/pr/95"
Model revision/branch to use from HuggingFace.
weights_dtype
string
default:"float32"
Data type for model weights. Options: float32, bfloat16. Use bfloat16 on TPU v5e for inference.
activations_dtype
string
default:"bfloat16"
Data type for layer activations. Options: float32, bfloat16.
precision
string
default:"DEFAULT"
JAX precision for matmul and conv operations. Options: DEFAULT, HIGH, HIGHEST.
from_pt
boolean
default:false
Load weights from PyTorch format.
train_new_unet
boolean
default:false
If true, randomly initialize UNet weights to train from scratch. Otherwise, load from pretrained model.

Attention configuration

attention
string
default:"dot_product"
Attention mechanism to use. Options: dot_product, flash, cudnn_flash_te (GPU only).
split_head_dim
boolean
default:true
Whether to split attention head dimensions for sharding.
attention_sharding_uniform
boolean
default:true
Use uniform sequence sharding for both self and cross attention.
mask_padding_tokens
boolean
default:true
Pass segment IDs to attention to avoid attending to padding tokens. Improves quality when padding is significant.
flash_block_sizes
object
Custom block sizes for flash attention.

Text encoder configuration

train_text_encoder
boolean
default:false
Enable training of the text encoder. Currently not supported for SDXL.
text_encoder_learning_rate
number
Learning rate for text encoder when train_text_encoder is true.

Training hyperparameters

learning_rate
number
Learning rate for the optimizer.
scale_lr
boolean
default:false
Scale learning rate by the number of GPUs/TPUs and batch size.
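As a sketch of what scale_lr implies, the common convention is to multiply the base learning rate by the global batch size (devices times per-device batch); whether train_sdxl.py uses exactly this rule is an assumption here.

```python
# Hypothetical sketch of linear learning-rate scaling when scale_lr=true.
# The exact rule in train_sdxl.py may differ; this shows the common
# convention of scaling the base LR by the global batch size.

def scaled_lr(base_lr: float, num_devices: int, per_device_batch_size: int) -> float:
    """Scale the base learning rate by the global batch size."""
    global_batch = num_devices * per_device_batch_size
    return base_lr * global_batch

# Example: 1e-5 base LR on 4 chips with per_device_batch_size=2.
print(scaled_lr(1e-5, 4, 2))  # 8e-05
```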
max_train_steps
number
default:200
Maximum number of training steps. Takes priority over num_train_epochs.
num_train_epochs
number
default:1
Number of training epochs.
max_train_samples
number
Maximum number of training samples to use. -1 means use all samples.
per_device_batch_size
number
default:2
Batch size per device.
warmup_steps_fraction
number
Fraction of total steps to use for learning rate warmup.
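To illustrate how warmup_steps_fraction interacts with max_train_steps, here is a minimal sketch of a linear-warmup schedule; the actual schedule shape used by MaxDiffusion may differ.

```python
# Illustrative linear-warmup schedule driven by warmup_steps_fraction.
# The real schedule in MaxDiffusion may differ in shape; this is a sketch.

def lr_at_step(step: int, base_lr: float, max_train_steps: int,
               warmup_steps_fraction: float) -> float:
    """Linear warmup from 0 to base_lr, then constant."""
    warmup_steps = int(max_train_steps * warmup_steps_fraction)
    if warmup_steps > 0 and step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    return base_lr

# With max_train_steps=200 and warmup_steps_fraction=0.1, warmup lasts 20 steps.
print(lr_at_step(0, 1e-5, 200, 0.1))   # first step: 1/20 of the base LR
print(lr_at_step(50, 1e-5, 200, 0.1))  # past warmup: full base LR
```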
seed
number
default:0
Random seed for reproducibility.

Optimizer parameters

adam_b1
number
Exponential decay rate for first moment estimates.
adam_b2
number
Exponential decay rate for second moment estimates.
adam_eps
number
Small constant for numerical stability.
adam_weight_decay
number
Weight decay coefficient for AdamW optimizer.
max_grad_norm
number
Maximum gradient norm for gradient clipping.

Loss and noise schedule

snr_gamma
number
SNR-weighted loss gamma parameter. Set to -1.0 to disable.
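The usual meaning of this parameter is min-SNR loss weighting, where each timestep's loss is weighted by min(SNR(t), gamma) / SNR(t) for epsilon prediction. Assuming that is the scheme used here, a minimal sketch:

```python
# Sketch of min-SNR loss weighting; that train_sdxl.py uses exactly this
# form is an assumption. High-SNR (low-noise) timesteps are down-weighted,
# which typically speeds up diffusion training convergence.

def min_snr_weight(snr: float, snr_gamma: float) -> float:
    """Per-timestep loss weight for epsilon prediction."""
    if snr_gamma < 0:  # -1.0 disables the weighting
        return 1.0
    return min(snr, snr_gamma) / snr

print(min_snr_weight(25.0, 5.0))   # high-SNR step down-weighted: 0.2
print(min_snr_weight(2.0, 5.0))    # low-SNR step keeps full weight: 1.0
print(min_snr_weight(25.0, -1.0))  # disabled: 1.0
```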
timestep_bias
object
Configuration for biasing timestep sampling during training.
  • strategy: Bias strategy. Options: none, earlier, later, range
  • multiplier: Bias multiplier (2.0 doubles weight, 0.5 halves it)
  • begin: Start timestep for range strategy
  • end: End timestep for range strategy
  • portion: Fraction of timesteps to bias
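The options above can be read as a recipe for building a biased sampling distribution over timesteps. The sketch below is modeled on the analogous feature in other diffusion training scripts; the exact MaxDiffusion behavior is an assumption.

```python
# Illustrative sketch of turning the timestep_bias options into sampling
# weights (an assumption about the exact implementation). "later" biases
# the high-noise end of the schedule; "portion" sets how many timesteps
# are biased; "multiplier" scales their sampling probability.

def timestep_weights(num_timesteps: int, strategy: str,
                     multiplier: float, portion: float,
                     begin: int = 0, end: int = 0) -> list:
    weights = [1.0] * num_timesteps
    biased = int(num_timesteps * portion)
    if strategy == "later":
        idx = range(num_timesteps - biased, num_timesteps)
    elif strategy == "earlier":
        idx = range(biased)
    elif strategy == "range":
        idx = range(begin, end)
    else:  # "none"
        idx = range(0)
    for i in idx:
        weights[i] *= multiplier
    total = sum(weights)
    return [w / total for w in weights]  # normalized sampling distribution

w = timestep_weights(1000, "later", 2.0, 0.25)
# The biased quarter is now twice as likely to be sampled as the rest.
print(w[999] / w[0])  # 2.0
```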
diffusion_scheduler_config
object
Override parameters for the diffusion scheduler.
  • _class_name: Scheduler class name (default: FlaxEulerDiscreteScheduler)
  • prediction_type: Prediction type (default: epsilon)
  • rescale_zero_terminal_snr: Whether to rescale zero terminal SNR
  • timestep_spacing: Timestep spacing strategy (default: trailing)
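For example, the scheduler could be overridden in the YAML config along these lines (a sketch following the field names listed above; the class name and values are illustrative, not recommendations):

```yaml
diffusion_scheduler_config:
  _class_name: FlaxDDPMScheduler   # illustrative alternative scheduler class
  prediction_type: v_prediction    # illustrative alternative to epsilon
  rescale_zero_terminal_snr: true
  timestep_spacing: trailing
```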

Dataset configuration

dataset_name
string
default:"diffusers/pokemon-gpt4-captions"
HuggingFace dataset identifier.
train_data_dir
string
default:""
Local directory containing training data. Either this or dataset_name must be set.
train_split
string
default:"train"
Dataset split to use for training.
dataset_type
string
default:"tf"
Dataset format type.
cache_latents_text_encoder_outputs
boolean
default:true
Cache image latents and text encoder outputs to reduce memory and speed up training. Only applies to small datasets that fit in memory.
dataset_save_location
string
default:"/tmp/pokemon-gpt4-captions_xl"
Path to save transformed dataset when caching is enabled.
image_column
string
default:"image"
Name of the image column in the dataset.
caption_column
string
default:"text"
Name of the caption column in the dataset.
resolution
number
default:1024
Image resolution for training. SDXL is typically trained at 1024x1024.
center_crop
boolean
default:false
Whether to center crop images before resizing.
random_flip
boolean
default:false
Whether to randomly flip images horizontally.
enable_data_shuffling
boolean
default:true
Shuffle the dataset during training.
hf_access_token
string
default:""
HuggingFace access token for private datasets or models.

Parallelism and sharding

hardware
string
default:"tpu"
Hardware type. Options: tpu, gpu.
mesh_axes
array
default:["data","fsdp","context","tensor"]
Logical mesh axes for parallelism.
dcn_data_parallelism
number
Data parallelism across DCN. -1 for auto-sharding.
dcn_fsdp_parallelism
number
default:1
FSDP parallelism across DCN.
dcn_tensor_parallelism
number
default:1
Tensor parallelism across DCN.
ici_data_parallelism
number
Data parallelism within ICI. -1 for auto-sharding.
ici_fsdp_parallelism
number
default:1
FSDP parallelism within ICI.
ici_tensor_parallelism
number
default:1
Tensor parallelism within ICI.
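A useful rule of thumb for these settings: the product of the ICI parallelism degrees should equal the number of chips per slice, and the product of the DCN degrees should equal the number of slices. The sketch below checks that constraint; the exact validation MaxDiffusion performs is an assumption.

```python
# Sketch of the usual mesh-shape constraint (an assumption about the exact
# validation in MaxDiffusion): ICI degrees multiply to chips per slice,
# DCN degrees multiply to the number of slices.

def mesh_is_valid(ici: dict, dcn: dict,
                  chips_per_slice: int, num_slices: int) -> bool:
    ici_product = 1
    for v in ici.values():
        ici_product *= v
    dcn_product = 1
    for v in dcn.values():
        dcn_product *= v
    return ici_product == chips_per_slice and dcn_product == num_slices

# A single 4-chip slice: pure data parallelism within the slice.
print(mesh_is_valid({"data": 4, "fsdp": 1, "tensor": 1},
                    {"data": 1, "fsdp": 1, "tensor": 1},
                    chips_per_slice=4, num_slices=1))  # True
```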

Checkpointing

checkpoint_every
number
Save checkpoint every N samples. -1 disables checkpointing.
enable_single_replica_ckpt_restoring
boolean
default:false
Enable one replica to read checkpoint and broadcast to others.

Metrics and logging

write_metrics
boolean
default:true
Record metrics such as loss and TFLOPS; see gcs_metrics for writing them to GCS.
gcs_metrics
boolean
default:false
Write metrics to GCS.
metrics_file
string
default:""
Local file path for storing scalar metrics (for testing).
log_period
number
default:100
Flush period for TensorBoard metrics, in steps.

Profiling

enable_profiler
boolean
default:false
Enable JAX profiler.
skip_first_n_steps_for_profiler
number
default:5
Skip first N steps when profiling to exclude compilation.
profiler_steps
number
default:10
Number of steps to profile.

Generation parameters

prompt
string
Prompt for test image generation during training.
negative_prompt
string
default:"purple, red"
Negative prompt for guidance.
do_classifier_free_guidance
boolean
default:true
Enable classifier-free guidance.
guidance_scale
number
Classifier-free guidance scale.
num_inference_steps
number
default:20
Number of denoising steps for test generation.

SDXL Lightning parameters

lightning_from_pt
boolean
default:true
Load Lightning weights from PyTorch.
lightning_repo
string
default:""
HuggingFace repo for SDXL Lightning (e.g., ByteDance/SDXL-Lightning).
lightning_ckpt
string
default:""
SDXL Lightning checkpoint filename (e.g., sdxl_lightning_4step_unet.safetensors).

LoRA configuration

lora_config
object
Configuration for loading LoRA adapters during inference.
  • lora_model_name_or_path: List of LoRA model paths or HuggingFace repos
  • weight_name: List of weight filenames
  • adapter_name: List of adapter names
  • scale: List of scaling factors
  • from_pt: List of booleans indicating PyTorch format
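Since each field is a parallel list (one entry per adapter), a config override might look like the following sketch; the repo and filename are illustrative examples, not defaults:

```yaml
lora_config:
  lora_model_name_or_path: ["ByteDance/Hyper-SD"]      # illustrative repo
  weight_name: ["Hyper-SDXL-2steps-lora.safetensors"]  # illustrative filename
  adapter_name: ["hyper-sdxl"]
  scale: [0.7]
  from_pt: [True]
```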

ControlNet parameters

controlnet_model_name_or_path
string
default:"diffusers/controlnet-canny-sdxl-1.0"
HuggingFace model identifier for ControlNet.
controlnet_from_pt
boolean
default:true
Load ControlNet weights from PyTorch.
controlnet_conditioning_scale
number
Conditioning scale for ControlNet.
controlnet_image
string
URL or path to conditioning image for ControlNet.

Quantization

quantization
string
default:""
Quantization configuration.
quantization_local_shard_count
number
Shard count for quantization range finding. Default is number of slices.
use_qwix_quantization
boolean
default:false
Enable qwix quantization.

Expected outputs

The training script produces the following outputs:

Checkpoints

When checkpoint_every is set, model checkpoints are saved to:
{base_output_directory}/{run_name}/checkpoints/
Checkpoints include:
  • UNet weights
  • Text encoder weights (if trained)
  • Optimizer state
  • Training step information

Metrics

Training metrics are saved to:
{base_output_directory}/{run_name}/metrics/
Metrics include:
  • Training loss
  • Learning rate
  • TFLOPS per device
  • Step time

Tensorboard logs

Tensorboard logs are written to:
{base_output_directory}/{run_name}/tensorboard/
View logs with:
tensorboard --logdir=gs://your-bucket/{run_name}/tensorboard/

Generating images from trained checkpoint

After training, generate images with:
python -m src.maxdiffusion.generate src/maxdiffusion/configs/base_xl.yml \
  run_name="my_run" \
  pretrained_model_name_or_path=<your_saved_checkpoint_path> \
  from_pt=False \
  attention=dot_product
