Overview
The CIFAR-10 diffusion model uses a significantly wider and deeper U-Net than the MNIST model. It adds dropout regularization, multi-scale self-attention, and several other architectural changes to handle the complexity of 32×32 RGB natural images.
Model specification
The CIFAR-10 U-Net is defined in src/models/diffusion_cifar.py as DiffusionModelCIFAR:
class DiffusionModelCIFAR(DiffusionModel):
    def __init__(self, image_size, channels, hidden_dims=[128, 256, 256, 256],
                 time_dim=128, dropout_p=0.1):
        super().__init__(image_size, channels, hidden_dims, time_dim)
Architecture parameters
| Parameter | Value | Description |
|---|---|---|
| image_size | 32 | Input image dimensions (32×32) |
| channels | 3 | RGB color images |
| hidden_dims | [128, 256, 256, 256] | Channel counts at each resolution |
| time_dim | 128 | Time embedding dimension |
| dropout_p | 0.1 | Dropout probability |
The hidden dimensions create four resolution levels: 32×32 → 16×16 → 8×8 → 4×4, with self-attention applied selectively at the 16×16 resolution (and again in the 4×4 bottleneck).
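To make the progression concrete, it can be computed directly from hidden_dims, since one stride-2 downsample sits between consecutive entries (a small standalone sketch; variable names are illustrative):

hidden_dims = [128, 256, 256, 256]
image_size = 32

for level, ch in enumerate(hidden_dims):
    res = image_size // (2 ** level)
    print(f"level {level}: {res}x{res} @ {ch} channels")
# level 0: 32x32 @ 128 channels
# level 1: 16x16 @ 256 channels
# level 2: 8x8 @ 256 channels
# level 3: 4x4 @ 256 channels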
Key architectural enhancements
1. Residual blocks with dropout
CIFAR-10 uses an enhanced ResBlock with dropout to prevent overfitting:
class ResBlockWithDropout(ResBlock):
    def __init__(self, in_ch, out_ch, time_dim, dropout_p=0.02):
        super().__init__(in_ch, out_ch, time_dim)
        self.dropout = nn.Dropout2d(p=dropout_p)

    def forward(self, x, t_emb):
        h = self.block1(x)
        h = self.dropout(h)  # Dropout after first conv
        h = h + self.time_emb(t_emb)[:, :, None, None]
        h = self.block2(h)
        return self.shortcut(x) + h
The dropout is applied after the first convolution block but before the time embedding injection.
Dropout rates vary by location: 0.02 in residual blocks, 0.05 in bottleneck attention blocks, and 0.1 in down/up blocks for stronger regularization.
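Note that nn.Dropout2d zeroes entire feature maps rather than individual activations, which is the usual choice inside convolutional blocks. A small standalone demo (p=0.5 is exaggerated for visibility, not a rate the model uses):

import torch
import torch.nn as nn

drop = nn.Dropout2d(p=0.5)
drop.train()                  # dropout is only active in training mode
x = torch.ones(1, 4, 2, 2)    # one sample with 4 feature maps
print(drop(x))                # whole maps are zeroed; survivors are scaled by 1 / (1 - p)
print(drop.eval()(x))         # identity at eval/sampling time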
2. Multi-scale attention
Unlike the MNIST model, which uses attention only in the bottleneck, the CIFAR-10 model applies self-attention at specific resolutions:
# Attention is used at the 16×16 resolution (index 1)
attention_resolutions = [1]
self.down_blocks = nn.ModuleList([
    DownBlockWithAttention(
        hidden_dims[i],
        hidden_dims[i + 1],
        time_dim,
        dropout_p,
        use_attention=(i in attention_resolutions),  # Only at 16×16
    )
    for i in range(len(hidden_dims) - 1)
])
Attention placement:
- 32×32 (index 0): No attention (too expensive)
- 16×16 (index 1): Self-attention enabled
- 8×8 (index 2): No attention
- 4×4 (bottleneck): Self-attention enabled
Applying attention at 16×16 captures mid-level features like object parts and textures, while bottleneck attention at 4×4 captures global structure.
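The SelfAttention module used by these blocks is not defined in this section. For reference, here is a minimal single-head spatial self-attention sketch that matches the SelfAttention(channels) call signature (the internals below are an assumption, not necessarily the repository's exact implementation):

import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    """Illustrative single-head self-attention over the H*W spatial positions."""
    def __init__(self, channels):
        super().__init__()
        self.norm = nn.GroupNorm(8, channels)
        self.qkv = nn.Conv2d(channels, channels * 3, kernel_size=1)
        self.proj = nn.Conv2d(channels, channels, kernel_size=1)

    def forward(self, x):
        b, c, h, w = x.shape
        q, k, v = self.qkv(self.norm(x)).chunk(3, dim=1)
        # Flatten spatial dims: (B, C, H, W) -> (B, H*W, C)
        q = q.flatten(2).transpose(1, 2)
        k = k.flatten(2).transpose(1, 2)
        v = v.flatten(2).transpose(1, 2)
        attn = torch.softmax(q @ k.transpose(1, 2) / c ** 0.5, dim=-1)  # (B, HW, HW)
        out = (attn @ v).transpose(1, 2).reshape(b, c, h, w)
        return x + self.proj(out)  # residual connection around attention

The quadratic cost over spatial positions explains the placement: at 16×16 the attention matrix is 256×256 per sample, whereas at 32×32 it would be 1024×1024, which is why attention is skipped at the highest resolution.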
3. Down blocks with attention
class DownBlockWithAttention(nn.Module):
    def __init__(self, in_ch, out_ch, time_dim, dropout_p=0.05, use_attention=False):
        super().__init__()
        self.res = ResBlockWithDropout(in_ch, out_ch, time_dim, dropout_p)
        self.attention = SelfAttention(out_ch) if use_attention else nn.Identity()
        self.pool = nn.Conv2d(out_ch, out_ch, 4, stride=2, padding=1)

    def forward(self, x, t_emb):
        h = self.res(x, t_emb)
        h = self.attention(h)  # Attention before downsampling
        return self.pool(h), h
Attention is applied after the residual block but before downsampling to preserve high-resolution features.
4. Up blocks with attention
class UpBlockWithAttention(nn.Module):
    def __init__(self, in_ch, skip_ch, out_ch, time_dim, dropout_p=0.05, use_attention=False):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_ch, out_ch, 4, stride=2, padding=1)
        self.res = ResBlockWithDropout(out_ch + skip_ch, out_ch, time_dim, dropout_p)
        self.attention = SelfAttention(out_ch) if use_attention else nn.Identity()

    def forward(self, x, skip, t_emb):
        x = self.up(x)
        x = torch.cat([x, skip], dim=1)
        x = self.res(x, t_emb)
        x = self.attention(x)  # Attention after skip fusion
        return x
5. Enhanced bottleneck
The bottleneck uses multiple attention layers for stronger feature learning:
class BottleneckWithAttention(nn.Module):
    def __init__(self, ch, time_dim, dropout_p=0.05):
        super().__init__()
        self.res1 = ResBlockWithDropout(ch, ch, time_dim, dropout_p)
        self.attention = SelfAttention(ch)
        self.res2 = ResBlockWithDropout(ch, ch, time_dim, dropout_p)

    def forward(self, x, t_emb):
        x = self.res1(x, t_emb)
        x = self.attention(x)
        x = self.res2(x, t_emb)
        return x
Network structure
Encoder path
self.down_blocks = nn.ModuleList([
    DownBlockWithAttention(128, 256, 128, 0.1, use_attention=False),  # 32→16, no attn
    DownBlockWithAttention(256, 256, 128, 0.1, use_attention=True),   # 16→8, WITH attn
    DownBlockWithAttention(256, 256, 128, 0.1, use_attention=False),  # 8→4, no attn
])
Channel and resolution progression:
- 32×32 @ 128 channels → 16×16 @ 256 channels
- 16×16 @ 256 channels → 8×8 @ 256 channels (with attention)
- 8×8 @ 256 channels → 4×4 @ 256 channels
Bottleneck
self.bottleneck = BottleneckWithAttention(256, 128, dropout_p=0.1)
Operates at 4×4 @ 256 channels with self-attention.
Decoder path
self.up_blocks = nn.ModuleList([
    UpBlockWithAttention(256, 256, 256, 128, 0.1, use_attention=False),  # 4→8
    UpBlockWithAttention(256, 256, 256, 128, 0.1, use_attention=True),   # 8→16, WITH attn
    UpBlockWithAttention(256, 256, 128, 128, 0.1, use_attention=False),  # 16→32
])
Mirrors the encoder’s attention pattern for symmetry.
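For orientation, here is a sketch of how the blocks above are wired together in the forward pass. The stem, output head, and time-embedding attribute names (init_conv, out_conv, time_mlp) are assumptions, but the skip-connection flow follows directly from the block definitions:

def forward(self, x, t):
    t_emb = self.time_mlp(t)           # assumed time-embedding module, output dim = time_dim
    h = self.init_conv(x)              # assumed stem: 3 -> 128 channels at 32x32

    skips = []
    for down in self.down_blocks:
        h, skip = down(h, t_emb)       # DownBlockWithAttention returns (pooled, pre-pool)
        skips.append(skip)

    h = self.bottleneck(h, t_emb)      # 4x4 @ 256 with self-attention

    for up in self.up_blocks:
        h = up(h, skips.pop(), t_emb)  # consume skips deepest-first (8x8, 16x16, 32x32)

    return self.out_conv(h)            # assumed head: 128 -> 3 channels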
Parameter count
The CIFAR-10 model is significantly larger than MNIST:
- Total parameters: ~12.5M
- Time embedding: ~132K parameters
- Encoder: ~4.2M parameters
- Bottleneck: ~2.8M parameters
- Decoder: ~5.1M parameters
- Output layers: ~9K parameters
This is approximately 10× larger than the MNIST model, reflecting the increased complexity of natural images.
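These totals can be reproduced with a quick check over model.parameters() (the exact number depends on the block definitions):

model = DiffusionModelCIFAR(
    image_size=32,
    channels=3,
    hidden_dims=[128, 256, 256, 256],
    time_dim=128,
    dropout_p=0.1,
)
total = sum(p.numel() for p in model.parameters())
print(f"{total / 1e6:.1f}M parameters")  # expected to be on the order of 12.5M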
Training configuration
The CIFAR-10 model uses AdamW with weight decay for better generalization:
self.optimizer = torch.optim.AdamW(
    self.model.parameters(),
    lr=2e-4,
    weight_decay=1e-5,
    betas=(0.9, 0.999),
)
Additional training features (combined in the training-step sketch below):
- Gradient clipping: max_norm=1.0 to prevent exploding gradients
- Linear beta schedule instead of cosine
- Exponential Moving Average (EMA) for sampling
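A sketch of how these pieces fit into one training step; noise_prediction_loss and update_ema stand in for the repository's actual helpers and are assumptions:

def train_step(self, batch):
    self.optimizer.zero_grad()
    loss = self.noise_prediction_loss(batch)   # assumed helper: sample t, add noise, MSE on predicted noise
    loss.backward()
    # Clip gradients before the optimizer step
    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
    self.optimizer.step()
    self.update_ema()                          # assumed helper: EMA update after the optimizer step
    return loss.item()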
Noise schedule
CIFAR-10 uses a linear schedule instead of cosine:
self.beta_schedule = torch.linspace(
    beta_start, beta_end, noise_steps, device=self.device
)
With beta_start=1e-4 and beta_end=0.02 over 1000 steps.
The linear schedule is used here because it matches the original DDPM paper, which also trained on CIFAR-10 at this resolution, and it adds noise more aggressively over most of the forward process than the cosine schedule, which works well for these more complex images.
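The quantities the forward and reverse processes actually consume are derived from these betas in the standard DDPM way (variable names here are generic, not necessarily the repository's):

import torch

noise_steps, beta_start, beta_end = 1000, 1e-4, 0.02
beta = torch.linspace(beta_start, beta_end, noise_steps)
alpha = 1.0 - beta
alpha_bar = torch.cumprod(alpha, dim=0)   # cumulative product, used in q(x_t | x_0)

# Forward process: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
print(alpha_bar[0].item())    # ~0.9999: almost no noise at t=0
print(alpha_bar[-1].item())   # ~4e-5: nearly pure noise at t=999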
Exponential Moving Average (EMA)
CIFAR-10 maintains an EMA copy of the model weights for sampling:
# EMA copy used only for sampling
self.ema_model = DiffusionModelCIFAR(...)
self.ema_model.load_state_dict(self.model.state_dict())
self.ema_model.eval()
self.ema_decay = 0.999

# Update after each training step
with torch.no_grad():
    for ema_param, param in zip(self.ema_model.parameters(), self.model.parameters()):
        ema_param.data.mul_(self.ema_decay).add_(param.data, alpha=1 - self.ema_decay)
The EMA model is used for generation instead of the training model, providing more stable and higher-quality samples.
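A sketch of what this looks like in a DDPM-style sampling loop; the alpha/alpha_bar buffers and the model call signature are assumptions, but the key point is that self.ema_model (not self.model) produces the noise predictions:

@torch.no_grad()
def sample(self, num_samples):
    model = self.ema_model.eval()      # EMA weights, not the live training weights
    x = torch.randn(num_samples, 3, 32, 32, device=self.device)
    for t in reversed(range(self.noise_steps)):
        t_batch = torch.full((num_samples,), t, device=self.device, dtype=torch.long)
        eps = model(x, t_batch)                 # predicted noise
        alpha = self.alpha[t]                   # assumed buffer: 1 - beta
        alpha_bar = self.alpha_bar[t]           # assumed buffer: cumprod of alpha
        beta = self.beta_schedule[t]
        mean = (x - beta / (1 - alpha_bar).sqrt() * eps) / alpha.sqrt()
        noise = torch.randn_like(x) if t > 0 else torch.zeros_like(x)
        x = mean + beta.sqrt() * noise
    return x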
Comparison with MNIST
| Aspect | MNIST | CIFAR-10 |
|---|---|---|
| Hidden dims | [32, 64, 128] | [128, 256, 256, 256] |
| Parameters | ~1.2M | ~12.5M |
| Attention | Bottleneck only | Bottleneck + 16×16 |
| Dropout | None | 0.02-0.1 (varied) |
| Optimizer | Adam | AdamW + weight decay |
| Beta schedule | Cosine | Linear |
| EMA | No | Yes (decay=0.999) |
| Grad clipping | No | Yes (max_norm=1.0) |
Usage example
import torch
from src.models.diffusion_cifar import DiffusionProcessCIFAR

# Initialize with CIFAR-10 parameters
diffusion = DiffusionProcessCIFAR(
    image_size=32,
    channels=3,
    hidden_dims=[128, 256, 256, 256],
    noise_steps=1000,
    dropout_p=0.1,
    ema_decay=0.999,
    device=torch.device('cuda'),
)

# Train on CIFAR-10 data
for batch in cifar_loader:
    loss = diffusion.train_step(batch)

# Generate samples using the EMA model
samples = diffusion.sample(num_samples=64)
See the CIFAR-10 training guide for complete training code and hyperparameter tuning recommendations.