Overview
Lazy distributions are the foundation of Zuko’s conditional distribution framework. A lazy distribution is a PyTorch module that builds and returns a torch.distributions.Distribution in its forward pass, given an optional context.
Lazy distributions solve a fundamental limitation in PyTorch: Distribution objects are not nn.Module instances, so their parameters cannot be easily moved to the GPU or accessed via .parameters(). Lazy distributions wrap distribution construction in a module, restoring both capabilities.

LazyDistribution

The abstract base class for all lazy distributions.

Methods
forward(c=None)

Parameters:
- c: A context tensor. If None, the distribution is unconditional.

Returns: A torch.distributions.Distribution object representing the distribution, conditioned on c when one is given.
Creating Custom Lazy Distributions
To create a custom lazy distribution, subclass LazyDistribution and implement the forward method:
UnconditionalDistribution
A convenience class for creating unconditional lazy distributions from distribution constructors.

UnconditionalDistribution wraps any distribution constructor and registers its tensor arguments as buffers or parameters, making them part of the module’s state.
Constructor
- f: A distribution constructor (e.g., torch.distributions.Normal or a custom distribution class). If f is a module, it is registered as a submodule.
- *args: Positional arguments passed to f. Tensor arguments are registered as buffers or parameters.
- buffer: Whether tensor arguments are registered as buffers (not trainable) or parameters (trainable).
- **kwargs: Keyword arguments passed to f. Tensor arguments are registered as buffers or parameters.

Methods
forward(c=None)

Constructs and returns f(*args, **kwargs). The context argument c is always ignored.

Parameters:
- c: A context tensor. This argument is ignored for unconditional distributions.

Returns: f(*args, **kwargs), the constructed distribution.
Examples
Using with DiagNormal
Using with PyTorch distributions
As a base distribution in a flow
When using buffer=True, the tensor arguments become part of the module’s state but are not trainable. This is useful for fixed parameters like the standard normal base distribution. Use buffer=False when you want the parameters to be optimized during training.

See Also
- LazyTransform - The transform counterpart to lazy distributions
- Flow - Combines lazy transforms and distributions into normalizing flows
- Distributions - Built-in distribution implementations
