The zuko.distributions module provides a collection of probability distributions that extend PyTorch's distribution framework. They are designed to work with normalizing flows, for example as base distributions, and support a range of probabilistic modeling tasks.
Core Distribution
NormalizingFlow
The primary distribution for implementing normalizing flows through transformations.
- NormalizingFlow - Transform a base distribution through an invertible transformation
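NormalizingFlow evaluates densities with the change-of-variables formula. The sketch below is plain Python, not zuko's implementation. It assumes a hypothetical one-dimensional affine transform f(x) = (x - shift) / scale that maps data x to a standard normal base variable z. Under that assumption, log p(x) = log p_Z(f(x)) + log |df/dx|, and sampling maps base draws back through the inverse f^{-1}.

```python
import math
import random


def std_normal_log_prob(z: float) -> float:
    """Log-density of the standard normal base distribution N(0, 1)."""
    return -0.5 * z * z - 0.5 * math.log(2 * math.pi)


class AffineFlow1D:
    """Hypothetical 1-D flow with f(x) = (x - shift) / scale (data -> base)."""

    def __init__(self, shift: float, scale: float):
        self.shift = shift
        self.scale = scale

    def forward(self, x: float) -> float:
        # f: data space -> base space
        return (x - self.shift) / self.scale

    def inverse(self, z: float) -> float:
        # f^{-1}: base space -> data space
        return z * self.scale + self.shift

    def log_abs_det_jacobian(self, x: float) -> float:
        # |df/dx| = 1 / |scale| for every x
        return -math.log(abs(self.scale))

    def log_prob(self, x: float) -> float:
        # Change of variables: log p(x) = log p_Z(f(x)) + log |df/dx|
        return std_normal_log_prob(self.forward(x)) + self.log_abs_det_jacobian(x)

    def sample(self, rng: random.Random) -> float:
        # Draw z from the base, then map it back to data space
        return self.inverse(rng.gauss(0.0, 1.0))
```

With shift=2 and scale=3 this flow reproduces the density of a normal distribution with mean 2 and standard deviation 3. A practical flow replaces the affine map with a learned invertible transformation, but the density bookkeeping stays the same.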
Special Distributions
Zuko provides several specialized distributions for different modeling scenarios.
Multivariate Distributions
- DiagNormal - Multivariate normal with diagonal covariance
- BoxUniform - Uniform distribution over a hypercube
- Joint - Concatenation of independent random variables
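When dimensions are independent, the log-density of a multivariate distribution is the sum of per-dimension (or per-block) terms. DiagNormal, BoxUniform, and Joint all rely on this property. The stdlib-only sketch below spells out those computations. The function names and the (size, log_prob) representation of Joint's marginals are hypothetical illustrations, not zuko's API, and the box bounds are treated as inclusive for simplicity.

```python
import math
from typing import Callable, List, Sequence, Tuple


def diag_normal_log_prob(x: Sequence[float], loc: Sequence[float], scale: Sequence[float]) -> float:
    # Diagonal covariance = independent dimensions, so the log-density is a sum of 1-D terms
    return sum(
        -0.5 * ((xi - m) / s) ** 2 - math.log(s) - 0.5 * math.log(2 * math.pi)
        for xi, m, s in zip(x, loc, scale)
    )


def box_uniform_log_prob(x: Sequence[float], lower: Sequence[float], upper: Sequence[float]) -> float:
    # Constant density 1 / volume inside the box, zero (log = -inf) outside
    if all(lo <= xi <= up for xi, lo, up in zip(x, lower, upper)):
        return -sum(math.log(up - lo) for lo, up in zip(lower, upper))
    return -math.inf


def joint_log_prob(
    x: Sequence[float],
    parts: List[Tuple[int, Callable[[Sequence[float]], float]]],
) -> float:
    # x is the concatenation of the marginals' values; each part is (size, log_prob_fn)
    total, start = 0.0, 0
    for size, log_prob in parts:
        total += log_prob(x[start:start + size])
        start += size
    return total
```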
Mixture and Composition
- Mixture - Weighted mixture of distributions
- TransformedUniform - Distribution of a variable that is uniformly distributed after a transformation
- Truncated - Truncate a distribution between bounds
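A mixture evaluates log p(x) = logsumexp_k(log w_k + log p_k(x)) with weights w = softmax(logits). Truncation renormalizes the base density by the probability mass that lies between the bounds. The plain-Python sketch below shows both computations for one-dimensional normal components. The helper names are hypothetical, and for simplicity the mixture components are passed as a list of log-density functions.

```python
import math
from typing import Callable, List


def normal_log_prob(x: float, loc: float = 0.0, scale: float = 1.0) -> float:
    return -0.5 * ((x - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)


def normal_cdf(x: float, loc: float = 0.0, scale: float = 1.0) -> float:
    # math.erf handles +/-inf, so infinite bounds work
    return 0.5 * (1.0 + math.erf((x - loc) / (scale * math.sqrt(2.0))))


def logsumexp(values: List[float]) -> float:
    # Subtract the maximum for numerical stability
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def mixture_log_prob(x: float, logits: List[float], components: List[Callable[[float], float]]) -> float:
    # log-softmax turns unnormalized logits into log-weights
    log_norm = logsumexp(logits)
    # log p(x) = logsumexp_k [log w_k + log p_k(x)]
    return logsumexp([l - log_norm + comp(x) for l, comp in zip(logits, components)])


def truncated_normal_log_prob(x: float, lower: float, upper: float,
                              loc: float = 0.0, scale: float = 1.0) -> float:
    # Outside the bounds the truncated density is zero
    if not lower <= x <= upper:
        return -math.inf
    # Inside, renormalize by the base mass contained in [lower, upper]
    mass = normal_cdf(upper, loc, scale) - normal_cdf(lower, loc, scale)
    return normal_log_prob(x, loc, scale) - math.log(mass)
```

Working in log space with logsumexp avoids underflow when component densities are tiny, which is the usual reason mixtures are parameterized by logits rather than raw weights.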
Shape-Based Distributions
- GeneralizedNormal - Generalized normal with shape parameter β
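The generalized normal family moves between familiar distributions through its shape parameter β. The sketch below assumes the standard form with zero location and unit scale, p(x; β) = β / (2 Γ(1/β)) exp(-|x|^β). zuko's GeneralizedNormal may parameterize things differently, so treat this as an illustration of the density rather than of its API.

```python
import math


def generalized_normal_log_prob(x: float, beta: float) -> float:
    """Log-density of the standard generalized normal with shape beta > 0 (assumed form)."""
    # log p(x) = log(beta) - log(2) - log Gamma(1 / beta) - |x| ** beta
    return math.log(beta) - math.log(2.0) - math.lgamma(1.0 / beta) - abs(x) ** beta
```

With β = 1 this reduces to a Laplace(0, 1) density, and with β = 2 to a normal density with variance 1/2. As β grows, the peak flattens toward a uniform density on [-1, 1].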
Order Statistics
- Sort - Ordered draws from a base distribution
- TopK - Top k elements from n draws
- Minimum - Minimum of n draws
- Maximum - Maximum of n draws
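Order statistics of n i.i.d. draws from a base density p with CDF F have closed-form densities. The maximum has density n F(x)^(n-1) p(x), and the minimum has density n (1 - F(x))^(n-1) p(x). This plain-Python sketch evaluates those densities and simulates Sort and TopK by sorting draws. The helper names are hypothetical, not zuko's code, and it assumes that "top k" means the k largest values.

```python
import math
import random
from typing import Callable, List, Tuple


def max_log_prob(x: float, n: int, base_log_prob: Callable[[float], float],
                 base_cdf: Callable[[float], float]) -> float:
    # Maximum of n i.i.d. draws: p(x) = n * F(x)^(n-1) * p(x); x must be strictly inside the support
    return math.log(n) + (n - 1) * math.log(base_cdf(x)) + base_log_prob(x)


def min_log_prob(x: float, n: int, base_log_prob: Callable[[float], float],
                 base_cdf: Callable[[float], float]) -> float:
    # Minimum of n i.i.d. draws: p(x) = n * (1 - F(x))^(n-1) * p(x)
    return math.log(n) + (n - 1) * math.log1p(-base_cdf(x)) + base_log_prob(x)


def sample_order_stats(sample_base: Callable[[random.Random], float], n: int, k: int,
                       rng: random.Random) -> Tuple[List[float], List[float]]:
    # Sort: draw n values and order them ascending
    draws = sorted(sample_base(rng) for _ in range(n))
    # TopK (assumed to mean the k largest), returned in descending order
    return draws, draws[-k:][::-1]
```

For a Uniform(0, 1) base, the maximum of 3 draws has density 3x², so its mean is 3/4, which the simulation should approach.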
Exported Classes
All of the distributions above are exported from zuko.distributions.
Key Features
- PyTorch Integration: All distributions inherit from torch.distributions.Distribution
- Differentiable Sampling: Most distributions support rsample() for reparameterized gradients
- Batch Operations: Full support for batched sampling and probability computation
- Composability: Distributions can be nested and combined, for example as the base of a Mixture, Truncated, or NormalizingFlow
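rsample() returns samples written as a differentiable function of the distribution's parameters, such as x = μ + σ ε with ε ~ N(0, 1). This lets gradients of an expectation flow through the samples. The stdlib sketch below applies this reparameterization trick by hand to g(x) = x². Its expectation μ² + σ² has exact gradients 2μ and 2σ. In PyTorch, autograd computes the derivatives that are written out manually here.

```python
import random
from typing import Tuple


def reparam_grad_estimate(mu: float, sigma: float, n: int, rng: random.Random) -> Tuple[float, float]:
    """Pathwise (reparameterization) gradient of E[x**2] for x ~ N(mu, sigma**2)."""
    grad_mu, grad_sigma = 0.0, 0.0
    for _ in range(n):
        eps = rng.gauss(0.0, 1.0)   # noise drawn independently of the parameters
        x = mu + sigma * eps        # reparameterized sample
        dg_dx = 2.0 * x             # derivative of g(x) = x**2
        grad_mu += dg_dx * 1.0      # chain rule with dx/dmu = 1
        grad_sigma += dg_dx * eps   # chain rule with dx/dsigma = eps
    return grad_mu / n, grad_sigma / n
```

With mu = 1.5 and sigma = 0.5, the estimates should approach 3.0 and 1.0 as n grows. Distributions without rsample() only support sample(), and gradients cannot flow through those draws.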
Common Methods
All distributions in this module support the standard PyTorch distribution methods and attributes:
- sample(sample_shape) - Draw samples from the distribution
- rsample(sample_shape) - Draw reparameterized samples (when supported)
- log_prob(x) - Compute the log probability density/mass of x
- batch_shape - Shape of the batch dimensions (attribute)
- event_shape - Shape of a single sample (attribute)
- expand(batch_shape) - Expand to a new batch shape
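The shape attributes follow PyTorch's convention. sample(sample_shape) returns values of shape sample_shape + batch_shape + event_shape, and log_prob reduces over the event dimensions. The helper below is a hypothetical illustration of that rule, not part of either library.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def sample_and_log_prob_shapes(sample_shape: Shape, batch_shape: Shape,
                               event_shape: Shape) -> Tuple[Shape, Shape]:
    # sample()/rsample() stack sample, batch, and event dimensions in that order
    sample = tuple(sample_shape) + tuple(batch_shape) + tuple(event_shape)
    # log_prob() returns one value per sample and batch element (event dims are summed out)
    log_prob = tuple(sample_shape) + tuple(batch_shape)
    return sample, log_prob
```

For instance, a diagonal normal over 3-dimensional vectors with 4 independent parameter sets has batch_shape (4,) and event_shape (3,). Drawing 10 samples then yields shape (10, 4, 3), and their log-probabilities have shape (10, 4).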
Zuko sets Distribution._validate_args = False by default for performance. To re-enable argument validation while debugging, call torch.distributions.Distribution.set_default_validate_args(True).