Overview
The train.py script trains Stable Diffusion 2.x models using the MaxDiffusion framework. It supports distributed training across TPU pods and GPU clusters with configurable parallelism strategies.
Command-line usage
Basic example
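The exact flags depend on the chosen config file; a hypothetical invocation (the config path, flag names, and values below are illustrative and not verified against the repository) might look like:

```shell
# Hypothetical example: positional config file followed by key=value overrides.
python train.py configs/base.yml \
  run_name=my-sd2-run \
  output_dir=gs://my-bucket/ \
  per_device_batch_size=4 \
  max_train_steps=10000
```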
Configuration parameters
Run configuration
Name for this training run. Used for organizing outputs and metrics.
GCS bucket path for outputs (e.g., gs://my-bucket/). Checkpoints and metrics will be saved here.
Local or GCS directory for model outputs.
Model configuration
HuggingFace model identifier or local path to pretrained model.
Path to a specific UNet checkpoint to load.
Model revision/branch to use from HuggingFace.
Data type for model weights. Options: float32, bfloat16.
Data type for layer activations. Options: float32, bfloat16.
JAX precision for matmul and conv operations. Options: DEFAULT, HIGH, HIGHEST.
Load weights from PyTorch format.
If true, randomly initialize UNet weights to train from scratch. Otherwise, load from pretrained model.
Attention configuration
Attention mechanism to use. Options: dot_product, flash.
Whether to split attention head dimensions for sharding.
Use uniform sequence sharding for both self and cross attention.
Pass segment IDs to attention to avoid attending to padding tokens. Improves quality when padding is significant.
Custom block sizes for flash attention.
Text encoder configuration
Enable training of the text encoder along with UNet.
Learning rate for the text encoder when train_text_encoder is true.
Training hyperparameters
Learning rate for the optimizer.
Scale learning rate by the number of GPUs/TPUs and batch size.
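As a sketch of what such scaling typically means (the function name and the exact rule are assumptions; some frameworks scale by global batch size only):

```python
def scaled_learning_rate(base_lr, num_devices, per_device_batch_size):
    """Hypothetical sketch of linear learning-rate scaling: multiply the
    base rate by the number of devices and the per-device batch size,
    i.e. by the global batch size."""
    return base_lr * num_devices * per_device_batch_size

# Base rate 1e-5 on 8 devices with per-device batch 4 (global batch 32):
print(scaled_learning_rate(1e-5, 8, 4))
```

The rationale is that a larger global batch produces a lower-variance gradient estimate, which tolerates a proportionally larger step size.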
Maximum number of training steps. Takes priority over num_train_epochs.
Maximum number of training samples to use. -1 means use all samples.
Batch size per device.
Fraction of total steps to use for learning rate warmup.
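A minimal sketch of a linear-warmup schedule driven by a warmup fraction (the actual schedule in train.py, such as any post-warmup decay, may differ):

```python
def lr_at_step(step, peak_lr, max_steps, warmup_fraction):
    """Linearly ramp the learning rate from ~0 to peak_lr over the first
    warmup_fraction of max_steps, then hold it constant. A sketch of what
    a warmup-fraction parameter typically controls."""
    warmup_steps = max(1, int(warmup_fraction * max_steps))
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    return peak_lr

# With max_steps=100 and a 0.1 fraction, warmup lasts the first 10 steps.
print(lr_at_step(0, 1.0, 100, 0.1))   # early in warmup
print(lr_at_step(50, 1.0, 100, 0.1))  # after warmup
```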
Random seed for reproducibility.
Optimizer parameters
Exponential decay rate for first moment estimates.
Exponential decay rate for second moment estimates.
Small constant for numerical stability.
Weight decay coefficient for AdamW optimizer.
Maximum gradient norm for gradient clipping.
Loss and noise schedule
SNR-weighted loss gamma parameter. Set to -1.0 to disable.
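The gamma parameter and the -1.0 sentinel suggest Min-SNR-style loss weighting; a sketch under that assumption, using the epsilon-prediction form min(SNR, gamma) / SNR:

```python
def snr_loss_weights(snr, gamma):
    """Per-timestep loss weights in the Min-SNR style (an assumption about
    what this parameter does): clamp each timestep's SNR at gamma so very
    easy (high-SNR) timesteps contribute less. gamma < 0 disables
    weighting and returns uniform weights."""
    if gamma < 0:
        return [1.0 for _ in snr]
    return [min(s, gamma) / s for s in snr]

# High-SNR timesteps are down-weighted; low-SNR ones keep weight 1.0.
print(snr_loss_weights([0.5, 5.0, 50.0], 5.0))
```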
Configuration for biasing timestep sampling during training.
- strategy: Bias strategy. Options: none, earlier, later, range
- multiplier: Bias multiplier (2.0 doubles weight, 0.5 halves it)
- begin: Start timestep for range strategy
- end: End timestep for range strategy
- portion: Fraction of timesteps to bias
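These options can be sketched as building an unnormalized sampling-weight table over timesteps (parameter names mirror the list above; the exact semantics in train.py are an assumption):

```python
def timestep_bias_weights(num_timesteps, strategy="none", multiplier=1.0,
                          begin=0, end=0, portion=0.25):
    """Sketch of biased timestep sampling: start from uniform weights and
    scale a chosen span of timesteps by `multiplier`. The span is the
    earliest/latest `portion` of timesteps, or an explicit [begin, end)
    range."""
    weights = [1.0] * num_timesteps
    if strategy == "none":
        return weights
    if strategy == "earlier":
        span = range(0, int(portion * num_timesteps))
    elif strategy == "later":
        span = range(num_timesteps - int(portion * num_timesteps), num_timesteps)
    elif strategy == "range":
        span = range(begin, end)
    else:
        raise ValueError(f"unknown strategy: {strategy}")
    for t in span:
        weights[t] *= multiplier
    return weights

# Double the sampling weight of the last 30% of timesteps:
print(timestep_bias_weights(10, "later", 2.0, portion=0.3))
```

Sampling timesteps proportionally to these weights (rather than uniformly) focuses training on the biased region.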
Override parameters for the diffusion scheduler.
- _class_name: Scheduler class name
- prediction_type: Prediction type (e.g., v_prediction, epsilon)
- rescale_zero_terminal_snr: Whether to rescale zero terminal SNR
- timestep_spacing: Timestep spacing strategy
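An override might look like the following fragment (the nesting key, class name, and values are illustrative assumptions, not taken from the repository's configs):

```yaml
# Hypothetical scheduler override using the option names listed above.
diffusion_scheduler_config:
  _class_name: DDIMScheduler
  prediction_type: v_prediction
  rescale_zero_terminal_snr: true
  timestep_spacing: trailing
```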
Dataset configuration
HuggingFace dataset identifier.
Local directory containing training data. Either this or dataset_name must be set.
Dataset split to use for training.
Dataset format type.
Cache image latents and text encoder outputs to reduce memory and speed up training. Only applies to small datasets that fit in memory.
Path to save transformed dataset when caching is enabled.
Name of the image column in the dataset.
Name of the caption column in the dataset.
Image resolution for training.
Whether to center crop images before resizing.
Whether to randomly flip images horizontally.
Shuffle the dataset during training.
HuggingFace access token for private datasets or models.
Parallelism and sharding
Hardware type. Options: tpu, gpu.
Logical mesh axes for parallelism.
Data parallelism across DCN. -1 for auto-sharding.
FSDP parallelism across DCN.
Tensor parallelism across DCN.
Data parallelism within ICI. -1 for auto-sharding.
FSDP parallelism within ICI.
Tensor parallelism within ICI.
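One invariant worth keeping in mind when setting these axes: the product of the explicitly set (non-auto) DCN and ICI parallelism degrees must divide the total device count. A small sketch of that sanity check (the helper name is hypothetical):

```python
def mesh_shape_is_valid(num_devices, dcn_axes, ici_axes):
    """Check that the product of all explicitly set parallelism degrees
    divides the device count. Axes set to -1 (auto-sharding) are skipped
    here; the real framework infers their sizes from the remainder."""
    product = 1
    for degree in list(dcn_axes) + list(ici_axes):
        if degree != -1:
            product *= degree
    return num_devices % product == 0

# 8 devices, pure FSDP within ICI across 2x4 and no DCN parallelism:
print(mesh_shape_is_valid(8, [1, 1, 1], [2, 4, 1]))
```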
Checkpointing
Save checkpoint every N samples. -1 disables checkpointing.
Enable one replica to read checkpoint and broadcast to others.
Metrics and logging
Save metrics such as loss and TFLOPS to GCS.
Write metrics to GCS.
Local file path for storing scalar metrics (for testing).
TensorBoard flush period.
Profiling
Enable JAX profiler.
Skip first N steps when profiling to exclude compilation.
Number of steps to profile.
Generation parameters
Prompt for test image generation during training.
Negative prompt for guidance.
Classifier-free guidance scale.
Number of denoising steps for test generation.
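The guidance scale combines the unconditional and conditional noise predictions in the standard classifier-free guidance form; a minimal elementwise sketch (plain lists stand in for the actual prediction arrays):

```python
def classifier_free_guidance(uncond_pred, cond_pred, guidance_scale):
    """Standard classifier-free guidance combine step: move the prediction
    from the unconditional output toward (and past) the conditional one by
    guidance_scale. A scale of 1.0 returns the conditional prediction."""
    return [u + guidance_scale * (c - u)
            for u, c in zip(uncond_pred, cond_pred)]

# Larger scales push harder toward the prompt-conditioned prediction:
print(classifier_free_guidance([0.0, 1.0], [1.0, 1.0], 7.5))
```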
Expected outputs
The training script produces the following outputs:
Checkpoints
When checkpoint_every is set, model checkpoints are saved to the output directory. Each checkpoint contains:
- Model weights (UNet and optionally text encoder)
- Optimizer state
- Training step information
Metrics
The following training metrics are saved:
- Training loss
- Learning rate
- TFLOPS per device
- Step time