PyTorch’s torch.distributions.Distribution class provides an excellent interface for probability distributions, but it has a critical limitation: it doesn’t inherit from nn.Module. This creates problems when building conditional or parameterized distributions:
```python
import torch
import torch.nn as nn
from torch.distributions import Normal

# This doesn't work as expected!
class ConditionalNormal(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.net = nn.Linear(input_dim, 2)  # Output mean and log_std

    def forward(self, context):
        params = self.net(context)
        mean, log_std = params.chunk(2, dim=-1)
        # Problem: this distribution isn't part of the module tree
        return Normal(mean, log_std.exp())
```
Because Distribution objects are not modules, distributions created inside the forward pass are not registered in the module tree: any parameters they own are missing from model.parameters() and therefore invisible to PyTorch’s optimizers.
Zuko introduces the LazyDistribution pattern to bridge this gap. A lazy distribution is an nn.Module that constructs and returns a distribution within its forward pass, given a context.
```python
class LazyDistribution(nn.Module, abc.ABC):
    r"""Abstract lazy distribution.

    A lazy distribution is a module that builds and returns a
    distribution p(X | c) within its forward pass, given a context c.
    """

    @abc.abstractmethod
    def forward(self, c: Tensor | None = None) -> Distribution:
        r"""
        Arguments:
            c: A context tensor.

        Returns:
            A distribution p(X | c).
        """
        pass
```
The key insight: instead of returning samples or values, a lazy distribution returns a distribution object that can then be sampled or evaluated.
The “lazy” terminology comes from the fact that the distribution is not instantiated until forward() is called, allowing parameters to be computed from context.
The simplest case is an unconditional distribution with learnable parameters:
```python
import torch
from zuko.distributions import DiagNormal
from zuko.lazy import UnconditionalDistribution

# Create an unconditional lazy normal distribution
mu = torch.zeros(3)
sigma = torch.ones(3)
base = UnconditionalDistribution(DiagNormal, mu, sigma, buffer=True)

# Call to get the actual distribution
dist = base()
print(dist)  # DiagNormal(loc: torch.Size([3]), scale: torch.Size([3]))

# Now we can sample
samples = dist.sample()
print(samples)  # tensor([ 1.5410, -0.2934, -2.1788])
```
By setting buffer=False, the parameters become trainable:
```python
import torch.optim as optim

# Create with trainable parameters
mu = torch.zeros(2, requires_grad=True)
sigma = torch.ones(2, requires_grad=True)
base = UnconditionalDistribution(DiagNormal, mu, sigma, buffer=False)

# Now parameters are part of the module
optimizer = optim.Adam(base.parameters(), lr=0.01)

# Training loop
for _ in range(100):
    dist = base()
    samples = dist.rsample((32,))
    loss = -dist.log_prob(samples).mean()  # Negative log-likelihood

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```
The UnconditionalDistribution class (defined in zuko/lazy.py:242-288) automatically registers tensor arguments as either buffers or parameters based on the buffer flag.
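To make the buffer/parameter distinction concrete, here is an illustrative sketch of that registration logic, using only plain PyTorch. This is not Zuko’s actual implementation; the class name `UnconditionalSketch` is hypothetical, and the real class in zuko/lazy.py handles more cases.

```python
import torch
import torch.nn as nn
from torch.distributions import Normal

class UnconditionalSketch(nn.Module):
    """Illustrative sketch (not Zuko's actual code): registers each
    tensor argument as a buffer or a parameter based on the flag."""

    def __init__(self, base, *args, buffer: bool = False):
        super().__init__()
        self.base = base  # the distribution class to instantiate
        self.names = []
        for i, arg in enumerate(args):
            name = f"arg{i}"
            tensor = torch.as_tensor(arg, dtype=torch.float32)
            if buffer:
                self.register_buffer(name, tensor)  # frozen state
            else:
                setattr(self, name, nn.Parameter(tensor))  # trainable
            self.names.append(name)

    def forward(self, c=None):
        # Build the distribution from the registered tensors
        return self.base(*(getattr(self, n) for n in self.names))

frozen = UnconditionalSketch(Normal, torch.zeros(3), torch.ones(3), buffer=True)
print(len(list(frozen.parameters())))  # 0

trainable = UnconditionalSketch(Normal, torch.zeros(3), torch.ones(3), buffer=False)
print(len(list(trainable.parameters())))  # 2
```

With buffer=True the tensors move with the module (e.g. under .to(device)) but receive no gradients; with buffer=False they become ordinary trainable parameters.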
Lazy distributions compose naturally with PyTorch’s module system. You can stack them, wrap them in nn.Sequential, and combine them with other modules.
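A minimal sketch of this composability, using only plain PyTorch (the class names `LazyNormal` and `Model` are illustrative, not part of Zuko): a lazy distribution nested inside a larger module contributes its parameters to the parent’s module tree like any other submodule.

```python
import torch
import torch.nn as nn
from torch.distributions import Distribution, Normal

class LazyNormal(nn.Module):
    """Minimal stand-in for an unconditional lazy distribution."""

    def __init__(self, dim: int):
        super().__init__()
        self.loc = nn.Parameter(torch.zeros(dim))
        self.log_scale = nn.Parameter(torch.zeros(dim))

    def forward(self, c=None) -> Distribution:
        return Normal(self.loc, self.log_scale.exp())

class Model(nn.Module):
    """A larger module that owns a lazy distribution as a submodule."""

    def __init__(self, dim: int):
        super().__init__()
        self.encoder = nn.Linear(dim, dim)
        self.prior = LazyNormal(dim)  # composes like any other module

model = Model(3)
# The lazy distribution's parameters appear in the parent's module tree:
names = [n for n, _ in model.named_parameters()]
print(names)  # ['encoder.weight', 'encoder.bias', 'prior.loc', 'prior.log_scale']
```

A single optimizer over model.parameters() then trains the encoder and the distribution’s parameters together.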
- **Parameter Management:** All learnable parameters are automatically registered and tracked; no manual parameter management is needed.
- **Context Flexibility:** The same architecture can produce different distributions based on context, enabling powerful conditional models.
- **Training Integration:** Works seamlessly with PyTorch optimizers, learning rate schedulers, and training loops.
```python
class MyLazyDistribution(LazyDistribution):
    def __init__(self, ...):
        super().__init__()
        # Initialize modules and parameters
        self.net = nn.Sequential(...)

    def forward(self, c: Tensor | None = None) -> Distribution:
        # Compute distribution parameters from context
        params = self.net(c) if c is not None else self.default_params
        # Return a distribution instance
        return SomeDistribution(*params)
```