Overview
The DiffusionModel class implements a U-Net architecture with time embeddings for predicting noise in diffusion models. It features an encoder-decoder structure with skip connections, residual blocks, self-attention in the bottleneck, and sinusoidal time embeddings.
Constructor
Parameters
- Height and width of the square input images.
- channels: Number of image channels (e.g., 1 for grayscale, 3 for RGB).
- hidden_dims: List of hidden dimensions for each level of the U-Net. Determines the capacity and depth of the network. Each entry creates one down/up block pair.
- time_dim: Dimensionality of the time embedding. This embedding is injected into each residual block to condition the network on the timestep.
Attributes
- time_mlp: Time embedding module that converts timestep integers to sinusoidal embeddings processed through an MLP.
- Initial convolution layer mapping from channels to hidden_dims[0] channels with kernel size 3.
- down_blocks: List of downsampling blocks. Each DownBlock applies a residual block followed by 2x spatial downsampling.
- Bottleneck block at the coarsest resolution, containing two residual blocks with a self-attention layer in between.
- up_blocks: List of upsampling blocks. Each UpBlock performs 2x spatial upsampling, concatenates skip connections, and applies a residual block.
- Group normalization layer before the final convolution, with 8 groups.
- Final convolution layer mapping from hidden_dims[0] back to channels with kernel size 3.
Methods
forward
Forward pass through the U-Net model.
Parameters
- Input noisy images tensor of shape [batch_size, channels, height, width].
- Timesteps tensor of shape [batch_size] containing integer timestep indices.
Returns
Predicted noise tensor of shape [batch_size, channels, height, width].
Implementation
The forward pass follows these steps:
- Time embedding: Converts timestep indices to sinusoidal embeddings via time_mlp
- Initial convolution: Projects input to the first hidden dimension
- Encoder path: Processes through down_blocks, storing skip connections
- Bottleneck: Applies residual blocks with self-attention at the coarsest resolution
- Decoder path: Processes through up_blocks, incorporating skip connections from the encoder
- Output: Applies group normalization, SiLU activation, and final convolution to predict noise
Architecture components
TimeEmbedding
Converts integer timesteps to continuous embeddings using sinusoidal positional encoding followed by an MLP.
- Input: Timestep tensor [batch_size]
- Output: Time embedding [batch_size, time_dim]
- Structure: Sinusoidal encoding → Linear(dim, dim*4) → GELU → Linear(dim*4, dim)
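The structure above could be sketched as follows. This is a minimal reconstruction, assuming a standard DDPM-style sinusoidal encoding; the frequency base 10000 and the class name are illustrative assumptions.

```python
import math
import torch
import torch.nn as nn

class TimeEmbedding(nn.Module):
    """Sinusoidal timestep encoding followed by the documented MLP (illustrative sketch)."""
    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim
        # Linear(dim, dim*4) -> GELU -> Linear(dim*4, dim), as documented above.
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim),
        )

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        # Sinusoidal positional encoding over the integer timestep indices.
        half = self.dim // 2
        freqs = torch.exp(-math.log(10000.0) * torch.arange(half) / half)
        args = t.float()[:, None] * freqs[None, :]
        enc = torch.cat([args.sin(), args.cos()], dim=-1)  # [batch_size, dim]
        return self.mlp(enc)                               # [batch_size, dim]
```

Given a `[batch_size]` tensor of integer timesteps, this returns a `[batch_size, time_dim]` embedding, matching the input/output shapes listed above.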
ResBlock
Residual block with time embedding injection.
Parameters:
- in_ch (int): Input channels
- out_ch (int): Output channels
- time_dim (int): Time embedding dimension
- GroupNorm(8) → SiLU → Conv2d(3x3)
- Add projected time embedding
- GroupNorm(8) → SiLU → Conv2d(3x3)
- Skip connection with optional 1x1 conv for channel matching
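The four steps above might be implemented along these lines (an illustrative sketch, not the project's source; the attribute names are assumptions, and channel counts must be divisible by the 8 GroupNorm groups):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResBlock(nn.Module):
    """GroupNorm -> SiLU -> Conv, time injection, GroupNorm -> SiLU -> Conv, residual skip."""
    def __init__(self, in_ch: int, out_ch: int, time_dim: int):
        super().__init__()
        self.norm1 = nn.GroupNorm(8, in_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.time_proj = nn.Linear(time_dim, out_ch)  # projects time embedding to out_ch
        self.norm2 = nn.GroupNorm(8, out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        # Optional 1x1 conv so the residual path matches out_ch.
        self.skip = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()

    def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor:
        h = self.conv1(F.silu(self.norm1(x)))
        h = h + self.time_proj(t_emb)[:, :, None, None]  # broadcast over H, W
        h = self.conv2(F.silu(self.norm2(h)))
        return h + self.skip(x)
```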
DownBlock
Downsampling block combining residual processing with spatial reduction.
Parameters:
- in_ch (int): Input channels
- out_ch (int): Output channels
- time_dim (int): Time embedding dimension
- ResBlock → Conv2d(4x4, stride=2) for 2x downsampling
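The stride-2 convolution halves each spatial dimension. For example (the channel count of 64 and padding of 1 are illustrative assumptions):

```python
import torch
import torch.nn as nn

# A 4x4 conv with stride 2 and padding 1 performs the 2x spatial downsampling.
down = nn.Conv2d(64, 64, kernel_size=4, stride=2, padding=1)
x = torch.randn(1, 64, 16, 16)
print(down(x).shape)  # torch.Size([1, 64, 8, 8])
```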
UpBlock
Upsampling block with skip connection fusion.
Parameters:
- in_ch (int): Input channels before upsampling
- skip_ch (int): Skip connection channels
- out_ch (int): Output channels
- time_dim (int): Time embedding dimension
- ConvTranspose2d(4x4, stride=2) for 2x upsampling
- Concatenate with skip connection
- ResBlock
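The first two steps can be illustrated with a shape check (channel counts of 128 coarse / 64 skip and padding of 1 are illustrative assumptions):

```python
import torch
import torch.nn as nn

# 2x upsampling via transposed conv, then skip-connection concatenation.
up = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)
x = torch.randn(1, 128, 8, 8)      # coarse decoder features
skip = torch.randn(1, 64, 16, 16)  # matching encoder skip connection
h = up(x)                          # -> [1, 64, 16, 16]
h = torch.cat([h, skip], dim=1)    # -> [1, 128, 16, 16], then fed to a ResBlock
```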
BottleneckBlock
Bottleneck processing at the coarsest resolution.
Parameters:
- ch (int): Channel dimension
- time_dim (int): Time embedding dimension
- ResBlock → SelfAttention → ResBlock
SelfAttention
Multi-head self-attention layer for spatial feature relationships.
Parameters:
- ch (int): Channel dimension
- num_heads (int, default=4): Number of attention heads
- GroupNorm → Multi-head attention → Projection → Residual connection
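A sketch of this layer, assuming spatial positions are flattened into a sequence for `nn.MultiheadAttention` (the use of that module and the attribute names are assumptions; `ch` must be divisible by both `num_heads` and the 8 GroupNorm groups):

```python
import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    """GroupNorm -> multi-head attention over spatial positions -> projection -> residual."""
    def __init__(self, ch: int, num_heads: int = 4):
        super().__init__()
        self.norm = nn.GroupNorm(8, ch)
        self.attn = nn.MultiheadAttention(ch, num_heads, batch_first=True)
        self.proj = nn.Linear(ch, ch)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, h, w = x.shape
        y = self.norm(x).flatten(2).transpose(1, 2)  # [B, H*W, C]
        y, _ = self.attn(y, y, y)                    # attend across spatial positions
        y = self.proj(y)
        return x + y.transpose(1, 2).reshape(b, c, h, w)
```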
Usage example
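No example survives in this section. The sketch below is an illustrative end-to-end reconstruction from the descriptions above, not the project's actual source: the constructor signature (the documented image-size parameter is omitted), attribute names other than time_mlp, down_blocks, and up_blocks, and the omission of the bottleneck self-attention are all simplifying assumptions.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class TimeEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.mlp = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))

    def forward(self, t):
        half = self.dim // 2
        freqs = torch.exp(-math.log(10000.0) * torch.arange(half) / half)
        a = t.float()[:, None] * freqs[None, :]
        return self.mlp(torch.cat([a.sin(), a.cos()], dim=-1))

class ResBlock(nn.Module):
    def __init__(self, in_ch, out_ch, time_dim):
        super().__init__()
        self.norm1, self.conv1 = nn.GroupNorm(8, in_ch), nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.t_proj = nn.Linear(time_dim, out_ch)
        self.norm2, self.conv2 = nn.GroupNorm(8, out_ch), nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.skip = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()

    def forward(self, x, t):
        h = self.conv1(F.silu(self.norm1(x)))
        h = h + self.t_proj(t)[:, :, None, None]
        h = self.conv2(F.silu(self.norm2(h)))
        return h + self.skip(x)

class DiffusionModel(nn.Module):
    def __init__(self, channels, hidden_dims, time_dim):
        super().__init__()
        self.time_mlp = TimeEmbedding(time_dim)
        self.init_conv = nn.Conv2d(channels, hidden_dims[0], 3, padding=1)
        self.down_blocks = nn.ModuleList(
            nn.ModuleList([ResBlock(hidden_dims[i], hidden_dims[i + 1], time_dim),
                           nn.Conv2d(hidden_dims[i + 1], hidden_dims[i + 1], 4, stride=2, padding=1)])
            for i in range(len(hidden_dims) - 1))
        self.mid1 = ResBlock(hidden_dims[-1], hidden_dims[-1], time_dim)
        self.mid2 = ResBlock(hidden_dims[-1], hidden_dims[-1], time_dim)
        self.up_blocks = nn.ModuleList(
            nn.ModuleList([nn.ConvTranspose2d(hidden_dims[i + 1], hidden_dims[i + 1], 4, stride=2, padding=1),
                           ResBlock(hidden_dims[i + 1] * 2, hidden_dims[i], time_dim)])
            for i in reversed(range(len(hidden_dims) - 1)))
        self.norm_out = nn.GroupNorm(8, hidden_dims[0])
        self.final_conv = nn.Conv2d(hidden_dims[0], channels, 3, padding=1)

    def forward(self, x, t):
        temb = self.time_mlp(t)
        h = self.init_conv(x)
        skips = []
        for res, down in self.down_blocks:
            h = res(h, temb)
            skips.append(h)                      # skip saved at full resolution
            h = down(h)
        h = self.mid2(self.mid1(h, temb), temb)  # bottleneck; attention omitted for brevity
        for up, res in self.up_blocks:
            h = res(torch.cat([up(h), skips.pop()], dim=1), temb)
        return self.final_conv(F.silu(self.norm_out(h)))

model = DiffusionModel(channels=3, hidden_dims=[32, 64, 128], time_dim=128)
x = torch.randn(8, 3, 32, 32)     # batch of noisy images
t = torch.randint(0, 1000, (8,))  # integer timestep indices
noise_pred = model(x, t)          # -> [8, 3, 32, 32], same shape as the input
```

As documented for forward, the predicted noise has the same shape as the noisy input batch.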
Architecture diagram
For a model with hidden_dims=[32, 64, 128]:
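The diagram itself did not survive extraction; based on the block descriptions above, the data flow for a 32x32 input might be sketched as (exact block count and resolutions are assumptions):

```
input [channels, 32, 32]
  init conv ─────────────────► 32 ch, 32 x 32
  DownBlock ─────────────────► 64 ch, 16 x 16   ──┐ skip
  DownBlock ─────────────────► 128 ch, 8 x 8    ──┤ skip
  Bottleneck (ResBlock → SelfAttention → ResBlock) @ 128 ch, 8 x 8
  UpBlock (+skip) ───────────► 64 ch, 16 x 16
  UpBlock (+skip) ───────────► 32 ch, 32 x 32
  GroupNorm → SiLU → final conv ─► [channels, 32, 32] (predicted noise)
```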
Related classes
- DiffusionProcess - Training and sampling logic using this model
- DiffusionModelCIFAR - Enhanced variant with dropout and configurable attention