Overview
The MNIST diffusion model uses a lightweight U-Net architecture optimized for generating 28×28 grayscale digits. This simplified design serves as an excellent starting point for understanding diffusion models before scaling to more complex datasets.
Model specification
The MNIST U-Net is defined in `src/models/diffusion.py` as the `DiffusionModel` class:
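As a rough sketch, the class might be instantiated like this, assuming the constructor takes the parameters listed in the table below (the exact signature may differ):

```python
from src.models.diffusion import DiffusionModel

# Hypothetical instantiation; argument names mirror the
# "Architecture parameters" table below.
model = DiffusionModel(
    image_size=28,              # 28x28 input images
    channels=1,                 # grayscale
    hidden_dims=[32, 64, 128],  # three resolutions: 28x28 -> 14x14 -> 7x7
    time_dim=128,               # time embedding dimension
)
```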
Architecture parameters
| Parameter | Value | Description |
|---|---|---|
| `image_size` | 28 | Input image dimensions (28×28) |
| `channels` | 1 | Grayscale images |
| `hidden_dims` | [32, 64, 128] | Channel counts at each resolution |
| `time_dim` | 128 | Time embedding dimension |
The hidden dimensions [32, 64, 128] create three resolution levels: 28×28 → 14×14 → 7×7, with the bottleneck operating at 7×7.
Network structure
Initial convolution
The input image is first processed by an initial convolution that maps from 1 channel (grayscale) to 32 channels:
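A minimal sketch of that convolution, assuming a 3×3 kernel with padding 1 (chosen here to preserve the 28×28 spatial size; the repo's kernel size is not specified):

```python
import torch.nn as nn

# 1 -> 32 channels; a 3x3 kernel with padding 1 keeps the 28x28 spatial size
init_conv = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1)
```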
Encoder blocks
The encoder consists of two `DownBlock` layers that progressively downsample:
- Input: 28×28 @ 32 channels
- After 1st down: 14×14 @ 64 channels
- After 2nd down: 7×7 @ 128 channels
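The `DownBlock` internals are not shown here; the following is a hedged sketch of one plausible implementation, assuming each block pairs a channel-expanding convolution with a stride-2 downsampling convolution:

```python
import torch.nn as nn

class DownBlock(nn.Module):
    """Sketch: conv that expands channels, then a stride-2 downsample."""
    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.GroupNorm(8, out_ch),
            nn.SiLU(),
        )
        # kernel 4, stride 2, padding 1 halves the spatial size exactly
        self.down = nn.Conv2d(out_ch, out_ch, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        x = self.conv(x)    # e.g. 28x28 @ 32 -> 28x28 @ 64
        return self.down(x) # 28x28 @ 64 -> 14x14 @ 64
```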
Bottleneck
At the coarsest resolution (7×7), the bottleneck applies self-attention:
- First `ResBlock` with time conditioning
- `SelfAttention` layer for global context
- Second `ResBlock` with time conditioning
Even at the small 7×7 resolution, self-attention helps the model capture relationships between different parts of the digit (e.g., connecting the top and bottom of an “8”).
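A hedged sketch of what the `SelfAttention` layer could look like, treating the 7×7 feature map as 49 tokens and using multi-head attention; the actual implementation may differ:

```python
import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    """Sketch: attention over the 7x7 = 49 spatial positions."""
    def __init__(self, channels: int, num_heads: int = 4):
        super().__init__()
        self.norm = nn.GroupNorm(8, channels)
        self.attn = nn.MultiheadAttention(channels, num_heads, batch_first=True)

    def forward(self, x):
        b, c, h, w = x.shape
        seq = self.norm(x).flatten(2).transpose(1, 2)  # (B, H*W, C) token sequence
        out, _ = self.attn(seq, seq, seq)              # global self-attention
        out = out.transpose(1, 2).reshape(b, c, h, w)
        return x + out                                 # residual connection
```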
Decoder blocks
The decoder upsamples back to the original resolution using skip connections:
- Input: 7×7 @ 128 channels
- After 1st up: 14×14 @ 64 channels (fused with skip)
- After 2nd up: 28×28 @ 32 channels (fused with skip)
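A hedged sketch of a plausible up block, assuming transposed-convolution upsampling and channel-wise concatenation of the encoder skip (a common U-Net pattern; the repo's fusion method is not specified):

```python
import torch
import torch.nn as nn

class UpBlock(nn.Module):
    """Sketch: upsample, concatenate the encoder skip, then convolve."""
    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1)
        self.conv = nn.Sequential(
            nn.Conv2d(out_ch * 2, out_ch, kernel_size=3, padding=1),  # *2 from skip concat
            nn.GroupNorm(8, out_ch),
            nn.SiLU(),
        )

    def forward(self, x, skip):
        x = self.up(x)                   # e.g. 7x7 @ 128 -> 14x14 @ 64
        x = torch.cat([x, skip], dim=1)  # fuse encoder features channel-wise
        return self.conv(x)
```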
Output layers
The final layers map from 32 channels back to 1 channel (the predicted noise):
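A plausible sketch of this output head; the GroupNorm and SiLU choices are assumptions, and only the 32 → 1 channel mapping comes from the text:

```python
import torch.nn as nn

# 32 -> 1 channel: the model's output is the predicted noise
out_layers = nn.Sequential(
    nn.GroupNorm(8, 32),
    nn.SiLU(),
    nn.Conv2d(32, 1, kernel_size=3, padding=1),
)
```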
Parameter count
The MNIST model is deliberately compact to enable fast training:
- Total parameters: ~1.2M
- Time embedding: ~132K parameters
- Encoder: ~290K parameters
- Bottleneck: ~530K parameters
- Decoder: ~210K parameters
- Output layers: ~3K parameters
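The total can be verified with a standard parameter count (using the hypothetical model object from the instantiation sketch above):

```python
total = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total / 1e6:.1f}M")  # expected: ~1.2M
```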
Training configuration
The model is trained with the following setup:
- Optimizer: Adam with learning rate 1e-4
- Loss function: MSE between predicted and actual noise
- Noise schedule: Cosine beta schedule with 1000 steps
- Mixed precision: Enabled on CUDA for faster training
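A minimal sketch of one training step under this setup (mixed precision omitted for brevity); the model call signature and the alphas_cumprod tensor are assumptions standing in for the repo's actual training loop:

```python
import torch
import torch.nn.functional as F

def train_step(model, optimizer, x0, alphas_cumprod, num_steps=1000):
    """One DDPM training step: add noise at a random timestep, predict it."""
    t = torch.randint(0, num_steps, (x0.shape[0],), device=x0.device)
    noise = torch.randn_like(x0)
    a_bar = alphas_cumprod[t].view(-1, 1, 1, 1)
    x_t = a_bar.sqrt() * x0 + (1 - a_bar).sqrt() * noise  # forward noising
    pred = model(x_t, t)                                  # predict the noise
    loss = F.mse_loss(pred, noise)                        # MSE objective
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```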
The cosine schedule is preferred over a linear schedule for MNIST because it adds noise more gradually at early timesteps, destroying image structure more slowly, which works better for simple images.
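For reference, a sketch of the cosine beta schedule as published by Nichol & Dhariwal (2021); the repo's exact constants may differ:

```python
import math
import torch

def cosine_beta_schedule(num_steps=1000, s=0.008):
    """Cosine schedule: alpha_bar follows a squared-cosine curve."""
    steps = torch.arange(num_steps + 1, dtype=torch.float64)
    f = torch.cos((steps / num_steps + s) / (1 + s) * math.pi / 2) ** 2
    alpha_bar = f / f[0]
    betas = 1 - alpha_bar[1:] / alpha_bar[:-1]
    return betas.clamp(max=0.999).float()
```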
Inference process
During sampling, the model iteratively denoises random noise:
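A minimal sketch of ancestral DDPM sampling under the 1000-step schedule above; the model call signature is an assumption:

```python
import torch

@torch.no_grad()
def sample(model, betas, shape=(16, 1, 28, 28), device="cpu"):
    """Ancestral DDPM sampling: start from noise, denoise step by step."""
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    x = torch.randn(shape, device=device)
    for t in reversed(range(len(betas))):
        t_batch = torch.full((shape[0],), t, device=device, dtype=torch.long)
        eps = model(x, t_batch)                            # predicted noise
        coef = betas[t] / (1 - alphas_cumprod[t]).sqrt()
        mean = (x - coef * eps) / alphas[t].sqrt()         # posterior mean
        if t > 0:
            x = mean + betas[t].sqrt() * torch.randn_like(x)
        else:
            x = mean                                       # final step: no noise
    return x
```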
Design tradeoffs
The MNIST architecture makes several simplifications compared to more complex models:
| Aspect | MNIST Choice | Rationale |
|---|---|---|
| Channels | [32, 64, 128] | Sufficient for simple digits |
| Attention | Bottleneck only | 7×7 is small enough for single attention |
| Dropout | None | MNIST is large enough to avoid overfitting |
| EMA | Not used | Adam optimizer is stable enough |
These simplifications make the MNIST model ideal for learning and experimentation, but would be insufficient for complex natural images.