MaxDiffusion supports model quantization to reduce memory usage and potentially improve inference speed on TPU hardware.

Overview

Quantization converts model weights from higher precision formats (like float32) to lower precision formats, reducing memory footprint while maintaining acceptable quality.
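To make the idea concrete, here is a minimal, framework-agnostic sketch of symmetric int8 weight quantization with NumPy. This is illustrative only and is not MaxDiffusion's implementation; the helper names are hypothetical:

```python
import numpy as np

# Illustrative sketch: symmetric int8 quantization with a per-tensor scale.
def quantize_int8(w):
    scale = np.max(np.abs(w)) / 127.0
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    return q, scale

def dequantize(q, scale):
    return q.astype(np.float32) * scale

w = np.random.randn(1024, 1024).astype(np.float32)
q, scale = quantize_int8(w)
print(w.nbytes // q.nbytes)  # -> 4 (int8 storage is 4x smaller than float32)
max_err = np.max(np.abs(dequantize(q, scale) - w))
```

The 4x storage reduction and the bounded round-trip error (`max_err` is at most half a quantization step) are the trade-off the rest of this page is about.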

Configuration

Quantization is configured in the model’s YAML config file. The following parameters control quantization behavior:

Basic parameters

quantization: ''  # Empty string disables quantization
quantization_local_shard_count: -1  # -1 defaults to number of slices
use_qwix_quantization: False
compile_topology_num_slices: -1  # Number of target slices

Parameter descriptions

quantization
  • Type: string
  • Default: '' (empty, disabled)
  • Description: Specifies the quantization mode. When empty, quantization is disabled.
quantization_local_shard_count
  • Type: integer
  • Default: -1 (auto-detect from number of slices)
  • Description: Shards the range-finding operation for quantization across devices. By default, this is set to the number of slices in your TPU configuration.
use_qwix_quantization
  • Type: boolean
  • Default: False
  • Description: Enables quantization through Qwix, a JAX quantization library.
compile_topology_num_slices
  • Type: integer
  • Default: -1 (auto-detect)
  • Description: Number of target slices for compilation topology. Set to a positive integer to override auto-detection.

Example configuration

From base_xl.yml:
quantization: ''
quantization_local_shard_count: -1
use_qwix_quantization: False
compile_topology_num_slices: -1
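As a contrast, a configuration with quantization enabled on a 2-slice pod might look like the following. This is a hypothetical sketch: the 'int8' mode name is an assumption, so check which modes your MaxDiffusion build actually supports:

```yaml
# Hypothetical example: quantization enabled (mode name is illustrative)
quantization: 'int8'
quantization_local_shard_count: 2   # match the number of slices
use_qwix_quantization: True
compile_topology_num_slices: 2
```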

Usage with training

When training models, you can enable quantization by setting the quantization parameter:
python -m src.maxdiffusion.train_sdxl \
  src/maxdiffusion/configs/base_xl.yml \
  run_name="my_xl_run" \
  output_dir="gs://your-bucket/" \
  per_device_batch_size=1 \
  quantization="your_quantization_mode"

Usage with inference

Quantization settings are loaded from the config file during inference:
python -m src.maxdiffusion.generate_sdxl \
  src/maxdiffusion/configs/base_xl.yml \
  run_name="my_run" \
  quantization="your_quantization_mode"

Multi-slice quantization

When running on multi-slice TPU pods, the quantization range-finding operation is automatically sharded across slices for efficiency. You can control this behavior with quantization_local_shard_count:
# Explicitly set shard count for a 4-slice pod
quantization_local_shard_count: 4
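The documented default behavior (-1 falls back to the number of slices) can be sketched as a small helper. This is an illustrative function, not MaxDiffusion's actual code:

```python
# Sketch of the documented default: a shard count of -1 falls back to the
# number of slices in the TPU topology; any positive value is an override.
def resolve_shard_count(quantization_local_shard_count: int, num_slices: int) -> int:
    if quantization_local_shard_count == -1:
        return num_slices
    return quantization_local_shard_count

print(resolve_shard_count(-1, 4))  # -> 4 (auto-detected from topology)
print(resolve_shard_count(2, 4))   # -> 2 (explicit override)
```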

Performance considerations

Memory savings

  • Quantization can significantly reduce model memory footprint
  • Exact savings depend on the quantization mode and model architecture
  • Particularly beneficial for large models like SDXL and Flux

Quality impact

  • Lower precision formats may introduce minor quality degradation
  • Test different quantization modes to find the right balance for your use case
  • Some quantization modes preserve quality better than others

Inference speed

  • Quantized models may have faster inference on certain hardware
  • TPUs can benefit from reduced memory bandwidth requirements
  • Actual speedup varies by model and quantization type

Best practices

  1. Start without quantization: Establish baseline quality and performance
  2. Test incrementally: Try different quantization modes and compare results
  3. Monitor quality: Use consistent prompts to evaluate quality impact
  4. Profile performance: Measure actual memory and speed improvements
  5. Match topology: Set compile_topology_num_slices to match your deployment environment
