This guide shows you how to create custom normalizing flow architectures by composing transformations, using different base distributions, and building complex conditional models.
## Understanding Flow Components
A normalizing flow in Zuko consists of two main components:
- Transformations: Bijective functions that map between spaces
- Base Distribution: The starting distribution (usually Gaussian)
```python
from zuko.flows import Flow
from zuko.lazy import UnconditionalDistribution, UnconditionalTransform

flow = Flow(
    transform=[...],  # List of transformations
    base=...,         # Base distribution
)
```
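Under the hood, a flow evaluates densities with the change-of-variables formula: `log p(x) = log p_base(f(x)) + log |det J_f(x)|`, where `f` maps data to the base space. Here is a minimal one-dimensional sketch using only the standard library; the affine map and its constants are illustrative, not part of Zuko:

```python
import math

def base_log_prob(z):
    # Standard normal log-density
    return -0.5 * (z * z + math.log(2 * math.pi))

def f(x, loc=1.0, scale=2.0):
    # Affine "transform": maps data to base space, z = (x - loc) / scale
    return (x - loc) / scale

def log_abs_det_jacobian(scale=2.0):
    # df/dx = 1 / scale, constant for an affine map
    return -math.log(scale)

def flow_log_prob(x):
    return base_log_prob(f(x)) + log_abs_det_jacobian()

# The resulting density is exactly that of Normal(loc=1, scale=2)
print(flow_log_prob(1.0))  # equals log(1 / (2 * sqrt(2 * pi)))
```

Every flow in this guide is this same recipe, with learned, higher-dimensional transforms in place of the fixed affine map.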
## Building Custom Flows

### Simple Custom Flow
Here’s a basic example combining affine and rotation transformations:
```python
import torch
import zuko
from zuko.flows import Flow, MaskedAutoregressiveTransform
from zuko.lazy import UnconditionalDistribution, UnconditionalTransform
from zuko.transforms import RotationTransform
from zuko.distributions import DiagNormal

flow = Flow(
    transform=[
        MaskedAutoregressiveTransform(
            features=3,
            context=5,
            hidden_features=(64, 64),
        ),
        UnconditionalTransform(
            RotationTransform,
            torch.randn(3, 3),  # Random initialization
        ),
        MaskedAutoregressiveTransform(
            features=3,
            context=5,
            hidden_features=(64, 64),
        ),
    ],
    base=UnconditionalDistribution(
        DiagNormal,
        loc=torch.zeros(3),
        scale=torch.ones(3),
        buffer=True,  # Not trainable
    ),
)
```
This example is similar to the custom flow shown in the README. The transformations are applied in sequence: f_2 ∘ rotation ∘ f_1, where f_1 and f_2 are the two autoregressive transforms.
### Masked Autoregressive Transform

The most flexible transformation for building expressive flows:

```python
import torch
from zuko.flows import MaskedAutoregressiveTransform
from zuko.transforms import MonotonicAffineTransform

transform = MaskedAutoregressiveTransform(
    features=5,
    context=8,
    univariate=MonotonicAffineTransform,  # Element-wise transformation
    hidden_features=(128, 128, 128),      # Hyper-network architecture
    activation=torch.nn.ReLU,             # Activation function
)
```
### Coupling Transform

Faster but less expressive than autoregressive:

```python
from zuko.flows import GeneralCouplingTransform
from zuko.transforms import MonotonicRQSTransform

transform = GeneralCouplingTransform(
    features=5,
    context=8,
    univariate=MonotonicRQSTransform,  # Rational-quadratic spline
    shapes=([8], [8], [7]),            # 8 bins: widths, heights, derivatives
    hidden_features=(256, 256),
)
```
### Spline Transforms

For highly flexible element-wise transformations:

```python
from zuko.flows import MaskedAutoregressiveTransform
from zuko.transforms import MonotonicRQSTransform

# Used as the univariate transform in an autoregressive or coupling layer
transform = MaskedAutoregressiveTransform(
    features=3,
    context=0,
    univariate=MonotonicRQSTransform,
    shapes=([16], [16], [15]),  # 16 bins for higher capacity
    hidden_features=(128, 128),
)
```
### Fixed Transforms

Add non-trainable transformations:

```python
import torch
from zuko.flows import Flow
from zuko.lazy import UnconditionalTransform
from zuko.transforms import AffineTransform, SigmoidTransform

flow = Flow(
    transform=[
        # Map [0, 255] to (0, 1)
        UnconditionalTransform(
            AffineTransform,
            torch.tensor(1 / 512),  # loc
            torch.tensor(1 / 256),  # scale
            buffer=True,            # Not trainable
        ),
        # Map (0, 1) to (-inf, inf)
        UnconditionalTransform(
            lambda: SigmoidTransform().inv,
        ),
        # ... more transforms
    ],
    base=...,
)
```
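As a sanity check on the affine constants used for preprocessing, x ↦ x/256 + 1/512 maps the integer range [0, 255] strictly inside (0, 1), so the inverse sigmoid that follows never sees 0 or 1:

```python
def to_unit_interval(x):
    # Same constants as the AffineTransform above: scale = 1/256, loc = 1/512
    return x / 256 + 1 / 512

print(to_unit_interval(0))    # 1/512 ≈ 0.00195
print(to_unit_interval(255))  # 511/512 ≈ 0.99805
```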
### Rotation Transforms

```python
import torch
from zuko.flows import Flow, MaskedAutoregressiveTransform
from zuko.lazy import UnconditionalTransform
from zuko.transforms import RotationTransform

flow = Flow(
    transform=[
        MaskedAutoregressiveTransform(...),
        UnconditionalTransform(
            RotationTransform,
            torch.randn(5, 5),  # Trainable rotation matrix
        ),
        MaskedAutoregressiveTransform(...),
    ],
    base=...,
)
```
Adding rotation transforms between autoregressive layers can improve expressivity by mixing features.
## Custom Base Distributions

### Standard Gaussian

The most common choice:

```python
from zuko.distributions import DiagNormal
from zuko.lazy import UnconditionalDistribution

base = UnconditionalDistribution(
    DiagNormal,
    loc=torch.zeros(features),
    scale=torch.ones(features),
    buffer=True,
)
```
### Uniform Base

Useful for bounded domains:

```python
from zuko.distributions import BoxUniform
from zuko.lazy import UnconditionalDistribution

base = UnconditionalDistribution(
    BoxUniform,
    torch.full([features], -3.0),  # lower bound
    torch.full([features], +3.0),  # upper bound
    buffer=True,
)
```
### Learnable Base Distribution

```python
# Learnable mean and scale
base = UnconditionalDistribution(
    DiagNormal,
    loc=torch.randn(features),   # Trainable
    scale=torch.ones(features),  # Trainable
    buffer=False,                # Register as parameters
)
```
## Advanced Architectures

### Deep Architectures with Permutations

For high-dimensional data, stack several autoregressive blocks and permute the features between them so that each block conditions on a different ordering:
```python
import torch
from zuko.flows import Flow, MaskedAutoregressiveTransform
from zuko.lazy import UnconditionalDistribution, UnconditionalTransform
from zuko.transforms import PermutationTransform
from zuko.distributions import DiagNormal

features = 64

flow = Flow(
    transform=[
        # First block
        MaskedAutoregressiveTransform(features, 0, hidden_features=(256,)),
        UnconditionalTransform(
            PermutationTransform,
            torch.randperm(features),
            buffer=True,  # Permutation indices are not trainable
        ),
        # Second block
        MaskedAutoregressiveTransform(features, 0, hidden_features=(256,)),
        UnconditionalTransform(
            PermutationTransform,
            torch.randperm(features),
            buffer=True,
        ),
        # Third block
        MaskedAutoregressiveTransform(features, 0, hidden_features=(256,)),
    ],
    base=UnconditionalDistribution(
        DiagNormal,
        torch.zeros(features),
        torch.ones(features),
        buffer=True,
    ),
)
```
### Mixing Coupling and Autoregressive Transforms

Alternate coupling and autoregressive transforms to trade off speed and expressivity:

```python
from zuko.flows import Flow, GeneralCouplingTransform, MaskedAutoregressiveTransform

flow = Flow(
    transform=[
        GeneralCouplingTransform(features=5, context=0, hidden_features=(128,)),
        MaskedAutoregressiveTransform(features=5, context=0, hidden_features=(128,)),
        GeneralCouplingTransform(features=5, context=0, hidden_features=(128,)),
        MaskedAutoregressiveTransform(features=5, context=0, hidden_features=(128,)),
    ],
    base=...,
)
```
## Conditional Flows

### Context-Dependent Flows

All transformations can accept context:

```python
import torch
from zuko.flows import Flow, MaskedAutoregressiveTransform
from zuko.lazy import UnconditionalDistribution

flow = Flow(
    transform=[
        MaskedAutoregressiveTransform(
            features=3,
            context=5,  # Context dimension
            hidden_features=(128, 128),
        ),
    ],
    base=UnconditionalDistribution(...),  # Base ignores context
)

# Use with context
batch_size = 32
c = torch.randn(batch_size, 5)
dist = flow(c)
x = dist.sample()
```
### Conditional Base Distribution

For conditional base distributions, create a custom LazyDistribution:

```python
import torch.nn as nn
from torch.distributions import Independent, Normal
from zuko.flows import Flow
from zuko.lazy import LazyDistribution

class ConditionalBase(LazyDistribution):
    def __init__(self, features: int, context: int):
        super().__init__()

        self.net = nn.Sequential(
            nn.Linear(context, 64),
            nn.ReLU(),
            nn.Linear(64, 2 * features),
        )

    def forward(self, c):
        params = self.net(c)
        mu, log_sigma = params.chunk(2, dim=-1)

        return Independent(Normal(mu, log_sigma.exp()), 1)

flow = Flow(
    transform=[...],
    base=ConditionalBase(features=3, context=5),
)
```
## Inverse Transforms

Use `.inv` to insert a transformation's inverse into a flow:

```python
from zuko.flows import Flow, MaskedAutoregressiveTransform

flow = Flow(
    transform=[
        MaskedAutoregressiveTransform(features=5, context=0),
        MaskedAutoregressiveTransform(features=5, context=0).inv,  # Inverse (IAF-style) layer
        MaskedAutoregressiveTransform(features=5, context=0),
    ],
    base=...,
)
```

Note that each layer above is a distinct instance; composing a transform directly with its own inverse would simply cancel out.
Be careful with inverse transforms in autoregressive flows. The inverse of MAF is IAF (Inverse Autoregressive Flow), which has the opposite computational profile: MAF evaluates log-densities in a single pass but samples sequentially, one feature at a time, while IAF samples in a single pass but evaluates log-densities sequentially.
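The asymmetry is easy to see with a toy autoregressive map z_i = x_i + Σ_{j&lt;i} x_j (a pure-Python sketch, not Zuko code): the forward direction depends only on the known inputs and could be fully vectorized, while the inverse must recover each coordinate from the previously recovered ones, forcing a sequential loop:

```python
def forward(x):
    # z[i] = x[i] + sum(x[:i]); depends only on x, so it vectorizes
    z, prefix = [], 0.0
    for xi in x:
        z.append(xi + prefix)
        prefix += xi
    return z

def inverse(z):
    # x[i] depends on the already-recovered x[:i]: inherently sequential
    x, prefix = [], 0.0
    for zi in z:
        xi = zi - prefix
        x.append(xi)
        prefix += xi
    return x

x = [1.0, 2.0, 3.0]
print(forward(x))           # [1.0, 3.0, 6.0]
print(inverse(forward(x)))  # [1.0, 2.0, 3.0]
```

Whichever direction a masked network computes in one pass, the other direction pays this sequential cost, once per feature.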
## Complete Custom Example

Here’s a sophisticated custom flow combining multiple techniques:
```python
import torch
import zuko
from zuko.flows import Flow, GeneralCouplingTransform, MaskedAutoregressiveTransform
from zuko.lazy import UnconditionalDistribution, UnconditionalTransform
from zuko.transforms import (
    AffineTransform,
    MonotonicRQSTransform,
    RotationTransform,
    SigmoidTransform,
)
from zuko.distributions import BoxUniform

flow = Flow(
    transform=[
        # Preprocessing: [0, 255] -> R
        UnconditionalTransform(
            AffineTransform,
            torch.tensor(1 / 512),
            torch.tensor(1 / 256),
            buffer=True,
        ),
        UnconditionalTransform(lambda: SigmoidTransform().inv),
        # First autoregressive block with splines
        MaskedAutoregressiveTransform(
            features=5,
            context=8,
            univariate=MonotonicRQSTransform,
            shapes=([8], [8], [7]),  # 8 bins
            hidden_features=(128, 128),
            passes=5,  # Fully autoregressive
        ),
        # Rotation for mixing
        UnconditionalTransform(RotationTransform, torch.randn(5, 5)),
        # Coupling transform (faster than autoregressive)
        GeneralCouplingTransform(
            features=5,
            context=8,
            univariate=MonotonicRQSTransform,
            shapes=([8], [8], [7]),
            hidden_features=(256, 256),
            activation=torch.nn.ELU,
        ).inv,  # Use the inverse direction
    ],
    base=UnconditionalDistribution(
        BoxUniform,
        torch.full([5], -3.0),
        torch.full([5], +3.0),
        buffer=True,
    ),
)

# Train as usual
optimizer = torch.optim.Adam(flow.parameters(), lr=1e-3)
```
## Tips for Custom Flows

- **Start simple**: Begin with pre-built flows like NSF or MAF, then customize only when needed.
- **Alternate orderings**: When stacking autoregressive transforms, alternate the feature ordering so that every dimension gets conditioned on every other. The pre-built `MAF` handles this for you; when composing manually, insert permutation (or rotation) transforms between blocks:

  ```python
  transform = [
      MaskedAutoregressiveTransform(...),
      UnconditionalTransform(PermutationTransform, torch.randperm(features), buffer=True),
      MaskedAutoregressiveTransform(...),
  ]
  ```

- **Splines vs affine**: Use spline transforms (`MonotonicRQSTransform`) for more expressivity, and affine transforms (`MonotonicAffineTransform`) for faster training.
## Next Steps