Multihost training allows you to scale MaxDiffusion across multiple TPU hosts for large-scale training workloads.

Overview

Multihost deployment enables distributed training by coordinating multiple TPU VMs to work together on the same training job. This approach is ideal when you need more compute power than a single host can provide.

Prerequisites

  • A TPU pod slice (multiple TPU VMs)
  • gcloud CLI configured with your project
  • MaxDiffusion repository cloned locally
  • GCS bucket for storing outputs

Setup and training

Step 1: Set environment variables

Configure your TPU pod details:
TPU_NAME=<your-tpu-name>
ZONE=<your-zone>
PROJECT_ID=<your-project-id>
Step 2: Run distributed training

Use gcloud to execute the training command across all workers in your TPU pod:
gcloud compute tpus tpu-vm ssh $TPU_NAME \
  --zone=$ZONE \
  --project=$PROJECT_ID \
  --worker=all \
  --command="
  export LIBTPU_INIT_ARGS=''
  git clone https://github.com/google/maxdiffusion
  cd maxdiffusion
  pip3 install 'jax[tpu]' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
  pip3 install -r requirements.txt
  pip3 install .
  python -m src.maxdiffusion.train src/maxdiffusion/configs/base_2_base.yml \
    run_name=my_run \
    output_dir=gs://your-bucket/"
This command:
  • SSHes into all workers simultaneously (--worker=all)
  • Clones MaxDiffusion on each worker
  • Installs dependencies on each worker
  • Launches coordinated training across all workers

Configuration options

Parallelism strategies

MaxDiffusion supports several parallelism strategies for multihost training:
  • Data parallelism (ici_data_parallelism) - Distribute different batches across devices
  • FSDP parallelism (ici_fsdp_parallelism) - Shard model parameters across devices
  • Tensor parallelism (ici_tensor_parallelism) - Split individual tensors across devices
Example configuration:
python -m src.maxdiffusion.train src/maxdiffusion/configs/base_2_base.yml \
  run_name=my_run \
  output_dir=gs://your-bucket/ \
  ici_data_parallelism=32 \
  ici_fsdp_parallelism=4 \
  ici_tensor_parallelism=1
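A useful rule of thumb: the product of the parallelism axes should equal the total number of devices in the slice. A minimal sketch of that sanity check (the helper name and the 128-device count are illustrative, not part of MaxDiffusion):

```python
def validate_mesh(num_devices, data=1, fsdp=1, tensor=1):
    """Check that the parallelism axes multiply to the device count."""
    product = data * fsdp * tensor
    if product != num_devices:
        raise ValueError(
            f"data * fsdp * tensor = {product}, "
            f"but the slice has {num_devices} devices"
        )
    return data, fsdp, tensor

# The example configuration above assumes a 128-device slice: 32 * 4 * 1 = 128.
print(validate_mesh(128, data=32, fsdp=4, tensor=1))
```

Running this kind of check before launching a large job catches mesh-shape mistakes cheaply, instead of failing minutes into startup across the whole pod.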

Environment variables

For optimal performance, set LIBTPU_INIT_ARGS with appropriate XLA flags:
export LIBTPU_INIT_ARGS='--xla_tpu_enable_async_collective_fusion=true \
  --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true \
  --xla_tpu_enable_async_collective_fusion_multiple_steps=true'
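If you launch training from a Python entry point rather than a shell, the same flags can be set via `os.environ`, provided this happens before JAX initializes the TPU runtime (a sketch, using the flag values above):

```python
import os

# LIBTPU_INIT_ARGS is read when the TPU runtime initializes, so set it
# before the first `import jax` (or export it in the shell as shown above).
os.environ["LIBTPU_INIT_ARGS"] = (
    "--xla_tpu_enable_async_collective_fusion=true "
    "--xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true "
    "--xla_tpu_enable_async_collective_fusion_multiple_steps=true"
)
```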

Monitoring

Training metrics are automatically logged to TensorBoard:
tensorboard --logdir=gs://your-bucket/run_name/tensorboard/

Best practices

  • Validate your training configuration on a single host before scaling to multihost, to catch configuration issues early.
  • Store all outputs (checkpoints, logs, datasets) in Google Cloud Storage so they are accessible from every worker.
  • Ensure that per_device_batch_size multiplied by the total number of devices yields a reasonable global batch size.
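The batch-size check is simple arithmetic (all numbers below are illustrative, not defaults):

```python
# Global batch size = per-device batch size x total device count.
per_device_batch_size = 1   # MaxDiffusion config key: per_device_batch_size
num_hosts = 4               # hosts in the TPU pod slice (illustrative)
devices_per_host = 4        # devices visible per host; varies by TPU generation
global_batch_size = per_device_batch_size * num_hosts * devices_per_host
print(global_batch_size)  # 16
```

If the resulting global batch size is too large for your learning-rate schedule (or too small to keep the devices busy), adjust per_device_batch_size rather than the mesh shape.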

Next steps

For managed Kubernetes-based deployment with easier cluster management, see the XPK deployment guide.
