MaxDiffusion supports training Stable Diffusion 1.4 and Stable Diffusion 2 Base models for text-to-image generation.
Stable Diffusion 2 Base training
Basic training command
export LIBTPU_INIT_ARGS=""
python -m src.maxdiffusion.train src/maxdiffusion/configs/base_2_base.yml \
run_name="my_run" \
jax_cache_dir=gs://your-bucket/cache_dir \
activations_dtype=float32 \
weights_dtype=float32 \
per_device_batch_size=2 \
precision=DEFAULT \
dataset_save_location=/tmp/my_dataset/ \
output_dir=gs://your-bucket/ \
attention=flash
Configuration
The base config is located at src/maxdiffusion/configs/base_2_base.yml.
Key parameters
| Parameter | Default | Description |
|---|---|---|
| pretrained_model_name_or_path | stabilityai/stable-diffusion-2-base | Base model to fine-tune |
| revision | main | Model revision |
| weights_dtype | float32 | Weight precision |
| activations_dtype | bfloat16 | Activation precision |
| attention | flash | Attention mechanism (dot_product, flash) |
| resolution | 512 | Training image resolution |
| per_device_batch_size | 1 | Batch size per device |
| learning_rate | 1.e-7 | Initial learning rate |
| max_train_steps | 20 | Maximum training steps |
| dataset_name | diffusers/pokemon-gpt4-captions | HuggingFace dataset |
Dataset configuration
dataset_name: 'diffusers/pokemon-gpt4-captions'
train_split: 'train'
dataset_type: 'tf'
cache_latents_text_encoder_outputs: True
dataset_save_location: '/tmp/pokemon-gpt4-captions'
image_column: 'image'
caption_column: 'text'
resolution: 512
center_crop: False
random_flip: False
Optimizer parameters
adam_b1: 0.9
adam_b2: 0.999
adam_eps: 1.e-8
adam_weight_decay: 1.e-2
max_grad_norm: 1.0
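These values describe a standard AdamW optimizer with global-norm gradient clipping. The sketch below shows one way they could be assembled with optax; it mirrors the config keys and is illustrative, not MaxDiffusion's actual optimizer-construction code.

```python
# Minimal sketch: map the optimizer parameters above onto optax AdamW with
# global-norm gradient clipping. Not MaxDiffusion's internal implementation.
import optax

optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),   # max_grad_norm
    optax.adamw(
        learning_rate=1e-7,           # learning_rate
        b1=0.9,                       # adam_b1
        b2=0.999,                     # adam_b2
        eps=1e-8,                     # adam_eps
        weight_decay=1e-2,            # adam_weight_decay
    ),
)
```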
Advanced features
SNR weighting (https://arxiv.org/pdf/2305.08891.pdf):
snr_gamma: -1.0 # Set to positive value to enable
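As a rough sketch of what min-SNR weighting does (assuming an epsilon-prediction objective and a standard DDPM noise schedule), each timestep's loss is weighted by min(SNR(t), snr_gamma) / SNR(t):

```python
# Illustrative min-SNR loss weights (arXiv:2305.08891), assuming an
# epsilon-prediction objective. `alphas_cumprod` comes from the noise
# scheduler; `timesteps` are the sampled training timesteps.
import jax.numpy as jnp

def min_snr_weights(alphas_cumprod, timesteps, snr_gamma=5.0):
    snr = alphas_cumprod[timesteps] / (1.0 - alphas_cumprod[timesteps])
    return jnp.minimum(snr, snr_gamma) / snr
```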
Timestep bias:
timestep_bias:
strategy: "none" # Options: none, earlier, later, range
multiplier: 1.0
begin: 0
end: 1000
portion: 0.25
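The idea behind these options is to re-weight the distribution that training timesteps are drawn from. The sketch below illustrates a "later" bias; the exact weighting MaxDiffusion applies for each strategy may differ.

```python
# Illustrative "later" bias: up-weight the last `portion` of timesteps by
# `multiplier` when sampling training timesteps.
import jax
import jax.numpy as jnp

def sample_biased_timesteps(key, batch_size, num_timesteps=1000,
                            multiplier=1.0, portion=0.25):
    weights = jnp.ones(num_timesteps)
    cutoff = int(num_timesteps * (1.0 - portion))
    weights = weights.at[cutoff:].mul(multiplier)  # bias toward later timesteps
    probs = weights / weights.sum()
    return jax.random.choice(key, num_timesteps, shape=(batch_size,), p=probs)
```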
Train new UNet from scratch:
train_new_unet: False # Set to True to initialize random weights
Train text encoder:
train_text_encoder: False # Set to True to fine-tune text encoder
text_encoder_learning_rate: 4.25e-6
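When the text encoder is trained, it gets its own learning rate. One way to express per-group learning rates is optax.multi_transform; the labels and flat-dict parameter layout below are illustrative, not MaxDiffusion's actual parameter pytree.

```python
# Illustrative per-group learning rates with optax.multi_transform.
# Assumes `params` is a flat dict whose keys identify text-encoder weights.
import optax

tx = optax.multi_transform(
    {
        "unet": optax.adamw(learning_rate=1e-7),
        "text_encoder": optax.adamw(learning_rate=4.25e-6),
    },
    param_labels=lambda params: {
        name: ("text_encoder" if "text_encoder" in name else "unet")
        for name in params
    },
)
```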
Parallelism configuration
# Automatic sharding (recommended)
ici_data_parallelism: -1 # Auto-shard
ici_fsdp_parallelism: 1
ici_tensor_parallelism: 1
# Manual sharding for multi-host
dcn_data_parallelism: -1 # Auto-shard across hosts
dcn_fsdp_parallelism: 1
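These axis sizes multiply out to the total device count, with -1 meaning "fill in the remaining factor". Conceptually they describe a device mesh along data, FSDP, and tensor axes, roughly as sketched below (MaxDiffusion resolves the -1 axis internally).

```python
# Rough sketch of how the parallelism axes could form a JAX device mesh.
import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh

fsdp, tensor = 1, 1                            # ici_fsdp_parallelism, ici_tensor_parallelism
data = jax.device_count() // (fsdp * tensor)   # resolves ici_data_parallelism = -1

devices = mesh_utils.create_device_mesh((data, fsdp, tensor))
mesh = Mesh(devices, axis_names=("data", "fsdp", "tensor"))
```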
Generate images from checkpoint
After training, generate images using your fine-tuned model:
python -m src.maxdiffusion.generate src/maxdiffusion/configs/base_2_base.yml \
run_name="my_run" \
output_dir=gs://your-bucket/ \
from_pt=False \
attention=dot_product
Stable Diffusion 1.4 training
Basic training command
export LIBTPU_INIT_ARGS=""
python -m src.maxdiffusion.train src/maxdiffusion/configs/base14.yml \
run_name="my_run" \
jax_cache_dir=gs://your-bucket/cache_dir \
activations_dtype=float32 \
weights_dtype=float32 \
per_device_batch_size=2 \
precision=DEFAULT \
dataset_save_location=/tmp/my_dataset/ \
output_dir=gs://your-bucket/ \
attention=flash
Configuration
The base config is located at src/maxdiffusion/configs/base14.yml.
Key differences from SD 2 Base
| Parameter | SD 1.4 | SD 2 Base |
|---|---|---|
| pretrained_model_name_or_path | CompVis/stable-diffusion-v1-4 | stabilityai/stable-diffusion-2-base |
| revision | flax | main |
| from_pt | False | True |
| max_train_steps | 800 | 20 |
| dataset_save_location | /tmp/pokemon-gpt4-captions_sd15 | /tmp/pokemon-gpt4-captions |
Generate images from checkpoint
python -m src.maxdiffusion.generate src/maxdiffusion/configs/base14.yml \
run_name="my_run" \
output_dir=gs://your-bucket/ \
from_pt=False \
attention=dot_product
The same generate script works for both SD 1.4 and SD 2 Base models.
Profiling and monitoring
Enable profiling to analyze training performance:
enable_profiler: True
skip_first_n_steps_for_profiler: 1
profiler_steps: 5
View TensorBoard metrics:
tensorboard --logdir=gs://your-bucket/my_run/tensorboard/
Checkpointing
Save checkpoints periodically:
checkpoint_every: 100 # Save every 100 steps
enable_single_replica_ckpt_restoring: False
Custom datasets
To use your own dataset:
Option 1: HuggingFace dataset
python -m src.maxdiffusion.train src/maxdiffusion/configs/base_2_base.yml \
dataset_name="your-username/your-dataset" \
image_column="image" \
caption_column="text"
Option 2: Local directory
python -m src.maxdiffusion.train src/maxdiffusion/configs/base_2_base.yml \
train_data_dir="/path/to/your/images" \
dataset_name=""
Option 3: GCS bucket
python -m src.maxdiffusion.train src/maxdiffusion/configs/base_2_base.yml \
train_data_dir="gs://your-bucket/training-data" \
dataset_name=""
Advanced configuration
Precision settings
precision: "DEFAULT" # Options: DEFAULT, HIGH, HIGHEST
weights_dtype: 'float32' # Options: float32, bfloat16
activations_dtype: 'bfloat16' # Options: float32, bfloat16
For best precision, use float32 weights and activations with HIGHEST precision. This will increase training time.
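The precision names correspond to JAX's matmul precision levels, roughly as follows (a sketch of the mapping, not MaxDiffusion's literal code):

```python
# Approximate mapping of the `precision` setting onto jax.lax.Precision,
# which controls matmul accuracy/speed on TPU.
import jax

PRECISION = {
    "DEFAULT": jax.lax.Precision.DEFAULT,   # fastest, lowest matmul precision
    "HIGH": jax.lax.Precision.HIGH,
    "HIGHEST": jax.lax.Precision.HIGHEST,   # full float32 matmuls, slowest
}
```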
Attention mechanisms
dot_product - Standard attention (slower, compatible with all hardware)
flash - Flash attention (faster, requires TPU v5+)
Learning rate scheduling
learning_rate: 1.e-7
scale_lr: False
warmup_steps_fraction: 0.0
learning_rate_schedule_steps: -1 # -1 uses max_train_steps
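One plausible reading of these options is a linear warmup over warmup_steps_fraction of the schedule, followed by the configured rate. The sketch below uses optax schedules and is illustrative only; the exact schedule shape MaxDiffusion builds may differ.

```python
# Illustrative warmup + constant learning-rate schedule.
import optax

learning_rate = 1e-7
schedule_steps = 20                        # learning_rate_schedule_steps (-1 -> max_train_steps)
warmup_fraction = 0.1                      # warmup_steps_fraction (config default is 0.0)
warmup_steps = int(warmup_fraction * schedule_steps)

schedule = optax.join_schedules(
    schedules=[
        optax.linear_schedule(0.0, learning_rate, warmup_steps),
        optax.constant_schedule(learning_rate),
    ],
    boundaries=[warmup_steps],
)
```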
Data preprocessing
cache_latents_text_encoder_outputs: True # Cache for faster training
center_crop: False
random_flip: False
enable_data_shuffling: True
tokenize_captions_num_proc: 4
transform_images_num_proc: 4
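For reference, the per-image transform these options describe (resize to resolution, optional center crop and random flip, scale to [-1, 1]) looks roughly like the TensorFlow sketch below; the actual MaxDiffusion preprocessing code may differ.

```python
# Illustrative per-image transform matching the options above.
import tensorflow as tf

def transform_image(image, resolution=512, center_crop=False, random_flip=False):
    if center_crop:
        # Resize the shorter side to `resolution`, then crop the center square.
        shape = tf.cast(tf.shape(image)[:2], tf.float32)
        scale = resolution / tf.reduce_min(shape)
        new_size = tf.cast(tf.math.ceil(shape * scale), tf.int32)
        image = tf.image.resize(image, new_size, antialias=True)
        image = tf.image.resize_with_crop_or_pad(image, resolution, resolution)
    else:
        image = tf.image.resize(image, (resolution, resolution), antialias=True)
    if random_flip:
        image = tf.image.random_flip_left_right(image)
    # Scale pixel values from [0, 255] to [-1, 1].
    return tf.cast(image, tf.float32) / 127.5 - 1.0
```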