
Installation

MaxDiffusion can be installed on Cloud TPUs, GPUs, or local machines. We recommend starting with a single TPU host for development before scaling to multi-host configurations.

System requirements

  • Operating System: Ubuntu 22.04 or later
  • Python: 3.12 or later
  • TensorFlow: 2.12.0 or later
  • Hardware: Cloud TPU v4/v5p/v6e or NVIDIA GPU with CUDA support
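
The requirements above can be checked before installing. The following is a minimal sketch, not part of MaxDiffusion: the helper names `version_ge` and `check_min` are hypothetical, the comparison relies on GNU `sort -V`, and the OS check reads `VERSION_ID` from `/etc/os-release` without confirming that the distribution is Ubuntu.

```shell
#!/usr/bin/env bash
# Preflight sketch: compare detected versions against the minimums listed above.
# version_ge and check_min are hypothetical helpers, not part of MaxDiffusion.

# version_ge A B: succeeds when version A >= version B (relies on GNU sort -V).
version_ge() {
  [ "$(printf '%s\n%s\n' "$1" "$2" | sort -V | head -n 1)" = "$2" ]
}

# check_min LABEL FOUND REQUIRED: print a pass/fail line; non-zero on failure.
check_min() {
  if version_ge "$2" "$3"; then
    echo "OK    $1 $2 (>= $3)"
  else
    echo "FAIL  $1 ${2:-not found} (need >= $3)"
    return 1
  fi
}

# Read the OS release in a subshell so its variables do not leak into this shell.
os_version="$( [ -r /etc/os-release ] && . /etc/os-release; echo "${VERSION_ID:-}" )"
# Empty when python3 is missing.
py_version="$(python3 -c 'import platform; print(platform.python_version())' 2>/dev/null)"

check_min "Ubuntu" "$os_version" "22.04" || true
check_min "Python" "$py_version" "3.12" || true
```

Each check prints one `OK` or `FAIL` line and the script keeps going, so all problems are reported in a single run.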

Quick install

1. Create a TPU VM

Create and SSH into a single-host TPU VM:
gcloud compute tpus tpu-vm create maxdiffusion-vm \
  --zone=us-central2-b \
  --accelerator-type=v6e-8 \
  --version=v2-alpha-tpuv6e

gcloud compute tpus tpu-vm ssh maxdiffusion-vm --zone=us-central2-b
For TPU v6e (Trillium), use the v2-alpha-tpuv6e VM image, which includes Ubuntu 22.04, Python 3.12, and TensorFlow 2.12+.

2. Clone the repository

git clone https://github.com/AI-Hypercomputer/maxdiffusion.git
cd maxdiffusion

3. Run setup script

Install MaxDiffusion and dependencies:
bash setup.sh MODE=stable DEVICE=tpu
If this is the first time Python 3.12+ is used on this machine, you may need to run the setup script three times to complete the installation.

4. Activate virtual environment

venv_name="maxdiffusion_venv"
source ~/$venv_name/bin/activate

5. Verify installation

Test your installation by generating an image:
python -m src.maxdiffusion.generate \
  src/maxdiffusion/configs/base21.yml \
  run_name="test" \
  prompt="A magical castle in the middle of a forest"
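
Before running a full generation, it can help to confirm that JAX sees the accelerators. The sketch below is one possible check, not an official MaxDiffusion command: the function name `jax_device_summary` and the `PYTHON` override are hypothetical, and it only assumes that `jax` is importable in the active environment.

```shell
#!/usr/bin/env bash
# Sketch: ask JAX which backend and how many devices it sees.
# PYTHON is a hypothetical override for choosing the interpreter.
PYTHON="${PYTHON:-python3}"

# Prints one line such as "backend=<name> devices=<count>"; fails if jax is unusable.
jax_device_summary() {
  "$PYTHON" -c '
import jax
print(f"backend={jax.default_backend()} devices={jax.device_count()}")
'
}

if summary="$(jax_device_summary 2>/dev/null)"; then
  echo "JAX is usable: $summary"
else
  echo "JAX could not be imported or found no usable backend; see Troubleshooting." >&2
fi
```

If the reported backend is `cpu` on an accelerator machine, the JAX install probably does not match the target device.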

TPU VM images

For different TPU generations, use the appropriate VM image:
TPU Type         | Recommended VM Image | Zones
-----------------|----------------------|------------------------------
v6e-8 (Trillium) | v2-alpha-tpuv6e      | us-central2-b, us-south1-a
v5p-8            | v2-alpha-tpuv5-lite  | us-east5-a, us-west4-a
v4-8             | tpu-ubuntu2204-base  | us-central2-b, europe-west4-a
See the Cloud TPU regions and zones documentation for the full list of available zones.
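
The table can be turned into a small lookup so that the image always matches the accelerator type. This is a sketch with hypothetical helper names (`vm_image_for`, `print_create_cmd`). It assumes that every size within a generation uses the image listed for the `-8` size, which the table does not state, and it only prints the `gcloud` command rather than running it.

```shell
#!/usr/bin/env bash
# Sketch: map a TPU type to the VM image from the table above.
# Assumption: other sizes of a generation (e.g. v6e-16) use the same image as -8.
vm_image_for() {
  case "$1" in
    v6e-*) echo "v2-alpha-tpuv6e" ;;
    v5p-*) echo "v2-alpha-tpuv5-lite" ;;
    v4-*)  echo "tpu-ubuntu2204-base" ;;
    *)     echo "unknown TPU type: $1" >&2; return 1 ;;
  esac
}

# print_create_cmd NAME ZONE TPU_TYPE: print (do not run) the create command.
print_create_cmd() {
  local image
  image="$(vm_image_for "$3")" || return 1
  printf 'gcloud compute tpus tpu-vm create %s --zone=%s --accelerator-type=%s --version=%s\n' \
    "$1" "$2" "$3" "$image"
}

print_create_cmd maxdiffusion-vm us-central2-b v6e-8
```

Printing the command first makes it easy to review the zone and image pairing before any resources are created.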

Development installation

For contributing to MaxDiffusion or modifying the source code:
git clone https://github.com/AI-Hypercomputer/maxdiffusion.git
cd maxdiffusion

# Install in editable mode
pip install -e .

# Install development dependencies
pip install pytest pylint pyink

Optional dependencies

For training

pip install accelerate datasets tensorboard

For video models

pip install opencv-python
sudo apt-get update && sudo apt-get install -y ffmpeg libsm6 libxext6

For testing

pip install pytest pytest-timeout pytest-xdist

Troubleshooting

Make sure JAX is installed for your target device:
# For TPU
pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

# For GPU
pip install -U "jax[cuda12]"

MaxDiffusion requires Python 3.12+. Check your version:
python --version
If using an older version, create a new virtual environment with Python 3.12.

MaxDiffusion models can be large. Consider attaching an external disk:
# Create and attach a 500GB disk
gcloud compute disks create maxdiffusion-disk \
  --size=500GB \
  --zone=us-central2-b

# Attach to your TPU VM
# See: https://cloud.google.com/tpu/docs/attach-durable-block-storage
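
To decide whether an extra disk is needed, you can first check how much space is free. The following sketch uses the standard `df` and `awk` tools. The helper name `free_gb` and the 100 GB threshold in `MIN_FREE_GB` are hypothetical illustrations, not MaxDiffusion recommendations.

```shell
#!/usr/bin/env bash
# free_gb PATH: print whole gigabytes available on the filesystem holding PATH.
# -P keeps each filesystem on one line; -k reports sizes in 1024-byte blocks.
free_gb() {
  df -Pk "$1" | awk 'NR == 2 { print int($4 / 1024 / 1024) }'
}

# MIN_FREE_GB is a hypothetical threshold; adjust it to the models you plan to use.
MIN_FREE_GB="${MIN_FREE_GB:-100}"
target="${1:-${HOME:-/}}"
available="$(free_gb "$target")"

if [ "$available" -lt "$MIN_FREE_GB" ]; then
  echo "Only ${available} GB free under $target; consider attaching a disk."
else
  echo "${available} GB free under $target."
fi
```

Pass a directory as the first argument to check a location other than your home directory, for example the mount point of an attached disk.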

For Python 3.12+ virtual environments, a single run of the setup script may not complete the installation. You may need to run it three times:
bash setup.sh MODE=stable DEVICE=tpu
bash setup.sh MODE=stable DEVICE=tpu
bash setup.sh MODE=stable DEVICE=tpu
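
The repeated runs above can also be wrapped in a small retry loop. This is a sketch only: `retry` is a hypothetical helper, and it assumes that `setup.sh` exits with a non-zero status when the installation is incomplete, which the source does not confirm. If the script exits with zero even on a partial install, run it the fixed number of times shown above instead.

```shell
#!/usr/bin/env bash
# retry MAX CMD...: run CMD until it succeeds, at most MAX times.
retry() {
  local max="$1" attempt=1
  shift
  until "$@"; do
    if [ "$attempt" -ge "$max" ]; then
      echo "Giving up after $attempt attempts: $*" >&2
      return 1
    fi
    attempt=$((attempt + 1))
    echo "Attempt $attempt of $max: $*" >&2
  done
}

# Usage from the repository root (commented out so the sketch has no side effects):
# retry 3 bash setup.sh MODE=stable DEVICE=tpu
```

The loop stops as soon as one run succeeds, so a setup that completes on the first attempt is not repeated.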

Next steps

  • Quickstart: Generate your first image
  • Single host deployment: Set up local development
  • Multi-host deployment: Scale to TPU Pods
  • Training guide: Start training models
