Overview

The CIFAR-10 diffusion classes extend the base implementation with architectural improvements for color image generation:
  • DiffusionProcessCIFAR: Enhanced training with AdamW optimizer, gradient clipping, EMA (Exponential Moving Average), and linear beta schedule
  • DiffusionModelCIFAR: Wider U-Net with dropout, configurable attention at specific resolutions, and multiple residual blocks per level
These classes maintain the same mathematical foundations as the base DDPM implementation while adding architectural refinements for better CIFAR-10 performance.

DiffusionProcessCIFAR

Constructor

DiffusionProcessCIFAR(
    image_size=32,
    channels=3,
    hidden_dims=[128, 256, 256, 256],
    beta_start=1e-4,
    beta_end=2e-2,
    noise_steps=1000,
    dropout_p=0.1,
    ema_decay=0.999,
    device=None
)

Parameters

image_size
int
default:"32"
Height and width of square images (CIFAR-10 uses 32×32).
channels
int
default:"3"
Number of image channels (3 for RGB images).
hidden_dims
list[int]
default:"[128, 256, 256, 256]"
Channel dimensions for each U-Net level. CIFAR-10 uses wider networks with base channels of 128 and multipliers [1, 2, 2, 2].
beta_start
float
default:"1e-4"
Initial noise variance in the linear noise schedule.
beta_end
float
default:"2e-2"
Final noise variance. Uses 0.02 following standard DDPM CIFAR-10 implementations.
noise_steps
int
default:"1000"
Total number of diffusion timesteps.
dropout_p
float
default:"0.1"
Dropout probability applied in residual blocks to reduce overfitting.
ema_decay
float
default:"0.999"
EMA decay rate for maintaining an exponential moving average of model weights. The EMA model is used for sampling to improve image quality.
device
torch.device
default:"None"
Device for computations. Defaults to CUDA if available, otherwise CPU.

Key differences from base class

  • Linear beta schedule: Uses a linear schedule torch.linspace(beta_start, beta_end, noise_steps) instead of a cosine schedule, matching standard DDPM CIFAR-10 implementations.
  • Posterior variance coefficients: Precomputes posterior_variance, posterior_mean_coef1, and posterior_mean_coef2 for the posterior distribution q(x_{t-1} | x_t, x_0), improving sampling accuracy.
  • AdamW optimizer: Uses AdamW with learning rate 2e-4, weight decay 1e-5, and betas (0.9, 0.999) instead of Adam.
  • EMA model: Maintains a separate EMA copy of the model that is updated after each training step and used for sampling.
  • Gradient clipping: Clips gradients to max norm 1.0 to stabilize training of the larger network.
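
The first two items can be made concrete with a short sketch. The variable names below (beta, alpha_bar, alpha_bar_prev) are assumptions for illustration; the actual attribute names may differ.

import torch

noise_steps, beta_start, beta_end = 1000, 1e-4, 2e-2

# Linear beta schedule, as used by the class
beta = torch.linspace(beta_start, beta_end, noise_steps)
alpha = 1.0 - beta
alpha_bar = torch.cumprod(alpha, dim=0)
# alpha_bar shifted right by one step; alpha_bar at t=0 is defined as 1
alpha_bar_prev = torch.cat([torch.ones(1), alpha_bar[:-1]])

# Standard DDPM posterior q(x_{t-1} | x_t, x_0) coefficients
posterior_variance = beta * (1.0 - alpha_bar_prev) / (1.0 - alpha_bar)
posterior_mean_coef1 = beta * torch.sqrt(alpha_bar_prev) / (1.0 - alpha_bar)           # multiplies x_0
posterior_mean_coef2 = (1.0 - alpha_bar_prev) * torch.sqrt(alpha) / (1.0 - alpha_bar)  # multiplies x_t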

Attributes

model
DiffusionModelCIFAR
Main training model with dropout.
ema_model
DiffusionModelCIFAR
Exponential moving average model used for sampling. Updated with ema_decay rate after each training step.
optimizer
torch.optim.AdamW
AdamW optimizer with lr=2e-4, weight_decay=1e-5.
posterior_variance
torch.Tensor
Precomputed posterior variance for DDPM sampling.
posterior_mean_coef1
torch.Tensor
Coefficient for x_0 term in posterior mean computation.
posterior_mean_coef2
torch.Tensor
Coefficient for x_t term in posterior mean computation.

Methods

add_noise

Identical to base class implementation. Adds noise according to the forward diffusion process.
def add_noise(self, x, t)
See DiffusionProcess.add_noise for details.
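
In closed form, the forward process samples x_t directly from x_0 (see the Mathematical formulation section below). A minimal sketch, assuming a precomputed alpha_bar buffer and a (noisy image, noise) return convention:

import torch

def add_noise_sketch(x0, t, alpha_bar):
    """Sample x_t ~ q(x_t | x_0) in one step and return (x_t, noise)."""
    noise = torch.randn_like(x0)
    ab = alpha_bar[t].view(-1, 1, 1, 1)  # per-sample alpha_bar_t, broadcast over CHW
    x_t = torch.sqrt(ab) * x0 + torch.sqrt(1.0 - ab) * noise
    return x_t, noise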

train_step

Enhanced training step with gradient clipping and EMA updates.
def train_step(self, x)
Parameters:
x
torch.Tensor
required
Clean images tensor of shape [batch_size, channels, height, width].
Returns:
loss
float
MSE loss between predicted and actual noise.
Implementation details:
  1. Samples random timesteps and adds noise
  2. Predicts noise using the main model (not EMA)
  3. Computes MSE loss with optional mixed precision
  4. Clips gradients to max norm 1.0 (unscaled for AMP)
  5. Updates model parameters
  6. Updates EMA model: ema_param ← ema_decay * ema_param + (1 - ema_decay) * param
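
A condensed sketch of these six steps (mixed precision omitted; attribute names such as noise_steps and ema_decay, and the add_noise return convention, are assumptions):

import torch
import torch.nn.functional as F

def train_step_sketch(diffusion, x):
    """One training step: noise-prediction loss, clipping, EMA update."""
    t = torch.randint(0, diffusion.noise_steps, (x.shape[0],), device=x.device)
    x_noisy, noise = diffusion.add_noise(x, t)       # 1. forward process
    pred = diffusion.model(x_noisy, t)               # 2. main model, not EMA
    loss = F.mse_loss(pred, noise)                   # 3. MSE loss

    diffusion.optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(diffusion.model.parameters(), max_norm=1.0)  # 4.
    diffusion.optimizer.step()                       # 5. parameter update

    # 6. EMA update: ema_param <- ema_decay * ema_param + (1 - ema_decay) * param
    with torch.no_grad():
        for p_ema, p in zip(diffusion.ema_model.parameters(),
                            diffusion.model.parameters()):
            p_ema.mul_(diffusion.ema_decay).add_(p, alpha=1.0 - diffusion.ema_decay)
    return loss.item()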

sample

DDPM sampling using the EMA model and posterior variance.
def sample(self, num_samples=16)
Parameters:
num_samples
int
default:"16"
Number of images to generate.
Returns:
samples
torch.Tensor
Generated images of shape [num_samples, channels, image_size, image_size], values in [-1, 1].
Key differences from base class:
  • Uses ema_model instead of the training model
  • Reconstructs x_0 from predicted noise and x_t, then clamps to [-1, 1]
  • Computes posterior mean using precomputed coefficients posterior_mean_coef1 and posterior_mean_coef2
  • Uses posterior_variance instead of recomputing variance from beta
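
A sketch of one reverse step combining these points (the alpha_bar buffer and the exact tensor handling are assumptions; the posterior_* names are the documented attributes):

import torch

@torch.no_grad()
def ddpm_step_sketch(diffusion, x_t, t):
    """One DDPM reverse step using the EMA model and precomputed posterior terms."""
    t_batch = torch.full((x_t.shape[0],), t, device=x_t.device, dtype=torch.long)
    eps = diffusion.ema_model(x_t, t_batch)

    # Reconstruct x_0 from the noise prediction, then clamp to the data range
    ab = diffusion.alpha_bar[t]
    x0 = (x_t - torch.sqrt(1.0 - ab) * eps) / torch.sqrt(ab)
    x0 = x0.clamp(-1.0, 1.0)

    # Posterior mean from the precomputed coefficients
    mean = diffusion.posterior_mean_coef1[t] * x0 + diffusion.posterior_mean_coef2[t] * x_t
    if t == 0:
        return mean  # no noise is added at the final step
    return mean + torch.sqrt(diffusion.posterior_variance[t]) * torch.randn_like(x_t)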

sample_ddim

DDIM sampling using the EMA model.
def sample_ddim(self, num_samples=16, ddim_steps=50, eta=0.0)
Parameters:
num_samples
int
default:"16"
Number of images to generate.
ddim_steps
int
default:"50"
Number of denoising steps. Must be in (0, noise_steps].
eta
float
default:"0.0"
Stochasticity parameter. 0 = deterministic, 1 = DDPM-like.
Returns:
samples
torch.Tensor
Generated images of shape [num_samples, channels, image_size, image_size].
Implementation:
  • Uses uniform grid of timesteps via torch.linspace(0, noise_steps-1, steps=ddim_steps)
  • Uses ema_model for all predictions
  • Clamps at the final step only (not at each intermediate step)
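
A sketch of one DDIM update between consecutive grid timesteps t and t_prev (the alpha_bar buffer is an assumption; t_prev < 0 denotes the final step, where alpha_bar is taken as 1):

import torch

@torch.no_grad()
def ddim_step_sketch(diffusion, x_t, t, t_prev, eta=0.0):
    """One DDIM update from timestep t to t_prev using the EMA model."""
    t_batch = torch.full((x_t.shape[0],), t, device=x_t.device, dtype=torch.long)
    eps = diffusion.ema_model(x_t, t_batch)

    ab_t = diffusion.alpha_bar[t]
    ab_prev = diffusion.alpha_bar[t_prev] if t_prev >= 0 else torch.tensor(1.0)

    # Predicted x_0 from the current noise estimate
    x0 = (x_t - torch.sqrt(1.0 - ab_t) * eps) / torch.sqrt(ab_t)

    # sigma controls stochasticity: eta=0 is deterministic, eta=1 is DDPM-like
    sigma = eta * torch.sqrt((1.0 - ab_prev) / (1.0 - ab_t) * (1.0 - ab_t / ab_prev))
    x_prev = torch.sqrt(ab_prev) * x0 + torch.sqrt(1.0 - ab_prev - sigma**2) * eps
    if eta > 0:
        x_prev = x_prev + sigma * torch.randn_like(x_t)
    return x_prev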

DiffusionModelCIFAR

Constructor

DiffusionModelCIFAR(
    image_size,
    channels,
    hidden_dims=[128, 256, 256, 256],
    time_dim=128,
    dropout_p=0.1
)

Parameters

image_size
int
required
Height and width of square input images.
channels
int
required
Number of image channels (typically 3 for CIFAR-10).
hidden_dims
list[int]
default:"[128, 256, 256, 256]"
Channel dimensions at each resolution level. CIFAR-10 uses 4 levels: 32×32, 16×16, 8×8, 4×4.
time_dim
int
default:"128"
Dimensionality of time embeddings.
dropout_p
float
default:"0.1"
Dropout probability in residual blocks.

Architecture features

  • Attention at 16×16 resolution: Self-attention is applied only at index 1 (16×16 resolution) in both encoder and decoder, plus in the bottleneck. This balances computation cost with modeling long-range dependencies.
  • Dropout regularization: All residual blocks use ResBlockWithDropout, applying 2D dropout after the first convolution to reduce overfitting on CIFAR-10.
  • Enhanced bottleneck: BottleneckWithAttention applies ResBlock → SelfAttention → ResBlock with dropout support.
  • Standard U-Net forward pass: Inherits from DiffusionModel but uses the enhanced blocks with dropout and selective attention.
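
The selective attention described above could be wired as in the following hypothetical assembly loop; the real constructor internals, the import path, and the attention_index name are assumptions.

from models.diffusion_cifar import DownBlockWithAttention  # assumed import path

hidden_dims = [128, 256, 256, 256]
attention_index = 1  # 16×16 resolution

# Build the encoder path; only level 1 gets self-attention
down_blocks = [
    DownBlockWithAttention(
        in_ch=hidden_dims[i],
        out_ch=hidden_dims[i + 1],
        time_dim=128,
        dropout_p=0.1,
        use_attention=(i == attention_index),
    )
    for i in range(len(hidden_dims) - 1)
]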

Methods

forward

Standard U-Net forward pass.
def forward(self, x, t)
Parameters:
x
torch.Tensor
required
Noisy images of shape [batch_size, channels, height, width].
t
torch.Tensor
required
Timesteps of shape [batch_size].
Returns:
noise_prediction
torch.Tensor
Predicted noise of shape [batch_size, channels, height, width].

Supporting classes

ResBlockWithDropout

Extends ResBlock with 2D dropout regularization.
ResBlockWithDropout(in_ch, out_ch, time_dim, dropout_p=0.02)
Structure:
  • GroupNorm → SiLU → Conv2d
  • Dropout2d(p=dropout_p)
  • Add time embedding
  • GroupNorm → SiLU → Conv2d
  • Skip connection
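
A self-contained sketch of this structure. The group count of 8, the linear time projection, and the 1×1 skip convolution are assumptions; only the overall ordering follows the list above.

import torch
import torch.nn as nn

class ResBlockWithDropoutSketch(nn.Module):
    """Residual block with Dropout2d after the first convolution."""
    def __init__(self, in_ch, out_ch, time_dim, dropout_p=0.02):
        super().__init__()
        self.norm1 = nn.GroupNorm(8, in_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.dropout = nn.Dropout2d(p=dropout_p)
        self.time_proj = nn.Linear(time_dim, out_ch)
        self.norm2 = nn.GroupNorm(8, out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.skip = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()

    def forward(self, x, t_emb):
        h = self.conv1(nn.functional.silu(self.norm1(x)))
        h = self.dropout(h)                                # 2D dropout
        h = h + self.time_proj(t_emb)[:, :, None, None]    # add time embedding
        h = self.conv2(nn.functional.silu(self.norm2(h)))
        return h + self.skip(x)                            # skip connection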

BottleneckWithAttention

Bottleneck block with dropout-enabled residual blocks.
BottleneckWithAttention(ch, time_dim, dropout_p=0.05)
Structure:
  • ResBlockWithDropout
  • SelfAttention
  • ResBlockWithDropout

DownBlockWithAttention

Downsampling block with optional attention.
DownBlockWithAttention(in_ch, out_ch, time_dim, dropout_p=0.05, use_attention=False)
Parameters:
  • use_attention (bool): Whether to apply self-attention after the residual block
Structure:
  • ResBlockWithDropout
  • SelfAttention (if use_attention=True)
  • Conv2d(4x4, stride=2) downsampling

UpBlockWithAttention

Upsampling block with optional attention.
UpBlockWithAttention(in_ch, skip_ch, out_ch, time_dim, dropout_p=0.05, use_attention=False)
Structure:
  • ConvTranspose2d(4x4, stride=2) upsampling
  • Concatenate skip connection
  • ResBlockWithDropout
  • SelfAttention (if use_attention=True)
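
Sketches of both blocks under the same assumptions (the SelfAttention signature and the import path are assumptions; skip-connection plumbing between encoder and decoder is omitted):

import torch
import torch.nn as nn
from models.diffusion_cifar import ResBlockWithDropout, SelfAttention  # assumed path

class DownBlockSketch(nn.Module):
    """ResBlockWithDropout -> optional SelfAttention -> strided-conv downsample."""
    def __init__(self, in_ch, out_ch, time_dim, dropout_p=0.05, use_attention=False):
        super().__init__()
        self.res = ResBlockWithDropout(in_ch, out_ch, time_dim, dropout_p)
        self.attn = SelfAttention(out_ch) if use_attention else nn.Identity()
        self.down = nn.Conv2d(out_ch, out_ch, kernel_size=4, stride=2, padding=1)

    def forward(self, x, t_emb):
        return self.down(self.attn(self.res(x, t_emb)))

class UpBlockSketch(nn.Module):
    """Transposed-conv upsample -> concat skip -> ResBlockWithDropout -> optional attention."""
    def __init__(self, in_ch, skip_ch, out_ch, time_dim, dropout_p=0.05, use_attention=False):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_ch, in_ch, kernel_size=4, stride=2, padding=1)
        self.res = ResBlockWithDropout(in_ch + skip_ch, out_ch, time_dim, dropout_p)
        self.attn = SelfAttention(out_ch) if use_attention else nn.Identity()

    def forward(self, x, skip, t_emb):
        h = torch.cat([self.up(x), skip], dim=1)  # concatenate skip connection
        return self.attn(self.res(h, t_emb))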

Usage example

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from models.diffusion_cifar import DiffusionProcessCIFAR

# CIFAR-10 images normalized to [-1, 1], the range train_step expects
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
cifar_dataloader = DataLoader(
    datasets.CIFAR10(root="./data", train=True, download=True, transform=transform),
    batch_size=128,
    shuffle=True,
)

# Initialize CIFAR-10 diffusion process
diffusion = DiffusionProcessCIFAR(
    image_size=32,
    channels=3,
    hidden_dims=[128, 256, 256, 256],
    dropout_p=0.1,
    ema_decay=0.999
)

# Training loop
num_epochs = 50
for epoch in range(num_epochs):
    for batch in cifar_dataloader:
        images = batch[0]  # Shape: [batch_size, 3, 32, 32], values in [-1, 1]
        loss = diffusion.train_step(images)
        print(f"Loss: {loss:.4f}")

# Generate samples using EMA model (DDPM)
samples = diffusion.sample(num_samples=16)

# Faster sampling with DDIM
samples_fast = diffusion.sample_ddim(
    num_samples=16,
    ddim_steps=50,
    eta=0.0  # Deterministic
)

# Denormalize for visualization: [−1, 1] → [0, 1]
samples_vis = (samples + 1.0) / 2.0

Mathematical formulation

Despite architectural differences, DiffusionProcessCIFAR uses the same DDPM equations.
Forward process:
q(x_t | x_0) = N(√ᾱ_t x_0, (1 - ᾱ_t) I)
x_t = √ᾱ_t x_0 + √(1 - ᾱ_t) ε
Training objective:
L = E_{t, x_0, ε} [||ε - ε_θ(x_t, t)||²]
Reverse process:
p_θ(x_{t-1} | x_t) = N(μ_θ(x_t, t), Σ_t)
μ_θ = (1/√α_t) * (x_t - (β_t/√(1-ᾱ_t)) * ε_θ(x_t, t))
The main differences are the linear (vs. cosine) beta schedule and the use of EMA parameters during sampling.
