Overview

The generate.py script runs inference for Stable Diffusion 2.x models (SD 2 base and SD 2.1). It supports single- and multi-host deployment with JAX sharding annotations.

Usage

Basic command

python -m src.maxdiffusion.generate \
  src/maxdiffusion/configs/base_2_base.yml \
  run_name="my_run"

With custom checkpoint

python -m src.maxdiffusion.generate \
  src/maxdiffusion/configs/base_2_base.yml \
  run_name="my_run" \
  pretrained_model_name_or_path=<your_saved_checkpoint_path> \
  from_pt=False \
  attention=dot_product

Configuration parameters

Model parameters

pretrained_model_name_or_path (string, default: "stabilityai/stable-diffusion-2-1")
  Path or identifier for the pretrained model. Use stabilityai/stable-diffusion-2-base for SD 2 base or stabilityai/stable-diffusion-2-1 for SD 2.1.
revision (string, default: "bf16")
  Model revision to load from the Hugging Face Hub.
from_pt (boolean, default: false)
  Set to true to load weights from PyTorch format.
unet_checkpoint (string, default: "")
  Path to a specific UNet checkpoint to load.

Generation parameters

prompt (string)
  Text prompt describing the image to generate.
negative_prompt (string, default: "purple, red")
  Negative prompt describing what should not appear in the image.
num_inference_steps (integer, default: 30)
  Number of denoising steps. Higher values generally produce better quality but take longer.
guidance_scale (float)
  Classifier-free guidance scale. Higher values produce images that follow the prompt more closely.
guidance_rescale (float)
  Guidance rescale factor from "Common Diffusion Noise Schedules and Sample Steps are Flawed". Helps correct overexposure when the terminal SNR approaches zero; 0.7 is recommended with guidance_scale=7.5.
resolution (integer, default: 768)
  Output image resolution in pixels (width and height).
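
The interplay of guidance_scale and guidance_rescale can be sketched in a few lines. This is a minimal NumPy illustration of classifier-free guidance with the rescaling trick described above, not the script's actual code:

```python
import numpy as np

def classifier_free_guidance(noise_uncond, noise_text,
                             guidance_scale=7.5, guidance_rescale=0.7):
    """Combine unconditional and text-conditioned noise predictions."""
    # Standard classifier-free guidance: push the prediction toward the
    # text-conditioned direction by guidance_scale.
    noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond)
    if guidance_rescale > 0:
        # Rescale so the guided prediction's standard deviation matches the
        # text-conditioned one, then blend by guidance_rescale. This is what
        # counteracts the overexposure effect at high guidance scales.
        rescaled = noise_pred * (noise_text.std() / noise_pred.std())
        noise_pred = guidance_rescale * rescaled + (1 - guidance_rescale) * noise_pred
    return noise_pred
```

With guidance_rescale=0 this reduces to plain classifier-free guidance; with guidance_rescale=1 the output's standard deviation equals that of the text-conditioned prediction.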

Performance parameters

weights_dtype (string, default: "float32")
  Data type for model weights. Options: float32, bfloat16, float16.
activations_dtype (string, default: "bfloat16")
  Data type for layer activations. Options: float32, bfloat16, float16.
attention (string, default: "dot_product")
  Attention implementation to use. Options: dot_product, flash.
precision (string, default: "DEFAULT")
  Matmul and convolution precision. Options: DEFAULT, HIGH, HIGHEST. float32 with HIGHEST gives the best precision at the cost of speed.

Parallelism parameters

ici_data_parallelism (integer)
  Data parallelism across ICI (intra-slice) devices. Use -1 for auto-sharding.
ici_fsdp_parallelism (integer, default: 1)
  FSDP parallelism across ICI devices.
ici_tensor_parallelism (integer, default: 1)
  Tensor parallelism across ICI devices.
dcn_data_parallelism (integer)
  Data parallelism across DCN (inter-slice) connections. Use -1 for auto-sharding.
per_device_batch_size (integer, default: 1)
  Batch size per device.
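
For example, auto-sharded data parallelism within a slice can be combined with a larger per-device batch. The flag values below are illustrative, not recommended settings:

```shell
python -m src.maxdiffusion.generate \
  src/maxdiffusion/configs/base_2_base.yml \
  run_name="my_run" \
  ici_data_parallelism=-1 \
  ici_fsdp_parallelism=1 \
  ici_tensor_parallelism=1 \
  per_device_batch_size=2
```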

System parameters

run_name (string, required)
  Unique identifier for the inference run.
output_dir (string, default: "sd-model-finetuned")
  Directory to save generated images.
seed (integer, default: 0)
  Random seed for reproducible generation.
jax_cache_dir (string, default: "")
  Directory for the JAX compilation cache.
hardware (string, default: "tpu")
  Hardware type. Options: tpu, gpu.

Output

Generated images are saved as PNG files in the current directory with the naming pattern image_{i}.png, where i is the image index in the batch.
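
A batch's outputs can be gathered back in index order with a small helper. This is a hypothetical convenience function, not part of the script:

```python
from pathlib import Path

def collect_generated_images(out_dir: str = ".") -> list:
    """Return image_{i}.png files sorted by their numeric batch index,
    so image_10.png sorts after image_2.png (unlike lexicographic order)."""
    return sorted(
        Path(out_dir).glob("image_*.png"),
        key=lambda p: int(p.stem.split("_")[1]),
    )
```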

Implementation details

The script:
  1. Loads the Stable Diffusion pipeline and weights
  2. Creates inference states for UNet, VAE, and text encoder
  3. Sets up DDIM scheduler for the denoising process
  4. Tokenizes and encodes the text prompts
  5. Runs the denoising loop for the specified number of steps
  6. Decodes latents to images using the VAE
  7. Saves images to disk
See source: src/maxdiffusion/generate.py
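
The heart of steps 3-6 is the scheduler update. A minimal NumPy sketch of one deterministic DDIM step (eta = 0), not the MaxDiffusion implementation, with illustrative cumulative schedule values:

```python
import numpy as np

def ddim_step(x_t, eps_pred, alpha_bar_t, alpha_bar_prev):
    """One deterministic DDIM update.

    alpha_bar_t / alpha_bar_prev are the cumulative noise-schedule
    products at the current and previous timesteps.
    """
    # Predict the clean latent from the current noisy latent and the
    # UNet's noise estimate...
    pred_x0 = (x_t - np.sqrt(1.0 - alpha_bar_t) * eps_pred) / np.sqrt(alpha_bar_t)
    # ...then re-noise it to the previous timestep's noise level.
    return np.sqrt(alpha_bar_prev) * pred_x0 + np.sqrt(1.0 - alpha_bar_prev) * eps_pred
```

The denoising loop of step 5 applies this update num_inference_steps times, feeding each output back in as the next x_t, after which the VAE decodes the final latents (step 6).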
