MaxDiffusion supports training Stable Diffusion 1.4 and Stable Diffusion 2 Base models for text-to-image generation.

Stable Diffusion 2 Base training

Basic training command

export LIBTPU_INIT_ARGS=""
python -m src.maxdiffusion.train src/maxdiffusion/configs/base_2_base.yml \
  run_name="my_run" \
  jax_cache_dir=gs://your-bucket/cache_dir \
  activations_dtype=float32 \
  weights_dtype=float32 \
  per_device_batch_size=2 \
  precision=DEFAULT \
  dataset_save_location=/tmp/my_dataset/ \
  output_dir=gs://your-bucket/ \
  attention=flash

Configuration

The base config is located at src/maxdiffusion/configs/base_2_base.yml.

Key parameters

| Parameter | Default | Description |
|---|---|---|
| `pretrained_model_name_or_path` | `stabilityai/stable-diffusion-2-base` | Base model to fine-tune |
| `revision` | `main` | Model revision |
| `weights_dtype` | `float32` | Weight precision |
| `activations_dtype` | `bfloat16` | Activation precision |
| `attention` | `flash` | Attention mechanism (`dot_product`, `flash`) |
| `resolution` | `512` | Training image resolution |
| `per_device_batch_size` | `1` | Batch size per device |
| `learning_rate` | `1.e-7` | Initial learning rate |
| `max_train_steps` | `20` | Maximum training steps |
| `dataset_name` | `diffusers/pokemon-gpt4-captions` | HuggingFace dataset |

Dataset configuration

dataset_name: 'diffusers/pokemon-gpt4-captions'
train_split: 'train'
dataset_type: 'tf'
cache_latents_text_encoder_outputs: True
dataset_save_location: '/tmp/pokemon-gpt4-captions'
image_column: 'image'
caption_column: 'text'
resolution: 512
center_crop: False
random_flip: False

Optimizer parameters

adam_b1: 0.9
adam_b2: 0.999
adam_eps: 1.e-8
adam_weight_decay: 1.e-2
max_grad_norm: 1.0
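To make the roles of these settings concrete, here is a pure-NumPy sketch of a single AdamW update with global-norm gradient clipping. This is an illustration of the standard AdamW math that these parameters configure, not MaxDiffusion's actual optimizer code.

```python
import numpy as np

# Illustrative single AdamW step showing how adam_b1, adam_b2, adam_eps,
# adam_weight_decay, and max_grad_norm enter the update. Hypothetical helper.
def adamw_step(param, grad, m, v, step,
               lr=1e-7, b1=0.9, b2=0.999, eps=1e-8,
               weight_decay=1e-2, max_grad_norm=1.0):
    # Clip the gradient by its global norm (max_grad_norm).
    gnorm = np.linalg.norm(grad)
    if gnorm > max_grad_norm:
        grad = grad * (max_grad_norm / gnorm)
    # Exponential moving averages of the gradient and its square.
    m = b1 * m + (1 - b1) * grad
    v = b2 * v + (1 - b2) * grad ** 2
    # Bias correction for the first and second moments.
    m_hat = m / (1 - b1 ** step)
    v_hat = v / (1 - b2 ** step)
    # Decoupled weight decay (the "W" in AdamW).
    param = param - lr * (m_hat / (np.sqrt(v_hat) + eps) + weight_decay * param)
    return param, m, v
```

In training, the same moving averages `m` and `v` are carried from step to step for every parameter tensor.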

Advanced features

SNR weighting (https://arxiv.org/pdf/2305.08891.pdf):
snr_gamma: -1.0  # Set to positive value to enable
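As described in the linked Min-SNR paper, a positive `snr_gamma` rescales each timestep's loss by `min(SNR(t), gamma) / SNR(t)` (for epsilon-prediction), down-weighting low-noise timesteps. A minimal sketch of that weighting, assuming this formulation:

```python
import numpy as np

# Min-SNR loss weighting sketch (arXiv:2305.08891). A non-positive
# snr_gamma (the default -1.0) disables the weighting entirely.
def min_snr_weights(snr, snr_gamma):
    if snr_gamma <= 0:
        return np.ones_like(snr)
    return np.minimum(snr, snr_gamma) / snr
```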
Timestep bias:
timestep_bias:
  strategy: "none"  # Options: none, earlier, later, range
  multiplier: 1.0
  begin: 0
  end: 1000
  portion: 0.25
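The idea behind timestep bias is to sample some timesteps more often than others during training. The sketch below shows one plausible way the `range`, `earlier`, and `later` strategies could weight a 1000-step schedule (the `portion` option is omitted for brevity); the exact MaxDiffusion implementation may differ.

```python
import numpy as np

# Hypothetical helper: build a sampling distribution over timesteps where
# the biased region receives `multiplier` times the base probability.
def timestep_weights(num_timesteps=1000, strategy="none",
                     multiplier=1.0, begin=0, end=1000):
    w = np.ones(num_timesteps)
    if strategy == "range":
        w[begin:end] *= multiplier
    elif strategy == "later":
        w[num_timesteps // 2:] *= multiplier
    elif strategy == "earlier":
        w[:num_timesteps // 2] *= multiplier
    return w / w.sum()  # normalize to a probability distribution
```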
Train new UNet from scratch:
train_new_unet: False  # Set to True to initialize random weights
Train text encoder:
train_text_encoder: False  # Set to True to fine-tune text encoder
text_encoder_learning_rate: 4.25e-6

Parallelism configuration

# Automatic sharding (recommended)
ici_data_parallelism: -1  # Auto-shard
ici_fsdp_parallelism: 1
ici_tensor_parallelism: 1

# Manual sharding for multi-host
dcn_data_parallelism: -1  # Auto-shard across hosts
dcn_fsdp_parallelism: 1
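Since the product of all mesh axes must equal the device count, a `-1` axis can be resolved by absorbing whatever factor remains after the fixed axes are accounted for. A conceptual sketch of that resolution (a hypothetical helper, not MaxDiffusion's actual code):

```python
# Resolve a single -1 ("auto-shard") axis given the total device count.
def resolve_mesh(device_count, data=-1, fsdp=1, tensor=1):
    axes = [data, fsdp, tensor]
    fixed = 1
    for a in axes:
        if a != -1:
            fixed *= a
    assert device_count % fixed == 0, "fixed axes must divide the device count"
    return [device_count // fixed if a == -1 else a for a in axes]
```

For example, on 16 devices with `ici_fsdp_parallelism: 4`, the `-1` data axis resolves to 4.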

Generate images from checkpoint

After training, generate images using your fine-tuned model:
python -m src.maxdiffusion.generate src/maxdiffusion/configs/base_2_base.yml \
  run_name="my_run" \
  output_dir=gs://your-bucket/ \
  from_pt=False \
  attention=dot_product

Stable Diffusion 1.4 training

Basic training command

export LIBTPU_INIT_ARGS=""
python -m src.maxdiffusion.train src/maxdiffusion/configs/base14.yml \
  run_name="my_run" \
  jax_cache_dir=gs://your-bucket/cache_dir \
  activations_dtype=float32 \
  weights_dtype=float32 \
  per_device_batch_size=2 \
  precision=DEFAULT \
  dataset_save_location=/tmp/my_dataset/ \
  output_dir=gs://your-bucket/ \
  attention=flash

Configuration

The base config is located at src/maxdiffusion/configs/base14.yml.

Key differences from SD 2 Base

| Parameter | SD 1.4 | SD 2 Base |
|---|---|---|
| `pretrained_model_name_or_path` | `CompVis/stable-diffusion-v1-4` | `stabilityai/stable-diffusion-2-base` |
| `revision` | `flax` | `main` |
| `from_pt` | `False` | `True` |
| `max_train_steps` | `800` | `20` |
| `dataset_save_location` | `/tmp/pokemon-gpt4-captions_sd15` | `/tmp/pokemon-gpt4-captions` |

Generate images from checkpoint

python -m src.maxdiffusion.generate src/maxdiffusion/configs/base14.yml \
  run_name="my_run" \
  output_dir=gs://your-bucket/ \
  from_pt=False \
  attention=dot_product
The same generate script works for both SD 1.4 and SD 2 Base models.

Profiling and monitoring

Enable profiling to analyze training performance:
enable_profiler: True
skip_first_n_steps_for_profiler: 1
profiler_steps: 5
View TensorBoard metrics:
tensorboard --logdir=gs://your-bucket/my_run/tensorboard/

Checkpointing

Save checkpoints periodically:
checkpoint_every: 100  # Save every 100 steps
enable_single_replica_ckpt_restoring: False

Custom datasets

To use your own dataset:

Option 1: HuggingFace dataset

python -m src.maxdiffusion.train src/maxdiffusion/configs/base_2_base.yml \
  dataset_name="your-username/your-dataset" \
  image_column="image" \
  caption_column="text"

Option 2: Local directory

python -m src.maxdiffusion.train src/maxdiffusion/configs/base_2_base.yml \
  train_data_dir="/path/to/your/images" \
  dataset_name=""

Option 3: GCS bucket

python -m src.maxdiffusion.train src/maxdiffusion/configs/base_2_base.yml \
  train_data_dir="gs://your-bucket/training-data" \
  dataset_name=""

Advanced configuration

Precision settings

precision: "DEFAULT"  # Options: DEFAULT, HIGH, HIGHEST
weights_dtype: 'float32'  # Options: float32, bfloat16
activations_dtype: 'bfloat16'  # Options: float32, bfloat16
For best precision, use float32 weights and activations with HIGHEST precision. This will increase training time.
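These settings map onto JAX's matmul precision levels: `DEFAULT` permits faster reduced-precision passes on TPU, while `HIGHEST` forces full float32 accumulation. A small sketch of how the same dot product can be requested at different precisions:

```python
import jax
import jax.numpy as jnp

# The precision argument controls matmul accuracy on TPU; on CPU both
# calls produce the same exact result.
a = jnp.ones((4, 4), dtype=jnp.float32)
b = jnp.ones((4, 4), dtype=jnp.float32)
fast = jnp.dot(a, b, precision=jax.lax.Precision.DEFAULT)
exact = jnp.dot(a, b, precision=jax.lax.Precision.HIGHEST)
```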

Attention mechanisms

  • dot_product - Standard attention (slower, compatible with all hardware)
  • flash - Flash attention (faster, requires TPU v5+)

Learning rate scheduling

learning_rate: 1.e-7
scale_lr: False
warmup_steps_fraction: 0.0
learning_rate_schedule_steps: -1  # -1 uses max_train_steps
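One plausible shape for a schedule driven by `warmup_steps_fraction` is a linear ramp over the first fraction of the schedule steps, then the base rate. This is a hypothetical illustration of the parameters' interaction; the actual schedule MaxDiffusion applies (e.g. whether it decays afterward) may differ.

```python
# Hypothetical warmup-then-constant schedule for illustration only.
def lr_at(step, base_lr=1e-7, warmup_steps_fraction=0.0, total_steps=20):
    warmup = int(warmup_steps_fraction * total_steps)
    if warmup > 0 and step < warmup:
        return base_lr * (step + 1) / warmup
    return base_lr
```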

Data preprocessing

cache_latents_text_encoder_outputs: True  # Cache for faster training
center_crop: False
random_flip: False
enable_data_shuffling: True
tokenize_captions_num_proc: 4
transform_images_num_proc: 4
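To illustrate what `center_crop` does to a non-square image, here is a dependency-light sketch: take the largest centered square region (the subsequent resize to `resolution` is omitted to keep the example minimal). This is an illustration of the standard transform, not MaxDiffusion's preprocessing code.

```python
import numpy as np

# Crop the largest centered square from an (H, W, C) image array.
def center_crop(img):
    h, w = img.shape[:2]
    side = min(h, w)
    top, left = (h - side) // 2, (w - side) // 2
    return img[top:top + side, left:left + side]
```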
