
The Problem: PyTorch Distributions Aren’t Modules

PyTorch’s torch.distributions.Distribution class provides an excellent interface for probability distributions, but it has a critical limitation: it doesn’t inherit from nn.Module. This creates problems when building conditional or parameterized distributions:
import torch
import torch.nn as nn
from torch.distributions import Normal

# This works, but the pattern has a hidden limitation
class ConditionalNormal(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.net = nn.Linear(input_dim, 2)  # Output mean and log_std
    
    def forward(self, context):
        params = self.net(context)
        mean, log_std = params.chunk(2, dim=-1)
        # Problem: This distribution isn't part of the module tree
        return Normal(mean, log_std.exp())
Here the weights of self.net are tracked, but only because they happen to live in a submodule. Since Distribution does not participate in the module tree, tensors passed directly to a distribution (for example, a learnable mean stored as a module attribute) are never registered: they do not appear in model.parameters() and are invisible to optimizers. There is also no common interface for a module whose output is a distribution rather than a tensor.
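To see the failure concretely, store a distribution with would-be learnable tensors directly on a module. (FrozenNormal is a hypothetical name used only for this illustration.)

```python
import torch
import torch.nn as nn
from torch.distributions import Normal

class FrozenNormal(nn.Module):
    def __init__(self):
        super().__init__()
        # Distribution is not an nn.Module, so nn.Module.__setattr__
        # stores it as a plain attribute: the loc/scale tensors inside
        # are never registered as parameters.
        self.dist = Normal(torch.zeros(3, requires_grad=True),
                           torch.ones(3, requires_grad=True))

model = FrozenNormal()
print(list(model.parameters()))  # [] -- nothing for an optimizer to see
```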

The Solution: LazyDistribution

Zuko introduces the LazyDistribution pattern to bridge this gap. A lazy distribution is an nn.Module that constructs and returns a distribution within its forward pass, given a context.

The Abstract Base Class

From zuko/lazy.py:29-50:
class LazyDistribution(nn.Module, abc.ABC):
    r"""Abstract lazy distribution.
    
    A lazy distribution is a module that builds and returns a distribution
    p(X | c) within its forward pass, given a context c.
    """
    
    @abc.abstractmethod
    def forward(self, c: Tensor | None = None) -> Distribution:
        r"""
        Arguments:
            c: A context tensor.
        
        Returns:
            A distribution p(X | c).
        """
        pass
The key insight: instead of returning samples or values, a lazy distribution returns a distribution object that can then be sampled or evaluated.
The “lazy” terminology comes from the fact that the distribution is not instantiated until forward() is called, allowing parameters to be computed from context.
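Because the distribution is rebuilt on every call, it always reflects the module's current parameters. A minimal torch-only sketch of the idea (LazyNormal is a hypothetical name, not part of Zuko):

```python
import torch
import torch.nn as nn
from torch.distributions import Distribution, Normal

class LazyNormal(nn.Module):
    """Hypothetical minimal lazy distribution: a module whose forward
    pass returns a Normal built from its current parameters."""

    def __init__(self):
        super().__init__()
        self.loc = nn.Parameter(torch.zeros(1))

    def forward(self) -> Distribution:
        return Normal(self.loc, torch.ones(1))

lazy = LazyNormal()
with torch.no_grad():
    lazy.loc += 5.0  # e.g. the effect of an optimizer step

dist = lazy()  # built now, so it sees the updated parameter
print(dist.mean)  # the mean is now 5.0
```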

UnconditionalDistribution

The simplest case is an unconditional distribution with learnable parameters:
from zuko.lazy import UnconditionalDistribution
from zuko.distributions import DiagNormal
import torch

# Create an unconditional lazy normal distribution
mu = torch.zeros(3)
sigma = torch.ones(3)
base = UnconditionalDistribution(DiagNormal, mu, sigma, buffer=True)

# Call to get the actual distribution
dist = base()
print(dist)  # DiagNormal(loc: torch.Size([3]), scale: torch.Size([3]))

# Now we can sample
samples = dist.sample()
print(samples)  # e.g. tensor([ 1.5410, -0.2934, -2.1788])

Learnable Parameters

By setting buffer=False, the parameters become trainable:
import torch.optim as optim

# Create with trainable parameters (buffer=False registers them as nn.Parameter)
mu = torch.zeros(2)
sigma = torch.ones(2)
base = UnconditionalDistribution(DiagNormal, mu, sigma, buffer=False)

# Now parameters are part of the module
optimizer = optim.Adam(base.parameters(), lr=0.01)

# Training loop: fit the distribution to observed data
data = torch.randn(256, 2) * 2.0 + 1.0  # toy dataset

for _ in range(100):
    dist = base()
    loss = -dist.log_prob(data).mean()  # negative log-likelihood of the data
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
The UnconditionalDistribution class (defined in zuko/lazy.py:242-288) automatically registers tensor arguments as either buffers or parameters based on the buffer flag.
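The registration mechanism can be sketched in plain PyTorch. SimpleUnconditional below is a hypothetical stand-in, not Zuko's actual code, but it shows the buffer-versus-parameter choice:

```python
import torch
import torch.nn as nn
from torch.distributions import Distribution, Normal

class SimpleUnconditional(nn.Module):
    """Hypothetical sketch of the UnconditionalDistribution idea:
    tensor arguments are registered as buffers or parameters and the
    distribution is rebuilt from them on every call."""

    def __init__(self, meta, *args: torch.Tensor, buffer: bool = False):
        super().__init__()
        self.meta = meta  # distribution constructor, e.g. Normal
        self.names = []
        for i, arg in enumerate(args):
            name = f"arg{i}"
            if buffer:
                self.register_buffer(name, arg)
            else:
                self.register_parameter(name, nn.Parameter(arg))
            self.names.append(name)

    def forward(self) -> Distribution:
        return self.meta(*(getattr(self, n) for n in self.names))

trainable = SimpleUnconditional(Normal, torch.zeros(3), torch.ones(3))
frozen = SimpleUnconditional(Normal, torch.zeros(3), torch.ones(3), buffer=True)
print(len(list(trainable.parameters())), len(list(frozen.parameters())))  # 2 0
```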

Building Conditional Distributions

Lazy distributions shine when building context-dependent distributions:
import torch
import torch.nn as nn
from zuko.lazy import LazyDistribution
from torch.distributions import Distribution, Independent, Normal

class ConditionalNormal(LazyDistribution):
    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 2 * output_dim),
        )
        self.output_dim = output_dim
    
    def forward(self, c: torch.Tensor | None = None) -> Distribution:
        params = self.net(c)
        mean, log_std = params.chunk(2, dim=-1)
        # Independent reinterprets the last dimension as the event dimension,
        # giving batch_shape [32] and event_shape [3] below.
        return Independent(Normal(mean, log_std.exp()), 1)

# Usage
model = ConditionalNormal(input_dim=5, output_dim=3)
context = torch.randn(32, 5)  # Batch of 32 contexts

# Get conditional distribution
dist = model(context)
print(dist.batch_shape)  # torch.Size([32])
print(dist.event_shape)  # torch.Size([3])

# Sample and evaluate
samples = dist.sample()
log_prob = dist.log_prob(samples)
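Training such a model is ordinary maximum likelihood on (context, sample) pairs. A self-contained torch-only sketch, with an illustrative dataset and a smaller network than the class above:

```python
import torch
import torch.nn as nn
from torch.distributions import Independent, Normal

torch.manual_seed(0)

# Compact conditional Gaussian, mirroring ConditionalNormal above.
class CondNormal(nn.Module):
    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.net = nn.Linear(input_dim, 2 * output_dim)

    def forward(self, c: torch.Tensor):
        mean, log_std = self.net(c).chunk(2, dim=-1)
        return Independent(Normal(mean, log_std.exp()), 1)

# Toy data: x depends linearly on its context c.
c = torch.randn(256, 5)
x = c @ torch.randn(5, 3) + 0.1 * torch.randn(256, 3)

model = CondNormal(5, 3)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

start = -model(c).log_prob(x).mean().item()
for _ in range(200):
    loss = -model(c).log_prob(x).mean()  # conditional NLL
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```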

Why This Works

With LazyDistribution:
  1. ✅ The neural network is an nn.Module with tracked parameters
  2. ✅ model.parameters() includes all trainable weights
  3. ✅ Gradients flow through the distribution parameters
  4. ✅ The distribution is created fresh for each context

The Flow Class

Zuko extends this pattern to normalizing flows with the Flow class:
from zuko.lazy import Flow, UnconditionalDistribution, UnconditionalTransform
from zuko.distributions import DiagNormal
from torch.distributions import ExpTransform
import torch

# Build a lazy flow
base = UnconditionalDistribution(
    DiagNormal,
    torch.zeros(2),
    torch.ones(2),
    buffer=True
)

transform = UnconditionalTransform(ExpTransform)

flow = Flow(transform, base)

# Get the actual normalizing flow
dist = flow()
print(type(dist))  # <class 'zuko.distributions.NormalizingFlow'>

# Use it like any distribution
samples = dist.sample((100,))
log_prob = dist.log_prob(samples)
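Conceptually, the object returned by flow() behaves like torch's TransformedDistribution: sampling pushes base samples through the transform, and log_prob combines the base log-density with the log-determinant of the inverse. A torch-only sketch of the same composition:

```python
import torch
from torch.distributions import (
    ExpTransform,
    Independent,
    Normal,
    TransformedDistribution,
)

# Standard normal base with a 2-dimensional event
base = Independent(Normal(torch.zeros(2), torch.ones(2)), 1)

# Pushing the base through exp() yields a log-normal distribution
dist = TransformedDistribution(base, [ExpTransform()])

x = dist.sample((100,))        # all positive: exp maps R onto (0, inf)
log_prob = dist.log_prob(x)    # base log-prob plus the log-det term
print(x.shape, log_prob.shape)  # torch.Size([100, 2]) torch.Size([100])
```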

Conditional Flows

The real power emerges with conditional flows:
from zuko.flows import NSF  # Neural Spline Flow

# Create a conditional flow
flow = NSF(features=2, context=5, transforms=3)

# Different distributions for different contexts
context1 = torch.randn(5)
context2 = torch.randn(5)

dist1 = flow(context1)
dist2 = flow(context2)

# Same architecture, different distributions!
samples1 = dist1.sample((100,))
samples2 = dist2.sample((100,))

LazyTransform: The Transformation Analog

Just like distributions, transformations can be lazy:
import torch
import torch.nn as nn
from zuko.lazy import LazyTransform
from torch.distributions import AffineTransform, Transform

class ConditionalAffine(LazyTransform):
    def __init__(self, input_dim: int, features: int):
        super().__init__()
        self.net = nn.Linear(input_dim, 2 * features)
    
    def forward(self, c: torch.Tensor | None = None) -> Transform:
        params = self.net(c)
        shift, log_scale = params.chunk(2, dim=-1)
        return AffineTransform(shift, log_scale.exp())
Lazy transformations enable context-dependent or parameterized transformations where the transformation itself varies based on input.
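The resulting transform can be applied and inverted like any torch Transform. A self-contained sketch (CondAffine is a hypothetical stand-in that uses a plain nn.Module in place of zuko's LazyTransform):

```python
import torch
import torch.nn as nn
from torch.distributions import AffineTransform, Transform

# Stand-in for the ConditionalAffine sketch above.
class CondAffine(nn.Module):
    def __init__(self, input_dim: int, features: int):
        super().__init__()
        self.net = nn.Linear(input_dim, 2 * features)

    def forward(self, c: torch.Tensor) -> Transform:
        shift, log_scale = self.net(c).chunk(2, dim=-1)
        return AffineTransform(shift, log_scale.exp())

lazy_t = CondAffine(input_dim=4, features=3)
c = torch.randn(8, 4)

t = lazy_t(c)       # a fresh Transform for this batch of contexts
x = torch.randn(8, 3)
y = t(x)            # forward: y = shift + scale * x
x_back = t.inv(y)   # affine transforms are exactly invertible
```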

Key Advantages

  - Composability: lazy distributions compose naturally with PyTorch's module system. You can stack them, wrap them in nn.Sequential, and combine them with other modules.
  - Parameter tracking: all learnable parameters are automatically registered and tracked; no manual parameter management is needed.
  - Conditional modeling: the same architecture can produce different distributions based on context, enabling powerful conditional models.
  - Training integration: works seamlessly with PyTorch optimizers, learning rate schedulers, and training loops.

Pattern Summary

The lazy pattern follows a simple structure:
class MyLazyDistribution(LazyDistribution):
    def __init__(self, ...):
        super().__init__()
        # Initialize modules and parameters
        self.net = nn.Sequential(...)
    
    def forward(self, c: Tensor | None = None) -> Distribution:
        # Compute distribution parameters from context
        params = self.net(c) if c is not None else self.default_params
        
        # Return a distribution instance
        return SomeDistribution(*params)
Key principles:
  1. Inherit from LazyDistribution (or LazyTransform)
  2. Initialize all learnable modules in __init__
  3. Return a distribution instance from forward()
  4. Accept optional context as input

Comparison Table

| Feature              | PyTorch Distribution | LazyDistribution      |
|----------------------|----------------------|-----------------------|
| Trainable parameters | ❌ Manual management | ✅ Automatic tracking |
| Context-dependent    | ❌ External logic    | ✅ Built-in           |
| Composable           | ⚠️ Limited           | ✅ Full nn.Module     |
| Gradient flow        | ⚠️ Requires care     | ✅ Automatic          |
| Optimizer compatible | ❌ Not directly      | ✅ Yes                |

Next Steps

Transformations

Explore lazy transformations and bijective mappings

Flow Models

See pre-built flow architectures using lazy patterns
