MaxDiffusion supports training and fine-tuning various diffusion models on TPUs and GPUs. This section covers all available training options and workflows.

Supported models

MaxDiffusion supports training the following models:
  • Stable Diffusion 1.4 - Text-to-image generation
  • Stable Diffusion 2 Base - Text-to-image generation with improved quality
  • Stable Diffusion XL (SDXL) - High-resolution text-to-image generation at 1024x1024
  • Flux Dev - Advanced text-to-image model with transformer architecture
  • Wan 2.1 - Video generation (text-to-video and image-to-video)
  • Dreambooth - Personalized fine-tuning for Stable Diffusion 1.x and 2.x

Hardware requirements

Minimum requirements

  • Ubuntu 22.04
  • Python 3.12
  • TensorFlow >= 2.12.0

Supported accelerators

  • TPU: v5p, v5e, v6e (Trillium)
  • GPU: NVIDIA GPUs with CUDA support
For optimal GPU training performance, install Transformer Engine to enable fused attention kernels (the cudnn_flash_te attention option).

Training workflow

The typical training workflow consists of:
  1. Prepare your dataset: Organize your dataset with images and captions. MaxDiffusion supports HuggingFace datasets, TFRecords, and local directories.
  2. Configure training parameters: Choose a base config file from src/maxdiffusion/configs/ and override parameters as needed.
  3. Run training: Execute the training script with your config and parameter overrides.
  4. Monitor progress: Track metrics using TensorBoard and profiler outputs.
  5. Evaluate and generate: Test your fine-tuned model by generating images or videos.
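Putting the steps together, a minimal single-host run might look like the following. The config file name, bucket path, and flag values are illustrative; pick a config from src/maxdiffusion/configs/ that matches your model:

```shell
# Positional argument: the base config; key=value pairs override it.
python src/maxdiffusion/train.py \
  src/maxdiffusion/configs/base_2_base.yml \
  run_name=my-first-run \
  output_dir=gs://my-bucket/maxdiffusion-runs \
  per_device_batch_size=1 \
  max_train_steps=1000
```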

Common training parameters

All training scripts share common configuration parameters:

Model parameters

  • pretrained_model_name_or_path - Base model to fine-tune
  • weights_dtype - Weight precision (float32, bfloat16)
  • activations_dtype - Activation precision (float32, bfloat16)
  • attention - Attention mechanism (dot_product, flash, cudnn_flash_te)
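Choosing bfloat16 over float32 halves the memory footprint of the weights. A quick back-of-the-envelope sketch (the parameter count is illustrative, and optimizer state is not included):

```python
def param_bytes(num_params, bits_per_param):
    # Memory footprint of the model weights alone at a given precision.
    return num_params * bits_per_param // 8

one_billion = 1_000_000_000
fp32_gb = param_bytes(one_billion, 32) / 1e9  # float32: 4 bytes per param
bf16_gb = param_bytes(one_billion, 16) / 1e9  # bfloat16: 2 bytes per param
```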

Dataset parameters

  • dataset_name - HuggingFace dataset name
  • train_data_dir - Local or GCS path to training data
  • resolution - Training image resolution
  • per_device_batch_size - Batch size per device

Training loop parameters

  • learning_rate - Initial learning rate
  • max_train_steps - Maximum training steps
  • warmup_steps_fraction - Fraction of steps for learning rate warmup
  • output_dir - Directory to save checkpoints (supports GCS)
  • run_name - Unique identifier for this training run
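To make the warmup interaction concrete, here is a sketch of how warmup_steps_fraction and max_train_steps combine into a schedule. The linear-warmup-plus-cosine-decay shape is an assumption for illustration, not necessarily MaxDiffusion's exact schedule:

```python
import math

def lr_at_step(step, learning_rate=1e-5, max_train_steps=1000,
               warmup_steps_fraction=0.1):
    """Linear warmup over warmup_steps_fraction * max_train_steps,
    then cosine decay toward zero (illustrative shape)."""
    warmup_steps = max(1, int(max_train_steps * warmup_steps_fraction))
    if step < warmup_steps:
        # Ramp linearly from ~0 up to the full learning rate.
        return learning_rate * (step + 1) / warmup_steps
    # Cosine decay over the remaining steps.
    progress = (step - warmup_steps) / max(1, max_train_steps - warmup_steps)
    return learning_rate * 0.5 * (1.0 + math.cos(math.pi * progress))
```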

Parallelism parameters

  • ici_data_parallelism - Data parallelism within a slice (over the inter-chip interconnect)
  • ici_fsdp_parallelism - Fully sharded data parallelism (FSDP) within a slice
  • ici_tensor_parallelism - Tensor parallelism within a slice
  • dcn_data_parallelism - Data parallelism across slices (over the data-center network)
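As a sanity check, the product of all parallelism degrees must equal the total number of chips in the job. A minimal sketch (plain Python, no JAX required; the numbers are illustrative):

```python
def total_chips(ici_data, ici_fsdp, ici_tensor, dcn_data):
    """The device mesh uses every chip exactly once, so the product of
    the per-slice (ICI) degrees times the cross-slice (DCN) degree must
    equal slice_count * chips_per_slice."""
    chips_per_slice = ici_data * ici_fsdp * ici_tensor
    return chips_per_slice * dcn_data

# e.g. two slices of four chips each: FSDP within each slice,
# data parallelism across slices.
assert total_chips(ici_data=1, ici_fsdp=4, ici_tensor=1, dcn_data=2) == 8
```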

Profiling and checkpointing

  • enable_profiler - Enable performance profiling
  • checkpoint_every - Save checkpoint every N steps (-1 to disable)
  • jax_cache_dir - Directory for JAX compilation cache
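These can be toggled as command-line overrides on any training run, for example (config file name and flag values are illustrative):

```shell
# Enable profiling and periodic checkpointing for an existing run setup.
python src/maxdiffusion/train.py \
  src/maxdiffusion/configs/base_2_base.yml \
  run_name=profiled-run \
  enable_profiler=True \
  checkpoint_every=500
```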

Getting started

When running MaxDiffusion training for the first time, we recommend:
  1. Start with a single TPU host before scaling to multi-host
  2. Use the default Pokemon dataset for initial testing
  3. Review the model-specific training guide for detailed instructions

Model-specific guides

  • Stable Diffusion training: Train SD 1.4 and SD 2 Base models
  • SDXL training: Fine-tune Stable Diffusion XL
  • Flux training: Train Flux transformer models
  • Wan training: Train video generation models
  • Dreambooth: Personalized model fine-tuning

XPK deployment

For large-scale training, MaxDiffusion supports deployment via XPK (Kubernetes-based orchestration):
python3 ~/xpk/xpk.py workload create \
  --cluster=$CLUSTER_NAME \
  --project=$PROJECT \
  --zone=$ZONE \
  --device-type=$DEVICE_TYPE \
  --num-slices=1 \
  --command="python src/maxdiffusion/train.py ..." \
  --base-docker-image=$IMAGE_DIR \
  --workload=$RUN_NAME
See the model-specific guides for complete XPK examples.
