This guide walks through training a DDPM (Denoising Diffusion Probabilistic Model) on the MNIST handwritten digit dataset. The training script demonstrates the core diffusion training loop, with periodic visualization and early stopping.
Prerequisites
Before starting, ensure you have the required dependencies installed:
pip install torch torchvision matplotlib tqdm
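Before running the training script, it can be useful to confirm all four packages are importable. The helper below is a small standalone sketch (not part of the training script):

```python
import importlib.util

def check_deps(packages):
    """Return the subset of packages that cannot be imported."""
    return [p for p in packages if importlib.util.find_spec(p) is None]

missing = check_deps(["torch", "torchvision", "matplotlib", "tqdm"])
if missing:
    print("Missing packages:", ", ".join(missing))
else:
    print("All dependencies found.")
```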
Training configuration
The MNIST training uses a lightweight configuration suitable for quick experimentation:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
epochs = 50
batch_size = 128
image_size = 28
channels = 1 # Grayscale images
Model architecture
The U-Net model is configured with a compact architecture:
diffusion = DiffusionProcess(
    image_size=28,
    channels=1,
    hidden_dims=[128, 256, 512],
    device=device
)
Data preparation
MNIST images are normalized to the range [-1, 1] to match the diffusion model’s output range:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),  # [0,1] → [-1,1]
])

dataset = datasets.MNIST(
    root="./data",
    train=True,
    download=True,
    transform=transform
)
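`Normalize((0.5,), (0.5,))` applies `(p - mean) / std` to each pixel, so values in [0, 1] land exactly in [-1, 1]. A pure-Python check of the mapping at the endpoints and midpoint:

```python
def normalize(p, mean=0.5, std=0.5):
    """Same per-pixel arithmetic that torchvision's Normalize applies."""
    return (p - mean) / std

assert normalize(0.0) == -1.0  # black pixel  -> -1
assert normalize(0.5) == 0.0   # mid-gray     ->  0
assert normalize(1.0) == 1.0   # white pixel  -> +1
```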
DataLoader setup
The dataloader is optimized for GPU training with worker processes:
num_workers = min(8, os.cpu_count() or 4)
loader = DataLoader(
    dataset,
    batch_size=128,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=(device.type == "cuda"),
    persistent_workers=num_workers > 0,
)
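With MNIST's 60,000 training images and a batch size of 128, each epoch iterates over ceil(60000 / 128) = 469 batches, the last of which is partially filled. A quick check of that arithmetic:

```python
import math

num_images = 60_000  # MNIST training split size
batch_size = 128

batches_per_epoch = math.ceil(num_images / batch_size)
last_batch = num_images % batch_size

print(batches_per_epoch)  # 469 batches per epoch
print(last_batch)         # 96 images in the final batch
```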
Training loop
Execute the training script from the project root:
python src/training/train_diffusion.py
Monitor training progress
The script outputs epoch loss and generates samples periodically:
Epoch 1/50
Epoch 2/50
...
Epoch 10/50
Every 10 epochs, the script generates:
- Forward noising visualization showing noise progression
- Sample images from the trained model
- Training loss curve
Training features
Early stopping
The training implements early stopping with patience to prevent overfitting:
src/training/train_diffusion.py
patience = 4
best_loss = float("inf")
wait = 0

for epoch in range(epochs):
    # ... training code ...
    if avg_loss < best_loss - 1e-4:
        best_loss = avg_loss
        wait = 0
        torch.save(diffusion.model.state_dict(), "best_model.pt")
    else:
        wait += 1
        if wait >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break
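The patience logic can be exercised on its own with a synthetic loss sequence. This standalone sketch mirrors the script's counter logic (the real loop also saves a checkpoint, omitted here):

```python
def run_early_stopping(losses, patience=4, min_delta=1e-4):
    """Return the 1-based epoch training stops at, or len(losses) if it never triggers."""
    best_loss = float("inf")
    wait = 0
    for epoch, avg_loss in enumerate(losses):
        if avg_loss < best_loss - min_delta:
            best_loss = avg_loss
            wait = 0          # meaningful improvement: reset the patience counter
        else:
            wait += 1         # no improvement beyond min_delta this epoch
            if wait >= patience:
                return epoch + 1
    return len(losses)

# Loss improves for 3 epochs, then plateaus: stops after 4 non-improving epochs.
losses = [0.9, 0.5, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3]
print(run_early_stopping(losses))  # → 7
```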
Visualization utilities
Two key visualization functions help monitor training:
Forward noising visualization (src/training/train_diffusion.py:59-71):
@torch.no_grad()
def visualize_noising(x0, diffusion, timesteps=[0, 200, 400, 600, 800, 999]):
    """Show how noise gradually corrupts the image"""
    x0 = x0[:8].to(device)
    for i, t in enumerate(timesteps):
        t_batch = torch.full((x0.size(0),), t, device=device, dtype=torch.long)
        x_t, _ = diffusion.add_noise(x0, t_batch)
        # ... plotting code ...
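DDPM implementations typically realize `add_noise` with the closed-form forward process x_t = √ᾱ_t·x_0 + √(1−ᾱ_t)·ε. Assuming that is what `DiffusionProcess.add_noise` computes, the idea can be sketched on a single scalar pixel (using an illustrative linear beta schedule, not necessarily the repository's):

```python
import math

T = 1000
# Illustrative linear schedule from 1e-4 to 0.02 (an assumption, not the repo's cosine schedule).
betas = [1e-4 + (0.02 - 1e-4) * t / (T - 1) for t in range(T)]

alpha_bars = []        # cumulative products of (1 - beta_t)
prod = 1.0
for b in betas:
    prod *= 1.0 - b
    alpha_bars.append(prod)

def add_noise_scalar(x0, t, eps):
    """Closed-form q(x_t | x_0) applied to a single pixel value."""
    ab = alpha_bars[t]
    return math.sqrt(ab) * x0 + math.sqrt(1.0 - ab) * eps

# At t=0 the pixel is barely corrupted; by t=999 almost all signal is gone.
print(add_noise_scalar(1.0, 0, 0.0))   # close to 1.0
print(alpha_bars[-1])                  # close to 0: signal mostly destroyed
```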
Denoising visualization (src/training/train_diffusion.py:76-104):
@torch.no_grad()
def visualize_sampling(diffusion, num_samples=16, steps=[999, 800, 400, 0]):
    """Show the reverse denoising process"""
    x_t = torch.randn(num_samples, channels, image_size, image_size, device=device)
    for t in reversed(range(diffusion.noise_steps)):
        # ... denoising steps ...
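Each of those elided denoising steps usually applies the standard DDPM update x_{t−1} = (x_t − (1−α_t)/√(1−ᾱ_t)·ε_θ(x_t, t)) / √α_t + σ_t·z. Assuming this implementation follows that update, here is the loop on a scalar, with a stand-in for the U-Net's noise prediction:

```python
import math
import random

T = 1000
# Illustrative linear schedule (an assumption; the repo uses a cosine schedule).
betas = [1e-4 + (0.02 - 1e-4) * t / (T - 1) for t in range(T)]
alphas = [1.0 - b for b in betas]
alpha_bars = []
prod = 1.0
for a in alphas:
    prod *= a
    alpha_bars.append(prod)

def fake_eps(x_t, t):
    """Stand-in for the trained U-Net's noise prediction (the real model goes here)."""
    return 0.0

random.seed(0)
x_t = random.gauss(0.0, 1.0)  # start from pure noise, as in visualize_sampling
for t in reversed(range(T)):
    # Posterior mean given the predicted noise, then add fresh noise (except at t=0).
    mean = (x_t - (1.0 - alphas[t]) / math.sqrt(1.0 - alpha_bars[t]) * fake_eps(x_t, t)) / math.sqrt(alphas[t])
    z = random.gauss(0.0, 1.0) if t > 0 else 0.0
    x_t = mean + math.sqrt(betas[t]) * z

print(f"final sample (scalar): {x_t:.3f}")
```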
Output files
The training script generates several outputs in the samples/ directory:
- beta_schedule.png - Visualization of the cosine beta schedule
- noising_epoch{N}.png - Forward noising process at epoch N
- samples_epoch{N}.png - Generated samples at epoch N
- training_curve.png - Loss curve over epochs
- DDPM.png - Final generated samples
- best_model.pt - Best model checkpoint
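beta_schedule.png plots the cosine beta schedule. For reference, here is a pure-Python sketch of the standard cosine schedule from Nichol & Dhariwal, assumed (not verified) to match what the repository computes:

```python
import math

def cosine_betas(T, s=0.008, max_beta=0.999):
    """Cosine noise schedule: betas derived from the alpha-bar curve."""
    def alpha_bar(u):
        # Fraction of signal variance remaining at progress u in [0, 1].
        return math.cos((u + s) / (1.0 + s) * math.pi / 2.0) ** 2
    betas = []
    for t in range(T):
        b = 1.0 - alpha_bar((t + 1) / T) / alpha_bar(t / T)
        betas.append(min(b, max_beta))  # clip to avoid a degenerate final step
    return betas

betas = cosine_betas(1000)
print(f"beta_1={betas[0]:.2e}, beta_1000={betas[-1]:.3f}")  # tiny early, near 1 late
```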
CUDA optimizations
The script includes GPU-specific optimizations:
src/training/train_diffusion.py
if device.type == "cuda":
    torch.backends.cudnn.benchmark = True
    if hasattr(torch, "set_float32_matmul_precision"):
        torch.set_float32_matmul_precision("high")
cudnn.benchmark=True enables cuDNN’s autotuner to find the best convolution algorithms for your hardware, providing significant speedups for fixed-size inputs.
Expected results
With default hyperparameters:
- Training time: 10-15 minutes on a modern GPU
- Memory usage: ~2-3 GB VRAM
- Final loss: ~0.01-0.02 MSE
- Sample quality: Recognizable digits with some noise
If you encounter out-of-memory errors, reduce batch_size from 128 to 64 or 32.
Next steps
- Try training on CIFAR-10 for color images: Training on CIFAR-10
- Scale up training on HPC clusters: HPC SLURM Guide
- Experiment with different architectures in src/models/diffusion.py