Overview
The train_sdxl.py script trains Stable Diffusion XL models using the MaxDiffusion framework. It supports distributed training across TPU pods and GPU clusters, with advanced attention mechanisms including fused attention via Transformer Engine.
Command-line usage
Basic example (TPU)
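A representative launch command is sketched below. The script path, config file, and the run_name/output_dir key names are assumptions based on typical MaxDiffusion-style usage and should be checked against your checkout:

```shell
# Hypothetical TPU launch -- the config path and key=value flag names are
# assumptions; verify them against your MaxDiffusion checkout.
python src/maxdiffusion/train_sdxl.py \
  src/maxdiffusion/configs/base_xl.yml \
  run_name=my_sdxl_run \
  output_dir=gs://my-bucket/sdxl-outputs
```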
GPU with fused attention
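A sketch of a GPU launch that enables the Transformer Engine fused attention kernel. The hardware and attention key names are assumptions inferred from the parameter descriptions below:

```shell
# Hypothetical GPU launch with fused attention via Transformer Engine.
# The hardware= and attention= keys are assumed names -- verify locally.
python src/maxdiffusion/train_sdxl.py \
  src/maxdiffusion/configs/base_xl.yml \
  run_name=my_sdxl_gpu_run \
  output_dir=/tmp/sdxl-outputs \
  hardware=gpu \
  attention=cudnn_flash_te
```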
Configuration parameters
Run configuration
Name for this training run. Used for organizing outputs and metrics.
GCS bucket path for outputs (e.g., gs://my-bucket/). Checkpoints and metrics will be saved here.
Local or GCS directory for model outputs.
Model configuration
HuggingFace model identifier or local path to pretrained SDXL model.
Path to a specific UNet checkpoint to load.
Model revision/branch to use from HuggingFace.
Data type for model weights. Options: float32, bfloat16. Use bfloat16 on TPU v5e for inference.
Data type for layer activations. Options: float32, bfloat16.
JAX precision for matmul and conv operations. Options: DEFAULT, HIGH, HIGHEST.
Load weights from PyTorch format.
If true, randomly initialize UNet weights to train from scratch. Otherwise, load from pretrained model.
Attention configuration
Attention mechanism to use. Options: dot_product, flash, cudnn_flash_te (GPU only).
Whether to split attention head dimensions for sharding.
Use uniform sequence sharding for both self and cross attention.
Pass segment IDs to attention to avoid attending to padding tokens. Improves quality when padding is significant.
Custom block sizes for flash attention.
Text encoder configuration
Enable training of the text encoder. Currently not supported for SDXL.
Learning rate for text encoder when train_text_encoder is true.
Training hyperparameters
Learning rate for the optimizer.
Scale learning rate by the number of GPUs/TPUs and batch size.
Maximum number of training steps. Takes priority over num_train_epochs.
Number of training epochs.
Maximum number of training samples to use. -1 means use all samples.
Batch size per device.
Fraction of total steps to use for learning rate warmup.
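Since warmup is specified as a fraction of total steps rather than a step count, the effective schedule can be sketched as below. This is a simplified linear-warmup illustration, not the script's exact schedule:

```python
def lr_at_step(step, base_lr, max_train_steps, warmup_fraction):
    """Linear warmup from ~0 to base_lr over warmup_fraction of the total
    steps, then constant. A didactic sketch; the real schedule may decay
    after warmup rather than staying flat."""
    warmup_steps = max(1, int(max_train_steps * warmup_fraction))
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    return base_lr
```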
Random seed for reproducibility.
Optimizer parameters
Exponential decay rate for first moment estimates.
Exponential decay rate for second moment estimates.
Small constant for numerical stability.
Weight decay coefficient for AdamW optimizer.
Maximum gradient norm for gradient clipping.
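The optimizer parameters above map onto the standard AdamW update with global-norm gradient clipping. The following numpy sketch shows the rules these knobs control; the actual optimizer is presumably built with a JAX library such as optax, not this code:

```python
import numpy as np

def adamw_step(param, grad, m, v, t, lr=1e-4, b1=0.9, b2=0.999,
               eps=1e-8, weight_decay=1e-2, max_grad_norm=1.0):
    """One AdamW update with gradient clipping (didactic sketch)."""
    # Clip the gradient by its global norm.
    norm = np.linalg.norm(grad)
    if norm > max_grad_norm:
        grad = grad * (max_grad_norm / norm)
    # First/second moment estimates with bias correction (b1, b2, eps).
    m = b1 * m + (1 - b1) * grad
    v = b2 * v + (1 - b2) * grad ** 2
    m_hat = m / (1 - b1 ** t)
    v_hat = v / (1 - b2 ** t)
    # Decoupled weight decay -- the "W" in AdamW.
    param = param - lr * (m_hat / (np.sqrt(v_hat) + eps) + weight_decay * param)
    return param, m, v
```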
Loss and noise schedule
SNR-weighted loss gamma parameter. Set to -1.0 to disable.
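This parameter typically implements min-SNR loss weighting: each timestep's loss is weighted by its signal-to-noise ratio clamped at gamma. The sketch below shows the common epsilon-prediction form; whether train_sdxl.py uses exactly this variant is an assumption:

```python
import numpy as np

def min_snr_loss_weight(snr, gamma):
    """Min-SNR weighting: clamp per-timestep SNR at gamma, divide by SNR.
    gamma < 0 disables weighting (uniform weights). A sketch of the common
    formulation, not the script's exact code."""
    if gamma < 0:
        return np.ones_like(snr)
    return np.minimum(snr, gamma) / snr
```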
Configuration for biasing timestep sampling during training.
- strategy: Bias strategy. Options: none, earlier, later, range
- multiplier: Bias multiplier (2.0 doubles weight, 0.5 halves it)
- begin: Start timestep for range strategy
- end: End timestep for range strategy
- portion: Fraction of timesteps to bias
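One common way such a config is interpreted is as a weighted distribution over timesteps, sketched below. The exact semantics here (how portion, begin, and end interact with each strategy) are assumptions, not the script's verified behavior:

```python
import numpy as np

def timestep_bias_weights(num_timesteps, strategy="none", multiplier=1.0,
                          begin=0, end=0, portion=0.25):
    """Sampling weights over timesteps, biased per the strategy.
    A sketch of a common interpretation of these options."""
    w = np.ones(num_timesteps)
    n = int(num_timesteps * portion)
    if strategy == "later":
        w[num_timesteps - n:] *= multiplier   # re-weight the last `portion`
    elif strategy == "earlier":
        w[:n] *= multiplier                   # ...or the first `portion`
    elif strategy == "range":
        w[begin:end] *= multiplier            # ...or an explicit range
    return w / w.sum()                        # normalize to a distribution

# Draw a biased timestep, e.g.:
# t = np.random.choice(1000, p=timestep_bias_weights(1000, "later", 2.0))
```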
Override parameters for the diffusion scheduler.
- _class_name: Scheduler class name (default: FlaxEulerDiscreteScheduler)
- prediction_type: Prediction type (default: epsilon)
- rescale_zero_terminal_snr: Whether to rescale zero terminal SNR
- timestep_spacing: Timestep spacing strategy (default: trailing)
Dataset configuration
HuggingFace dataset identifier.
Local directory containing training data. Either this or dataset_name must be set.
Dataset split to use for training.
Dataset format type.
Cache image latents and text encoder outputs to reduce memory and speed up training. Only applies to small datasets that fit in memory.
Path to save transformed dataset when caching is enabled.
Name of the image column in the dataset.
Name of the caption column in the dataset.
Image resolution for training. SDXL is typically trained at 1024x1024.
Whether to center crop images before resizing.
Whether to randomly flip images horizontally.
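The crop, resize, and flip options above can be illustrated with the following numpy sketch. It is not the script's actual pipeline (which likely uses higher-quality resampling than nearest-neighbor):

```python
import numpy as np

def preprocess(image, resolution=1024, center_crop=True, random_flip=True,
               rng=None):
    """Illustrative transforms: square center crop, nearest-neighbor resize
    to `resolution`, optional random horizontal flip."""
    h, w = image.shape[:2]
    if center_crop:
        s = min(h, w)
        top, left = (h - s) // 2, (w - s) // 2
        image = image[top:top + s, left:left + s]
        h = w = s
    # Nearest-neighbor resize (illustrative only).
    ys = np.arange(resolution) * h // resolution
    xs = np.arange(resolution) * w // resolution
    image = image[ys][:, xs]
    if random_flip and (rng or np.random.default_rng()).random() < 0.5:
        image = image[:, ::-1]
    return image
```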
Shuffle the dataset during training.
HuggingFace access token for private datasets or models.
Parallelism and sharding
Hardware type. Options: tpu, gpu.
Logical mesh axes for parallelism.
Data parallelism across DCN. -1 for auto-sharding.
FSDP parallelism across DCN.
Tensor parallelism across DCN.
Data parallelism within ICI. -1 for auto-sharding.
FSDP parallelism within ICI.
Tensor parallelism within ICI.
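In MaxText-style meshes, the DCN axis sizes typically multiply out to the number of slices and the ICI axis sizes to the chips per slice; the product convention below is an assumption borrowed from that style, useful as a sanity check when choosing these values:

```python
from math import prod

def check_mesh(dcn_axes, ici_axes, num_slices, chips_per_slice):
    """Verify that DCN (between-slice) axis sizes multiply to the slice
    count and ICI (within-slice) axis sizes to the chips per slice.
    Ignores -1 auto-sharding; an assumed convention, not the script's code."""
    assert prod(dcn_axes) == num_slices, "DCN axes must cover all slices"
    assert prod(ici_axes) == chips_per_slice, "ICI axes must cover the slice"

# e.g. 2 slices of 4 chips: data-parallel across DCN, FSDP within ICI
check_mesh(dcn_axes=[2, 1, 1], ici_axes=[1, 4, 1],
           num_slices=2, chips_per_slice=4)
```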
Checkpointing
Save checkpoint every N samples. -1 disables checkpointing.
Enable one replica to read checkpoint and broadcast to others.
Metrics and logging
Save metrics such as loss and TFLOPS to GCS.
Write metrics to GCS.
Local file path for storing scalar metrics (for testing).
Tensorboard flush period.
Profiling
Enable JAX profiler.
Skip first N steps when profiling to exclude compilation.
Number of steps to profile.
Generation parameters
Prompt for test image generation during training.
Negative prompt for guidance.
Enable classifier-free guidance.
Classifier-free guidance scale.
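Classifier-free guidance combines an unconditional and a prompt-conditioned noise prediction using the standard CFG formula, sketched below (the formula is standard; the surrounding code is not the script's):

```python
import numpy as np

def apply_cfg(noise_uncond, noise_cond, guidance_scale):
    """Classifier-free guidance: push the prediction away from the
    unconditional output toward the prompt-conditioned one.
    A scale of 1.0 reduces to the conditional prediction."""
    return noise_uncond + guidance_scale * (noise_cond - noise_uncond)
```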
Number of denoising steps for test generation.
SDXL Lightning parameters
Load Lightning weights from PyTorch.
HuggingFace repo for SDXL Lightning (e.g., ByteDance/SDXL-Lightning).
SDXL Lightning checkpoint filename (e.g., sdxl_lightning_4step_unet.safetensors).
LoRA configuration
Configuration for loading LoRA adapters during inference.
- lora_model_name_or_path: List of LoRA model paths or HuggingFace repos
- weight_name: List of weight filenames
- adapter_name: List of adapter names
- scale: List of scaling factors
- from_pt: List of booleans indicating PyTorch format
ControlNet parameters
HuggingFace model identifier for ControlNet.
Load ControlNet weights from PyTorch.
Conditioning scale for ControlNet.
URL or path to conditioning image for ControlNet.
Quantization
Quantization configuration.
Shard count for quantization range finding. Default is number of slices.
Enable qwix quantization.
Expected outputs
The training script produces the following outputs:
Checkpoints
When checkpoint_every is set, model checkpoints are saved to the output directory, containing:
- UNet weights
- Text encoder weights (if trained)
- Optimizer state
- Training step information
Metrics
The following training metrics are saved:
- Training loss
- Learning rate
- TFLOPS per device
- Step time