Zuko provides several specialized distributions that extend PyTorch’s distribution framework for specific modeling scenarios.

Multivariate Distributions

DiagNormal

A multivariate normal distribution with diagonal covariance (independent variables).
DiagNormal(loc, scale, ndims=1)
  • loc (Tensor, required): The mean μ of the variables. Shape determines the dimensionality.
  • scale (Tensor, required): The standard deviation σ of the variables. Must have the same shape as loc.
  • ndims (int, default: 1): The number of batch dimensions to interpret as event dimensions.
Example:
import torch
from zuko.distributions import DiagNormal

# 3-dimensional normal with independent components
d = DiagNormal(torch.zeros(3), torch.ones(3))
print(d.event_shape)  # torch.Size([3])

sample = d.sample()
print(sample)  # tensor([ 1.5410, -0.2934, -2.1788])

# Compute log probability
log_p = d.log_prob(sample)
print(log_p)  # scalar tensor
Unlike MultivariateNormal, DiagNormal assumes no correlation between variables, making it more efficient for high-dimensional spaces.
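DiagNormal's density coincides with wrapping PyTorch's Normal in Independent; a quick check using only torch.distributions (no zuko required) illustrates why the diagonal form suffices when there is no correlation:

```python
import torch
from torch.distributions import Independent, MultivariateNormal, Normal

loc = torch.zeros(3)
scale = torch.full((3,), 2.0)

# Diagonal covariance: 3 independent 1D normals
diag = Independent(Normal(loc, scale), 1)

# Full covariance restricted to a diagonal matrix
full = MultivariateNormal(loc, covariance_matrix=torch.diag(scale**2))

x = torch.tensor([0.5, -1.0, 2.0])
# The two densities agree when the covariance is diagonal
print(torch.allclose(diag.log_prob(x), full.log_prob(x)))  # True
```

The diagonal form avoids the Cholesky factorization that MultivariateNormal performs, which is what makes it cheaper in high dimensions.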

BoxUniform

A uniform distribution over an axis-aligned hypercube domain.
BoxUniform(lower, upper, ndims=1)
  • lower (Tensor, required): The lower bounds (inclusive) for each dimension.
  • upper (Tensor, required): The upper bounds (exclusive) for each dimension.
  • ndims (int, default: 1): The number of batch dimensions to interpret as event dimensions.
Mathematical Definition: l_i \leq X_i < u_i, where l_i and u_i are the lower and upper bounds in dimension i.
Example:
import torch
from zuko.distributions import BoxUniform

# Uniform over [-1, 1]³
d = BoxUniform(-torch.ones(3), torch.ones(3))
print(d.event_shape)  # torch.Size([3])

sample = d.sample()
print(sample)  # tensor([-0.0075,  0.5364, -0.8230])

Joint

Combines multiple independent distributions into a single multivariate distribution.
Joint(*marginals)
  • marginals (Distribution, ...; required): Variable number of distributions p(Z_i) to combine. The resulting distribution concatenates samples from each marginal.
Mathematical Definition: p(X = x) = \prod_i p(Z_i = x_i)
Example:
from torch.distributions import Uniform, Normal
from zuko.distributions import Joint

# Combine uniform and normal distributions
d = Joint(Uniform(0.0, 1.0), Normal(0.0, 1.0))
print(d.event_shape)  # torch.Size([2])

sample = d.sample()
print(sample)  # tensor([0.4963, 0.2072])

# Log probability is sum of marginal log probabilities
log_p = d.log_prob(sample)
Use Cases:
  • Combining heterogeneous random variables
  • Building complex priors from simple components
  • Creating distributions with different supports per dimension
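Because the marginals are independent, the joint log-density factorizes into a sum of marginal log-densities; a hand-rolled version of the definition using only torch.distributions (an illustration, not zuko's implementation):

```python
import torch
from torch.distributions import Normal, Uniform

u = Uniform(0.0, 1.0)
n = Normal(0.0, 1.0)

x = torch.tensor([0.25, 1.5])  # one value per marginal, concatenated
# log p(x) = log p_U(x_0) + log p_N(x_1)
log_p = u.log_prob(x[0]) + n.log_prob(x[1])
print(log_p)  # approx -2.044 (0 from the uniform, -2.044 from the normal)
```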

Mixture and Composition

Mixture

A weighted mixture of distributions.
Mixture(base, logits)
  • base (Distribution, required): A batch of base distributions p(Z_i). The last batch dimension indexes the mixture components.
  • logits (Tensor, required): The unnormalized log-weights log w_i for each component. Shape must match the last dimension of base.batch_shape.
Mathematical Definition: p(X = x) = \frac{1}{\sum_i w_i} \sum_i w_i \, p(Z_i = x)
Example:
import torch
from torch.distributions import Normal
from zuko.distributions import Mixture

# Mixture of 3 Gaussians
means = torch.tensor([-2.0, 0.0, 2.0])
stds = torch.ones(3)
weights = torch.tensor([1.0, 2.0, 1.0])  # unnormalized

d = Mixture(
    Normal(means, stds),
    torch.log(weights)
)

print(d.event_shape)  # torch.Size([])
sample = d.sample()
print(sample)  # tensor(-1.6920)
The logits parameter uses log-space for numerical stability. Weights are normalized automatically via softmax.
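For intuition, the mixture log-density can be reproduced in plain PyTorch: normalize the logits with log_softmax, then combine component log-densities with logsumexp (a sketch of the definition above, not zuko's internals):

```python
import torch
from torch.distributions import Normal

# Three components and unnormalized weights, as in the example above
comps = Normal(torch.tensor([-2.0, 0.0, 2.0]), torch.ones(3))
logits = torch.log(torch.tensor([1.0, 2.0, 1.0]))

x = torch.tensor(0.3)
# log p(x) = logsumexp_i( log_softmax(logits)_i + log p_i(x) )
log_w = torch.log_softmax(logits, dim=-1)
log_p = torch.logsumexp(log_w + comps.log_prob(x), dim=-1)
print(log_p)  # approx -1.508
```

Working entirely in log-space avoids underflow when component densities are tiny.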

TransformedUniform

A distribution where a transformation of the variable is uniformly distributed.
TransformedUniform(f, lower, upper)
  • f (Transform, required): A transformation f that is monotonically increasing over [l, u].
  • lower (Tensor, required): The lower bound l (inclusive) in the original space.
  • upper (Tensor, required): The upper bound u (exclusive) in the original space.
Mathematical Definition: p(X = x) = \frac{1}{f(u) - f(l)} \begin{cases} f'(x) & \text{if } f(l) \leq f(x) < f(u) \\ 0 & \text{otherwise} \end{cases}
Example:
from torch.distributions.transforms import ExpTransform
from zuko.distributions import TransformedUniform

# Uniform in log-space between e^(-1) and e^1
d = TransformedUniform(ExpTransform(), -1.0, 1.0)

sample = d.sample()
print(sample)  # tensor(0.4281)

log_p = d.log_prob(sample)
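Plugging f = exp, l = -1, u = 1 into the definition gives p(x) = e^x / (e - 1/e) on [-1, 1); a quick sanity check of that formula in plain PyTorch:

```python
import math

import torch

# Density of TransformedUniform(ExpTransform(), -1.0, 1.0):
# p(x) = exp(x) / (e - 1/e) for -1 <= x < 1
x = torch.tensor(0.5)
log_p = x - math.log(math.e - 1.0 / math.e)
print(log_p)  # approx -0.3546
```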

Truncated

Truncates a base distribution between specified bounds.
Truncated(base, lower=-inf, upper=+inf)
  • base (Distribution, required): A univariate base distribution p(X) to truncate.
  • lower (Tensor, default: -inf): The lower bound l (inclusive).
  • upper (Tensor, default: +inf): The upper bound u (exclusive).
Mathematical Definition: p(X = x | l \leq X < u) = \frac{1}{P(X \leq u) - P(X \leq l)} \begin{cases} p(X = x) & \text{if } l \leq x < u \\ 0 & \text{otherwise} \end{cases}
Example:
from torch.distributions import Normal
from zuko.distributions import Truncated

# Standard normal truncated to [1, 2]
d = Truncated(Normal(0.0, 1.0), 1.0, 2.0)

sample = d.sample()
print(sample)  # tensor(1.3333) - always in [1, 2]
The base distribution must be univariate (no event dimensions). Use Independent to truncate multivariate distributions component-wise.
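The renormalization in the definition uses only the base CDF, so it can be reproduced with plain torch.distributions (an illustration of the formula, not zuko's code):

```python
import torch
from torch.distributions import Normal

base = Normal(0.0, 1.0)
l, u = torch.tensor(1.0), torch.tensor(2.0)

# Normalizing constant: probability mass of the base in [l, u)
Z = base.cdf(u) - base.cdf(l)

x = torch.tensor(1.5)
# Truncated log-density: base density renormalized by Z
log_p = base.log_prob(x) - torch.log(Z)
print(log_p)  # approx -0.0481
```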

Shape-Based Distributions

GeneralizedNormal

A generalized normal distribution with shape parameter β.
GeneralizedNormal(beta)
  • beta (Tensor, required): The shape parameter β > 0. Controls the tail behavior (β = 1 gives Laplace, β = 2 gives Normal).
Mathematical Definition: p(X = x) = \frac{\beta}{2 \Gamma(1 / \beta)} \exp(-|x|^\beta)
Example:
import torch
from zuko.distributions import GeneralizedNormal

# β=2 corresponds to normal with σ² = 1/2
d = GeneralizedNormal(2.0)
sample = d.sample()
print(sample)  # tensor(-0.0281)

# β=1 corresponds to Laplace distribution
d_laplace = GeneralizedNormal(1.0)

# β<1 gives heavy tails, β>2 gives light tails
d_heavy = GeneralizedNormal(0.5)
d_light = GeneralizedNormal(4.0)
The generalized normal distribution interpolates between different tail behaviors. The Gamma function normalizes the density.
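The density above can be coded directly; with β = 2 it matches a normal with variance 1/2, which makes a convenient check. A plain-PyTorch sketch (gn_log_prob is a hypothetical helper, not part of zuko):

```python
import torch
from torch.distributions import Normal

def gn_log_prob(x, beta):
    # log p(x) = log(beta) - log(2) - log Gamma(1/beta) - |x|**beta
    return (torch.log(beta) - torch.log(torch.tensor(2.0))
            - torch.lgamma(1.0 / beta) - x.abs() ** beta)

x = torch.tensor(0.7)
ours = gn_log_prob(x, torch.tensor(2.0))
ref = Normal(0.0, 0.5**0.5).log_prob(x)  # N(0, 1/2)
print(torch.allclose(ours, ref))  # True
```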

Order Statistics

Sort

Distribution of n ordered draws from a base distribution.
Sort(base, n=2, descending=False)
base
Distribution
required
A univariate base distribution p(Z).
n
int
default:"2"
The number of draws.
descending
bool
default:"False"
Whether to sort in descending order.
Mathematical Definition: p(X = x) = \begin{cases} n! \, \prod_{i = 1}^n p(Z = x_i) & \text{if } x \text{ is ordered} \\ 0 & \text{otherwise} \end{cases}
Example:
from torch.distributions import Normal
from zuko.distributions import Sort

# 3 ordered draws from standard normal
d = Sort(Normal(0.0, 1.0), 3)
print(d.event_shape)  # torch.Size([3])

sample = d.sample()
print(sample)  # tensor([-2.1788, -0.2934,  1.5410]) - always ordered

# Descending order
d_desc = Sort(Normal(0.0, 1.0), 3, descending=True)
sample_desc = d_desc.sample()
print(sample_desc)  # tensor([1.5410, -0.2934, -2.1788])
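The n! factor counts the permutations that sort to the same ordered vector. Sampling and evaluating the definition by hand looks like this (a plain-PyTorch sketch of the formula, not zuko's implementation):

```python
import math

import torch
from torch.distributions import Normal

base, n = Normal(0.0, 1.0), 3
z = base.sample((n,))
x, _ = torch.sort(z)  # an ordered draw, as Sort(base, n) would produce

# log p(x) = log(n!) + sum_i log p(x_i), valid because x is ordered
log_p = math.log(math.factorial(n)) + base.log_prob(x).sum()
```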

TopK

Distribution of the top k elements among n draws.
TopK(base, k=1, n=2, descending=False)
  • base (Distribution, required): A univariate base distribution p(Z).
  • k (int, default: 1): The number of selected elements. Must satisfy 1 ≤ k < n.
  • n (int, default: 2): The total number of draws.
  • descending (bool, default: False): Whether to select the k largest elements (descending) or the k smallest (ascending).
Mathematical Definition: p(X = x) = \begin{cases} \frac{n!}{(n - k)!} \, \prod_{i = 1}^k p(Z = x_i) \, P(Z \geq x_k)^{n - k} & \text{if } x \text{ is ordered} \\ 0 & \text{otherwise} \end{cases}
Example:
from torch.distributions import Normal
from zuko.distributions import TopK

# Top 2 out of 5 draws (smallest values by default)
d = TopK(Normal(0.0, 1.0), k=2, n=5)
print(d.event_shape)  # torch.Size([2])

sample = d.sample()
print(sample)  # tensor([-2.1788, -0.2934]) - 2 smallest of 5

# Top 2 largest values
d_max = TopK(Normal(0.0, 1.0), k=2, n=5, descending=True)
sample_max = d_max.sample()
print(sample_max)  # tensor([1.5410, 0.8321]) - 2 largest of 5
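Evaluating the definition by hand for the ascending case: the k retained draws contribute their densities, and each of the n − k discarded draws lands at or above x_k (a plain-PyTorch sketch of the formula, not zuko's code):

```python
import math

import torch
from torch.distributions import Normal

base, k, n = Normal(0.0, 1.0), 2, 5
x = torch.tensor([-1.0, 0.2])  # candidate value: the 2 smallest of 5, ordered

log_coef = math.log(math.factorial(n) // math.factorial(n - k))  # n!/(n-k)!
survival = 1.0 - base.cdf(x[-1])  # P(Z >= x_k)
log_p = log_coef + base.log_prob(x).sum() + (n - k) * torch.log(survival)
print(log_p)  # approx -1.959
```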

Minimum

Distribution of the minimum among n draws.
Minimum(base, n=2)
  • base (Distribution, required): A univariate base distribution p(Z).
  • n (int, default: 2): The number of draws.
Mathematical Definition: p(X = x) = n \, p(Z = x) \, P(Z \geq x)^{n - 1}
Example:
from torch.distributions import Normal
from zuko.distributions import Minimum

# Minimum of 3 standard normal draws
d = Minimum(Normal(0.0, 1.0), 3)
print(d.event_shape)  # torch.Size([]) - scalar

sample = d.sample()
print(sample)  # tensor(-2.1788) - likely negative

# Minimum shifts the distribution toward smaller values
samples = d.sample((1000,))
print(samples.mean())  # approximately -0.85
The minimum distribution is useful for modeling worst-case scenarios or lower bounds in statistical applications.
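Both the CDF and the density of the minimum follow from independence: the minimum exceeds x only if all n draws do. A plain-PyTorch check of the formula above:

```python
import math

import torch
from torch.distributions import Normal

base, n = Normal(0.0, 1.0), 3
x = torch.tensor(-0.5)

# P(min <= x) = 1 - (1 - F(x))**n
cdf_min = 1.0 - (1.0 - base.cdf(x)) ** n

# log density: log(n) + log p(x) + (n - 1) log P(Z >= x)
log_p = math.log(n) + base.log_prob(x) + (n - 1) * torch.log(1.0 - base.cdf(x))
print(cdf_min)  # approx 0.6694
```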

Maximum

Distribution of the maximum among n draws.
Maximum(base, n=2)
  • base (Distribution, required): A univariate base distribution p(Z).
  • n (int, default: 2): The number of draws.
Mathematical Definition: p(X = x) = n \, p(Z = x) \, P(Z \leq x)^{n - 1}
Example:
from torch.distributions import Normal
from zuko.distributions import Maximum

# Maximum of 3 standard normal draws
d = Maximum(Normal(0.0, 1.0), 3)
print(d.event_shape)  # torch.Size([]) - scalar

sample = d.sample()
print(sample)  # tensor(1.5410) - likely positive

# Maximum shifts the distribution toward larger values
samples = d.sample((1000,))
print(samples.mean())  # approximately 0.85
Use Cases:
  • Extreme value theory
  • Best-case scenario modeling
  • Upper confidence bounds
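As a sanity check, the mean of the maximum of two standard normals is 1/√π ≈ 0.564, which a quick Monte Carlo estimate in plain PyTorch reproduces:

```python
import math

import torch

torch.manual_seed(0)

# E[max(Z1, Z2)] for iid standard normals equals 1/sqrt(pi)
z = torch.randn(100_000, 2)
est = z.max(dim=1).values.mean()
print(est, 1 / math.sqrt(math.pi))  # both approx 0.564
```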

Common Patterns

Batched Distributions

All distributions support batching for parallel computation:
import torch
from zuko.distributions import DiagNormal

# Batch of 10 different 3D normal distributions
loc = torch.randn(10, 3)
scale = torch.ones(10, 3)
d = DiagNormal(loc, scale)

print(d.batch_shape)   # torch.Size([10])
print(d.event_shape)   # torch.Size([3])

# Sample from all at once
samples = d.sample((100,))
print(samples.shape)   # torch.Size([100, 10, 3])

Combining Distributions

import torch
from torch.distributions import Beta, Normal
from zuko.distributions import DiagNormal, Joint, Truncated

# Create a complex prior distribution
prior = Joint(
    Truncated(Normal(0.0, 1.0), 0.0, 2.0),  # positive normal
    Beta(2.0, 2.0),                          # beta in [0, 1]
    DiagNormal(torch.zeros(3), torch.ones(3))  # 3D normal
)

print(prior.event_shape)  # torch.Size([5]) - concatenated
sample = prior.sample()

Using with Normalizing Flows

import torch
from zuko.flows import MAF  # flows live in zuko.flows, not zuko.transforms

# A flow in zuko is a lazy distribution: calling it returns a
# NormalizingFlow over 5 features. MAF uses a standard-normal
# (DiagNormal) base by default.
flow = MAF(features=5, context=0, hidden_features=(64, 64))

# Train by maximum likelihood on samples from a target
data = target_distribution.sample((1000,))  # target_distribution defined elsewhere
loss = -flow().log_prob(data).mean()
