XPK (Accelerated Processing Kit) simplifies running MaxDiffusion on Google Kubernetes Engine (GKE) for both experimentation and production workloads.

Prerequisites

Verify you have these permissions for your account or service account:
  • Storage Admin
  • Kubernetes Engine Admin

Set up XPK

1. Install system dependencies

Install kubectl and gke-gcloud-auth-plugin:
sudo apt-get update
sudo apt-get install snapd
sudo snap install kubectl --classic
Install the GKE authentication plugin:
echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | sudo tee -a /etc/apt/sources.list.d/google-cloud-sdk.list

curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key --keyring /usr/share/keyrings/cloud.google.gpg add -

sudo apt-get update && sudo apt-get install google-cloud-sdk-gke-gcloud-auth-plugin
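As a quick sanity check, you can confirm that both tools are now on your PATH (the version checks below are guarded so they are simply skipped when a tool is missing; exact version strings will differ on your machine):

```shell
# Report whether each tool is installed, printing its version when present.
KUBECTL_OK=no
PLUGIN_OK=no
command -v kubectl >/dev/null && KUBECTL_OK=yes && kubectl version --client
command -v gke-gcloud-auth-plugin >/dev/null && PLUGIN_OK=yes && gke-gcloud-auth-plugin --version
echo "kubectl found: $KUBECTL_OK, gke-gcloud-auth-plugin found: $PLUGIN_OK"
```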
2. Authenticate gcloud

Authenticate your gcloud installation:
gcloud auth login
Configure Docker to use gcloud credentials:
gcloud auth configure-docker us-docker.pkg.dev
Test the Docker installation:
docker run hello-world
If you get a permission error, run:
sudo usermod -aG docker $USER
Then log out and log back in to the machine.
3. Install XPK

Install XPK using pip:
pip install xpk
Alternatively, clone the XPK repository:
git clone https://github.com/google/xpk.git
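If you cloned the repository instead of installing from PyPI, you can install it from the working tree. A sketch, assuming the `xpk` directory created by the clone above:

```shell
# Install xpk from the clone; falls back to a notice if the directory is absent.
if [ -d xpk ]; then
  pip install ./xpk
  STATUS="installed xpk from source"
else
  STATUS="xpk/ not found; run the git clone first"
fi
echo "$STATUS"
```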

Build Docker image

1. Clone MaxDiffusion

Clone the MaxDiffusion repository:
git clone https://github.com/google/MaxDiffusion.git
cd MaxDiffusion
2. Build dependency image

Build the MaxDiffusion base image. This only needs to be rerun when you change dependencies:
# Default will pick stable versions of dependencies
bash docker_build_dependency_image.sh

Using JAX AI Images

Build the MaxDiffusion Docker image using JAX AI base images for a more reliable build environment:
bash docker_build_dependency_image.sh \
  MODE=jax_ai_image \
  BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:jax0.5.2-rev2
Find available JAX AI base images at us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu.
JAX AI Images is currently in the experimental phase.
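One way to browse the available tags is the Artifact Registry CLI. This is a sketch that assumes you have the gcloud CLI installed and read access to the repository; the listing is skipped when gcloud is unavailable:

```shell
# List tags published for the JAX AI TPU base image.
IMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu
if command -v gcloud >/dev/null; then
  gcloud artifacts docker tags list "$IMAGE"
fi
```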

Run workloads

After you build maxdiffusion_base_image, XPK can pick up changes in your working directory when it launches workloads.
Include pip install . in your workload command so the package is installed from the current directory and your local changes take effect inside the container.

Configure environment

gcloud config set project $PROJECT_ID
gcloud config set compute/zone $ZONE

# GCS buckets for outputs and datasets (create them first if they don't exist)
BASE_OUTPUT_DIR=gs://output_bucket/
DATASET_PATH=gs://dataset_bucket/
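The variables above only name the buckets. If the buckets don't exist yet, you can create them first — a sketch using `gcloud storage` (the bucket names and `us-central1` location are placeholders, and the creation step is skipped when gcloud is unavailable; `gsutil mb` works as well):

```shell
BASE_OUTPUT_DIR=gs://output_bucket/
DATASET_PATH=gs://dataset_bucket/

# Create the buckets when the gcloud CLI is available.
if command -v gcloud >/dev/null; then
  gcloud storage buckets create "$BASE_OUTPUT_DIR" --location=us-central1
  gcloud storage buckets create "$DATASET_PATH" --location=us-central1
fi
```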

Create workload

xpk workload create \
  --cluster ${CLUSTER_NAME} \
  --base-docker-image maxdiffusion_base_image \
  --workload ${USER}-first-job \
  --tpu-type=v4-8 \
  --num-slices=1 \
  --command "pip install . && python src/maxdiffusion/train.py \
    src/maxdiffusion/configs/base_2_base.yml \
    run_name='my_run' \
    output_dir='gs://your-bucket/'"
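Once the workload is created, you can check on it with `xpk workload list`. A sketch, assuming the same `CLUSTER_NAME` as above (the placeholder default and the guard on xpk being installed are illustrative):

```shell
CLUSTER_NAME=${CLUSTER_NAME:-my-cluster}   # placeholder cluster name
if command -v xpk >/dev/null; then
  xpk workload list --cluster "$CLUSTER_NAME"
fi
```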

Advanced usage

Large-scale training example

For large-scale Wan 2.1 training on a v5p-256 slice:
RUN_NAME=wan-training-${RANDOM}
OUTPUT_DIR=gs://$BUCKET_NAME/wan/
DATASET_DIR=gs://$BUCKET_NAME/tfrecords_dataset/train/
EVAL_DATA_DIR=gs://$BUCKET_NAME/tfrecords_dataset/eval_timesteps/
IMAGE_DIR=maxdiffusion_base_image  # the image built earlier

# XLA flags must be a single space-separated string; backslash line
# continuations inside single quotes would embed literal backslashes.
LIBTPU_INIT_ARGS='--xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true'

python3 ~/xpk/xpk.py workload create \
  --cluster=$CLUSTER_NAME \
  --project=$PROJECT \
  --zone=$ZONE \
  --device-type=v5p-256 \
  --num-slices=1 \
  --command="LIBTPU_INIT_ARGS='${LIBTPU_INIT_ARGS}' \
    HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ \
    python src/maxdiffusion/train_wan.py \
    src/maxdiffusion/configs/base_wan_14b.yml \
    attention='flash' \
    run_name=${RUN_NAME} \
    output_dir=${OUTPUT_DIR} \
    train_data_dir=${DATASET_DIR} \
    per_device_batch_size=0.25 \
    ici_data_parallelism=32 \
    ici_fsdp_parallelism=4 \
    max_train_steps=5000" \
  --base-docker-image=${IMAGE_DIR} \
  --workload=${RUN_NAME}
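When a run finishes or needs to be restarted, the workload can be torn down with `xpk workload delete`. A sketch using the same cluster and workload names as above (placeholder defaults shown, and guarded for machines without xpk on PATH):

```shell
RUN_NAME=${RUN_NAME:-wan-training-demo}    # placeholder matching RUN_NAME above
CLUSTER_NAME=${CLUSTER_NAME:-my-cluster}
if command -v xpk >/dev/null; then
  xpk workload delete --workload "$RUN_NAME" --cluster "$CLUSTER_NAME"
fi
```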
