Latent space interpolation reveals the structure of the learned data distribution: decode a smooth path between two random noise vectors and observe how the generated images change along it. This is particularly effective with DDIM’s deterministic sampling.
Concept
Instead of sampling from independent noise vectors, we can interpolate between two noise vectors z0 and z1 at timestep T:

z_α = (1 − α) · z0 + α · z1,  α ∈ [0, 1]

Then run the full reverse diffusion process from each interpolated z_α to generate the final images.
DDIM’s deterministic trajectories (η=0) produce smooth, meaningful interpolations. DDPM’s stochastic sampling can lead to less coherent transitions.
DDPM interpolation
This implementation uses standard DDPM sampling for each interpolation point. From src/utilities/interpolation_and_timesteps.py:76:
```python
@torch.no_grad()
def interpolate_noise_and_generate(diffusion, n=8, steps=7, save_path='interp.png'):
    """
    Interpolate between two random noise vectors and generate images.

    Args:
        diffusion: DiffusionProcess instance
        n: Number of samples per interpolation step
        steps: Number of interpolation steps (rows in the grid)
        save_path: Where to save the output grid

    Returns:
        Path to saved image grid
    """
    # Two random noise endpoints at T
    z0 = torch.randn(
        n,
        diffusion.model.channels,
        diffusion.model.image_size,
        diffusion.model.image_size,
        device=next(diffusion.model.parameters()).device
    )
    z1 = torch.randn_like(z0)

    # Linear interpolation coefficients
    alphas = torch.linspace(0, 1, steps, device=z0.device)

    cols = []
    for a in alphas:
        z = (1 - a) * z0 + a * z1
        x = sample_from_xt(diffusion, z)
        cols.append((x + 1) / 2)  # Rescale from [-1, 1] to [0, 1]

    from torchvision.utils import save_image
    grid = torch.cat(cols, dim=0)
    save_image(grid, save_path, nrow=n)  # Each row of the grid is one α
    return save_path
```
The sample_from_xt helper performs the reverse process from a given noise vector:
```python
@torch.no_grad()
def sample_from_xt(diffusion, x_T):
    """Sample from the diffusion process starting from x_T."""
    model, T, dev = diffusion.model, diffusion.noise_steps, x_T.device
    x_t = x_T.clone()
    for t in reversed(range(T)):
        t_b = torch.full((x_t.size(0),), t, device=dev, dtype=torch.long)
        pred_eps = model(x_t, t_b)
        beta_t = diffusion.beta_schedule[t]
        alpha_t = diffusion.alpha_schedule[t]
        sqrt_one_minus_alpha_cumprod_t = diffusion.sqrt_one_minus_alpha_cumprod[t]
        # Posterior mean in the ε-parameterization
        x_t = (1 / torch.sqrt(alpha_t)) * (x_t - (beta_t / sqrt_one_minus_alpha_cumprod_t) * pred_eps)
        if t > 0:
            # Stochastic term: fresh Gaussian noise at every step except the last
            x_t += torch.sqrt(beta_t) * torch.randn_like(x_t)
    return x_t
```
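The update inside the loop is the ε-parameterized form of the DDPM posterior mean. Its equivalence to the closed-form mean of q(x_{t−1} | x_t, x_0) can be checked numerically; the schedule values below are arbitrary toy numbers, not the repo's actual schedule:

```python
import torch

torch.manual_seed(0)
# Toy schedule (assumed values, just for the identity check)
beta = torch.tensor([0.01, 0.02, 0.03])
alpha = 1 - beta
abar = torch.cumprod(alpha, dim=0)

t = 2
x0 = torch.randn(5)
eps = torch.randn(5)
# Forward noising: x_t = sqrt(abar_t) x0 + sqrt(1 - abar_t) eps
x_t = torch.sqrt(abar[t]) * x0 + torch.sqrt(1 - abar[t]) * eps

# ε-parameterized mean, as used in sample_from_xt
mu_eps = (x_t - (beta[t] / torch.sqrt(1 - abar[t])) * eps) / torch.sqrt(alpha[t])

# Closed-form posterior mean of q(x_{t-1} | x_t, x0)
mu_post = (torch.sqrt(abar[t - 1]) * beta[t] / (1 - abar[t])) * x0 \
        + (torch.sqrt(alpha[t]) * (1 - abar[t - 1]) / (1 - abar[t])) * x_t

print(torch.allclose(mu_eps, mu_post, atol=1e-6))  # True
```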
DDIM interpolation
DDIM produces smoother interpolations due to deterministic trajectories. From src/utilities/interpolation_and_timesteps.py:115:
```python
@torch.no_grad()
def interpolate_noise_and_generate_ddim(diffusion, n=8, steps=7, save_path="interp.png"):
    """
    Interpolation grid using deterministic DDIM sampling from x_T.

    Args:
        diffusion: DiffusionProcess instance
        n: Number of samples per interpolation step
        steps: Number of interpolation steps (rows in the grid)
        save_path: Output path

    Returns:
        Path to saved image grid
    """
    dev = next(diffusion.model.parameters()).device
    C, H, W = (
        diffusion.model.channels,
        diffusion.model.image_size,
        diffusion.model.image_size,
    )
    z0 = torch.randn(n, C, H, W, device=dev)
    z1 = torch.randn_like(z0)

    alphas = torch.linspace(0, 1, steps, device=dev)
    cols = []
    for a in alphas:
        z = (1 - a) * z0 + a * z1
        x = ddim_sample_from_xt(diffusion, z)
        cols.append((x + 1) / 2)

    from torchvision.utils import save_image
    grid = torch.cat(cols, dim=0)
    save_image(grid, save_path, nrow=n)
    print(f"Saved {save_path}")
    return save_path
```
The DDIM sampling helper uses deterministic updates:
```python
@torch.no_grad()
def ddim_sample_from_xt(diffusion, x_T):
    """Deterministic (η=0) reverse process for a batch of initial noises x_T."""
    model, T, dev = diffusion.model, diffusion.noise_steps, x_T.device
    x_t = x_T.clone()
    abar = diffusion.alpha_cumprod  # \bar{α}_t
    for t in reversed(range(T)):
        t_b = torch.full((x_t.size(0),), t, device=dev, dtype=torch.long)
        eps_hat = model(x_t, t_b)
        # x0 prediction (standard ε-prediction parameterization)
        x0_hat = (x_t - torch.sqrt(1 - abar[t]) * eps_hat) / torch.sqrt(abar[t])
        if t > 0:
            # Deterministic DDIM step: re-noise x0_hat to level t-1 using the same ε
            x_t = torch.sqrt(abar[t - 1]) * x0_hat + torch.sqrt(1 - abar[t - 1]) * eps_hat
        else:
            x_t = x0_hat
    return x_t
```
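The `x0_hat` line simply inverts the forward noising formula x_t = √ᾱ_t · x_0 + √(1 − ᾱ_t) · ε, so if the model predicted ε exactly, x_0 would be recovered exactly. A quick check with an arbitrary ᾱ_t value:

```python
import torch

torch.manual_seed(0)
abar_t = torch.tensor(0.3)  # arbitrary \bar{α}_t value for the check
x0 = torch.randn(4, 1, 8, 8)
eps = torch.randn_like(x0)

# Forward: noise x0 to level t
x_t = torch.sqrt(abar_t) * x0 + torch.sqrt(1 - abar_t) * eps

# Invert exactly as ddim_sample_from_xt does
x0_hat = (x_t - torch.sqrt(1 - abar_t) * eps) / torch.sqrt(abar_t)

print(torch.allclose(x0_hat, x0, atol=1e-5))  # True
```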
Usage example
Load trained model
Set up the diffusion process with a trained checkpoint:

```python
import os
import torch
from src.models.diffusion import DiffusionProcess

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

diffusion = DiffusionProcess(
    image_size=28,
    channels=1,
    hidden_dims=[128, 256, 512],
    device=device,
)
diffusion.model.load_state_dict(
    torch.load('best_model.pt', map_location=device)
)
diffusion.model.eval()
```
Generate DDPM interpolation
Create an interpolation grid using stochastic DDPM sampling:

```python
from src.utilities.interpolation_and_timesteps import interpolate_noise_and_generate

interpolate_noise_and_generate(
    diffusion,
    n=8,       # 8 samples per interpolation step
    steps=9,   # 9 interpolation steps
    save_path='samples/interp.png'
)
```
Generate DDIM interpolation
Create a smoother interpolation with deterministic DDIM:

```python
from src.utilities.interpolation_and_timesteps import interpolate_noise_and_generate_ddim

interpolate_noise_and_generate_ddim(
    diffusion,
    n=8,
    steps=9,
    save_path='samples/interp_ddim.png'
)
```
Interpreting results
The generated grid has:
- Rows: Interpolation steps from α=0 (top, pure z0) to α=1 (bottom, pure z1)
- Columns: Independent samples (different random draws within the batch)

(With `save_image(..., nrow=n)`, each batch of n images at a single α fills one row of the grid.)
Look for smooth semantic transitions in DDIM interpolations. For MNIST, you might see gradual digit morphing (e.g., 3 → 8). For CIFAR-10, you might see color/texture changes or object transformations.
Comparing DDPM vs DDIM interpolations
| Aspect | DDPM | DDIM |
|---|---|---|
| Smoothness | Less smooth, stochastic jumps | Very smooth transitions |
| Reproducibility | Different every run | Identical every run |
| Semantic meaning | Harder to interpret | Clear semantic paths |
| Speed | All 1000 steps required | Can use fewer steps |
DDIM interpolations are preferred for research and analysis because they reveal the true structure of the latent space without stochastic noise interference.
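The speed row deserves a note: because each DDIM step jumps from x_t to a predicted x_0 and re-noises, the loop can run over a subsequence of timesteps instead of all T. A sketch of choosing an evenly spaced subsequence (a sampler using it would pair `tau[i]` with `tau[i-1]` instead of t with t−1; the variable names here are illustrative, not from the repo):

```python
import torch

T = 1000                  # full training schedule length
num_inference_steps = 50  # accelerated sampling budget

# Evenly spaced timestep subsequence from 0 up to the final timestep T-1
tau = torch.linspace(0, T - 1, num_inference_steps).round().long()

print(len(tau))                       # 50
print(tau[0].item(), tau[-1].item())  # 0 999
```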
Advanced: Spherical interpolation
For higher-dimensional latent spaces, spherical linear interpolation (slerp) can produce better results than linear interpolation:
```python
def slerp(z0, z1, alpha):
    """
    Spherical linear interpolation between z0 and z1.

    Better preserves the norm of latent vectors, which can be important
    for maintaining the typical variance of the noise distribution.
    """
    shape = z0.shape
    # Flatten to (batch, dim) so the angle is computed per sample,
    # not per image row
    z0_flat, z1_flat = z0.flatten(1), z1.flatten(1)

    # Angle between the normalized endpoints; the clamp guards acos against
    # values nudged outside [-1, 1] by floating-point error
    z0_norm = z0_flat / z0_flat.norm(dim=1, keepdim=True)
    z1_norm = z1_flat / z1_flat.norm(dim=1, keepdim=True)
    dot = (z0_norm * z1_norm).sum(dim=1, keepdim=True).clamp(-1 + 1e-7, 1 - 1e-7)
    omega = torch.acos(dot)
    so = torch.sin(omega)

    out = (torch.sin((1.0 - alpha) * omega) / so) * z0_flat \
        + (torch.sin(alpha * omega) / so) * z1_flat
    return out.view(shape)

# Usage
for a in alphas:
    z = slerp(z0, z1, a)
    x = ddim_sample_from_xt(diffusion, z)
    cols.append((x + 1) / 2)
```
Use spherical interpolation when you notice that linear interpolation produces samples with unusual characteristics in the middle of the path. This is more common with high-resolution images or conditional models.
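The "unusual characteristics" have a simple explanation: if z0 and z1 are independent standard Gaussians, the linear interpolant (1 − α)z0 + αz1 has variance (1 − α)² + α², which dips to 0.5 at the midpoint, so mid-path samples come from noise with lower variance than anything the model saw during training. A quick empirical check:

```python
import torch

torch.manual_seed(0)
d = 100_000
z0 = torch.randn(d)
z1 = torch.randn(d)

# Midpoint of the linear interpolation
z_mid = 0.5 * z0 + 0.5 * z1

print(z_mid.var().item())  # ≈ 0.5, not the 1.0 the model expects
```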