
Overview

The CIFAR-10 diffusion model uses a significantly wider and deeper U-Net compared to MNIST. It incorporates dropout regularization, multi-scale self-attention, and careful architectural choices to handle the complexity of 32×32 RGB natural images.

Model specification

The CIFAR-10 U-Net is defined in src/models/diffusion_cifar.py as DiffusionModelCIFAR:
class DiffusionModelCIFAR(DiffusionModel):
    def __init__(self, image_size, channels, hidden_dims=[128, 256, 256, 256], 
                 time_dim=128, dropout_p=0.1):
        super().__init__(image_size, channels, hidden_dims, time_dim)

Architecture parameters

Parameter     Value                  Description
image_size    32                     Input image dimensions (32×32)
channels      3                      RGB color images
hidden_dims   [128, 256, 256, 256]   Channel counts at each resolution
time_dim      128                    Time embedding dimension
dropout_p     0.1                    Dropout probability
The hidden dimensions create four resolution levels (32×32 → 16×16 → 8×8 → 4×4), with self-attention applied at the 16×16 level and again in the 4×4 bottleneck.
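For orientation, a minimal shape check of the U-Net is sketched below. It assumes the model's forward signature is forward(x, t) with a batch of integer timesteps, which this section does not show explicitly:
import torch
from src.models.diffusion_cifar import DiffusionModelCIFAR

model = DiffusionModelCIFAR(image_size=32, channels=3,
                            hidden_dims=[128, 256, 256, 256],
                            time_dim=128, dropout_p=0.1)
x = torch.randn(4, 3, 32, 32)        # a batch of CIFAR-10-sized images
t = torch.randint(0, 1000, (4,))     # random diffusion timesteps
# Assumes forward(x, t); the U-Net predicts noise with the same shape as the input
assert model(x, t).shape == x.shape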

Key architectural enhancements

1. Residual blocks with dropout

CIFAR-10 uses an enhanced ResBlock with dropout to prevent overfitting:
class ResBlockWithDropout(ResBlock):
    def __init__(self, in_ch, out_ch, time_dim, dropout_p=0.02):
        super().__init__(in_ch, out_ch, time_dim)
        self.dropout = nn.Dropout2d(p=dropout_p)
    
    def forward(self, x, t_emb):
        h = self.block1(x)
        h = self.dropout(h)  # Dropout after first conv
        h = h + self.time_emb(t_emb)[:, :, None, None]
        h = self.block2(h)
        return self.shortcut(x) + h
The dropout is applied after the first convolution block but before the time embedding injection.
Dropout rates vary by location: 0.02 in residual blocks, 0.05 in bottleneck attention blocks, and 0.1 in down/up blocks for stronger regularization.

2. Multi-scale attention

Unlike MNIST, which uses attention only in the bottleneck, CIFAR-10 applies self-attention at specific resolutions:
# Attention is used at the 16×16 resolution (index 1)
attention_resolutions = [1]

self.down_blocks = nn.ModuleList([
    DownBlockWithAttention(
        hidden_dims[i], 
        hidden_dims[i + 1],
        time_dim, 
        dropout_p,
        use_attention=(i in attention_resolutions),  # Only at 16×16
    )
    for i in range(len(hidden_dims) - 1)
])
Attention placement:
  • 32×32 (index 0): No attention (too expensive)
  • 16×16 (index 1): Self-attention enabled
  • 8×8 (index 2): No attention
  • 4×4 (bottleneck): Self-attention enabled
Applying attention at 16×16 captures mid-level features like object parts and textures, while bottleneck attention at 4×4 captures global structure.
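SelfAttention is used throughout this section but its definition is not shown here. The block below is a plausible minimal sketch (single-head QKV attention over spatial positions, with a GroupNorm pre-norm and a residual connection); the actual implementation in src/models/diffusion_cifar.py may differ:
import torch
import torch.nn as nn

# Plausible sketch of SelfAttention; the repository's implementation may differ.
class SelfAttention(nn.Module):
    """Single-head self-attention over the h*w spatial positions."""
    def __init__(self, channels):
        super().__init__()
        self.norm = nn.GroupNorm(8, channels)
        self.qkv = nn.Conv2d(channels, channels * 3, kernel_size=1)
        self.proj = nn.Conv2d(channels, channels, kernel_size=1)

    def forward(self, x):
        b, c, h, w = x.shape
        q, k, v = self.qkv(self.norm(x)).chunk(3, dim=1)       # each (b, c, h, w)
        q = q.reshape(b, c, h * w).permute(0, 2, 1)            # (b, hw, c)
        k = k.reshape(b, c, h * w)                              # (b, c, hw)
        v = v.reshape(b, c, h * w).permute(0, 2, 1)            # (b, hw, c)
        attn = torch.softmax(q @ k / c ** 0.5, dim=-1)          # (b, hw, hw)
        out = (attn @ v).permute(0, 2, 1).reshape(b, c, h, w)
        return x + self.proj(out)                               # residual connection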

3. Down blocks with attention

class DownBlockWithAttention(nn.Module):
    def __init__(self, in_ch, out_ch, time_dim, dropout_p=0.05, use_attention=False):
        super().__init__()
        self.res = ResBlockWithDropout(in_ch, out_ch, time_dim, dropout_p)
        self.attention = SelfAttention(out_ch) if use_attention else nn.Identity()
        self.pool = nn.Conv2d(out_ch, out_ch, 4, stride=2, padding=1)
    
    def forward(self, x, t_emb):
        h = self.res(x, t_emb)
        h = self.attention(h)  # Attention before downsampling
        return self.pool(h), h
Attention is applied after the residual block but before downsampling to preserve high-resolution features.

4. Up blocks with attention

class UpBlockWithAttention(nn.Module):
    def __init__(self, in_ch, skip_ch, out_ch, time_dim, dropout_p=0.05, use_attention=False):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_ch, out_ch, 4, stride=2, padding=1)
        self.res = ResBlockWithDropout(out_ch + skip_ch, out_ch, time_dim, dropout_p)
        self.attention = SelfAttention(out_ch) if use_attention else nn.Identity()
    
    def forward(self, x, skip, t_emb):
        x = self.up(x)
        x = torch.cat([x, skip], dim=1)
        x = self.res(x, t_emb)
        x = self.attention(x)  # Attention after skip fusion
        return x

5. Enhanced bottleneck

The bottleneck places a self-attention layer between two residual blocks for stronger feature learning:
class BottleneckWithAttention(nn.Module):
    def __init__(self, ch, time_dim, dropout_p=0.05):
        super().__init__()
        self.res1 = ResBlockWithDropout(ch, ch, time_dim, dropout_p)
        self.attention = SelfAttention(ch)
        self.res2 = ResBlockWithDropout(ch, ch, time_dim, dropout_p)
    
    def forward(self, x, t_emb):
        x = self.res1(x, t_emb)
        x = self.attention(x)
        x = self.res2(x, t_emb)
        return x

Network structure

Encoder path

self.down_blocks = nn.ModuleList([
    DownBlockWithAttention(128, 256, 128, 0.1, use_attention=False),  # 32→16, no attn
    DownBlockWithAttention(256, 256, 128, 0.1, use_attention=True),   # 16→8, WITH attn
    DownBlockWithAttention(256, 256, 128, 0.1, use_attention=False),  # 8→4, no attn
])
Channel and resolution progression:
  1. 32×32 @ 128 channels → 16×16 @ 256 channels
  2. 16×16 @ 256 channels → 8×8 @ 256 channels (with attention)
  3. 8×8 @ 256 channels → 4×4 @ 256 channels
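The sketch below traces this progression through the three down blocks defined above (the import path and the shape of t_emb are assumptions):
import torch
from src.models.diffusion_cifar import DownBlockWithAttention  # import path assumed

blocks = [
    DownBlockWithAttention(128, 256, 128, 0.1, use_attention=False),
    DownBlockWithAttention(256, 256, 128, 0.1, use_attention=True),
    DownBlockWithAttention(256, 256, 128, 0.1, use_attention=False),
]
x, t_emb = torch.randn(1, 128, 32, 32), torch.randn(1, 128)  # features after the input conv
for block in blocks:
    x, skip = block(x, t_emb)
    print(tuple(skip.shape), "->", tuple(x.shape))
# (1, 256, 32, 32) -> (1, 256, 16, 16)
# (1, 256, 16, 16) -> (1, 256, 8, 8)
# (1, 256, 8, 8)   -> (1, 256, 4, 4)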

Bottleneck

self.bottleneck = BottleneckWithAttention(256, 128, dropout_p=0.1)
Operates at 4×4 @ 256 channels with self-attention.

Decoder path

self.up_blocks = nn.ModuleList([
    UpBlockWithAttention(256, 256, 256, 128, 0.1, use_attention=False),  # 4→8
    UpBlockWithAttention(256, 256, 256, 128, 0.1, use_attention=True),   # 8→16, WITH attn
    UpBlockWithAttention(256, 256, 128, 128, 0.1, use_attention=False),  # 16→32
])
Mirrors the encoder’s attention pattern for symmetry.

Parameter count

The CIFAR-10 model is significantly larger than MNIST:
  • Total parameters: ~12.5M
  • Time embedding: ~132K parameters
  • Encoder: ~4.2M parameters
  • Bottleneck: ~2.8M parameters
  • Decoder: ~5.1M parameters
  • Output layers: ~9K parameters
This is approximately 10× larger than the MNIST model, reflecting the increased complexity of natural images.
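These figures can be reproduced with a small helper, as sketched below. It only uses the submodule names shown in the snippets above (down_blocks, bottleneck, up_blocks); the time embedding and output layers are omitted because their attribute names are not shown in this section:
from src.models.diffusion_cifar import DiffusionModelCIFAR

def count_params(module):
    return sum(p.numel() for p in module.parameters())

model = DiffusionModelCIFAR(image_size=32, channels=3,
                            hidden_dims=[128, 256, 256, 256],
                            time_dim=128, dropout_p=0.1)
print(f"total:      {count_params(model) / 1e6:.2f}M")
print(f"encoder:    {count_params(model.down_blocks) / 1e6:.2f}M")
print(f"bottleneck: {count_params(model.bottleneck) / 1e6:.2f}M")
print(f"decoder:    {count_params(model.up_blocks) / 1e6:.2f}M")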

Training configuration

The CIFAR-10 model uses AdamW with weight decay for better generalization:
self.optimizer = torch.optim.AdamW(
    self.model.parameters(),
    lr=2e-4,
    weight_decay=1e-5,
    betas=(0.9, 0.999),
)
Additional training features:
  • Gradient clipping: max_norm=1.0 to prevent exploding gradients
  • Linear beta schedule instead of cosine
  • Exponential Moving Average (EMA) for sampling
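A hedged sketch of how gradient clipping and the EMA update fit into a single training step; compute_loss and update_ema are illustrative helper names, not confirmed by this section:
def train_step(self, batch):
    loss = self.compute_loss(batch)   # noise-prediction MSE (hypothetical helper)
    self.optimizer.zero_grad()
    loss.backward()
    # Clip gradients to max_norm=1.0 before the optimizer step
    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
    self.optimizer.step()
    self.update_ema()                 # EMA update (see the EMA section below); name is illustrative
    return loss.item()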

Noise schedule

CIFAR-10 uses a linear schedule instead of cosine:
self.beta_schedule = torch.linspace(
    beta_start, beta_end, noise_steps, device=self.device
)
With beta_start=1e-4 and beta_end=0.02 over 1000 steps.
The linear schedule is preferred for CIFAR-10 as it matches the original DDPM paper and provides more aggressive early denoising for complex images.
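The quantities derived from this schedule follow the standard DDPM formulation; a minimal sketch is below (the attribute names alpha and alpha_bar are assumptions, not taken from the source):
# α_t = 1 - β_t  and  ᾱ_t = prod_{s<=t} α_s   (attribute names assumed)
self.alpha = 1.0 - self.beta_schedule
self.alpha_bar = torch.cumprod(self.alpha, dim=0)

# Forward noising: x_t = sqrt(ᾱ_t) * x_0 + sqrt(1 - ᾱ_t) * ε
def q_sample(self, x0, t, noise):
    sqrt_ab = self.alpha_bar[t].sqrt()[:, None, None, None]
    sqrt_one_minus_ab = (1.0 - self.alpha_bar[t]).sqrt()[:, None, None, None]
    return sqrt_ab * x0 + sqrt_one_minus_ab * noise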

Exponential Moving Average (EMA)

CIFAR-10 maintains an EMA copy of the model weights for sampling:
# EMA copy used only for sampling
self.ema_model = DiffusionModelCIFAR(...)
self.ema_model.load_state_dict(self.model.state_dict())
self.ema_model.eval()
self.ema_decay = 0.999

# Update after each training step
with torch.no_grad():
    for ema_param, param in zip(self.ema_model.parameters(), self.model.parameters()):
        ema_param.data.mul_(self.ema_decay).add_(param.data, alpha=1 - self.ema_decay)
The EMA model is used for generation instead of the training model, providing more stable and higher-quality samples.
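A hedged sketch of how sampling with the EMA weights could look, using standard DDPM ancestral sampling with σ_t² = β_t; the alpha, alpha_bar, and noise_steps attributes follow the schedule sketch above and are assumptions:
@torch.no_grad()
def sample(self, num_samples):
    model = self.ema_model                                   # sample from the EMA copy
    x = torch.randn(num_samples, 3, 32, 32, device=self.device)
    for t in reversed(range(self.noise_steps)):              # noise_steps attribute assumed
        t_batch = torch.full((num_samples,), t, device=self.device, dtype=torch.long)
        eps = model(x, t_batch)                              # predicted noise
        alpha = self.alpha[t]                                # derived quantities assumed above
        alpha_bar = self.alpha_bar[t]
        beta = self.beta_schedule[t]
        mean = (x - beta / (1.0 - alpha_bar).sqrt() * eps) / alpha.sqrt()
        noise = torch.randn_like(x) if t > 0 else torch.zeros_like(x)
        x = mean + beta.sqrt() * noise                       # σ_t² = β_t variant
    return x.clamp(-1.0, 1.0)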

Comparison with MNIST

Aspect          MNIST              CIFAR-10
Hidden dims     [32, 64, 128]      [128, 256, 256, 256]
Parameters      ~1.2M              ~12.5M
Attention       Bottleneck only    Bottleneck + 16×16
Dropout         None               0.02-0.1 (varied)
Optimizer       Adam               AdamW + weight decay
Beta schedule   Cosine             Linear
EMA             No                 Yes (decay=0.999)
Grad clipping   No                 Yes (max_norm=1.0)

Usage example

import torch
from src.models.diffusion_cifar import DiffusionProcessCIFAR

# Initialize with CIFAR-10 parameters
diffusion = DiffusionProcessCIFAR(
    image_size=32,
    channels=3,
    hidden_dims=[128, 256, 256, 256],
    noise_steps=1000,
    dropout_p=0.1,
    ema_decay=0.999,
    device=torch.device('cuda')
)

# Train on CIFAR-10 data
for batch in cifar_loader:
    loss = diffusion.train_step(batch)

# Generate samples using EMA model
samples = diffusion.sample(num_samples=64)
See the CIFAR-10 training guide for complete training code and hyperparameter tuning recommendations.
