
Overview

The Flow class is the primary way to construct normalizing flows in Zuko. It combines lazy transformations and a lazy base distribution to create a complete normalizing flow p(x | c) = p_0(f^{-1}(x | c)) |det J_{f^{-1}}(x | c)|.
A normalizing flow transforms a simple base distribution (like a standard normal) through a series of invertible transformations. The Flow class handles the composition of transformations and the change of variables formula automatically.
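The change of variables formula above can be sketched with plain torch.distributions (no Zuko required), using a simple affine map as the invertible transformation. PyTorch's TransformedDistribution applies the same formula internally, so the manual computation and the built-in one should agree:

```python
import torch
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import AffineTransform

# Base distribution p_0(z): standard normal over 3 independent components
base = Normal(torch.zeros(3), torch.ones(3))

# Invertible map f: x = 2*z + 1, so f^{-1}(x) = (x - 1) / 2
f = AffineTransform(loc=torch.tensor(1.0), scale=torch.tensor(2.0))

x = torch.randn(3)

# Change of variables by hand:
# log p(x) = log p_0(f^{-1}(x)) - log |det J_f(f^{-1}(x))|
z = f.inv(x)
manual = base.log_prob(z) - f.log_abs_det_jacobian(z, x)

# TransformedDistribution applies the same formula internally
dist = TransformedDistribution(base, [f])
auto = dist.log_prob(x)

print(torch.allclose(manual, auto))  # True
```

Zuko's NormalizingFlow performs this same bookkeeping for you, with the transformation parameters produced lazily from the context.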

Flow

Creates a lazy normalizing flow.
class Flow(LazyDistribution)
Flow is a LazyDistribution that returns a NormalizingFlow distribution when called with a context.

Constructor

def __init__(
    transform: LazyTransform | Sequence[LazyTransform],
    base: LazyDistribution
)
transform
LazyTransform | Sequence[LazyTransform]
required
A lazy transformation or sequence of lazy transformations. If a sequence is provided, the transformations are automatically composed using LazyComposedTransform. The transformations are applied in order: the first transformation in the list is applied first.
base
LazyDistribution
required
A lazy distribution representing the base distribution p_0(z). Typically an UnconditionalDistribution wrapping a simple distribution like a standard normal.

Methods

forward

def forward(c: Tensor | None = None) -> NormalizingFlow
Builds and returns a normalizing flow distribution.
c
Tensor | None
A context tensor. If provided, the flow conditions on this context. If None, the flow is unconditional.
Returns: A NormalizingFlow distribution object representing p(x | c).
When a context c is provided with shape (..., context_features), the base distribution is automatically expanded to match the batch dimensions (...). This allows the flow to handle batched contexts correctly.
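The expansion behavior can be sketched with plain torch.distributions. Here an Independent-wrapped Normal stands in for the base distribution (Zuko's DiagNormal behaves similarly, treating the last dimension as the event); expanding it to the context's batch shape makes samples and log-probabilities carry the batch dimension:

```python
import torch
from torch.distributions import Normal, Independent

# Base over 3 features, with no batch dimensions
base = Independent(Normal(torch.zeros(3), torch.ones(3)), 1)
print(base.batch_shape, base.event_shape)  # torch.Size([]) torch.Size([3])

# A batch of 10 contexts implies batch shape (10,)
c = torch.randn(10, 5)
expanded = base.expand(c.shape[:-1])
print(expanded.batch_shape)  # torch.Size([10])

# Samples and log-probabilities now carry the batch dimension
x = expanded.sample()
print(x.shape)  # torch.Size([10, 3])
print(expanded.log_prob(x).shape)  # torch.Size([10])
```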

Examples

Basic Flow with MAF

import torch
from zuko.lazy import Flow, UnconditionalDistribution
from zuko.flows.autoregressive import MaskedAutoregressiveTransform
from zuko.distributions import DiagNormal

# Define the base distribution (standard normal)
base = UnconditionalDistribution(
    DiagNormal,
    loc=torch.zeros(3),
    scale=torch.ones(3),
    buffer=True
)

# Create a flow with a single MAF transformation
flow = Flow(
    transform=MaskedAutoregressiveTransform(
        features=3,
        context=5,
        hidden_features=(128, 128)
    ),
    base=base
)

# Use the flow
context = torch.randn(10, 5)  # Batch of 10 contexts
dist = flow(context)  # Get conditional distribution p(x|c)

# Sample from the flow
samples = dist.sample((100,))  # Sample 100 points for each context
print(samples.shape)  # torch.Size([100, 10, 3])

# Evaluate log probability
x = torch.randn(10, 3)
log_prob = dist.log_prob(x)
print(log_prob.shape)  # torch.Size([10])

Multi-Layer Flow with Mixed Transforms

This example shows how to build a flow with multiple transformations, including both conditional and unconditional transforms:
import torch
from zuko.lazy import Flow, UnconditionalDistribution, UnconditionalTransform
from zuko.flows.autoregressive import MaskedAutoregressiveTransform
from zuko.distributions import DiagNormal
from zuko.transforms import RotationTransform

# Define the base distribution
base = UnconditionalDistribution(
    DiagNormal,
    loc=torch.zeros(3),
    scale=torch.ones(3),
    buffer=True
)

# Create a flow with multiple transformations
flow = Flow(
    transform=[
        # First MAF layer (conditional)
        MaskedAutoregressiveTransform(3, 5, hidden_features=(64, 64)),
        
        # Unconditional rotation
        UnconditionalTransform(RotationTransform, torch.randn(3, 3)),
        
        # Second MAF layer (conditional)
        MaskedAutoregressiveTransform(3, 5, hidden_features=(64, 64)),
    ],
    base=base
)

# The transformations are automatically composed
context = torch.randn(10, 5)
dist = flow(context)
samples = dist.sample()
print(samples.shape)  # torch.Size([10, 3])
This is the pattern shown in the Zuko README. It demonstrates how to mix conditional transformations (like MaskedAutoregressiveTransform) with unconditional ones (like UnconditionalTransform(RotationTransform, ...)). The unconditional transformations are still parameterized and can be learned during training.

Training a Flow

import torch
import torch.optim as optim
from zuko.lazy import Flow, UnconditionalDistribution
from zuko.flows.autoregressive import MaskedAutoregressiveTransform
from zuko.distributions import DiagNormal

# Create a flow
flow = Flow(
    transform=MaskedAutoregressiveTransform(
        features=3,
        context=5,
        hidden_features=(128, 128, 128),
    ),
    base=UnconditionalDistribution(
        DiagNormal,
        loc=torch.zeros(3),
        scale=torch.ones(3),
        buffer=True,
    )
)

# Setup optimizer
optimizer = optim.Adam(flow.parameters(), lr=1e-3)

# Training loop (trainset is any iterable yielding (x, c) batches)
for epoch in range(100):
    for x, c in trainset:
        # x: data samples, shape (batch_size, 3)
        # c: context, shape (batch_size, 5)
        
        # Get conditional distribution
        dist = flow(c)
        
        # Compute negative log likelihood
        loss = -dist.log_prob(x).mean()
        
        # Optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    print(f"Epoch {epoch}: loss = {loss.item():.4f}")

# After training, sample from the learned distribution
test_context = torch.randn(64, 5)
test_dist = flow(test_context)
generated_samples = test_dist.sample()
print(generated_samples.shape)  # torch.Size([64, 3])

Unconditional Flow

You can also create unconditional flows by using transformations that don’t depend on context:
import torch
from zuko.lazy import Flow, UnconditionalDistribution, UnconditionalTransform
from torch.distributions.transforms import AffineTransform
from zuko.distributions import DiagNormal

# Unconditional base
base = UnconditionalDistribution(
    DiagNormal,
    loc=torch.zeros(3),
    scale=torch.ones(3),
    buffer=True
)

# Unconditional affine transformation
transform = UnconditionalTransform(
    AffineTransform,
    loc=torch.zeros(3),
    scale=torch.ones(3),
    buffer=False  # Make it trainable
)

# Create unconditional flow
flow = Flow(transform=transform, base=base)

# Call without context
dist = flow()  # No context needed
samples = dist.sample((100,))
print(samples.shape)  # torch.Size([100, 3])

Using Pre-built Flow Classes

Zuko provides pre-built flow classes for common architectures. These are often more convenient than building flows manually:
import torch
import zuko

# Neural Spline Flow (NSF) with 3 features and 5 context features
flow = zuko.flows.NSF(
    features=3,
    context=5,
    transforms=3,  # Number of transformation layers
    hidden_features=[128, 128, 128]  # Hidden layer sizes
)

# Use it the same way
context = torch.randn(10, 5)
dist = flow(context)
samples = dist.sample()
log_prob = dist.log_prob(samples)
Pre-built flows like NSF, MAF, RealNVP, etc., are subclasses of Flow with pre-configured architectures. They’re great for quick experimentation, while custom Flow objects give you more control over the architecture.

Understanding Transform Composition

When you pass a list of transformations to Flow, they are automatically composed. The order matters:
# These transformations are applied in sequence:
flow = Flow(
    transform=[t1, t2, t3],  # Applied as: t3(t2(t1(x)))
    base=base
)

# During sampling: z ~ base() -> t1(z) -> t2(...) -> t3(...) -> x
# During evaluation: x -> t3^{-1}(x) -> t2^{-1}(...) -> t1^{-1}(...) -> z
The transformations are applied left-to-right during sampling (forward direction) and right-to-left during density evaluation (inverse direction).
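This ordering can be illustrated with PyTorch's ComposeTransform, which composes transforms left-to-right in the same way (a sketch with scalar affine transforms, not Zuko's LazyComposedTransform):

```python
import torch
from torch.distributions.transforms import AffineTransform, ComposeTransform

# Two simple invertible transforms
t1 = AffineTransform(loc=torch.tensor(1.0), scale=torch.tensor(2.0))  # z -> 2z + 1
t2 = AffineTransform(loc=torch.tensor(0.0), scale=torch.tensor(3.0))  # z -> 3z

composed = ComposeTransform([t1, t2])

# Forward (sampling direction): t2(t1(z)) = 3 * (2*1 + 1) = 9
z = torch.tensor(1.0)
print(composed(z))  # tensor(9.)

# Inverse (density direction): t1^{-1}(t2^{-1}(x)) = (9/3 - 1)/2 = 1
x = torch.tensor(9.0)
print(composed.inv(x))  # tensor(1.)
```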

Integration with PyTorch Training

Flow objects are standard nn.Module instances, so they integrate seamlessly with PyTorch:
import torch
import torch.nn as nn

# Flow is a module
assert isinstance(flow, nn.Module)

# Move to GPU
flow = flow.to('cuda')

# Get parameters
params = list(flow.parameters())

# Save and load
torch.save(flow.state_dict(), 'flow.pth')
flow.load_state_dict(torch.load('flow.pth'))

# Set to eval mode
flow.eval()
with torch.no_grad():
    dist = flow(context)
    samples = dist.sample()

Advanced: Custom Base Distributions

You can use any lazy distribution as the base, not just UnconditionalDistribution:
import torch
import torch.nn as nn
from zuko.lazy import Flow, LazyDistribution
from zuko.distributions import DiagNormal

class ContextDependentBase(LazyDistribution):
    """A base distribution whose parameters depend on context."""
    
    def __init__(self, features: int, context: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(context, 64),
            nn.ReLU(),
            nn.Linear(64, features * 2)  # Mean and log-scale
        )
        self.features = features
    
    def forward(self, c: torch.Tensor | None = None):
        if c is None:
            raise ValueError("Context required")
        
        params = self.net(c)
        mu, log_scale = params.chunk(2, dim=-1)
        scale = torch.exp(log_scale)
        
        return DiagNormal(mu, scale)

# Use custom base in a flow
flow = Flow(
    transform=my_transform,
    base=ContextDependentBase(features=3, context=5)
)

context = torch.randn(10, 5)
dist = flow(context)  # Base distribution now depends on context
When using a context-dependent base distribution, the base parameters are conditioned on the context before the transformations are applied. This is different from conditioning the transformations themselves. Consider whether you need context dependence in the base, the transforms, or both.

See Also

Build docs developers (and LLMs) love