MaxDiffusion supports inference for Stable Diffusion 2 base and Stable Diffusion 2.1 models, generating 512×512 resolution images.

Quick start

python -m src.maxdiffusion.generate \
  src/maxdiffusion/configs/base_2_base.yml \
  run_name="my_run"

Architecture

Stable Diffusion 2 uses:
  • UNet: Noise prediction model with cross-attention layers
  • Text encoder: OpenCLIP ViT-H/14 for text embeddings
  • VAE: Latent autoencoder for image encoding/decoding (8x compression)
  • Scheduler: DDIM scheduler for iterative denoising
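The VAE's 8x spatial compression means a 512×512 image corresponds to a 64×64 latent grid. A minimal sketch of the latent shape computation (the 4-channel latent layout is the standard Stable Diffusion convention, not stated above; function name is illustrative):

```python
def latent_shape(resolution, batch_size=1, vae_scale_factor=8, latent_channels=4):
    """Compute the latent tensor shape for a given output resolution.

    vae_scale_factor=8 reflects the VAE's 8x spatial compression;
    latent_channels=4 is the standard Stable Diffusion latent channel count.
    """
    side = resolution // vae_scale_factor
    return (batch_size, latent_channels, side, side)

print(latent_shape(512))  # (1, 4, 64, 64)
```

The UNet denoises tensors of this shape; only the final VAE decode returns to 512×512 pixel space.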

Configuration

The base configuration files define model architecture and default parameters:
  • base_2_base.yml: Stable Diffusion 2 base configuration
  • base21.yml: Stable Diffusion 2.1 configuration

Custom generation

Override configuration values via command line:
python -m src.maxdiffusion.generate \
  src/maxdiffusion/configs/base21.yml \
  run_name="custom_generation" \
  prompt="a photograph of an astronaut riding a horse" \
  negative_prompt="blurry, low quality" \
  num_inference_steps=50 \
  guidance_scale=7.5 \
  per_device_batch_size=2 \
  seed=42

Parameters

Parameter               Description                                      Default
prompt                  Text description of desired image                Required
negative_prompt         Concepts to avoid                                Empty
num_inference_steps     Denoising steps (more = higher quality)          50
guidance_scale          CFG strength (higher = more prompt adherence)    7.5
guidance_rescale        Rescale noise to prevent overexposure            0.0
resolution              Output image resolution                          512
per_device_batch_size   Images per device                                1
seed                    Random seed                                      0
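guidance_scale controls how the two UNet predictions are combined via the classifier-free guidance formula. A minimal numpy sketch of that combination (function name is illustrative):

```python
import numpy as np

def apply_cfg(noise_uncond, noise_text, guidance_scale):
    # Classifier-free guidance: extrapolate from the unconditional
    # prediction toward the text-conditioned one by guidance_scale.
    return noise_uncond + guidance_scale * (noise_text - noise_uncond)

uncond = np.zeros(4)
text = np.ones(4)
print(apply_cfg(uncond, text, 7.5))  # [7.5 7.5 7.5 7.5]
```

At guidance_scale=1.0 the result is exactly the text-conditioned prediction; larger values push further in the prompt's direction at the cost of image diversity.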

Guidance rescale

Based on "Common Diffusion Noise Schedules and Sample Steps Are Flawed" (section 3.4), guidance rescale mitigates the overexposure that occurs when the terminal SNR approaches zero. Recommended values:
  • guidance_scale=7.5
  • guidance_rescale=0.7
python -m src.maxdiffusion.generate \
  src/maxdiffusion/configs/base21.yml \
  run_name="rescaled" \
  guidance_scale=7.5 \
  guidance_rescale=0.7

Implementation details

The inference pipeline (src/maxdiffusion/generate.py) implements:
  1. Text encoding: Tokenize prompt and encode with CLIP text encoder
  2. Latent initialization: Sample random noise scaled by scheduler sigma
  3. Denoising loop: Iteratively denoise latents with UNet predictions
  4. CFG: Combine conditional and unconditional predictions
  5. VAE decode: Decode final latents to pixel space
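Step 2 samples Gaussian noise and scales it by the scheduler's starting sigma so the initial latents match the noise level of the first timestep. A numpy sketch (init_noise_sigma defaults to 1.0, the standard value for DDIM-style schedulers; names illustrative):

```python
import numpy as np

def init_latents(rng, batch_size, init_noise_sigma=1.0,
                 latent_channels=4, latent_size=64):
    # Sample standard Gaussian noise in latent space and scale it
    # to the scheduler's starting noise level.
    noise = rng.standard_normal(
        (batch_size, latent_channels, latent_size, latent_size))
    return noise * init_noise_sigma

latents = init_latents(np.random.default_rng(42), batch_size=2)
print(latents.shape)  # (2, 4, 64, 64)
```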

Denoising loop

The core denoising iteration (generate.py:48-78):
def loop_body(step, args, model, pipeline, prompt_embeds, guidance_scale, guidance_rescale):
  latents, scheduler_state, state = args
  latents_input = jnp.concatenate([latents] * 2)  # Duplicate latents for CFG (uncond + text)

  t = scheduler_state.timesteps[step]
  timestep = jnp.broadcast_to(t, latents_input.shape[0])
  latents_input = pipeline.scheduler.scale_model_input(scheduler_state, latents_input, t)

  # UNet forward pass
  noise_pred = model.apply(
      {"params": state.params},
      latents_input,
      timestep,
      encoder_hidden_states=prompt_embeds,
  ).sample

  # Apply classifier-free guidance
  noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0)
  noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)

  # Optional guidance rescale
  noise_pred = jax.lax.cond(
      guidance_rescale[0] > 0,
      lambda _: rescale_noise_cfg(noise_pred, noise_prediction_text, guidance_rescale),
      lambda _: noise_pred,
      operand=None,
  )

  # Scheduler step: compute the previous (less noisy) latents
  latents, scheduler_state = pipeline.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
  return latents, scheduler_state, state
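The (step, args) -> args signature above matches the body expected by jax.lax.fori_loop, which threads the carried state (latents, scheduler_state, state) through every step under jit. Its control flow reduces to a plain Python loop, sketched here with a toy body (names illustrative):

```python
def run_denoising(loop_body, num_inference_steps, init_args):
    # Pure-Python equivalent of
    # jax.lax.fori_loop(0, num_inference_steps, loop_body, init_args):
    # each iteration consumes and returns the full carried state.
    args = init_args
    for step in range(num_inference_steps):
        args = loop_body(step, args)
    return args

# Toy body: "denoise" by halving the latent value each step.
(final,) = run_denoising(lambda step, args: (args[0] * 0.5,), 3, (8.0,))
print(final)  # 1.0
```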

Sharding

Stable Diffusion 2 supports single and multi-host inference with data parallelism. The model components are sharded according to logical_axis_rules in the config. For multi-device inference, increase batch size:
python -m src.maxdiffusion.generate \
  src/maxdiffusion/configs/base21.yml \
  run_name="multi_device" \
  per_device_batch_size=4
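With pure data parallelism, each device generates its own batch, so the global batch is per_device_batch_size times the device count; the command above on a 4-device host would produce 16 images per generation step. A trivial sketch of that relationship (function name is illustrative):

```python
def global_batch_size(per_device_batch_size, num_devices):
    # Under data parallelism, images per generation step scale
    # linearly with the number of devices.
    return per_device_batch_size * num_devices

print(global_batch_size(4, 4))  # 16
```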

Output

Generated images are saved as image_{i}.png in the current directory. The pipeline reports:
  • Compile time: JAX compilation duration
  • Inference time: Generation time after compilation

Checkpointing

Load custom checkpoints by specifying:
pretrained_model_name_or_path="stabilityai/stable-diffusion-2-1" \
revision="main" \
from_pt=True

Next steps

  • SDXL inference: Higher quality with Stable Diffusion XL
  • ControlNet: Conditional generation with edge maps
  • Dreambooth training: Fine-tune on custom subjects
  • Configuration: Full configuration reference
