The zuko.distributions module provides a collection of probability distributions built on PyTorch's distribution framework. They are designed to be used with normalizing flows, for example as base distributions, and can also be used on their own in other probabilistic models.

Core Distribution

NormalizingFlow

The primary distribution for implementing normalizing flows through transformations.
  • NormalizingFlow - Transform a base distribution through an invertible transformation
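The change-of-variables rule behind a normalizing flow can be sketched with PyTorch's built-in TransformedDistribution, used here only as a stand-in for zuko's NormalizingFlow (this is not zuko's API). If y = f(x) with x drawn from the base, then log p(y) = log p_base(f^-1(y)) + log|det d f^-1/dy|. A single elementwise exp is used as the transform because its inverse and Jacobian are easy to check by hand.

```python
import torch
from torch.distributions import LogNormal, Normal, TransformedDistribution
from torch.distributions.transforms import ExpTransform

# Base distribution and an invertible transform y = exp(x)
base = Normal(torch.tensor(0.0), torch.tensor(1.0))
flow = TransformedDistribution(base, [ExpTransform()])

y = torch.tensor([0.5, 1.0, 2.0])
x = torch.log(y)  # inverse transform f^-1(y)

# Change of variables: log|d f^-1/dy| = -log(y) = -x
manual = base.log_prob(x) - x

# Pushing N(0, 1) through exp gives a log-normal, so both should agree
reference = LogNormal(0.0, 1.0).log_prob(y)
samples = flow.sample((100,))  # every sample is strictly positive
```

A real flow replaces the fixed exp with a learned, parameterized bijection, but the density is evaluated the same way.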

Special Distributions

Zuko provides several specialized distributions for different modeling scenarios:

Multivariate Distributions

  • DiagNormal - Multivariate normal with diagonal covariance
  • BoxUniform - Uniform distribution over a hypercube
  • Joint - Concatenation of independent random variables
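Assuming DiagNormal and BoxUniform behave like independent per-dimension Normal and Uniform distributions reinterpreted as a single multivariate event, as their descriptions above suggest, their shapes and densities can be reproduced with plain PyTorch. The last two lines mimic Joint by hand: independent parts are concatenated along the last dimension and their log-densities add. This is an illustration under that assumption, not zuko's implementation.

```python
import torch
from torch.distributions import Independent, Normal, Uniform

# Diagonal-covariance normal over 3 dimensions: one event of shape (3,)
diag_normal = Independent(Normal(torch.zeros(3), torch.ones(3)), 1)

# Uniform over the box [0, 4]^2
box = Independent(Uniform(torch.zeros(2), torch.full((2,), 4.0)), 1)

x = diag_normal.sample((5,))   # shape (5, 3)
lp = diag_normal.log_prob(x)   # shape (5,): per-dimension log-densities are summed
u = box.sample((5,))           # shape (5, 2)
lp_box = box.log_prob(u)       # each value is -log(4 * 4) = -log(16)

# Joint of independent parts: concatenate samples, add log-densities
joint_sample = torch.cat([x, u], dim=-1)  # shape (5, 5)
joint_lp = lp + lp_box                    # shape (5,)
```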

Mixture and Composition

Shape-Based Distributions

Order Statistics

  • Sort - Ordered draws from a base distribution
  • TopK - Top k elements from n draws
  • Minimum - Minimum of n draws
  • Maximum - Maximum of n draws
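The semantics of these four distributions can be sketched by hand in PyTorch: draw n i.i.d. values from a base distribution, then sort them, keep the k largest, or take the minimum or maximum. The helper functions max_log_prob and min_log_prob are hypothetical names that evaluate the standard densities of the maximum, n F(x)^(n-1) p(x), and of the minimum, n (1 - F(x))^(n-1) p(x), where F and p are the base CDF and density. This is a from-scratch illustration rather than zuko's implementation, and it requires a base distribution that implements cdf().

```python
import math

import torch
from torch.distributions import Normal

base = Normal(torch.tensor(0.0), torch.tensor(1.0))
n = 4

draws = base.sample((1000, n))                   # 1000 groups of n i.i.d. draws
sorted_draws = torch.sort(draws, dim=-1).values  # like Sort: ascending order
top2 = torch.topk(draws, k=2, dim=-1).values     # like TopK with k=2 (largest first)
minimum = draws.min(dim=-1).values               # like Minimum
maximum = draws.max(dim=-1).values               # like Maximum


def max_log_prob(x, base, n):
    # log density of the maximum of n i.i.d. draws: log(n F(x)^(n-1) p(x))
    return math.log(n) + (n - 1) * torch.log(base.cdf(x)) + base.log_prob(x)


def min_log_prob(x, base, n):
    # log density of the minimum of n i.i.d. draws: log(n (1 - F(x))^(n-1) p(x))
    return math.log(n) + (n - 1) * torch.log1p(-base.cdf(x)) + base.log_prob(x)
```

For a Uniform(0, 1) base these reduce to the familiar n x^(n-1) for the maximum and n (1 - x)^(n-1) for the minimum.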

Exported Classes

All distributions are exported from zuko.distributions:
from zuko.distributions import (
    BoxUniform,
    DiagNormal,
    GeneralizedNormal,
    Joint,
    Maximum,
    Minimum,
    Mixture,
    NormalizingFlow,
    Sort,
    TopK,
    TransformedUniform,
    Truncated,
)

Key Features

  • PyTorch Integration: All distributions inherit from torch.distributions.Distribution
  • Differentiable Sampling: Most distributions support rsample() for reparameterized gradients
  • Batch Operations: Full support for batched sampling and probability computation
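  • Composability: Distributions can wrap or combine other distributions; for example, the order statistics are built from a base distribution and Joint concatenates independent ones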
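Reparameterized sampling is what makes a sample-based loss differentiable with respect to distribution parameters. The sketch below uses a plain PyTorch diagonal normal (not a zuko class) to show the difference between rsample(), whose samples are a differentiable function of loc, and sample(), which returns values without a gradient path.

```python
import torch
from torch.distributions import Independent, Normal

loc = torch.zeros(3, requires_grad=True)
dist = Independent(Normal(loc, torch.ones(3)), 1)

x = dist.rsample((64,))              # x = loc + scale * eps with eps ~ N(0, I)
loss = (x ** 2).sum(dim=-1).mean()   # any differentiable function of the samples
loss.backward()                      # gradients reach loc through the samples

x_detached = dist.sample((64,))      # sample() runs under no_grad: no gradient path
```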

Common Methods

All distributions in this module support the standard PyTorch distribution methods and attributes:
  • sample(shape) - Draw samples from the distribution
  • rsample(shape) - Draw reparameterized samples (when supported)
  • log_prob(x) - Compute log probability density/mass
  • batch_shape - Shape of the batch dimensions
  • event_shape - Shape of a single sample
  • expand(batch_shape) - Expand to a new batch shape
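How these shapes interact can be seen with a plain PyTorch distribution; the sketch below assumes zuko's distributions follow the same torch.distributions conventions, since they inherit from Distribution. Samples have shape sample_shape + batch_shape + event_shape, and log_prob sums out the event dimensions.

```python
import torch
from torch.distributions import Independent, Normal

# 4 independent distributions, each over 3-dimensional events
dist = Independent(Normal(torch.zeros(4, 3), torch.ones(4, 3)), 1)

print(dist.batch_shape)  # torch.Size([4])
print(dist.event_shape)  # torch.Size([3])

x = dist.sample((10,))   # (10,) + (4,) + (3,) = (10, 4, 3)
lp = dist.log_prob(x)    # event dims summed out: (10, 4)

wide = dist.expand((2, 4))  # new batch shape must broadcast with (4,)
print(wide.batch_shape)     # torch.Size([2, 4])
```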

Note: Zuko sets Distribution._validate_args = False by default for performance, which disables PyTorch's checks on distribution parameters and on the values passed to log_prob. Enable validation manually when debugging.
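Validation can be turned back on with PyTorch's own switches: Distribution.set_default_validate_args globally, or validate_args=True for a single instance. This is a sketch using only PyTorch APIs; it assumes zuko applies its default when it is imported, so a global override should come after importing zuko.

```python
import torch
from torch.distributions import Distribution, Normal, Uniform

# Re-enable validation for distributions created from now on
Distribution.set_default_validate_args(True)

caught = []

# Per-instance validation: parameters are checked at construction
try:
    Normal(torch.tensor(0.0), torch.tensor(-1.0), validate_args=True)  # scale must be > 0
except ValueError:
    caught.append("invalid parameter")

# Values passed to log_prob are checked against the support
unit = Uniform(torch.tensor(0.0), torch.tensor(1.0), validate_args=True)
try:
    unit.log_prob(torch.tensor(2.0))  # 2.0 lies outside [0, 1)
except ValueError:
    caught.append("value outside support")
```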
