Overview

Lazy distributions are the foundation of Zuko’s conditional distribution framework. A lazy distribution is a PyTorch module that builds and returns a torch.distributions.Distribution within its forward pass, given an optional context.
Lazy distributions solve a fundamental limitation in PyTorch: Distribution objects are not nn.Module instances, so their parameters cannot be easily moved to GPU or accessed via .parameters(). Lazy distributions wrap this functionality as modules.
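To make the limitation concrete, here is a minimal torch-only sketch. The LazyNormal class below is illustrative, not part of Zuko's API; it shows the pattern a lazy distribution follows: register the distribution's tensors on a module, then build the Distribution in forward.

```python
import torch
import torch.nn as nn
from torch.distributions import Normal

# A plain Distribution is invisible to the module system: its tensors
# are not parameters or buffers, so they do not move with .to(device)
# and do not appear in .parameters() or state_dict().
dist = Normal(torch.zeros(3), torch.ones(3))
print(isinstance(dist, nn.Module))  # False

# A lazy distribution wraps construction in a module, so the tensors
# can be registered (here as buffers) and managed like any other state.
class LazyNormal(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("loc", torch.zeros(3))
        self.register_buffer("scale", torch.ones(3))

    def forward(self, c=None):
        # Build the distribution on demand, from module-managed tensors.
        return Normal(self.loc, self.scale)

lazy = LazyNormal()
print(sorted(lazy.state_dict()))  # ['loc', 'scale']
```

Because loc and scale are registered on the module, calling lazy.to(device) or saving lazy.state_dict() handles them automatically.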

LazyDistribution

The abstract base class for all lazy distributions.
class LazyDistribution(nn.Module, abc.ABC)
A lazy distribution builds and returns a distribution p(X | c) within its forward pass, given a context c. This design enables conditional distributions while retaining all features of PyTorch modules.

Methods

forward

def forward(c: Tensor | None = None) -> Distribution
Builds and returns a conditional distribution.
c (Tensor | None): A context tensor. If None, the distribution is unconditional.

Returns: A torch.distributions.Distribution object representing p(X | c).

Creating Custom Lazy Distributions

To create a custom lazy distribution, subclass LazyDistribution and implement the forward method:
import torch
import torch.nn as nn
from torch.distributions import Normal
from zuko.lazy import LazyDistribution

class ConditionalNormal(LazyDistribution):
    def __init__(self, features: int, context: int):
        super().__init__()
        self.mu_net = nn.Linear(context, features)
        self.sigma_net = nn.Sequential(
            nn.Linear(context, features),
            nn.Softplus()
        )
    
    def forward(self, c: torch.Tensor | None = None) -> Normal:
        if c is None:
            raise ValueError("Context is required")
        
        mu = self.mu_net(c)
        sigma = self.sigma_net(c)
        
        return Normal(mu, sigma)

# Usage
dist = ConditionalNormal(features=3, context=5)
context = torch.randn(10, 5)  # Batch of 10 contexts
p_x_given_c = dist(context)  # Returns a Normal distribution
samples = p_x_given_c.sample()  # Sample from p(X|c)
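Since a lazy distribution is an ordinary module, it can be trained by minimizing the negative log-likelihood of observed (x, c) pairs. The sketch below uses a plain nn.Module stand-in for LazyDistribution so that it runs without Zuko installed; with Zuko, subclass LazyDistribution exactly as above.

```python
import torch
import torch.nn as nn
from torch.distributions import Normal

# Stand-in for LazyDistribution: an nn.Module whose forward
# builds and returns a Distribution given a context.
class ConditionalNormal(nn.Module):
    def __init__(self, features: int, context: int):
        super().__init__()
        self.mu_net = nn.Linear(context, features)
        self.sigma_net = nn.Sequential(nn.Linear(context, features), nn.Softplus())

    def forward(self, c):
        return Normal(self.mu_net(c), self.sigma_net(c))

model = ConditionalNormal(features=3, context=5)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

# Toy data: x depends linearly on c, plus noise.
c = torch.randn(256, 5)
x = c[:, :3] + 0.1 * torch.randn(256, 3)

for _ in range(100):
    loss = -model(c).log_prob(x).mean()  # negative log-likelihood
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```

The key line is `-model(c).log_prob(x).mean()`: the distribution is rebuilt from the current parameters on every forward pass, so gradients flow through mu_net and sigma_net.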

UnconditionalDistribution

A convenience class for creating unconditional lazy distributions from distribution constructors.
class UnconditionalDistribution(LazyDistribution)
UnconditionalDistribution wraps any distribution constructor and registers its arguments as buffers or parameters, making them part of the module’s state.

Constructor

def __init__(
    f: Callable[..., Distribution],
    *args,
    buffer: bool = False,
    **kwargs
)
f (Callable[..., Distribution], required): A distribution constructor (e.g., torch.distributions.Normal or a custom distribution class). If f is a module, it is registered as a submodule.

*args (Any): Positional arguments passed to f. Tensor arguments are registered as buffers or parameters.

buffer (bool, default: False): Whether tensor arguments are registered as buffers (not trainable) rather than parameters (trainable).

**kwargs (Any): Keyword arguments passed to f. Tensor arguments are registered as buffers or parameters.

Methods

forward

def forward(c: Tensor | None = None) -> Distribution
Returns the distribution by calling f(*args, **kwargs). The context argument c is always ignored.
c (Tensor | None): A context tensor. This argument is ignored for unconditional distributions.

Returns: f(*args, **kwargs), the constructed distribution.

Examples

Using with DiagNormal

import torch
from zuko.lazy import UnconditionalDistribution
from zuko.distributions import DiagNormal

# Create an unconditional diagonal normal distribution
mu = torch.zeros(3)
sigma = torch.ones(3)

base = UnconditionalDistribution(
    DiagNormal,
    mu,
    sigma,
    buffer=True  # Register as buffers (not trainable)
)

# Get the distribution (context is ignored)
dist = base()
print(dist)  # DiagNormal(loc: torch.Size([3]), scale: torch.Size([3]))

# Sample from the distribution
samples = dist.sample()
print(samples)  # e.g. tensor([ 1.5410, -0.2934, -2.1788])
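The effect of buffer=True can be reproduced with register_buffer in a plain module, which is roughly what UnconditionalDistribution does under the hood. This is a simplified torch-only sketch, not Zuko's actual implementation:

```python
import torch
import torch.nn as nn
from torch.distributions import Normal

class UnconditionalNormal(nn.Module):
    def __init__(self, loc, scale):
        super().__init__()
        # buffer=True behavior: part of the module state, saved in
        # state_dict, moved by .to(device), but excluded from .parameters().
        self.register_buffer("loc", loc)
        self.register_buffer("scale", scale)

    def forward(self, c=None):
        return Normal(self.loc, self.scale)  # context is ignored

base = UnconditionalNormal(torch.zeros(3), torch.ones(3))
print(list(base.parameters()))   # [] -- buffers are not trainable
print(sorted(base.state_dict())) # ['loc', 'scale']
```

This is why buffer=True suits fixed quantities like a standard normal base: the tensors travel with the module but never receive gradient updates.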

Using with PyTorch distributions

import torch
from torch.distributions import Normal
from zuko.lazy import UnconditionalDistribution

# Create an unconditional normal distribution with trainable parameters
loc = torch.zeros(5)
scale = torch.ones(5)

base = UnconditionalDistribution(
    Normal,
    loc,
    scale,
    buffer=False  # Register as parameters (trainable)
)

# The tensors are registered as trainable parameters of the module,
# so requires_grad does not need to be set manually
print(list(base.parameters()))  # [Parameter(loc), Parameter(scale)]

# Use in a normalizing flow
from zuko.lazy import Flow

flow = Flow(
    transform=my_transform,  # any LazyTransform, e.g. a MaskedAutoregressiveTransform
    base=base
)
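Once registered as parameters, the location and scale can be fit directly by maximum likelihood. The following torch-only sketch shows the same pattern with nn.Parameter standing in for UnconditionalDistribution; the log-scale parametrization is a design choice of this sketch (it keeps the scale positive during optimization), not something Zuko imposes.

```python
import torch
import torch.nn as nn
from torch.distributions import Normal

class TrainableNormal(nn.Module):
    def __init__(self, features: int):
        super().__init__()
        # buffer=False behavior: registered as trainable parameters.
        self.loc = nn.Parameter(torch.zeros(features))
        self.log_scale = nn.Parameter(torch.zeros(features))  # log keeps scale > 0

    def forward(self, c=None):
        return Normal(self.loc, self.log_scale.exp())

base = TrainableNormal(5)
optimizer = torch.optim.Adam(base.parameters(), lr=0.05)

data = torch.randn(1024, 5) * 2.0 + 3.0  # true loc 3, true scale 2

for _ in range(200):
    loss = -base().log_prob(data).mean()  # negative log-likelihood
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print(base.loc.detach())  # should be close to tensor([3., 3., 3., 3., 3.])
```

After training, base.loc and base.log_scale.exp() approach the empirical mean and standard deviation of the data.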

As a base distribution in a flow

import torch
from zuko.lazy import Flow, UnconditionalDistribution
from zuko.flows.autoregressive import MaskedAutoregressiveTransform
from zuko.distributions import DiagNormal

# Define the base distribution
base = UnconditionalDistribution(
    DiagNormal,
    loc=torch.zeros(3),
    scale=torch.ones(3),
    buffer=True
)

# Create a flow with the unconditional base
flow = Flow(
    transform=MaskedAutoregressiveTransform(3, 5, hidden_features=(64, 64)),
    base=base
)

# Use the flow
context = torch.randn(10, 5)
dist = flow(context)  # Returns a conditional distribution
samples = dist.sample((100,))  # Sample 100 points from p(x|c)
When using buffer=True, the tensor arguments become part of the module’s state but are not trainable. This is useful for fixed parameters like the standard normal base distribution. Use buffer=False when you want the parameters to be optimized during training.
UnconditionalDistribution ignores the context argument in its forward method. Even if you pass a context, it will always return the same distribution. Use conditional lazy distributions (like those created by MaskedAutoregressiveTransform) when you need context-dependent distributions.

See Also

  • LazyTransform - The transform counterpart to lazy distributions
  • Flow - Combines lazy transforms and distributions into normalizing flows
  • Distributions - Built-in distribution implementations
