The NormalizingFlow class implements a normalizing flow that maps a random variable X to a base distribution p(Z) through an invertible transformation f.

Mathematical Formulation

The density of a realization x is given by the change of variables formula:

p(X = x) = p(Z = f(x)) \left| \det \frac{\partial f(x)}{\partial x} \right|

To sample from p(X), realizations z ~ p(Z) are mapped through the inverse transformation g = f⁻¹.
The log absolute determinant of the Jacobian (LADJ) is computed automatically through the transform’s call_and_ladj method, enabling efficient density evaluation.
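As a worked check of the formula, take f(x) = log(x) with a standard normal base, so that X = f⁻¹(Z) = exp(Z) is log-normal and |det ∂f/∂x| = 1/x. This sketch uses plain torch.distributions objects rather than this class:

```python
import torch
from torch.distributions import LogNormal, Normal

# Change of variables with f(x) = log(x) and base p(Z) = N(0, 1):
# log p(x) = log p(z = log x) + log |det df/dx| = log p(log x) - log x
base = Normal(0.0, 1.0)
x = torch.tensor([0.5, 1.0, 2.0])
log_p = base.log_prob(torch.log(x)) - torch.log(x)

# The result matches the analytic log-normal density
print(torch.allclose(log_p, LogNormal(0.0, 1.0).log_prob(x)))  # True
```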

Constructor

NormalizingFlow(transform, base)
transform (torch.distributions.Transform, required)
An invertible transformation f that maps between the target and base spaces. Must implement inv() for the inverse transformation and call_and_ladj() for computing both the forward transformation and its log absolute determinant Jacobian.
base (torch.distributions.Distribution, required)
The base distribution p(Z) in the latent space. Samples from this distribution are transformed to the target space. The distribution is automatically wrapped in Independent if needed to match event dimensions.

Properties

batch_shape (torch.Size)
Returns the batch shape of the base distribution. This determines how many independent distributions are represented in a single object.
event_shape (torch.Size)
Returns the event shape after applying the inverse transform to the base distribution’s event shape. This is the shape of a single sample.
has_rsample (bool)
Always True. The distribution supports reparameterized sampling for gradient-based optimization.
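Because sampling is reparameterized, gradients flow from samples back to the transform's parameters. A minimal sketch with plain torch transforms; the learnable shift here is illustrative, not part of the class:

```python
import torch
from torch.distributions import Normal
from torch.distributions.transforms import AffineTransform

# Illustrative learnable parameter: f(x) = x - shift, so g(z) = z + shift
shift = torch.tensor(0.0, requires_grad=True)
g = AffineTransform(-shift, 1.0).inv

z = Normal(0.0, 1.0).rsample((1024,))  # reparameterized base samples
x = g(z)                               # flow samples, differentiable in shift

loss = (x ** 2).mean()
loss.backward()
print(shift.grad is not None)  # True: gradients reach the parameter
```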

Methods

log_prob

log_prob(x: Tensor) -> Tensor
Computes the log probability density at x using the change of variables formula. Arguments:
  • x - Sample points at which to evaluate the log probability
Returns:
  • Log probability density values with shape matching the batch shape
Implementation:
  1. Applies the forward transform: z = f(x)
  2. Computes the log absolute determinant Jacobian (LADJ)
  3. Returns: log p(z) + LADJ
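The steps above can be sketched with a plain torch Transform; here log_abs_det_jacobian stands in for the fused call_and_ladj (an assumption about the internals, not the actual implementation):

```python
import torch
from torch.distributions import Normal
from torch.distributions.transforms import ExpTransform

def log_prob(f, base, x):
    z = f(x)                             # 1. forward transform
    ladj = f.log_abs_det_jacobian(x, z)  # 2. log |det df/dx|
    return base.log_prob(z) + ladj       # 3. change of variables

f = ExpTransform()       # f(x) = exp(x), so the LADJ is simply x
base = Normal(0.0, 1.0)
x = torch.randn(5)
print(log_prob(f, base, x).shape)  # torch.Size([5])
```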

rsample

rsample(shape: Size = ()) -> Tensor
Generates reparameterized samples using the inverse transform. Arguments:
  • shape - Desired sample shape (prepended to batch_shape + event_shape)
Returns:
  • Samples with shape shape + batch_shape + event_shape
Implementation:
  1. Samples z from the base distribution
  2. Applies the inverse transform: x = f⁻¹(z)
  3. Returns x
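A hedged sketch of these two steps with plain torch objects, mirroring the Gamma example in the Basic Usage section rather than the class internals:

```python
import torch
from torch.distributions import Gamma
from torch.distributions.transforms import ExpTransform

f = ExpTransform()      # f(x) = exp(x), so g = f⁻¹ = log
base = Gamma(2.0, 1.0)  # p(Z), supported on (0, ∞)

z = base.rsample((4,))  # 1. sample z from the base distribution
x = f.inv(z)            # 2. apply the inverse transform: x = log(z)
print(x.shape)          # torch.Size([4])
```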

rsample_and_log_prob

rsample_and_log_prob(shape: Size = ()) -> tuple[Tensor, Tensor]
Efficiently generates samples and their log probabilities simultaneously. Arguments:
  • shape - Desired sample shape
Returns:
  • Tuple of (samples, log_probabilities)
Advantages:
  • More efficient than calling rsample() and log_prob() separately
  • Reuses intermediate computations from the inverse transform
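The reuse of intermediate computations can be sketched as follows; the LADJ of the inverse transform stands in for zuko's fused call (an assumption, not the class's exact code):

```python
import torch
from torch.distributions import Gamma
from torch.distributions.transforms import ExpTransform

def rsample_and_log_prob(f, base, shape=()):
    z = base.rsample(shape)  # sample once from p(Z)
    x = f.inv(z)             # map to the target space
    # Reuse z: log p(x) = log p(z) - log |det dg/dz|, with g = f⁻¹
    ladj = f.inv.log_abs_det_jacobian(z, x)
    return x, base.log_prob(z) - ladj

x, log_p = rsample_and_log_prob(ExpTransform(), Gamma(2.0, 1.0), (10,))
print(x.shape, log_p.shape)  # torch.Size([10]) torch.Size([10])
```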

expand

expand(batch_shape: Size, new: Distribution | None = None) -> Distribution
Creates a new distribution with expanded batch dimensions. Arguments:
  • batch_shape - Target batch shape
  • new - Optional instance to populate (internal use)
Returns:
  • Expanded distribution instance
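The batch semantics follow torch.distributions; a quick illustration with a plain torch distribution rather than this class specifically:

```python
import torch
from torch.distributions import Normal

d = Normal(0.0, 1.0)
e = d.expand(torch.Size([3, 2]))  # broadcast to a batch of 3 x 2 copies

print(e.batch_shape)      # torch.Size([3, 2])
print(e.sample().shape)   # torch.Size([3, 2])
```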

Examples

Basic Usage

import torch
from torch.distributions import Gamma
from torch.distributions.transforms import ExpTransform
from zuko.distributions import NormalizingFlow

# Create a flow whose samples are the logs of Gamma(2, 1) draws (a log-gamma law)
transform = ExpTransform()
base = Gamma(2.0, 1.0)
flow = NormalizingFlow(transform, base)

# Sample from the distribution
x = flow.sample()
print(x)  # tensor(1.5157)

# Compute log probability
log_p = flow.log_prob(x)
print(log_p)  # tensor(-1.2345)

With Custom Transforms

from torch.distributions.transforms import AffineTransform
from torch.distributions import Normal

# Define an affine transformation
shift = torch.tensor([1.0, 2.0])
scale = torch.tensor([0.5, 1.5])
transform = AffineTransform(shift, scale)

# Base distribution
base = Normal(torch.zeros(2), torch.ones(2))

# Create flow
flow = NormalizingFlow(transform, base)

# Sample multiple points
samples = flow.rsample((100,))
print(samples.shape)  # torch.Size([100, 2])

Composed Flows

from torch.distributions.transforms import (
    AffineTransform,
    ComposeTransform,
    SigmoidTransform,
)

# Compose multiple transformations: an affine map followed by a logit
transform = ComposeTransform([
    AffineTransform(torch.zeros(3), torch.ones(3)),
    SigmoidTransform().inv,  # logit, mapping (0, 1) onto the real line
])

base = Normal(torch.zeros(3), torch.ones(3))
flow = NormalizingFlow(transform, base)

# Efficient sampling with log probabilities
samples, log_probs = flow.rsample_and_log_prob((1000,))
print(samples.shape)     # torch.Size([1000, 3])
print(log_probs.shape)   # torch.Size([1000])

Batched Distributions

# Create a batch of flows with different base distributions
base = Normal(
    loc=torch.randn(5, 2),      # 5 different means
    scale=torch.ones(5, 2)
)

transform = ExpTransform()
flow = NormalizingFlow(transform, base)

print(flow.batch_shape)   # torch.Size([5])
print(flow.event_shape)   # torch.Size([2])

# Sample from all flows at once
samples = flow.sample((10,))
print(samples.shape)      # torch.Size([10, 5, 2])

References

Normalizing flows enable flexible density estimation by transforming simple base distributions into complex target distributions while maintaining tractable likelihoods.
  1. Tabak, E. G., & Turner, C. V. (2013). A Family of Non-parametric Density Estimation Algorithms. Communications on Pure and Applied Mathematics.
  2. Rezende, D., & Mohamed, S. (2015). Variational Inference with Normalizing Flows. ICML 2015. arXiv:1505.05770
  3. Papamakarios, G., et al. (2021). Normalizing Flows for Probabilistic Modeling and Inference. Journal of Machine Learning Research. arXiv:1912.02762
