## Overview

This utility offers three main capabilities:

- Per-timestep loss analysis: Measure noise prediction error across diffusion timesteps
- DDPM noise interpolation: Generate smooth transitions between random noise vectors
- DDIM noise interpolation: Deterministic interpolation using DDIM sampling
## Usage

Run the script directly to generate interpolation grids.

### Prerequisites

- Trained MNIST model at `best_model.pt` in the project root
- MNIST dataset (downloaded automatically if not present)
## Expected outputs

All outputs are saved to `samples/`:

- `interp.png`: DDPM stochastic interpolation grid (8×9)
- `interp_ddim.png`: DDIM deterministic interpolation grid (8×9)
- Timestep loss plot: matplotlib figure showing error vs. timestep buckets
## Functions

### per_timestep_loss

Analyzes noise prediction error across diffusion timesteps by bucketing and averaging losses. Signature (from `interpolation_and_timesteps.py:23-58`):
| Parameter | Type | Description | Default |
|---|---|---|---|
| `diffusion` | `DiffusionProcess` | Trained diffusion model | Required |
| `loader` | `DataLoader` | PyTorch data loader with real images | Required |
| `num_batches` | `int` | Number of batches to evaluate | `10` |
| `buckets` | `int` | Number of timestep buckets (`T/buckets` timesteps per bucket) | `10` |
| `device` | `str` | Compute device | `'cuda'` |

Returns: `torch.Tensor` of shape `(buckets,)` containing the average MSE loss per bucket.
Implementation:
- Splits timesteps 0–999 into equal buckets (e.g., 10 buckets = 100 steps each)
- For each batch:
  - Sample random timesteps `t` uniformly from [0, T)
  - Add noise: `x_t, noise = diffusion.add_noise(x, t)`
  - Predict noise: `pred = model(x_t, t)`
  - Compute per-sample MSE: `((pred - noise)**2).mean(dim=(1,2,3))`
  - Accumulate into the appropriate bucket based on `t`
- Average losses across batches
- Plot bar chart showing error vs. timestep range
  - X-axis: Timestep buckets (e.g., “0-99”, “100-199”, …)
  - Y-axis: MSE (ε-prediction error)
  - Title: “Noise-prediction error vs timestep”
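The bucketing logic above can be sketched as follows. This is a minimal sketch, not the actual implementation: it assumes the `DiffusionProcess` exposes `T`, `add_noise(x, t)`, and a `model` attribute, and it omits the plotting step.

```python
import torch

def per_timestep_loss(diffusion, loader, num_batches=10, buckets=10, device="cpu"):
    # Accumulate per-bucket loss sums and sample counts
    T = diffusion.T
    sums = torch.zeros(buckets)
    counts = torch.zeros(buckets)
    for i, (x, _) in enumerate(loader):
        if i >= num_batches:
            break
        x = x.to(device)
        t = torch.randint(0, T, (x.size(0),), device=device)   # uniform timesteps
        x_t, noise = diffusion.add_noise(x, t)                  # forward diffusion
        pred = diffusion.model(x_t, t)                          # predict epsilon
        mse = ((pred - noise) ** 2).mean(dim=(1, 2, 3))         # per-sample MSE
        b = (t * buckets // T).cpu()                            # bucket index per sample
        sums.index_add_(0, b, mse.detach().cpu())
        counts.index_add_(0, b, torch.ones_like(mse.cpu()))
    return sums / counts.clamp(min=1)                           # average loss per bucket
```

The `clamp(min=1)` guards against division by zero for buckets that happened to receive no samples.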
### sample_from_xt

Generates samples using the standard DDPM stochastic reverse process from initial noise. Signature (from `interpolation_and_timesteps.py:60-73`):
| Parameter | Type | Description |
|---|---|---|
| `diffusion` | `DiffusionProcess` | Trained diffusion model |
| `x_T` | `torch.Tensor` | Initial noise of shape (B, C, H, W) |

Returns: `torch.Tensor` of generated images (same shape as `x_T`).
Implementation:
Performs the full 1000-step DDPM reverse diffusion process.
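A sketch of what such a reverse loop typically looks like, assuming the `DiffusionProcess` exposes `T`, a `betas` schedule, and a `model` attribute (these names are assumptions, not the script's actual API):

```python
import torch

@torch.no_grad()
def sample_from_xt(diffusion, x_T):
    betas = diffusion.betas                        # (T,) noise schedule
    alphas = 1.0 - betas
    alpha_bars = torch.cumprod(alphas, dim=0)
    x = x_T
    for t in reversed(range(diffusion.T)):
        t_batch = torch.full((x.size(0),), t, dtype=torch.long, device=x.device)
        eps = diffusion.model(x, t_batch)          # predicted noise
        coef = betas[t] / torch.sqrt(1.0 - alpha_bars[t])
        mean = (x - coef * eps) / torch.sqrt(alphas[t])
        if t > 0:
            # Stochastic DDPM step: add fresh Gaussian noise
            x = mean + torch.sqrt(betas[t]) * torch.randn_like(x)
        else:
            x = mean                               # final step is noise-free
    return x
```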
### interpolate_noise_and_generate

Generates an interpolation grid by linearly interpolating between two random noise vectors and sampling with DDPM. Signature (from `interpolation_and_timesteps.py:75-95`):
| Parameter | Type | Description | Default |
|---|---|---|---|
| `diffusion` | `DiffusionProcess` | Trained diffusion model | Required |
| `n` | `int` | Number of parallel interpolations (rows) | `8` |
| `steps` | `int` | Number of interpolation steps (columns) | `7` |
| `save_path` | `str` | Output path for grid image | `'interp.png'` |

Returns: `str` path to the saved image.
Implementation:
- Generate two random noise endpoints: `z0, z1 ~ N(0, I)`
- Create linear interpolation: `z_α = (1-α)·z0 + α·z1` for `α ∈ [0, 1]`
- Sample from each interpolated noise: `x = sample_from_xt(diffusion, z_α)`
- Arrange as a grid with `n` rows and `steps` columns
- Save to `save_path` using `torchvision.utils.save_image`
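The endpoint-and-lerp construction (the first two steps above) can be sketched like this; the helper name `interpolate_noise` and the MNIST-sized default `shape` are illustrative assumptions. The actual function would then feed each interpolated tensor to `sample_from_xt` and save the grid:

```python
import torch

def interpolate_noise(n=8, steps=7, shape=(1, 28, 28)):
    z0 = torch.randn(n, *shape)                   # first endpoint, one per row
    z1 = torch.randn(n, *shape)                   # second endpoint, one per row
    alphas = torch.linspace(0.0, 1.0, steps)
    # One interpolated noise tensor per column: z_a = (1-a) z0 + a z1
    zs = [(1 - a) * z0 + a * z1 for a in alphas]
    # Stack so each row's columns are contiguous, then flatten to a batch
    return torch.stack(zs, dim=1).reshape(n * steps, *shape)
```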
### ddim_sample_from_xt

Generates samples using the deterministic DDIM reverse process (η=0) from initial noise. Signature (from `interpolation_and_timesteps.py:97-112`):
| Parameter | Type | Description |
|---|---|---|
| `diffusion` | `DiffusionProcess` | Trained diffusion model |
| `x_T` | `torch.Tensor` | Initial noise of shape (B, C, H, W) |

Returns: `torch.Tensor` of generated images (same shape as `x_T`).
Implementation:
Performs deterministic DDIM sampling with η=0, which provides:

- Consistency: The same `x_T` always produces the same `x_0`
- Invertibility: Images can be encoded back to noise
- Interpolation: Smooth transitions in latent space
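A minimal sketch of an η=0 DDIM reverse loop, under the same assumed attribute names as above (`diffusion.T`, `diffusion.betas`, `diffusion.model`); the actual script may use fewer steps or a different schedule:

```python
import torch

@torch.no_grad()
def ddim_sample_from_xt(diffusion, x_T):
    alphas = 1.0 - diffusion.betas
    alpha_bars = torch.cumprod(alphas, dim=0)
    x = x_T
    for t in reversed(range(diffusion.T)):
        t_batch = torch.full((x.size(0),), t, dtype=torch.long, device=x.device)
        eps = diffusion.model(x, t_batch)
        ab_t = alpha_bars[t]
        ab_prev = alpha_bars[t - 1] if t > 0 else torch.tensor(1.0)
        # Predict x_0 from the current x_t and predicted noise
        x0_pred = (x - torch.sqrt(1 - ab_t) * eps) / torch.sqrt(ab_t)
        # Deterministic step (eta = 0): no fresh noise is injected
        x = torch.sqrt(ab_prev) * x0_pred + torch.sqrt(1 - ab_prev) * eps
    return x
```

Because no noise is injected, calling this twice on the same `x_T` yields bitwise-identical outputs, which is exactly the consistency property listed above.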
### interpolate_noise_and_generate_ddim

Generates an interpolation grid using deterministic DDIM sampling. Signature (from `interpolation_and_timesteps.py:114-139`):
| Parameter | Type | Description | Default |
|---|---|---|---|
| `diffusion` | `DiffusionProcess` | Trained diffusion model | Required |
| `n` | `int` | Number of parallel interpolations (rows) | `8` |
| `steps` | `int` | Number of interpolation steps (columns) | `7` |
| `save_path` | `str` | Output path for grid image | `"interp.png"` |

Returns: `str` path to the saved image.
Difference from the DDPM version:

Uses `ddim_sample_from_xt()` instead of `sample_from_xt()`, providing:

- Deterministic output: No randomness in the reverse process
- Smoother interpolations: Consistent mapping from noise to images
- Better quality: Determinism often produces cleaner transitions
## Main script execution

When run as `__main__`, the script (from `interpolation_and_timesteps.py:142-193`) performs:

1. Initialization
2. Dataset loading
3. Analysis and generation:
   - Timestep loss diagnostic (optional, displays a matplotlib plot)
   - DDPM interpolation grid
   - DDIM interpolation grid
## Output examples

### Timestep loss plot

The typical loss pattern shows higher error at middle timesteps.

### Interpolation grid structure

File: `samples/interp.png` (DDPM) or `samples/interp_ddim.png` (DDIM)

Dimensions: 8 rows × 9 columns = 72 total images
Interpretation:
- Each row: Independent interpolation between two random endpoints
- Left column (α=0.0): First random noise vector endpoint
- Right column (α=1.0): Second random noise vector endpoint
- Middle columns: Linear interpolation in noise space
- Smooth semantic transitions demonstrate learned manifold structure
## DDPM vs DDIM comparison

DDPM (`interp.png`):
- Stochastic sampling adds noise at each step
- Slight variations in intermediate steps
- May show more diversity but less smoothness

DDIM (`interp_ddim.png`):
- Deterministic sampling (η=0)
- Perfectly smooth transitions
- Consistent and reproducible interpolations
- Generally cleaner visual quality
## Use cases

### Diagnostic analysis

Use `per_timestep_loss()` to:
- Identify problematic timestep ranges during training
- Validate that model learns all diffusion stages
- Compare different model architectures or hyperparameters
- Debug training issues (e.g., collapsed loss at certain timesteps)
### Interpolation visualization

Use the interpolation functions to:

- Demonstrate learned latent space structure
- Generate smooth transitions between concepts
- Create visualizations for papers or presentations
- Validate model quality (smooth interpolations = good generalization)
- Compare DDPM vs DDIM sampling quality
### Research applications
- Latent space exploration: Understand what the model learns
- Semantic interpolation: Find meaningful directions in noise space
- Model comparison: Evaluate different training configurations
- Quality metrics: Smooth interpolations correlate with sample quality
## Performance notes

### Computational cost

- `per_timestep_loss()`: Fast, ~1-2 minutes for 10 batches
- `interpolate_noise_and_generate()`: Slow, ~8-10 minutes for an 8×9 grid (requires 72 full DDPM samplings)
- `interpolate_noise_and_generate_ddim()`: Same as DDPM, ~8-10 minutes (full 1000 steps)
## Related functions

- `DiffusionProcess.add_noise()`: Forward diffusion noise addition
- `DiffusionProcess.sample()`: Standard DDPM sampling
- `DiffusionProcess.sample_ddim()`: Accelerated DDIM sampling
- See Diffusion process for core diffusion operations