
Overview

The MNIST diffusion model uses a lightweight U-Net architecture optimized for generating 28×28 grayscale digits. This simplified design serves as an excellent starting point for understanding diffusion models before scaling to more complex datasets.

Model specification

The MNIST U-Net is defined in src/models/diffusion.py as the DiffusionModel class:
from torch import nn

class DiffusionModel(nn.Module):
    def __init__(self, image_size, channels, hidden_dims=[32, 64, 128], time_dim=128):
        super().__init__()
        self.image_size = image_size  # 28 for MNIST
        self.channels = channels      # 1 for grayscale
        self.hidden_dims = hidden_dims
        self.time_dim = time_dim

Architecture parameters

Parameter     Value            Description
image_size    28               Input image dimensions (28×28)
channels      1                Grayscale images
hidden_dims   [32, 64, 128]    Channel counts at each resolution
time_dim      128              Time embedding dimension
The hidden dimensions [32, 64, 128] create three resolution levels: 28×28 → 14×14 → 7×7, with the bottleneck operating at 7×7.
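
The time embedding module itself is not reproduced on this page. A common design, and one consistent with the ~132K time-embedding parameter count listed under Parameter count below, is a sinusoidal timestep embedding followed by a small MLP (128 → 512 → 128 gives 131,712 parameters). The sketch below is hypothetical, not the repository's actual module:

import math
import torch
from torch import nn

class TimeEmbedding(nn.Module):
    """Hypothetical sketch: sinusoidal timestep features followed by a two-layer MLP."""
    def __init__(self, time_dim=128):
        super().__init__()
        self.time_dim = time_dim
        self.mlp = nn.Sequential(
            nn.Linear(time_dim, time_dim * 4),  # 128 -> 512
            nn.SiLU(),
            nn.Linear(time_dim * 4, time_dim),  # 512 -> 128
        )

    def forward(self, t):
        # Standard sinusoidal embedding over log-spaced frequencies
        half = self.time_dim // 2
        freqs = torch.exp(-math.log(10000.0) * torch.arange(half, device=t.device) / half)
        args = t.float()[:, None] * freqs[None, :]
        return self.mlp(torch.cat([args.sin(), args.cos()], dim=-1))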

Network structure

Initial convolution

The input image is first processed by an initial convolution that maps from 1 channel (grayscale) to 32 channels:
self.init_conv = nn.Conv2d(channels, hidden_dims[0], 3, padding=1)

Encoder blocks

The encoder consists of two DownBlock layers that progressively downsample:
self.down_blocks = nn.ModuleList([
    DownBlock(hidden_dims[0], hidden_dims[1], time_dim),  # 28×28 → 14×14, 32 → 64 channels
    DownBlock(hidden_dims[1], hidden_dims[2], time_dim)   # 14×14 → 7×7, 64 → 128 channels
])
Resolution progression:
  • Input: 28×28 @ 32 channels
  • After 1st down: 14×14 @ 64 channels
  • After 2nd down: 7×7 @ 128 channels
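
The DownBlock implementation is not shown on this page. Below is a minimal sketch consistent with the shapes above, assuming the common pattern of a time-conditioned residual block followed by a stride-2 convolution, with the pre-downsample feature returned as the skip connection the decoder will fuse (this matches the UpBlock skip channel counts later on, but the repository's actual code may differ):

import torch
from torch import nn
import torch.nn.functional as F

class ResBlock(nn.Module):
    """Two 3×3 convolutions with GroupNorm/SiLU; the time embedding is added per channel."""
    def __init__(self, in_ch, out_ch, time_dim):
        super().__init__()
        self.norm1 = nn.GroupNorm(8, in_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.time_proj = nn.Linear(time_dim, out_ch)
        self.norm2 = nn.GroupNorm(8, out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.skip = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()

    def forward(self, x, t_emb):
        h = self.conv1(F.silu(self.norm1(x)))
        h = h + self.time_proj(F.silu(t_emb))[:, :, None, None]  # broadcast over H and W
        h = self.conv2(F.silu(self.norm2(h)))
        return h + self.skip(x)

class DownBlock(nn.Module):
    """ResBlock, then a stride-2 convolution that halves the spatial resolution."""
    def __init__(self, in_ch, out_ch, time_dim):
        super().__init__()
        self.res = ResBlock(in_ch, out_ch, time_dim)
        self.down = nn.Conv2d(out_ch, out_ch, 3, stride=2, padding=1)

    def forward(self, x, t_emb):
        h = self.res(x, t_emb)    # e.g. 28×28 @ 32 -> 28×28 @ 64
        return self.down(h), h    # downsampled tensor plus the skip feature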

Bottleneck

At the coarsest resolution (7×7), the bottleneck applies self-attention:
self.bottleneck = BottleneckBlock(hidden_dims[2], time_dim)  # 128 channels
The bottleneck structure includes:
  • First ResBlock with time conditioning
  • SelfAttention layer for global context
  • Second ResBlock with time conditioning
Even at the small 7×7 resolution, self-attention helps the model capture relationships between different parts of the digit (e.g., connecting the top and bottom of an “8”).
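
A hypothetical sketch of that structure, reusing the ResBlock from the encoder sketch above (the repository's SelfAttention may be implemented differently, for example with hand-rolled QKV projections):

class SelfAttention(nn.Module):
    """Multi-head self-attention over the 49 spatial positions of the 7×7 feature map."""
    def __init__(self, ch, num_heads=4):
        super().__init__()
        self.norm = nn.GroupNorm(8, ch)
        self.attn = nn.MultiheadAttention(ch, num_heads, batch_first=True)

    def forward(self, x):
        b, c, h, w = x.shape
        tokens = self.norm(x).flatten(2).transpose(1, 2)    # (B, H*W, C)
        out, _ = self.attn(tokens, tokens, tokens)
        return x + out.transpose(1, 2).reshape(b, c, h, w)  # residual connection

class BottleneckBlock(nn.Module):
    """ResBlock -> self-attention -> ResBlock, all at 7×7 resolution."""
    def __init__(self, ch, time_dim):
        super().__init__()
        self.res1 = ResBlock(ch, ch, time_dim)
        self.attn = SelfAttention(ch)
        self.res2 = ResBlock(ch, ch, time_dim)

    def forward(self, x, t_emb):
        x = self.res1(x, t_emb)
        x = self.attn(x)
        return self.res2(x, t_emb)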

Decoder blocks

The decoder upsamples back to the original resolution using skip connections:
self.up_blocks = nn.ModuleList([
    UpBlock(hidden_dims[2], hidden_dims[2], hidden_dims[1], time_dim),  # 7×7 → 14×14
    UpBlock(hidden_dims[1], hidden_dims[1], hidden_dims[0], time_dim)   # 14×14 → 28×28
])
Resolution progression:
  • Input: 7×7 @ 128 channels
  • After 1st up: 14×14 @ 64 channels (fused with skip)
  • After 2nd up: 28×28 @ 32 channels (fused with skip)
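
A matching sketch of UpBlock under the same assumptions, with the (in_ch, skip_ch, out_ch, time_dim) signature used in the constructor calls above: upsample 2×, concatenate the encoder skip, then apply a time-conditioned residual block.

class UpBlock(nn.Module):
    """Transposed-conv upsampling, skip concatenation, then a ResBlock."""
    def __init__(self, in_ch, skip_ch, out_ch, time_dim):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_ch, in_ch, 4, stride=2, padding=1)  # doubles H and W
        self.res = ResBlock(in_ch + skip_ch, out_ch, time_dim)

    def forward(self, x, skip, t_emb):
        x = self.up(x)                    # e.g. 7×7 @ 128 -> 14×14 @ 128
        x = torch.cat([x, skip], dim=1)   # fuse with the encoder feature
        return self.res(x, t_emb)         # e.g. 14×14 @ 256 -> 14×14 @ 64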

Output layers

The final layers map from 32 channels back to 1 channel (the predicted noise):
self.out_norm = nn.GroupNorm(8, hidden_dims[0])
self.out_conv = nn.Conv2d(hidden_dims[0], channels, 3, padding=1)
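
Putting the pieces together, here is a hypothetical forward method (the attribute name time_mlp and the exact wiring are assumptions; see src/models/diffusion.py for the real implementation):

def forward(self, x, t):
    t_emb = self.time_mlp(t)                 # (B, time_dim) time embedding (name assumed)
    h = self.init_conv(x)                    # 28×28 @ 32
    skips = []
    for down in self.down_blocks:
        h, skip = down(h, t_emb)             # downsample and remember the skip feature
        skips.append(skip)
    h = self.bottleneck(h, t_emb)            # 7×7 @ 128, with self-attention
    for up in self.up_blocks:
        h = up(h, skips.pop(), t_emb)        # upsample and fuse the matching skip
    return self.out_conv(F.silu(self.out_norm(h)))  # predicted noise, 28×28 @ 1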

Parameter count

The MNIST model is deliberately compact to enable fast training:
  • Total parameters: ~1.2M
  • Time embedding: ~132K parameters
  • Encoder: ~290K parameters
  • Bottleneck: ~530K parameters
  • Decoder: ~210K parameters
  • Output layers: ~3K parameters
This is approximately 10× smaller than the CIFAR-10 model, reflecting the simpler nature of grayscale digits.
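
These figures are easy to sanity-check against the actual class (exact counts depend on the block implementations):

from src.models.diffusion import DiffusionModel

model = DiffusionModel(image_size=28, channels=1)
total = sum(p.numel() for p in model.parameters())
print(f"{total / 1e6:.2f}M parameters")  # expected to be on the order of 1.2M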

Training configuration

The model is trained with the following setup:
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-4)
Key training details:
  • Optimizer: Adam with learning rate 1e-4
  • Loss function: MSE between predicted and actual noise
  • Noise schedule: Cosine beta schedule with 1000 steps
  • Mixed precision: Enabled on CUDA for faster training
The cosine schedule is preferred over a linear schedule for MNIST because it adds noise more gradually in the early timesteps, preserving image structure for longer, which suits simple, low-resolution images.
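
The schedule and training-step internals are not shown on this page. The sketch below illustrates the standard cosine beta schedule (Nichol & Dhariwal, 2021) and a noise-prediction MSE step consistent with the setup above; the function names here are illustrative, not the repository's API.

import math
import torch
import torch.nn.functional as F

def cosine_beta_schedule(timesteps=1000, s=0.008):
    """Betas derived from a cosine-shaped cumulative-alpha curve."""
    steps = torch.arange(timesteps + 1, dtype=torch.float64)
    alphas_cumprod = torch.cos((steps / timesteps + s) / (1 + s) * math.pi / 2) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - alphas_cumprod[1:] / alphas_cumprod[:-1]
    return betas.clamp(0, 0.999).float()

def train_step(model, optimizer, x_0, alphas_cumprod):
    """Sample a random timestep, noise the batch, and regress the model's noise estimate."""
    t = torch.randint(0, len(alphas_cumprod), (x_0.shape[0],), device=x_0.device)
    noise = torch.randn_like(x_0)
    a = alphas_cumprod.to(x_0.device)[t].view(-1, 1, 1, 1)
    x_t = a.sqrt() * x_0 + (1 - a).sqrt() * noise    # forward process q(x_t | x_0)
    loss = F.mse_loss(model(x_t, t), noise)          # MSE between predicted and actual noise
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

# Usage: betas = cosine_beta_schedule(1000); alphas_cumprod = torch.cumprod(1 - betas, dim=0)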

Inference process

During sampling, the model iteratively denoises random noise:
def sample(self, num_samples=16):
    self.model.eval()
    with torch.no_grad():
        # Start with pure noise
        x_t = torch.randn(num_samples, self.model.channels,
                          self.model.image_size, self.model.image_size,
                          device=self.device)
        
        # Gradually denoise over 1000 steps
        for t in reversed(range(self.noise_steps)):
            t_batch = torch.full((num_samples,), t, device=self.device, dtype=torch.long)
            predicted_noise = self.model(x_t, t_batch)
            
            # Update using DDPM formula
            model_mean = sqrt_recip_alpha_t * (
                x_t - (beta_t / sqrt_one_minus_alpha_cumprod_t) * predicted_noise
            )
            
            if t > 0:
                noise = torch.randn_like(x_t)
                x_t = model_mean + sigma_t * noise
            else:
                x_t = model_mean
        
        return torch.clamp(x_t, -1, 1)
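
The per-timestep scalars referenced in the loop (beta_t, sqrt_recip_alpha_t, and so on) are typically precomputed from the beta schedule. Standard DDPM definitions, written here against the hypothetical cosine_beta_schedule sketch above:

betas = cosine_beta_schedule(1000)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)

# At a given integer timestep t:
beta_t = betas[t]
sqrt_recip_alpha_t = (1.0 / alphas[t]).sqrt()
sqrt_one_minus_alpha_cumprod_t = (1.0 - alphas_cumprod[t]).sqrt()
sigma_t = beta_t.sqrt()   # one common choice for the sampling noise scale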

Design tradeoffs

The MNIST architecture makes several simplifications compared to more complex models:
Aspect      MNIST Choice      Rationale
Channels    [32, 64, 128]     Sufficient for simple digits
Attention   Bottleneck only   7×7 is small enough for a single attention layer
Dropout     None              MNIST is large enough to avoid overfitting
EMA         Not used          Adam optimizer is stable enough
These simplifications make the MNIST model ideal for learning and experimentation, but would be insufficient for complex natural images.

Usage example

To create and use the MNIST diffusion model:
import torch
from src.models.diffusion import DiffusionProcess

# Initialize with MNIST parameters
diffusion = DiffusionProcess(
    image_size=28,
    channels=1,
    hidden_dims=[32, 64, 128],
    noise_steps=1000,
    device=torch.device('cuda')
)

# Train on MNIST data
for batch in mnist_loader:
    loss = diffusion.train_step(batch)

# Generate samples
samples = diffusion.sample(num_samples=64)
See the training guide for complete training code and hyperparameter recommendations.
