Overview

The train.py script trains Stable Diffusion 2.x models using the MaxDiffusion framework. It supports distributed training across TPU pods and GPU clusters with configurable parallelism strategies.

Command-line usage

python -m src.maxdiffusion.train src/maxdiffusion/configs/base21.yml [OPTIONS]

Basic example

export LIBTPU_INIT_ARGS=""
python -m src.maxdiffusion.train \
  src/maxdiffusion/configs/base21.yml \
  run_name="my_run" \
  jax_cache_dir=gs://your-bucket/cache_dir \
  activations_dtype=float32 \
  weights_dtype=float32 \
  per_device_batch_size=2 \
  precision=DEFAULT \
  dataset_save_location=/tmp/my_dataset/ \
  output_dir=gs://your-bucket/ \
  attention=flash

Configuration parameters

Run configuration

run_name
string
required
Name for this training run. Used for organizing outputs and metrics.
base_output_directory
string
GCS bucket path for outputs (e.g., gs://my-bucket/). Checkpoints and metrics will be saved here.
output_dir
string
default:"sd-model-finetuned"
Local or GCS directory for model outputs.

Model configuration

pretrained_model_name_or_path
string
default:"stabilityai/stable-diffusion-2-1"
HuggingFace model identifier or local path to pretrained model.
unet_checkpoint
string
default:""
Path to a specific UNet checkpoint to load.
revision
string
default:"bf16"
Model revision/branch to use from HuggingFace.
weights_dtype
string
default:"float32"
Data type for model weights. Options: float32, bfloat16.
activations_dtype
string
default:"bfloat16"
Data type for layer activations. Options: float32, bfloat16.
precision
string
default:"DEFAULT"
JAX precision for matmul and conv operations. Options: DEFAULT, HIGH, HIGHEST.
from_pt
boolean
default:false
Load weights from PyTorch format.
train_new_unet
boolean
default:false
If true, randomly initialize UNet weights to train from scratch. Otherwise, load from pretrained model.

Attention configuration

attention
string
default:"dot_product"
Attention mechanism to use. Options: dot_product, flash.
split_head_dim
boolean
default:true
Whether to split attention head dimensions for sharding.
attention_sharding_uniform
boolean
default:true
Use uniform sequence sharding for both self and cross attention.
mask_padding_tokens
boolean
default:true
Pass segment IDs to attention to avoid attending to padding tokens. Improves quality when padding is significant.
flash_block_sizes
object
Custom block sizes for flash attention.

Text encoder configuration

train_text_encoder
boolean
default:false
Enable training of the text encoder along with UNet.
text_encoder_learning_rate
number
Learning rate for text encoder when train_text_encoder is true.

Training hyperparameters

learning_rate
number
Learning rate for the optimizer.
scale_lr
boolean
default:false
Scale learning rate by the number of GPUs/TPUs and batch size.
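A common convention for this kind of scaling (assumed here; the exact formula lives in train.py and may differ) is linear scaling by the global batch size:

```python
# Illustrative linear LR scaling — a common convention, not
# necessarily the exact formula train.py uses.
def scaled_lr(base_lr: float, num_devices: int, per_device_batch_size: int) -> float:
    # Scale the base learning rate by the global batch size.
    return base_lr * num_devices * per_device_batch_size

print(scaled_lr(1e-5, 8, 2))
```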
max_train_steps
number
default:800
Maximum number of training steps. Takes priority over num_train_epochs.
max_train_samples
number
Maximum number of training samples to use. -1 means use all samples.
per_device_batch_size
number
default:1
Batch size per device.
warmup_steps_fraction
number
Fraction of total steps to use for learning rate warmup.
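Assuming the warmup length is simply this fraction times max_train_steps (a plausible reading; the script may round differently):

```python
# Hypothetical conversion of warmup_steps_fraction to an absolute
# step count.
max_train_steps = 800
warmup_steps_fraction = 0.1
warmup_steps = int(max_train_steps * warmup_steps_fraction)
print(warmup_steps)  # 80
```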
seed
number
default:0
Random seed for reproducibility.

Optimizer parameters

adam_b1
number
Exponential decay rate for first moment estimates.
adam_b2
number
Exponential decay rate for second moment estimates.
adam_eps
number
Small constant for numerical stability.
adam_weight_decay
number
Weight decay coefficient for AdamW optimizer.
max_grad_norm
number
Maximum gradient norm for gradient clipping.

Loss and noise schedule

snr_gamma
number
SNR-weighted loss gamma parameter. Set to -1.0 to disable.
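This parameter likely corresponds to the min-SNR loss weighting of Hang et al. (2023), sketched below under the assumption of epsilon prediction; the actual weighting in train.py may differ:

```python
# Min-SNR loss weighting sketch (assumes epsilon prediction).
# Low-noise timesteps with very high SNR get their loss down-weighted,
# so training effort is spread more evenly across noise levels.
def min_snr_weight(snr: float, gamma: float) -> float:
    return min(snr, gamma) / snr

print(min_snr_weight(25.0, 5.0))  # 0.2  (high SNR: down-weighted)
print(min_snr_weight(2.0, 5.0))   # 1.0  (low SNR: unchanged)
```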
timestep_bias
object
Configuration for biasing timestep sampling during training.
  • strategy: Bias strategy. Options: none, earlier, later, range
  • multiplier: Bias multiplier (2.0 doubles weight, 0.5 halves it)
  • begin: Start timestep for range strategy
  • end: End timestep for range strategy
  • portion: Fraction of timesteps to bias
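A hypothetical YAML override using the fields above (values purely illustrative):

```yaml
timestep_bias:
  strategy: later   # oversample later (noisier) timesteps
  multiplier: 2.0   # biased timesteps are sampled twice as often
  portion: 0.25     # bias the last 25% of the timestep range
```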
diffusion_scheduler_config
object
Override parameters for the diffusion scheduler.
  • _class_name: Scheduler class name
  • prediction_type: Prediction type (e.g., v_prediction, epsilon)
  • rescale_zero_terminal_snr: Whether to rescale zero terminal SNR
  • timestep_spacing: Timestep spacing strategy
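For example, Stable Diffusion 2.1 at 768px is a v-prediction model, so an override might look like the following (class name and spacing are illustrative):

```yaml
diffusion_scheduler_config:
  _class_name: DDPMScheduler   # scheduler class (illustrative)
  prediction_type: v_prediction
  rescale_zero_terminal_snr: false
  timestep_spacing: trailing
```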

Dataset configuration

dataset_name
string
default:"diffusers/pokemon-gpt4-captions"
HuggingFace dataset identifier.
train_data_dir
string
default:""
Local directory containing training data. Either this or dataset_name must be set.
train_split
string
default:"train"
Dataset split to use for training.
dataset_type
string
default:"tf"
Dataset pipeline type; tf uses a TensorFlow input pipeline.
cache_latents_text_encoder_outputs
boolean
default:true
Cache image latents and text encoder outputs to reduce memory and speed up training. Only applies to small datasets that fit in memory.
dataset_save_location
string
default:"/tmp/pokemon-gpt4-captions_sd21"
Path to save transformed dataset when caching is enabled.
image_column
string
default:"image"
Name of the image column in the dataset.
caption_column
string
default:"text"
Name of the caption column in the dataset.
resolution
number
default:768
Image resolution for training.
center_crop
boolean
default:false
Whether to center crop images before resizing.
random_flip
boolean
default:false
Whether to randomly flip images horizontally.
enable_data_shuffling
boolean
default:true
Shuffle the dataset during training.
hf_access_token
string
default:""
HuggingFace access token for private datasets or models.

Parallelism and sharding

hardware
string
default:"tpu"
Hardware type. Options: tpu, gpu.
mesh_axes
array
default:["data","fsdp","context","tensor"]
Logical mesh axes for parallelism.
dcn_data_parallelism
number
Data parallelism across DCN (the data-center network linking TPU slices). -1 for auto-sharding.
dcn_fsdp_parallelism
number
default:1
FSDP parallelism across DCN.
dcn_tensor_parallelism
number
default:1
Tensor parallelism across DCN.
ici_data_parallelism
number
Data parallelism within ICI (the inter-chip interconnect inside a slice). -1 for auto-sharding.
ici_fsdp_parallelism
number
default:1
FSDP parallelism within ICI.
ici_tensor_parallelism
number
default:1
Tensor parallelism within ICI.
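As a hypothetical illustration, a two-slice job might replicate data across slices over DCN while sharding weights with FSDP inside each slice over ICI; the product of the axis sizes at each level must match the topology:

```yaml
# Hypothetical sharding for 2 slices of 4 chips each:
dcn_data_parallelism: 2    # replicate across the 2 slices
dcn_fsdp_parallelism: 1
dcn_tensor_parallelism: 1
ici_data_parallelism: 1
ici_fsdp_parallelism: 4    # shard weights across the 4 chips in a slice
ici_tensor_parallelism: 1
```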

Checkpointing

checkpoint_every
number
Save checkpoint every N samples. -1 disables checkpointing.
enable_single_replica_ckpt_restoring
boolean
default:false
Enable one replica to read checkpoint and broadcast to others.

Metrics and logging

write_metrics
boolean
default:true
Enable recording of training metrics such as loss and TFLOPS.
gcs_metrics
boolean
default:true
Write metrics to GCS.
metrics_file
string
default:""
Local file path for storing scalar metrics (for testing).
log_period
number
default:10000000000
Tensorboard flush period; the very large default effectively disables periodic flushing.

Profiling

enable_profiler
boolean
default:false
Enable JAX profiler.
skip_first_n_steps_for_profiler
number
default:1
Skip first N steps when profiling to exclude compilation.
profiler_steps
number
default:5
Number of steps to profile.

Generation parameters

prompt
string
Prompt for test image generation during training.
negative_prompt
string
default:"purple, red"
Negative prompt for guidance.
guidance_scale
number
Classifier-free guidance scale.
num_inference_steps
number
default:30
Number of denoising steps for test generation.

Expected outputs

The training script produces the following outputs:

Checkpoints

When checkpoint_every is set, model checkpoints are saved to:
{base_output_directory}/{run_name}/checkpoints/
Checkpoints include:
  • Model weights (UNet and optionally text encoder)
  • Optimizer state
  • Training step information

Metrics

Training metrics are saved to:
{base_output_directory}/{run_name}/metrics/
Metrics include:
  • Training loss
  • Learning rate
  • TFLOPS per device
  • Step time

Tensorboard logs

Tensorboard logs are written to:
{base_output_directory}/{run_name}/tensorboard/
View logs with:
tensorboard --logdir=gs://your-bucket/{run_name}/tensorboard/

Training output

During training, you’ll see output like:
***** Running training *****
Instantaneous batch size per device = 2
Total train batch size (w. parallel & distributed) = 8
Total optimization steps = 800
completed step: 0, seconds: 2.5, TFLOP/s/device: 45.2, loss: 0.234
completed step: 1, seconds: 1.2, TFLOP/s/device: 94.3, loss: 0.198
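The total train batch size in the log is per_device_batch_size multiplied by the number of participating devices; the sample above (2 per device, total 8) would correspond to a hypothetical 4-chip slice:

```python
# Relating the two batch-size lines in the sample log output.
per_device_batch_size = 2
num_devices = 4  # hypothetical 4-chip slice
total_train_batch_size = per_device_batch_size * num_devices
print(total_train_batch_size)  # 8, matching the log above
```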