Quick start
Architecture
Stable Diffusion 2 uses:

- UNet: Noise prediction model with cross-attention layers
- Text encoder: OpenCLIP ViT-H/14 for text embeddings
- VAE: Latent autoencoder for image encoding/decoding (8x compression)
- Scheduler: DDIM scheduler for iterative denoising
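The 8x VAE compression determines the shapes the UNet operates on. A minimal sketch of that relationship (the 4 latent channels are the standard Stable Diffusion 2 value, assumed here):

```python
# Sketch of the latent shapes implied by the 8x VAE compression.
# Assumes the standard SD2 layout: 4 latent channels, square images.

def latent_shape(batch: int, resolution: int,
                 vae_scale: int = 8, latent_channels: int = 4):
    """Shape of the UNet's latent input for a square image."""
    side = resolution // vae_scale
    return (batch, latent_channels, side, side)

print(latent_shape(1, 512))   # (1, 4, 64, 64)
print(latent_shape(4, 768))   # (4, 4, 96, 96)
```

This is why a 512x512 generation denoises a 64x64 latent rather than a full-resolution image.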
Configuration
The base configuration files define the model architecture and default parameters:

- `base_2_base.yml`: Stable Diffusion 2 base configuration
- `base21.yml`: Stable Diffusion 2.1 configuration
Custom generation
Override configuration values via the command line.

Parameters
| Parameter | Description | Default |
|---|---|---|
| `prompt` | Text description of desired image | Required |
| `negative_prompt` | Concepts to avoid | Empty |
| `num_inference_steps` | Denoising steps (more = higher quality) | 50 |
| `guidance_scale` | CFG strength (higher = more prompt adherence) | 7.5 |
| `guidance_rescale` | Rescale noise to prevent overexposure | 0.0 |
| `resolution` | Output image resolution | 512 |
| `per_device_batch_size` | Images per device | 1 |
| `seed` | Random seed | 0 |
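Combining these parameters, an invocation might look like the following sketch (the entry point `src/maxdiffusion/generate.py` is named later in this page; the `configs/` path and the `key=value` override syntax are assumptions to verify against your checkout):

```shell
python src/maxdiffusion/generate.py src/maxdiffusion/configs/base_2_base.yml \
  prompt="a photograph of an astronaut riding a horse" \
  negative_prompt="blurry, low quality" \
  num_inference_steps=50 guidance_scale=7.5 seed=0
```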
Guidance rescale
Based on "Common Diffusion Noise Schedules and Sample Steps Are Flawed" (section 3.4), guidance rescale helps solve overexposure when the terminal SNR approaches zero. Recommended values: `guidance_scale=7.5`, `guidance_rescale=0.7`.
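A NumPy sketch of the rescaling the paper proposes (the function name and argument layout mirror the `rescale_noise_cfg` helper popularized by diffusers; treat this as illustrative, not this repo's exact code):

```python
import numpy as np

def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.7):
    """Shrink the CFG output's per-sample std toward that of the
    text-conditioned prediction, blending by guidance_rescale."""
    axes = tuple(range(1, noise_cfg.ndim))
    std_text = noise_pred_text.std(axis=axes, keepdims=True)
    std_cfg = noise_cfg.std(axis=axes, keepdims=True)
    rescaled = noise_cfg * (std_text / std_cfg)
    # guidance_rescale=0 leaves CFG untouched; 1 fully rescales.
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg
```

At `guidance_rescale=0.7` most of the variance correction is kept while avoiding washed-out results, matching the recommendation above.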
Implementation details
The inference pipeline (src/maxdiffusion/generate.py) implements:
- Text encoding: Tokenize prompt and encode with CLIP text encoder
- Latent initialization: Sample random noise scaled by scheduler sigma
- Denoising loop: Iteratively denoise latents with UNet predictions
- CFG: Combine conditional and unconditional predictions
- VAE decode: Decode final latents to pixel space
Denoising loop
The core denoising iteration is implemented at generate.py:48-78.
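The code at those lines is not reproduced here; the following is a runnable toy sketch of steps 2-4 above, with a stand-in UNet and a simplified Euler-style update in place of the real `scheduler.step` call:

```python
import numpy as np

rng = np.random.default_rng(0)

def fake_unet(latents, t, emb):
    # Stand-in: the real UNet predicts noise conditioned on the
    # timestep and the text embeddings.
    return 0.1 * latents

def generate_latents(text_emb, uncond_emb, steps=50, guidance_scale=7.5,
                     init_sigma=1.0):
    # Latent initialization: random noise scaled by the scheduler sigma.
    latents = init_sigma * rng.standard_normal((1, 4, 64, 64))
    for t in range(steps):
        # Denoising loop: conditional and unconditional predictions.
        noise_cond = fake_unet(latents, t, text_emb)
        noise_uncond = fake_unet(latents, t, uncond_emb)
        # CFG: steer toward the text-conditioned prediction.
        noise = noise_uncond + guidance_scale * (noise_cond - noise_uncond)
        # Simplified update; the real pipeline calls scheduler.step here.
        latents = latents - noise / steps
    return latents  # the final latents would then go through the VAE decoder

out = generate_latents(text_emb=None, uncond_emb=None)
```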
Sharding
Stable Diffusion 2 supports single- and multi-host inference with data parallelism. The model components are sharded according to `logical_axis_rules` in the config.
For multi-device inference, increase batch size:
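For example (hedged: the config path and override syntax are assumptions; the flag name comes from the parameter table above):

```shell
# 2 images per device; on a host with 4 accelerators this yields
# a global batch of 8 images per generation call.
python src/maxdiffusion/generate.py src/maxdiffusion/configs/base_2_base.yml \
  prompt="a watercolor landscape" per_device_batch_size=2
```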
Output
Generated images are saved as `image_{i}.png` in the current directory. The pipeline reports:
- Compile time: JAX compilation duration
- Inference time: Generation time after compilation
Checkpointing
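A hedged sketch, assuming the config exposes a Hugging Face-style `pretrained_model_name_or_path` key (verify the exact key name in the configuration reference):

```shell
python src/maxdiffusion/generate.py src/maxdiffusion/configs/base_2_base.yml \
  pretrained_model_name_or_path=/path/to/finetuned-sd2 \
  prompt="a photograph of my custom subject"
```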
Load custom checkpoints by specifying the checkpoint path as a config override.

Next steps
- SDXL inference: Higher quality with Stable Diffusion XL
- ControlNet: Conditional generation with edge maps
- DreamBooth training: Fine-tune on custom subjects
- Configuration: Full configuration reference