XPK (Accelerated Processing Kit) simplifies running MaxDiffusion on Google Kubernetes Engine (GKE) for both experimentation and production workloads.
## Prerequisites

Verify that you have these permissions for your account or service account:

- Storage Admin
- Kubernetes Engine Admin
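If those roles still need to be granted, a dry-run sketch like the following prints the `gcloud` commands to run. `PROJECT_ID` and `MEMBER` are hypothetical placeholders, not values from this guide; drop the leading `echo` to actually execute the commands (requires the gcloud CLI).

```shell
# Dry-run sketch: print the IAM bindings to grant.
# roles/storage.admin = Storage Admin, roles/container.admin = Kubernetes Engine Admin.
PROJECT_ID=my-project          # hypothetical project ID
MEMBER=user:you@example.com    # hypothetical principal
for ROLE in roles/storage.admin roles/container.admin; do
  echo gcloud projects add-iam-policy-binding "$PROJECT_ID" \
    --member="$MEMBER" --role="$ROLE"
done
```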
## Set up XPK

### Install system dependencies

Install kubectl and the gke-gcloud-auth-plugin. First, kubectl:

```shell
sudo apt-get update
sudo apt install snapd
sudo snap install kubectl --classic
```

Then install the GKE authentication plugin:

```shell
echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | sudo tee -a /etc/apt/sources.list.d/google-cloud-sdk.list
curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key --keyring /usr/share/keyrings/cloud.google.gpg add -
sudo apt update && sudo apt-get install google-cloud-sdk-gke-gcloud-auth-plugin
```
### Authenticate gcloud

Authenticate your gcloud installation:

```shell
gcloud auth login
```

Configure Docker to use gcloud credentials:

```shell
gcloud auth configure-docker us-docker.pkg.dev
```
Test the Docker installation (for example, with `docker run hello-world`). If you get a permission error, run:

```shell
sudo usermod -aG docker $USER
```

Then log out and log back in to the machine.

### Install XPK
Install XPK using pip:

```shell
pip install xpk
```

Alternatively, clone the XPK repository:

```shell
git clone https://github.com/google/xpk.git
```
## Build Docker image

### Clone MaxDiffusion

Clone the MaxDiffusion repository:

```shell
git clone https://github.com/google/MaxDiffusion.git
cd MaxDiffusion
```
### Build dependency image

Build the MaxDiffusion base image. This only needs to be rerun when you change dependencies:

```shell
# Default will pick stable versions of dependencies
bash docker_build_dependency_image.sh
```
### Using JAX AI Images

Build the MaxDiffusion Docker image using JAX AI base images for a more reliable build environment:

```shell
bash docker_build_dependency_image.sh \
  MODE=jax_ai_image \
  BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:jax0.5.2-rev2
```

JAX AI Images are currently in an experimental phase.
## Run workloads

After building the `maxdiffusion_base_image`, XPK can handle updates to the working directory when running workloads.

When using XPK, include `pip install .` in your command to install the package from the current directory. This ensures local changes are applied within the container.
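As a minimal sketch, the training command passed to `--command` can be built once as a shell variable, with `pip install .` prepended so local edits are installed before training starts (the config path and run name are the same ones used elsewhere in this guide):

```shell
# Build the --command string; "pip install ." runs inside the container
# before training so that local working-directory changes take effect.
TRAIN_CMD="pip install . && python src/maxdiffusion/train.py \
src/maxdiffusion/configs/base_2_base.yml \
run_name='my_run' \
output_dir='gs://your-bucket/'"
echo "$TRAIN_CMD"
```

The variable can then be passed as `--command "$TRAIN_CMD"` to `xpk workload create`.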
```shell
gcloud config set project $PROJECT_ID
gcloud config set compute/zone $ZONE

# Create GCS buckets for outputs and datasets
BASE_OUTPUT_DIR=gs://output_bucket/
DATASET_PATH=gs://dataset_bucket/
```
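One way to keep the two URIs consistent is to derive both from a single bucket name. `BUCKET_NAME` below is a hypothetical placeholder, and the echoed `gcloud storage` command only shows how the bucket itself could be created (drop the `echo` to run it for real):

```shell
BUCKET_NAME=my-maxdiffusion-bucket   # hypothetical bucket name
BASE_OUTPUT_DIR=gs://${BUCKET_NAME}/output/
DATASET_PATH=gs://${BUCKET_NAME}/dataset/
# Dry-run: print the bucket-creation command
echo gcloud storage buckets create "gs://${BUCKET_NAME}" --location=us-central1
```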
### Create workload
Using pip-installed XPK:

```shell
xpk workload create \
--cluster ${CLUSTER_NAME} \
--base-docker-image maxdiffusion_base_image \
--workload ${USER}-first-job \
--tpu-type=v4-8 \
--num-slices=1 \
--command "pip install . && python src/maxdiffusion/train.py \
src/maxdiffusion/configs/base_2_base.yml \
run_name='my_run' \
output_dir='gs://your-bucket/'"
```

Using the XPK repository:

```shell
python3 xpk/xpk.py workload create \
--cluster ${CLUSTER_NAME} \
--base-docker-image maxdiffusion_base_image \
--workload ${USER}-first-job \
--tpu-type=v4-8 \
--num-slices=1 \
--command "pip install . && python src/maxdiffusion/train.py \
src/maxdiffusion/configs/base_2_base.yml \
run_name='my_run' \
output_dir='gs://your-bucket/'"
```
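After submission, the workload can be inspected or cleaned up with XPK's other subcommands. The sketch below only prints the commands (`CLUSTER_NAME` is a placeholder); flag names should be checked against `xpk workload --help` for your installed version:

```shell
CLUSTER_NAME=my-cluster            # hypothetical cluster name
WORKLOAD_NAME=${USER}-first-job
# Dry-run: print the inspection and cleanup commands (drop "echo" to run)
echo xpk workload list --cluster "$CLUSTER_NAME"
echo xpk workload delete --workload "$WORKLOAD_NAME" --cluster "$CLUSTER_NAME"
```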
## Advanced usage

### Large-scale training example

For large-scale Wan 2.1 training on a v5p-256 slice:
```shell
RUN_NAME=wan-training-${RANDOM}
OUTPUT_DIR=gs://$BUCKET_NAME/wan/
DATASET_DIR=gs://$BUCKET_NAME/tfrecords_dataset/train/
EVAL_DATA_DIR=gs://$BUCKET_NAME/tfrecords_dataset/eval_timesteps/
LIBTPU_INIT_ARGS='--xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true'
```
```shell
python3 ~/xpk/xpk.py workload create \
--cluster=$CLUSTER_NAME \
--project=$PROJECT \
--zone=$ZONE \
--device-type=v5p-256 \
--num-slices=1 \
--command="HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ \
python src/maxdiffusion/train_wan.py \
src/maxdiffusion/configs/base_wan_14b.yml \
attention='flash' \
run_name=${RUN_NAME} \
output_dir=${OUTPUT_DIR} \
train_data_dir=${DATASET_DIR} \
per_device_batch_size=0.25 \
ici_data_parallelism=32 \
ici_fsdp_parallelism=4 \
max_train_steps=5000" \
--base-docker-image=${IMAGE_DIR} \
--workload=${RUN_NAME}
```
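Note that `LIBTPU_INIT_ARGS` set in the local shell is not automatically visible inside the workload container. One way to thread it through (a sketch, not the only option; check `xpk workload create --help` for an environment-variable flag in your XPK version) is to embed it in the command string itself:

```shell
# Sketch: prefix the training command with the XLA flags so they reach
# the container environment.
LIBTPU_INIT_ARGS='--xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true'
CMD="LIBTPU_INIT_ARGS='${LIBTPU_INIT_ARGS}' python src/maxdiffusion/train_wan.py src/maxdiffusion/configs/base_wan_14b.yml"
echo "$CMD"
```

`$CMD` would then be passed as the value of `--command`.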
## Resources