Supported models
MaxDiffusion supports training the following models:

- Stable Diffusion 1.4 - Text-to-image generation
- Stable Diffusion 2 Base - Text-to-image generation with improved quality
- Stable Diffusion XL (SDXL) - High-resolution text-to-image generation at 1024x1024
- Flux Dev - Advanced text-to-image model with transformer architecture
- Wan 2.1 - Video generation (text-to-video and image-to-video)
- Dreambooth - Personalized fine-tuning for Stable Diffusion 1.x and 2.x
Hardware requirements
Minimum requirements
- Ubuntu 22.04
- Python 3.12
- TensorFlow >= 2.12.0
Supported accelerators
- TPU: v5p, v5e, v6e (Trillium)
- GPU: NVIDIA GPUs with CUDA support
Training workflow
The typical training workflow consists of:

Prepare your dataset
Organize your dataset with images and captions. MaxDiffusion supports HuggingFace datasets, TFRecords, and local directories.
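For a local directory, one common convention is to pair each image with a same-named caption file. The sketch below is a hypothetical illustration of that layout (the function name and file layout are assumptions, not MaxDiffusion API):

```python
from pathlib import Path

def pair_images_with_captions(data_dir):
    """Pair each image with a same-named .txt caption file.

    Hypothetical layout: data_dir/img001.jpg alongside data_dir/img001.txt.
    Images without a caption file are skipped.
    """
    pairs = []
    for image_path in sorted(Path(data_dir).glob("*.jpg")):
        caption_path = image_path.with_suffix(".txt")
        if caption_path.exists():
            pairs.append((image_path.name, caption_path.read_text().strip()))
    return pairs
```

HuggingFace datasets and TFRecords carry captions inline instead, so no pairing step is needed there.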
Configure training parameters
Choose a base config file from src/maxdiffusion/configs/ and override parameters as needed.

Common training parameters
All training scripts share common configuration parameters:

Model parameters
- pretrained_model_name_or_path - Base model to fine-tune
- weights_dtype - Weight precision (float32, bfloat16)
- activations_dtype - Activation precision (float32, bfloat16)
- attention - Attention mechanism (dot_product, flash, cudnn_flash_te)
Dataset parameters
- dataset_name - HuggingFace dataset name
- train_data_dir - Local or GCS path to training data
- resolution - Training image resolution
- per_device_batch_size - Batch size per device
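Note that per_device_batch_size is a per-chip value: the effective global batch size is that value multiplied by the total device count. A minimal sketch (device counts hard-coded here for illustration):

```python
def global_batch_size(per_device_batch_size, num_devices):
    # Each device processes its own micro-batch every step,
    # so the global batch is the product of the two.
    return per_device_batch_size * num_devices

# Example: per_device_batch_size=16 on an 8-chip slice
print(global_batch_size(16, 8))  # 128
```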
Training loop parameters
- learning_rate - Initial learning rate
- max_train_steps - Maximum training steps
- warmup_steps_fraction - Fraction of steps for learning rate warmup
- output_dir - Directory to save checkpoints (supports GCS)
- run_name - Unique identifier for this training run
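warmup_steps_fraction is a fraction of max_train_steps, not an absolute step count. A hedged sketch of a linear warmup to the base learning rate (holding it flat afterwards is an assumption here; real schedules may decay):

```python
def lr_at_step(step, learning_rate, max_train_steps, warmup_steps_fraction):
    """Linearly ramp the learning rate over the warmup window, then hold it."""
    warmup_steps = int(max_train_steps * warmup_steps_fraction)
    if warmup_steps > 0 and step < warmup_steps:
        return learning_rate * (step + 1) / warmup_steps
    return learning_rate

# With max_train_steps=1000 and warmup_steps_fraction=0.1,
# the first 100 steps ramp from learning_rate/100 up to learning_rate.
```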
Parallelism parameters
- ici_data_parallelism - Data parallelism within a slice (over ICI)
- ici_fsdp_parallelism - FSDP parallelism within a slice (over ICI)
- ici_tensor_parallelism - Tensor parallelism within a slice (over ICI)
- dcn_data_parallelism - Data parallelism across slices (over DCN)
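These axes multiply together to tile the device mesh, so their product must equal the total number of chips. A small sketch of that consistency check (illustrative only, not MaxDiffusion's actual validation code):

```python
def check_mesh(num_devices, ici_data, ici_fsdp, ici_tensor, dcn_data):
    """The parallelism axes must exactly tile the available devices."""
    required = ici_data * ici_fsdp * ici_tensor * dcn_data
    if required != num_devices:
        raise ValueError(f"mesh needs {required} devices, have {num_devices}")
    return True

# Example: 2 slices of 8 chips each - FSDP within each slice,
# data parallelism across slices.
check_mesh(16, ici_data=1, ici_fsdp=8, ici_tensor=1, dcn_data=2)
```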
Profiling and checkpointing
- enable_profiler - Enable performance profiling
- checkpoint_every - Save a checkpoint every N steps (-1 to disable)
- jax_cache_dir - Directory for the JAX compilation cache
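The checkpoint_every convention (save every N steps, -1 disables) can be expressed as a simple predicate; this is a hypothetical sketch, not the library's internal logic:

```python
def should_checkpoint(step, checkpoint_every):
    """Return True when a checkpoint is due at this step; -1 disables saving."""
    if checkpoint_every <= 0:
        return False
    return step > 0 and step % checkpoint_every == 0

assert should_checkpoint(1000, 500)      # step 1000 is a multiple of 500
assert not should_checkpoint(1000, -1)   # checkpointing disabled
```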
Getting started
When running MaxDiffusion training for the first time, we recommend:

- Start with a single TPU host before scaling to multi-host
- Use the default Pokemon dataset for initial testing
- Review the model-specific training guide for detailed instructions
- Stable Diffusion training - Train SD 1.4 and SD 2 Base models
- SDXL training - Fine-tune Stable Diffusion XL
- Flux training - Train Flux transformer models
- Wan training - Train video generation models
- Dreambooth - Personalized model fine-tuning