Latent space interpolation reveals the structure of the learned data distribution by smoothly transitioning between two random noise vectors. This is particularly effective with DDIM’s deterministic sampling.

Concept

Instead of sampling each image from an independent noise vector, we can interpolate between two noise vectors $z_0$ and $z_1$ at timestep $T$:

$$z_\alpha = (1 - \alpha)\, z_0 + \alpha\, z_1, \qquad \alpha \in [0, 1]$$

We then run the full reverse diffusion process from each interpolated $z_\alpha$ to generate the final images.
DDIM’s deterministic trajectories ($\eta = 0$) produce smooth, meaningful interpolations, while DDPM’s stochastic sampling can lead to less coherent transitions.
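For reference, the deterministic update DDIM uses (implemented in ddim_sample_from_xt below) first predicts $x_0$ from the noise estimate, then steps to the previous timestep:

$$\hat{x}_0 = \frac{x_t - \sqrt{1 - \bar{\alpha}_t}\,\hat{\epsilon}_\theta(x_t, t)}{\sqrt{\bar{\alpha}_t}}, \qquad x_{t-1} = \sqrt{\bar{\alpha}_{t-1}}\,\hat{x}_0 + \sqrt{1 - \bar{\alpha}_{t-1}}\,\hat{\epsilon}_\theta(x_t, t)$$

Because no fresh noise is injected, the same pair $(z_0, z_1)$ always maps to the same pair of images, and nearby $z_\alpha$ map to nearby outputs.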

DDPM interpolation

This implementation uses standard DDPM sampling for each interpolation point. From src/utilities/interpolation_and_timesteps.py:76:
@torch.no_grad()
def interpolate_noise_and_generate(diffusion, n=8, steps=7, save_path='interp.png'):
    """
    Interpolate between two random noise vectors and generate images.
    
    Args:
        diffusion: DiffusionProcess instance
        n: Number of samples in each column
        steps: Number of interpolation steps (columns in the grid)
        save_path: Where to save the output grid
    
    Returns:
        Path to saved image grid
    """
    # Two random noise endpoints at T
    z0 = torch.randn(
        n,
        diffusion.model.channels,
        diffusion.model.image_size,
        diffusion.model.image_size,
        device=next(diffusion.model.parameters()).device
    )
    z1 = torch.randn_like(z0)
    
    # Linear interpolation coefficients
    alphas = torch.linspace(0, 1, steps, device=z0.device)
    
    cols = []
    for a in alphas:
        z = (1-a)*z0 + a*z1
        x = sample_from_xt(diffusion, z)
        cols.append((x+1)/2)  # Map from [-1, 1] to [0, 1]
    
    from torchvision.utils import save_image
    grid = torch.cat(cols, dim=0)
    save_image(grid, save_path, nrow=n)  # Each row of the saved grid is one α; columns are samples
    return save_path
The sample_from_xt helper performs the reverse process from a given noise vector:
@torch.no_grad()
def sample_from_xt(diffusion, x_T):
    """Sample from the diffusion process starting from x_T."""
    model, T, dev = diffusion.model, diffusion.noise_steps, x_T.device
    x_t = x_T.clone()
    
    for t in reversed(range(T)):
        t_b = torch.full((x_t.size(0),), t, device=dev, dtype=torch.long)
        pred_eps = model(x_t, t_b)
        
        beta_t = diffusion.beta_schedule[t]
        alpha_t = diffusion.alpha_schedule[t]
        sqrt_one_minus_alpha_cumprod_t = diffusion.sqrt_one_minus_alpha_cumprod[t]
        
        # Posterior mean under the ε-prediction parameterization
        x_t = (1/torch.sqrt(alpha_t)) * (x_t - (beta_t / sqrt_one_minus_alpha_cumprod_t) * pred_eps)
        
        if t > 0:
            # Inject fresh noise with σ_t = sqrt(β_t); the final step (t = 0) is noise-free
            x_t += torch.sqrt(beta_t) * torch.randn_like(x_t)
    
    return x_t
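
In equation form, each loop iteration applies the standard DDPM ancestral step with the variance choice $\sigma_t^2 = \beta_t$:

$$x_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}}\, \hat{\epsilon}_\theta(x_t, t) \right) + \sqrt{\beta_t}\, z, \qquad z \sim \mathcal{N}(0, I)$$

with the noise term dropped at $t = 0$. The $\sqrt{\beta_t}\, z$ term is what makes DDPM trajectories stochastic and the resulting interpolations less coherent.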

DDIM interpolation

DDIM produces smoother interpolations due to deterministic trajectories. From src/utilities/interpolation_and_timesteps.py:115:
@torch.no_grad()
def interpolate_noise_and_generate_ddim(diffusion, n=8, steps=7, save_path="interp.png"):
    """
    Interpolation grid using deterministic DDIM sampling from x_T.
    
    Args:
        diffusion: DiffusionProcess instance
        n: Number of samples in each column
        steps: Number of interpolation steps
        save_path: Output path
    
    Returns:
        Path to saved image grid
    """
    dev = next(diffusion.model.parameters()).device
    C, H, W = (
        diffusion.model.channels,
        diffusion.model.image_size,
        diffusion.model.image_size,
    )

    z0 = torch.randn(n, C, H, W, device=dev)
    z1 = torch.randn_like(z0)
    alphas = torch.linspace(0, 1, steps, device=dev)

    cols = []
    for a in alphas:
        z = (1 - a) * z0 + a * z1
        x = ddim_sample_from_xt(diffusion, z)
        cols.append((x + 1) / 2)

    from torchvision.utils import save_image
    grid = torch.cat(cols, dim=0)
    save_image(grid, save_path, nrow=n)
    print(f"Saved {save_path}")
    return save_path
The DDIM sampling helper uses deterministic updates:
@torch.no_grad()
def ddim_sample_from_xt(diffusion, x_T):
    """Deterministic (η=0) reverse process for a batch of initial noises x_T."""
    model, T, dev = diffusion.model, diffusion.noise_steps, x_T.device
    x_t = x_T.clone()
    abar = diffusion.alpha_cumprod  # \bar{α}_t
    
    for t in reversed(range(T)):
        t_b = torch.full((x_t.size(0),), t, device=dev, dtype=torch.long)
        eps_hat = model(x_t, t_b)
        
        # x0 prediction (standard ε-prediction parameterization)
        x0_hat = (x_t - torch.sqrt(1 - abar[t]) * eps_hat) / torch.sqrt(abar[t])
        
        if t > 0:
            x_t = torch.sqrt(abar[t-1]) * x0_hat + torch.sqrt(1 - abar[t-1]) * eps_hat
        else:
            x_t = x0_hat
    
    return x_t

Usage example

1. Load trained model

Set up the diffusion process with a trained checkpoint:
import torch
from src.models.diffusion import DiffusionProcess

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

diffusion = DiffusionProcess(
    image_size=28,
    channels=1,
    hidden_dims=[128, 256, 512],
    device=device,
)

diffusion.model.load_state_dict(
    torch.load('best_model.pt', map_location=device)
)
diffusion.model.eval()

2. Generate DDPM interpolation

Create an interpolation grid using stochastic DDPM:
from src.utilities.interpolation_and_timesteps import interpolate_noise_and_generate

interpolate_noise_and_generate(
    diffusion,
    n=8,              # 8 samples per column
    steps=9,          # 9 interpolation steps
    save_path='samples/interp.png'
)
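Note that torchvision’s save_image does not create missing directories, so make sure the output directory exists first:

import os

os.makedirs('samples', exist_ok=True)  # create the output directory if it does not exist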

3. Generate DDIM interpolation

Create a smoother interpolation with deterministic DDIM:
from src.utilities.interpolation_and_timesteps import interpolate_noise_and_generate_ddim

interpolate_noise_and_generate_ddim(
    diffusion,
    n=8,
    steps=9,
    save_path='samples/interp_ddim.png'
)
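
To reproduce a grid later, seed PyTorch’s global RNG before calling either helper; both draw their endpoints with torch.randn, so the same seed yields the same $z_0$ and $z_1$ (and, for deterministic DDIM, the identical grid):

import torch

torch.manual_seed(0)  # same seed => same noise endpoints on the same device and dtype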

Interpreting results

The generated grid has:
  • Rows: Interpolation steps, from $\alpha = 0$ (top row, pure $z_0$) to $\alpha = 1$ (bottom row, pure $z_1$), since the α-batches are concatenated along dim 0 and save_image(..., nrow=n) places n images per row
  • Columns: Independent samples (different random draws within the batch)
Look for smooth semantic transitions in DDIM interpolations. For MNIST, you might see gradual digit morphing (e.g., 3 → 8). For CIFAR-10, you might see color or texture changes and object transformations.

Comparing DDPM vs DDIM interpolations

| Aspect           | DDPM                          | DDIM                    |
|------------------|-------------------------------|-------------------------|
| Smoothness       | Less smooth, stochastic jumps | Very smooth transitions |
| Reproducibility  | Different every run           | Identical every run     |
| Semantic meaning | Harder to interpret           | Clear semantic paths    |
| Speed            | All 1000 steps required       | Can use fewer steps     |
DDIM interpolations are preferred for research and analysis because they reveal the true structure of the latent space without stochastic noise interference.
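
The speed advantage comes from DDIM’s ability to skip timesteps while staying on the same deterministic trajectory. As a rough sketch (not part of the repository’s utilities), a strided variant of ddim_sample_from_xt might look like this, assuming the same diffusion attributes (model, noise_steps, alpha_cumprod):

@torch.no_grad()
def ddim_sample_from_xt_strided(diffusion, x_T, num_steps=50):
    """Hypothetical strided DDIM (η=0): visit only num_steps evenly spaced timesteps."""
    model, T, dev = diffusion.model, diffusion.noise_steps, x_T.device
    abar = diffusion.alpha_cumprod
    # Evenly spaced timesteps from T-1 down to 0
    ts = torch.linspace(T - 1, 0, num_steps, device=dev).long().tolist()
    x_t = x_T.clone()
    for i, t in enumerate(ts):
        t_b = torch.full((x_t.size(0),), t, device=dev, dtype=torch.long)
        eps_hat = model(x_t, t_b)
        x0_hat = (x_t - torch.sqrt(1 - abar[t]) * eps_hat) / torch.sqrt(abar[t])
        if i + 1 < len(ts):
            t_prev = ts[i + 1]
            # Jump directly to the earlier timestep along the deterministic trajectory
            x_t = torch.sqrt(abar[t_prev]) * x0_hat + torch.sqrt(1 - abar[t_prev]) * eps_hat
        else:
            x_t = x0_hat
    return x_t

With num_steps=50 this trades a small amount of sample quality for roughly a 20× reduction in model evaluations compared to the full 1000-step loop.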

Advanced: Spherical interpolation

For high-dimensional latent spaces, spherical linear interpolation (slerp) can produce better results than linear interpolation, because linear interpolation pulls midpoints toward the origin and off the typical norm shell of the Gaussian noise distribution:
def slerp(z0, z1, alpha, eps=1e-7):
    """
    Spherical linear interpolation between z0 and z1.
    
    Better preserves the norm of latent vectors, which can be important
    for maintaining the typical variance of the noise distribution.
    """
    # Flatten each sample so the angle is computed over the whole latent vector
    flat0 = z0.reshape(z0.size(0), -1)
    flat1 = z1.reshape(z1.size(0), -1)
    
    # Angle between the unit-normalized endpoints (clamp to keep acos finite)
    unit0 = flat0 / flat0.norm(dim=-1, keepdim=True)
    unit1 = flat1 / flat1.norm(dim=-1, keepdim=True)
    dot = (unit0 * unit1).sum(dim=-1, keepdim=True).clamp(-1 + eps, 1 - eps)
    omega = torch.acos(dot)
    so = torch.sin(omega)
    
    # Great-circle interpolation of the original (unnormalized) vectors
    out = (torch.sin((1.0 - alpha) * omega) / so) * flat0 \
        + (torch.sin(alpha * omega) / so) * flat1
    return out.reshape(z0.shape)

# Usage
for a in alphas:
    z = slerp(z0, z1, a)
    x = ddim_sample_from_xt(diffusion, z)
    cols.append((x + 1) / 2)
Use spherical interpolation when linear interpolation produces degraded samples (e.g., blurry or washed-out images) in the middle of the path. This is more common with high-resolution images or conditional models.
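
A quick calculation shows why the middle of a linear path is atypical: for independent $z_0, z_1 \sim \mathcal{N}(0, I_d)$,

$$\mathbb{E}\,\big\|(1 - \alpha)\, z_0 + \alpha\, z_1\big\|^2 = \big((1 - \alpha)^2 + \alpha^2\big)\, d,$$

which dips to $d/2$ at $\alpha = 0.5$, while samples from $\mathcal{N}(0, I_d)$ concentrate near norm $\sqrt{d}$. Slerp follows an arc that keeps interpolants close to that shell.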
