Overview
Training script for a diffusion model on the MNIST dataset. Implements a simple DDPM with cosine beta scheduling, early stopping, and periodic visualization of the noising and sampling process.
Usage
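No command-line arguments are documented; assuming the script is run from the repository root, invocation would look like:

```bash
python src/training/train_diffusion.py
```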
Configuration parameters
Training hyperparameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| epochs | int | 50 | Maximum number of training epochs |
| batch_size | int | 128 | Batch size for training |
| image_size | int | 28 | Image dimensions (28x28 for MNIST) |
| channels | int | 1 | Number of input channels (grayscale) |
| save_dir | str | "samples" | Directory for saving visualization outputs |
Model architecture
| Parameter | Type | Default | Description |
|---|---|---|---|
| hidden_dims | list[int] | [128, 256, 512] | Hidden dimensions for U-Net layers |
| noise_steps | int | 1000 | Number of diffusion timesteps |
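The overview mentions cosine beta scheduling. A minimal sketch of the standard cosine schedule (Nichol & Dhariwal, 2021), which is presumably what DiffusionProcess builds from noise_steps:

```python
import torch

def cosine_beta_schedule(noise_steps: int = 1000, s: float = 0.008) -> torch.Tensor:
    """Cosine schedule: compute alpha_bar(t), then derive betas from its ratios."""
    t = torch.linspace(0, noise_steps, noise_steps + 1)
    alpha_bar = torch.cos(((t / noise_steps) + s) / (1 + s) * torch.pi / 2) ** 2
    alpha_bar = alpha_bar / alpha_bar[0]              # normalize so alpha_bar(0) = 1
    betas = 1 - (alpha_bar[1:] / alpha_bar[:-1])      # beta_t from consecutive ratios
    return betas.clamp(max=0.999)                     # clip to avoid singularities near t = T
```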
Early stopping
| Parameter | Type | Default | Description |
|---|---|---|---|
| patience | int | 4 | Number of epochs to wait before early stopping |
| min_delta | float | 1e-4 | Minimum loss improvement threshold |
Device configuration
- Auto-detection: Uses CUDA if available, otherwise CPU
- CUDA optimizations: Enables cuDNN benchmark and high precision matmul when available
- Data loading: Automatically configures num_workers (min of 8 or CPU count); see the sketch after this list
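A sketch of the setup described above, using the standard PyTorch and os APIs (variable names are illustrative):

```python
import os
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    torch.backends.cudnn.benchmark = True        # tune conv algorithms for fixed input shapes
    torch.set_float32_matmul_precision("high")   # allow TF32 matmuls where supported
num_workers = min(8, os.cpu_count() or 1)        # cap DataLoader workers at 8
```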
Data preprocessing
Applies a set of transformations to MNIST images before training.
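The exact transform list is not reproduced here; a typical DDPM preprocessing pipeline for MNIST, assuming the conventional ToTensor-plus-normalize-to-[-1, 1] recipe, would be:

```python
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.ToTensor(),                  # uint8 [0, 255] -> float [0, 1]
    transforms.Normalize((0.5,), (0.5,)),   # [0, 1] -> [-1, 1], matching the model's output range
])
dataset = datasets.MNIST(root="data", train=True, download=True, transform=transform)
```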
Training process
Loss function
Mean squared error (MSE) between the predicted noise and the actual noise added during the forward process.
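A sketch of a single training step under this objective; the q_sample helper and its signature are assumptions about the DiffusionProcess API, not confirmed by the source:

```python
import torch
import torch.nn.functional as F

def training_step(model, diffusion, x0, device):
    # Sample a random timestep per image and the noise to inject.
    t = torch.randint(0, diffusion.noise_steps, (x0.shape[0],), device=device)
    noise = torch.randn_like(x0)
    x_t = diffusion.q_sample(x0, t, noise)     # hypothetical forward-noising helper
    predicted_noise = model(x_t, t)
    return F.mse_loss(predicted_noise, noise)  # MSE between predicted and true noise
```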
Optimization
- Optimizer: Adam (configured in DiffusionProcess)
- Loss tracking: Records average loss per epoch
- Best model saving: Saves a checkpoint when the loss improves by more than 1e-4
Early stopping logic
Training stops when the average epoch loss fails to improve on the best loss by at least min_delta for patience consecutive epochs.
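A minimal sketch of that bookkeeping, using the documented patience and min_delta defaults (variable names are illustrative):

```python
best_loss = float("inf")
bad_epochs = 0
min_delta, patience = 1e-4, 4

epoch_losses = [0.120, 0.080, 0.0799, 0.0798, 0.0797, 0.0796]  # illustrative values

for epoch, loss in enumerate(epoch_losses, start=1):
    if best_loss - loss > min_delta:   # counts as an improvement
        best_loss = loss
        bad_epochs = 0                 # reset the patience counter
        # the script would save best_model.pt here
    else:
        bad_epochs += 1
        if bad_epochs >= patience:
            print(f"early stop at epoch {epoch}")
            break
```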
Outputs
Generated files
| File | Description |
|---|---|
| samples/beta_schedule.png | Visualization of the cosine beta schedule |
| samples/noising_epoch{n}.png | Forward noising visualization (every 10 epochs) |
| samples/samples_epoch{n}.png | Generated samples (every 10 epochs) |
| samples/training_curve.png | Training loss curve |
| DDPM.png | Final generated samples (16 images) |
| best_model.pt | Best model checkpoint based on training loss |
Visualization schedule
- Epoch 1: Initial noising and samples
- Every 10 epochs: Updated noising and samples
- End of training: Final samples and loss curve
Code example
Customize training parameters by modifying the configuration section at the top of the script.
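The values below are the documented defaults; the flat-variable layout is an assumption about how the script organizes its configuration section:

```python
# Training hyperparameters
epochs = 50
batch_size = 128
image_size = 28       # MNIST is 28x28
channels = 1          # grayscale
save_dir = "samples"  # directory for visualization outputs

# Model architecture
hidden_dims = [128, 256, 512]
noise_steps = 1000

# Early stopping
patience = 4
min_delta = 1e-4
```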
Utilities
visualize_noising
Visualizes the forward noising process at specific timesteps.
Parameters:
- x0: Initial clean images (batch)
- diffusion: DiffusionProcess instance
- timesteps: List of timesteps to visualize (default: [0, 200, 400, 600, 800, 999])
- fname: Output filename
visualize_sampling
Visualizes the reverse denoising process (sampling).
Parameters:
- diffusion: DiffusionProcess instance
- num_samples: Number of samples to generate (default: 16)
- steps: Timesteps to visualize (default: [999, 800, 400, 0])
- fname: Output filename
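Hypothetical call sites for both helpers, using the documented parameter defaults; x0 and diffusion stand for a batch of clean images and the script's DiffusionProcess instance:

```python
# x0: batch of clean MNIST images; diffusion: the script's DiffusionProcess.
visualize_noising(x0, diffusion,
                  timesteps=[0, 200, 400, 600, 800, 999],
                  fname="samples/noising_epoch1.png")
visualize_sampling(diffusion, num_samples=16,
                   steps=[999, 800, 400, 0],
                   fname="samples/samples_epoch1.png")
```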
Source
Location: src/training/train_diffusion.py