
What are Normalizing Flows?

Normalizing flows are a powerful class of generative models that transform a simple base distribution (like a standard normal) into a complex target distribution through a series of invertible transformations. The key insight is that we can express complex probability distributions by warping simple ones.

Mathematical Foundation

Given a random variable $X$ and a transformation $f$, normalizing flows use the change of variables formula to compute the probability density:

$$p(X = x) = p(Z = f(x)) \left| \det \frac{\partial f(x)}{\partial x} \right|$$

where:
  • $p(Z)$ is a simple base distribution (e.g., standard normal)
  • $f$ is an invertible transformation
  • The determinant term accounts for how the transformation warps space
The absolute value of the Jacobian determinant ensures the density remains non-negative, as it measures how volumes change under the transformation.
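As a concrete check, the formula can be evaluated by hand for the map $f(x) = \log x$ with a standard normal base, which yields the log-normal density. This is an illustrative stdlib sketch, not Zuko code:

```python
import math

# Change-of-variables check for f(x) = log(x) with a standard normal
# base: p_X(x) = p_Z(log x) * |d(log x)/dx| = p_Z(log x) / x,
# which is exactly the log-normal density.

def standard_normal_pdf(z):
    return math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)

def flow_density(x):
    z = math.log(x)        # f(x)
    jacobian = 1.0 / x     # |det df/dx|
    return standard_normal_pdf(z) * jacobian

def lognormal_pdf(x):      # closed-form log-normal density, for comparison
    return math.exp(-0.5 * math.log(x) ** 2) / (x * math.sqrt(2.0 * math.pi))

print(abs(flow_density(1.5) - lognormal_pdf(1.5)) < 1e-12)  # → True
```

The Jacobian factor $1/x$ is what converts the base density evaluated at $\log x$ into a valid density over $x$.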

How Flows Enable Density Estimation

Normalizing flows enable efficient density estimation through two key operations:
  1. Sampling: Draw $z \sim p(Z)$ from the base distribution and map through the inverse: $x = f^{-1}(z)$
  2. Density evaluation: For any point $x$, compute $p(x)$ using the change of variables formula
This bidirectional property makes flows uniquely powerful for generative modeling, as they support both:
  • Fast sampling (mapping base samples through the inverse $f^{-1}$)
  • Exact likelihood computation (mapping data through the forward $f$)
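Both operations can be sketched for a simple 1-D affine flow. The parameters and the plain-Python implementation below are illustrative, not Zuko's API:

```python
import math
import random

# Minimal 1-D flow sketch: f(x) = (x - MU) / SIGMA maps data to a
# standard normal base. MU and SIGMA are illustrative parameters.
MU, SIGMA = 2.0, 0.5

def f(x):
    return (x - MU) / SIGMA        # forward: data -> base

def f_inv(z):
    return MU + SIGMA * z          # inverse: base -> data

def log_prob(x):
    z = f(x)
    base = -0.5 * z * z - 0.5 * math.log(2 * math.pi)
    ladj = -math.log(SIGMA)        # log |det df/dx| = log(1 / SIGMA)
    return base + ladj

# Sampling: draw from the base, map through the inverse
sample = f_inv(random.gauss(0.0, 1.0))

# Density evaluation: exact log-likelihood at any point
density = log_prob(2.0)
```

Here `log_prob(2.0)` matches the closed-form log-density of a Normal(2.0, 0.5) at its mean, confirming the Jacobian correction.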

The NormalizingFlow Class

Zuko implements normalizing flows through the NormalizingFlow distribution class, which wraps PyTorch’s Distribution interface.
from zuko.distributions import NormalizingFlow
from torch.distributions import Normal
from torch.distributions.transforms import ExpTransform
import torch

# Create a flow: log-normal distribution
# The transform maps data to the base space, so f(x) = log(x);
# sampling applies the inverse, giving X = exp(Z)
base = Normal(0.0, 1.0)
transform = ExpTransform().inv
flow = NormalizingFlow(transform, base)

# Sample from the flow
samples = flow.sample((1000,))

# Evaluate log probability
x = torch.tensor([1.5, 2.0, 0.5])
log_prob = flow.log_prob(x)

Architecture

The NormalizingFlow class (defined in zuko/distributions.py:39-139) connects a transformation with a base distribution:
class NormalizingFlow(Distribution):
    def __init__(self, transform: Transform, base: Distribution) -> None:
        super().__init__()
        
        # Handle event dimension alignment
        reinterpreted = transform.codomain.event_dim - len(base.event_shape)
        if reinterpreted > 0:
            base = Independent(base, reinterpreted)
        
        self.transform = transform
        self.base = base
The NormalizingFlow class automatically handles dimension alignment between the transformation’s codomain and the base distribution’s event shape.
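A small sketch of the alignment rule with illustrative numbers (plain Python; `Independent` itself is a PyTorch class):

```python
import math

# Sketch of the alignment rule: a transform whose codomain is 1-D
# (vectors) over a base distribution with scalar events.
codomain_event_dim = 1
base_event_shape = ()
reinterpreted = codomain_event_dim - len(base_event_shape)  # 1 > 0

# Wrapping in Independent sums per-dimension log-densities into one event:
def normal_log_prob(v):
    return -0.5 * v * v - 0.5 * math.log(2 * math.pi)

z = [0.3, -1.2]                              # one 2-D "event"
joint = sum(normal_log_prob(v) for v in z)   # what Independent(base, 1) computes
```

Without this wrapping, the base would report one log-density per dimension, which could not be added to the transform's scalar log-determinant.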

Log Probability Computation

The log probability implements the change of variables formula:
def log_prob(self, x: Tensor) -> Tensor:
    z, ladj = self.transform.call_and_ladj(x)
    return self.base.log_prob(z) + ladj
where:
  • z = f(x) is the transformed value
  • ladj is the log absolute determinant of the Jacobian
  • The sum combines base log probability with the volume correction

Sampling

Sampling uses the inverse transformation:
def rsample(self, shape: Size = ()) -> Tensor:
    if self.base.has_rsample:
        z = self.base.rsample(shape)
    else:
        z = self.base.sample(shape)
    
    return self.transform.inv(z)
The method uses rsample (reparameterized sampling) when available to enable gradient-based optimization through the sampling process.
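The reparameterization trick behind rsample can be illustrated without autograd: writing a sample as a deterministic function of fixed noise makes its derivative with respect to the parameters well-defined. A stdlib sketch:

```python
import random

# Reparameterization sketch: a sample written as z = mu + sigma * eps,
# with eps ~ N(0, 1) held fixed, is differentiable in mu and sigma.
random.seed(0)
eps = random.gauss(0.0, 1.0)

def rsample(mu, sigma):
    return mu + sigma * eps  # same eps: gradients flow through mu, sigma

# A finite-difference derivative with respect to sigma recovers eps:
h = 1e-6
grad = (rsample(0.0, 1.0 + h) - rsample(0.0, 1.0)) / h
print(abs(grad - eps) < 1e-6)  # → True
```

In PyTorch, this is what allows loss gradients to propagate through `flow.rsample()` into the flow's parameters.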

Efficient Sampling with Log Probability

For training, you often need both samples and their log probabilities. The rsample_and_log_prob method computes both efficiently:
def rsample_and_log_prob(self, shape: Size = ()) -> tuple[Tensor, Tensor]:
    z = self.base.rsample(shape)
    x, ladj = self.transform.inv.call_and_ladj(z)
    return x, self.base.log_prob(z) - ladj
Notice the sign change: when going from $Z$ to $X$ via $x = g(z)$ with $g = f^{-1}$, the log probability becomes: $$\log p(X = g(z)) = \log p(Z = z) - \log \left| \det \frac{\partial g(z)}{\partial z} \right|$$
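A quick numeric check of the sign change, using $g(z) = \exp(z)$ so that the inverse-direction and forward-direction formulas must agree (stdlib sketch):

```python
import math

# Sign-change check with g(z) = exp(z), i.e. f(x) = log(x):
# inverse direction: log p_X(g(z)) = log p_Z(z) - log|g'(z)|
# forward direction: log p_X(x)    = log p_Z(f(x)) + log|f'(x)|

def base_log_prob(z):
    return -0.5 * z * z - 0.5 * math.log(2 * math.pi)

z = 0.7
x = math.exp(z)

lhs = base_log_prob(z) - z                            # -log|g'(z)| = -z
rhs = base_log_prob(math.log(x)) + math.log(1.0 / x)  # +log|f'(x)|
print(abs(lhs - rhs) < 1e-12)  # → True
```

Subtracting the Jacobian term in the inverse direction is exactly equivalent to adding it in the forward direction.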

Practical Example

Here’s a complete example building a normalizing flow for 2D data:
import torch
from torch.distributions import Normal
from torch.distributions.transforms import AffineTransform
from zuko.distributions import NormalizingFlow, DiagNormal

# Define base distribution (2D standard normal)
base = DiagNormal(torch.zeros(2), torch.ones(2))

# Define transformation (elementwise affine over the event dimension)
# Note: AffineTransform scales elementwise, not by matrix product
scale = torch.tensor([2.0, 1.5])
shift = torch.tensor([1.0, -0.5])
transform = AffineTransform(shift, scale, event_dim=1)

# Create flow
flow = NormalizingFlow(transform, base)

# Generate samples
samples = flow.sample((1000,))  # Shape: (1000, 2)

# Evaluate density
log_density = flow.log_prob(samples)
print(f"Mean log density: {log_density.mean():.4f}")
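The density such an affine flow computes can be hand-checked in one dimension (illustrative values, plain Python rather than Zuko):

```python
import math

# Hand-check of an elementwise affine flow in one dimension, assuming
# z = shift + scale * x with a standard normal base.
shift, scale = 1.0, 2.0

def flow_log_prob(x):
    z = shift + scale * x                       # forward: data -> base
    base = -0.5 * z * z - 0.5 * math.log(2 * math.pi)
    return base + math.log(abs(scale))          # + log|det df/dx|

# Sampling inverts the map: X = (Z - shift) / scale with Z ~ N(0, 1),
# i.e. X ~ N(-0.5, 0.5). The flow's log-density at the mean matches
# the closed-form normal density:
closed_form = math.log(1.0 / (0.5 * math.sqrt(2 * math.pi)))
print(abs(flow_log_prob(-0.5) - closed_form) < 1e-12)  # → True
```

In the 2D case the per-dimension terms are simply summed over the event dimension.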

Composing Multiple Transformations

Real-world flows typically compose multiple transformations. Zuko’s ComposedTransform enables this:
from zuko.transforms import ComposedTransform
from torch.distributions.transforms import TanhTransform, AffineTransform

# Stack transformations; the composition maps data to the base space,
# so its inverse (used for sampling) must be defined on the base's support
transforms = ComposedTransform(
    AffineTransform(-1.0, 2.0),  # map (0, 1) to (-1, 1)
    TanhTransform().inv,         # atanh: map (-1, 1) to the real line
)

flow = NormalizingFlow(transforms, Normal(0.0, 1.0))  # samples lie in (0, 1)
When composing transformations, ensure domains and codomains align properly. The codomain of transformation $f_i$ must match the domain of $f_{i+1}$.
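Composition works because the log-absolute-determinants of the stacked maps simply add. A stdlib sketch with two illustrative maps, checked against a finite difference:

```python
import math

# Composition sketch: for f = f2(f1(x)), the log-abs-det terms add:
# log|f'(x)| = log|f2'(f1(x))| + log|f1'(x)|

def f1(x):
    return 2.0 * x - 1.0       # ladj contribution: log 2

def f2(y):
    return math.atanh(y)       # ladj contribution: -log(1 - y^2)

def composed(x):
    return f2(f1(x))

x = 0.3
ladj = math.log(2.0) - math.log(1.0 - f1(x) ** 2)

# Check against a finite-difference estimate of log|f'(x)|:
h = 1e-6
fd = math.log(abs((composed(x + h) - composed(x - h)) / (2.0 * h)))
print(abs(ladj - fd) < 1e-4)  # → True
```

This additivity is what lets ComposedTransform accumulate a single log-determinant across an arbitrarily deep stack.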


Next Steps

Lazy Distributions

Learn how Zuko makes flows trainable with PyTorch

Transformations

Explore the transformation building blocks
