Zuko provides several specialized distributions that extend PyTorch’s distribution framework for specific modeling scenarios.
Multivariate Distributions
DiagNormal
A multivariate normal distribution with diagonal covariance (independent variables).
DiagNormal(loc, scale, ndims=1)
loc: The mean μ of the variables. Its shape determines the dimensionality.
scale: The standard deviation σ of the variables. Must have the same shape as loc.
ndims: The number of batch dimensions to interpret as event dimensions.
Example:
import torch
from zuko.distributions import DiagNormal
# 3-dimensional normal with independent components
d = DiagNormal(torch.zeros(3), torch.ones(3))
print(d.event_shape) # torch.Size([3])
sample = d.sample()
print(sample) # tensor([ 1.5410, -0.2934, -2.1788])
# Compute log probability
log_p = d.log_prob(sample)
print(log_p) # scalar tensor
Unlike MultivariateNormal, DiagNormal assumes no correlation between variables, making it more efficient for high-dimensional spaces.
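The equivalence (and the efficiency argument) can be sanity-checked with plain torch, no zuko required: a MultivariateNormal with a diagonal covariance matrix assigns exactly the same log-densities as a product of independent 1-D normals, but pays for a dense Cholesky factorization.

```python
import torch
from torch.distributions import Independent, MultivariateNormal, Normal

# A diagonal-covariance multivariate normal factorizes into independent
# 1-D normals, so log-densities agree while avoiding the O(d^3)
# Cholesky factorization that a dense covariance matrix requires.
loc = torch.tensor([0.0, 1.0, -1.0])
scale = torch.tensor([1.0, 0.5, 2.0])

diag = Independent(Normal(loc, scale), 1)               # diagonal, O(d)
full = MultivariateNormal(loc, torch.diag(scale ** 2))  # dense, O(d^3)

x = torch.tensor([0.3, 0.9, -0.5])
print(torch.allclose(diag.log_prob(x), full.log_prob(x)))  # True
```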
BoxUniform
A uniform distribution over an axis-aligned hypercube domain.
BoxUniform(lower, upper, ndims=1)
lower: The lower bounds (inclusive) for each dimension.
upper: The upper bounds (exclusive) for each dimension.
ndims: The number of batch dimensions to interpret as event dimensions.
Mathematical Definition:
l_i ≤ X_i < u_i
where l_i and u_i are the lower and upper bounds in dimension i.
Example:
import torch
from zuko.distributions import BoxUniform
# Uniform over [-1, 1]³
d = BoxUniform(-torch.ones(3), torch.ones(3))
print(d.event_shape) # torch.Size([3])
sample = d.sample()
print(sample) # tensor([-0.0075, 0.5364, -0.8230])
Joint
Combines multiple independent distributions into a single multivariate distribution.
marginals (Distribution, ..., required): Variable number of distributions p(Z_i) to combine. The resulting distribution concatenates samples from each marginal.
Mathematical Definition:
p(X = x) = ∏_i p(Z_i = x_i)
Example:
from torch.distributions import Uniform, Normal
from zuko.distributions import Joint
# Combine uniform and normal distributions
d = Joint(Uniform(0.0, 1.0), Normal(0.0, 1.0))
print(d.event_shape) # torch.Size([2])
sample = d.sample()
print(sample) # tensor([0.4963, 0.2072])
# Log probability is sum of marginal log probabilities
log_p = d.log_prob(sample)
Use Cases:
- Combining heterogeneous random variables
- Building complex priors from simple components
- Creating distributions with different supports per dimension
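The factorization behind Joint can be sketched with plain torch distributions: for independent marginals, the joint log-density is just the sum of the marginal log-densities, one coordinate per marginal.

```python
import torch
from torch.distributions import Normal, Uniform

# For independent marginals, the joint log-density is the sum of the
# marginal log-densities, evaluated coordinate-wise.
u = Uniform(0.0, 1.0)
n = Normal(0.0, 1.0)
x = torch.tensor([0.25, -0.7])  # one coordinate per marginal
log_p = u.log_prob(x[0]) + n.log_prob(x[1])
print(log_p)  # tensor(-1.1639): log 1 + log N(-0.7; 0, 1)
```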
Mixture and Composition
Mixture
A weighted mixture of distributions.
Mixture(base, logits)
base: A batch of base distributions p(Z_i). The last batch dimension indexes the mixture components.
logits: The unnormalized log-weights log w_i for each component. Shape must match the last dimension of base.batch_shape.
Mathematical Definition:
p(X = x) = (1 / ∑_i w_i) ∑_i w_i p(Z_i = x)
Example:
import torch
from torch.distributions import Normal
from zuko.distributions import Mixture
# Mixture of 3 Gaussians
means = torch.tensor([-2.0, 0.0, 2.0])
stds = torch.ones(3)
weights = torch.tensor([1.0, 2.0, 1.0]) # unnormalized
d = Mixture(
Normal(means, stds),
torch.log(weights)
)
print(d.event_shape) # torch.Size([])
sample = d.sample()
print(sample) # tensor(-1.6920)
The logits parameter uses log-space for numerical stability. Weights are normalized automatically via softmax.
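The same computation can be written by hand in log-space with plain torch, which is how a numerically stable mixture density is typically evaluated: normalize the logits with log-softmax, add the component log-densities, and reduce with logsumexp.

```python
import torch
from torch.distributions import Normal

# Mixture log-density via log-softmax + logsumexp: normalize the logits,
# add the per-component log-densities, then reduce over components.
means = torch.tensor([-2.0, 0.0, 2.0])
stds = torch.ones(3)
logits = torch.log(torch.tensor([1.0, 2.0, 1.0]))

x = torch.tensor(0.5)
log_w = torch.log_softmax(logits, dim=-1)  # normalized log-weights
log_p = torch.logsumexp(log_w + Normal(means, stds).log_prob(x), dim=-1)
```

Working in log-space avoids underflow when components assign tiny densities to x, which is common far from the component means.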
TransformedUniform
A distribution where a transformation of the variable is uniformly distributed.
TransformedUniform(f, lower, upper)
f: A transformation f that is monotonically increasing over [l, u].
lower: The lower bound l (inclusive) in the original space.
upper: The upper bound u (exclusive) in the original space.
Mathematical Definition:
p(X = x) = { f'(x) / (f(u) − f(l))   if f(l) ≤ f(x) < f(u)
           { 0                       otherwise
Example:
from torch.distributions.transforms import ExpTransform
from zuko.distributions import TransformedUniform
# X such that exp(X) is uniform over [e^(-1), e^1]
d = TransformedUniform(ExpTransform(), -1.0, 1.0)
sample = d.sample()
print(sample) # tensor(0.4281)
log_p = d.log_prob(sample)
Truncated
Truncates a base distribution between specified bounds.
Truncated(base, lower=-inf, upper=+inf)
base: A univariate base distribution p(X) to truncate.
lower: The lower bound l (inclusive).
upper: The upper bound u (exclusive).
Mathematical Definition:
p(X = x | l ≤ X < u) = { p(X = x) / (P(X ≤ u) − P(X ≤ l))   if l ≤ x < u
                       { 0                                   otherwise
Example:
from torch.distributions import Normal
from zuko.distributions import Truncated
# Standard normal truncated to [1, 2]
d = Truncated(Normal(0.0, 1.0), 1.0, 2.0)
sample = d.sample()
print(sample) # tensor(1.3333) - always in [1, 2]
The base distribution must be univariate (no event dimensions). Use Independent to truncate multivariate distributions component-wise.
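The renormalization itself is easy to reproduce as a sanity check with plain torch (a sketch, not Zuko's implementation): divide the base density by the mass captured between the bounds and confirm the result integrates to one.

```python
import torch
from torch.distributions import Normal

# Truncated density by hand: base density divided by P(l <= X < u)
base = Normal(0.0, 1.0)
l, u = torch.tensor(1.0), torch.tensor(2.0)
mass = base.cdf(u) - base.cdf(l)  # probability kept by the truncation

xs = torch.linspace(1.0, 2.0, 1001)
dens = (base.log_prob(xs) - torch.log(mass)).exp()
print(torch.trapezoid(dens, xs))  # ≈ 1.0: a valid density on [1, 2)
```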
Shape-Based Distributions
GeneralizedNormal
A generalized normal distribution with shape parameter β.
GeneralizedNormal(beta)
beta: The shape parameter β > 0. Controls the tail behavior (β=1 gives Laplace, β=2 gives Normal).
Mathematical Definition:
p(X = x) = β / (2 Γ(1/β)) exp(−|x|^β)
Example:
import torch
from zuko.distributions import GeneralizedNormal
# β=2 corresponds to normal with σ² = 1/2
d = GeneralizedNormal(2.0)
sample = d.sample()
print(sample) # tensor(-0.0281)
# β=1 corresponds to Laplace distribution
d_laplace = GeneralizedNormal(1.0)
# β<1 gives heavy tails, β>2 gives light tails
d_heavy = GeneralizedNormal(0.5)
d_light = GeneralizedNormal(4.0)
The generalized normal distribution interpolates between different tail behaviors. The Gamma function normalizes the density.
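The normalizing constant β / (2 Γ(1/β)) can be verified numerically with plain torch and the standard library (gn_density below is a hypothetical helper, not part of Zuko): the density should integrate to 1 for any β > 0.

```python
import math
import torch

# Hypothetical helper: the generalized normal density with its
# normalizing constant beta / (2 * Gamma(1/beta))
def gn_density(x, beta):
    const = beta / (2.0 * math.gamma(1.0 / beta))
    return const * torch.exp(-torch.abs(x) ** beta)

# Heavy tails (small beta) need a wide integration range
xs = torch.linspace(-100.0, 100.0, 200001)
for beta in (0.5, 1.0, 2.0, 4.0):
    print(round(torch.trapezoid(gn_density(xs, beta), xs).item(), 3))
```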
Order Statistics
Sort
Distribution of n ordered draws from a base distribution.
Sort(base, n=2, descending=False)
base: A univariate base distribution p(Z).
n: The number of draws.
descending: Whether to sort in descending order.
Mathematical Definition:
p(X = x) = { n! ∏_{i=1}^{n} p(Z = x_i)   if x is ordered
           { 0                            otherwise
Example:
from torch.distributions import Normal
from zuko.distributions import Sort
# 3 ordered draws from standard normal
d = Sort(Normal(0.0, 1.0), 3)
print(d.event_shape) # torch.Size([3])
sample = d.sample()
print(sample) # tensor([-2.1788, -0.2934, 1.5410]) - always ordered
# Descending order
d_desc = Sort(Normal(0.0, 1.0), 3, descending=True)
sample_desc = d_desc.sample()
print(sample_desc) # tensor([1.5410, -0.2934, -2.1788])
TopK
Distribution of the top k elements among n draws.
TopK(base, k=1, n=2, descending=False)
base: A univariate base distribution p(Z).
k: The number of selected elements. Must satisfy 1 ≤ k < n.
n: The total number of draws.
descending: Whether to select the k largest elements (True) or the k smallest (False).
Mathematical Definition:
p(X = x) = { (n! / (n−k)!) ∏_{i=1}^{k} p(Z = x_i) P(Z ≥ x_k)^{n−k}   if x is ordered
           { 0                                                        otherwise
Example:
from torch.distributions import Normal
from zuko.distributions import TopK
# Top 2 out of 5 draws (smallest values by default)
d = TopK(Normal(0.0, 1.0), k=2, n=5)
print(d.event_shape) # torch.Size([2])
sample = d.sample()
print(sample) # tensor([-2.1788, -0.2934]) - 2 smallest of 5
# Top 2 largest values
d_max = TopK(Normal(0.0, 1.0), k=2, n=5, descending=True)
sample_max = d_max.sample()
print(sample_max) # tensor([1.5410, 0.8321]) - 2 largest of 5
Minimum
Distribution of the minimum among n draws.
Minimum(base, n=2)
base: A univariate base distribution p(Z).
n: The number of draws.
Mathematical Definition:
p(X = x) = n p(Z = x) P(Z ≥ x)^{n−1}
Example:
from torch.distributions import Normal
from zuko.distributions import Minimum
# Minimum of 3 standard normal draws
d = Minimum(Normal(0.0, 1.0), 3)
print(d.event_shape) # torch.Size([]) - scalar
sample = d.sample()
print(sample) # tensor(-2.1788) - likely negative
# Minimum shifts the distribution toward smaller values
samples = d.sample((1000,))
print(samples.mean()) # approximately -0.85
The minimum distribution is useful for modeling worst-case scenarios or lower bounds in statistical applications.
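The quoted mean can be checked numerically with plain torch: integrating x against the minimum density for n = 3 standard normal draws recovers the known value −3/(2√π) ≈ −0.846.

```python
import torch
from torch.distributions import Normal

# Density of the minimum of n standard normal draws:
# n * p(x) * P(Z >= x)^(n - 1)
base = Normal(0.0, 1.0)
n = 3
xs = torch.linspace(-8.0, 8.0, 8001)
dens = n * base.log_prob(xs).exp() * (1.0 - base.cdf(xs)) ** (n - 1)
mean = torch.trapezoid(xs * dens, xs)
print(round(mean.item(), 3))  # ≈ -0.846, matching the -0.85 quoted above
```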
Maximum
Distribution of the maximum among n draws.
Maximum(base, n=2)
base: A univariate base distribution p(Z).
n: The number of draws.
Mathematical Definition:
p(X = x) = n p(Z = x) P(Z ≤ x)^{n−1}
Example:
from torch.distributions import Normal
from zuko.distributions import Maximum
# Maximum of 3 standard normal draws
d = Maximum(Normal(0.0, 1.0), 3)
print(d.event_shape) # torch.Size([]) - scalar
sample = d.sample()
print(sample) # tensor(1.5410) - likely positive
# Maximum shifts the distribution toward larger values
samples = d.sample((1000,))
print(samples.mean()) # approximately 0.85
Use Cases:
- Extreme value theory
- Best-case scenario modeling
- Upper confidence bounds
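By symmetry with the minimum, the mean of the maximum of three standard normals is +3/(2√π) ≈ 0.846, which a quick numerical integration with plain torch confirms.

```python
import torch
from torch.distributions import Normal

# Density of the maximum of n standard normal draws:
# n * p(x) * P(Z <= x)^(n - 1), obtained by differentiating F(x)^n
base = Normal(0.0, 1.0)
n = 3
xs = torch.linspace(-8.0, 8.0, 8001)
dens = n * base.log_prob(xs).exp() * base.cdf(xs) ** (n - 1)
mean = torch.trapezoid(xs * dens, xs)
print(round(mean.item(), 3))  # ≈ 0.846, mirroring the minimum
```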
Common Patterns
Batched Distributions
All distributions support batching for parallel computation:
import torch
from zuko.distributions import DiagNormal
# Batch of 10 different 3D normal distributions
loc = torch.randn(10, 3)
scale = torch.ones(10, 3)
d = DiagNormal(loc, scale)
print(d.batch_shape) # torch.Size([10])
print(d.event_shape) # torch.Size([3])
# Sample from all at once
samples = d.sample((100,))
print(samples.shape) # torch.Size([100, 10, 3])
Combining Distributions
import torch
from torch.distributions import Beta, Normal
from zuko.distributions import DiagNormal, Joint, Truncated
# Create a complex prior distribution
prior = Joint(
Truncated(Normal(0.0, 1.0), 0.0, 2.0), # positive normal
Beta(2.0, 2.0), # beta in [0, 1]
DiagNormal(torch.zeros(3), torch.ones(3)) # 3D normal
)
print(prior.event_shape) # torch.Size([5]) - concatenated
sample = prior.sample()
Using with Normalizing Flows
import torch
from zuko.flows import MAF
# High-level flows live in zuko.flows; calling one yields a
# NormalizingFlow distribution over a diagonal normal base
flow = MAF(features=5, context=0, hidden_features=[64, 64])
# Train the flow by maximum likelihood
data = target_distribution.sample((1000,))  # target_distribution defined elsewhere
loss = -flow().log_prob(data).mean()
See Also