Overview
The `generate.py` script provides inference capabilities for Stable Diffusion 2.x models (SD 2 base and SD 2.1). It supports single-host and multi-host deployment with JAX sharding annotations.
Usage
Basic command
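A minimal invocation might look like the following. This follows the config-first CLI pattern used elsewhere in the repository (a config YAML as the first positional argument, followed by `key=value` overrides); the exact config filename is an assumption:

```shell
# Run SD 2 base inference with default settings.
# base_2_base.yml is assumed to be the SD 2 base config shipped with the repo.
python src/maxdiffusion/generate.py src/maxdiffusion/configs/base_2_base.yml \
  run_name="sd2-inference-test" \
  prompt="A photograph of an astronaut riding a horse"
```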
With custom checkpoint
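To load a custom UNet checkpoint, the same command gains one more override. The key name `unet_checkpoint` and the bucket path below are assumptions for illustration:

```shell
# Load UNet weights from a custom checkpoint instead of the pretrained ones.
python src/maxdiffusion/generate.py src/maxdiffusion/configs/base_2_base.yml \
  run_name="sd2-custom-unet" \
  unet_checkpoint="gs://my-bucket/checkpoints/unet_step_10000"  # hypothetical path
```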
Configuration parameters
Model parameters
- Path or identifier for the pretrained model. Use `stabilityai/stable-diffusion-2-base` for SD 2 base or `stabilityai/stable-diffusion-2-1` for SD 2.1.
- Model revision to use from HuggingFace.
- Set to `true` to load weights from PyTorch format.
- Path to a specific UNet checkpoint to load.
Generation parameters
- Text prompt describing the image to generate.
- Negative prompt to guide what should not appear in the image.
- Number of denoising steps. Higher values generally produce better quality but take longer.
- Classifier-free guidance scale. Higher values produce images more closely aligned with the prompt.
- Guidance rescale factor from "Common Diffusion Noise Schedules and Sample Steps are Flawed". Helps correct overexposure when the terminal SNR approaches zero; the recommended value is 0.7 with `guidance_scale=7.5`.
- Output image resolution in pixels (width and height).
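The interaction of the guidance scale and the guidance rescale factor can be sketched as follows. This is an illustrative NumPy version of the standard rescale fix, not the script's actual code, which operates on JAX arrays:

```python
import numpy as np

def guided_noise(eps_uncond, eps_text, guidance_scale=7.5, guidance_rescale=0.7):
    """Classifier-free guidance with the rescale fix from
    'Common Diffusion Noise Schedules and Sample Steps are Flawed'."""
    # Standard classifier-free guidance: push the prediction toward the
    # text-conditioned one, away from the unconditional one.
    eps_cfg = eps_uncond + guidance_scale * (eps_text - eps_uncond)
    if guidance_rescale == 0.0:
        return eps_cfg
    # Rescale so the guided prediction keeps the per-sample std of the
    # text-conditioned prediction, countering overexposure near zero terminal SNR.
    axes = tuple(range(1, eps_text.ndim))
    std_text = eps_text.std(axis=axes, keepdims=True)
    std_cfg = eps_cfg.std(axis=axes, keepdims=True)
    eps_rescaled = eps_cfg * (std_text / std_cfg)
    # Blend rescaled and plain CFG predictions by the rescale factor.
    return guidance_rescale * eps_rescaled + (1.0 - guidance_rescale) * eps_cfg
```

With `guidance_rescale=0.0` this reduces to plain classifier-free guidance.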
Performance parameters
- Data type for model weights. Options: `float32`, `bfloat16`, `float16`.
- Data type for layer activations. Options: `float32`, `bfloat16`, `float16`.
- Attention mechanism to use. Options: `dot_product`, `flash`.
- Matmul and conv precision. Options: `DEFAULT`, `HIGH`, `HIGHEST`. FP32 with `HIGHEST` provides the best precision at the cost of speed.

Parallelism parameters
- Data parallelism across ICI (inter-chip interconnect, i.e. within a slice) devices. Use `-1` for auto-sharding.
- FSDP parallelism across ICI devices.
- Tensor parallelism across ICI devices.
- Data parallelism across DCN (data-center network, i.e. between slices) slices. Use `-1` for auto-sharding.
- Batch size per device.
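The `-1` auto value can be read as "whatever factor is left over once the other axes are fixed". A minimal sketch of that resolution rule (`resolve_parallelism` is a hypothetical helper, not part of the script):

```python
def resolve_parallelism(num_devices, data=-1, fsdp=1, tensor=1):
    """Replace a single -1 axis with the factor that makes
    data * fsdp * tensor == num_devices."""
    axes = [data, fsdp, tensor]
    fixed = 1
    for a in axes:
        if a != -1:
            fixed *= a
    if num_devices % fixed != 0:
        raise ValueError("device count is not divisible by the fixed axes")
    auto = num_devices // fixed
    return tuple(auto if a == -1 else a for a in axes)
```

For example, on 8 devices with `fsdp=2`, the data axis resolves to 4.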
System parameters
- Unique identifier for the inference run.
- Directory to save generated images.
- Random seed for reproducible generation.
- Directory for the JAX compilation cache.
- Hardware type. Options: `tpu`, `gpu`.

Output
Generated images are saved as PNG files in the current directory with the naming pattern `image_{i}.png`, where `i` is the image index in the batch.
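The saving step can be sketched as follows, using NumPy and Pillow for illustration (`save_images` is a hypothetical helper, not the script's actual function):

```python
import os
import numpy as np
from PIL import Image

def save_images(images, outdir="."):
    """Write a batch of float images with values in [0, 1] to image_{i}.png files."""
    for i, img in enumerate(images):
        # Convert [0, 1] floats to 8-bit RGB before encoding as PNG.
        arr = (np.clip(img, 0.0, 1.0) * 255).round().astype(np.uint8)
        Image.fromarray(arr).save(os.path.join(outdir, f"image_{i}.png"))
```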
Implementation details
The script:

- Loads the Stable Diffusion pipeline and weights
- Creates inference states for UNet, VAE, and text encoder
- Sets up DDIM scheduler for the denoising process
- Tokenizes and encodes the text prompts
- Runs the denoising loop for the specified number of steps
- Decodes latents to images using the VAE
- Saves images to disk
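A stripped-down version of the denoising loop, with NumPy standing in for JAX and a stub in place of the real UNet, might look like this. The schedule values (a toy linear-beta schedule) and the helper name are assumptions for illustration; the script's DDIM scheduler has its own configuration:

```python
import numpy as np

def ddim_sample(denoise_fn, shape, num_steps=50, seed=0):
    """Minimal deterministic DDIM loop (eta=0) over a toy linear-beta schedule."""
    rng = np.random.default_rng(seed)
    T = 1000
    betas = np.linspace(1e-4, 0.02, T)
    alphas_bar = np.cumprod(1.0 - betas)
    # Evenly spaced timesteps, from most noisy to least noisy.
    timesteps = np.linspace(T - 1, 0, num_steps).astype(int)
    x = rng.standard_normal(shape)  # start from pure Gaussian noise
    for i, t in enumerate(timesteps):
        a_t = alphas_bar[t]
        a_prev = alphas_bar[timesteps[i + 1]] if i + 1 < num_steps else 1.0
        eps = denoise_fn(x, t)  # model's noise prediction at step t
        # Predict the clean latent from the current noisy latent.
        x0 = (x - np.sqrt(1.0 - a_t) * eps) / np.sqrt(a_t)
        # Deterministic DDIM update toward the previous timestep.
        x = np.sqrt(a_prev) * x0 + np.sqrt(1.0 - a_prev) * eps
    return x
```

In the real script, `denoise_fn` is the sharded UNet applied to text-conditioned latents, and the result is decoded to images by the VAE.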
Source: `src/maxdiffusion/generate.py`