In normalizing flows, transformations are invertible mappings that connect distributions. A transformation $f: \mathcal{X} \to \mathcal{Y}$ must be:
- Bijective: one-to-one and onto (invertible)
- Differentiable: gradients can be computed
- Tractable Jacobian: the determinant can be computed efficiently
These properties enable the change of variables formula:
$$p(X = x) = p(Y = f(x)) \left| \det \frac{\partial f(x)}{\partial x} \right|$$
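As a quick sanity check, the formula can be verified numerically for a simple affine map (a pure-Python sketch; the map f(x) = 2x + 1 and the Gaussian base are illustrative choices, not part of Zuko):

```python
import math

# Sketch: f(x) = 2x + 1 maps X ~ N(0, 1) onto Y = f(X) ~ N(1, 4).
# The change of variables formula says p_X(x) = p_Y(f(x)) * |f'(x)|.

def normal_pdf(v, mean, std):
    """Density of N(mean, std^2) at v."""
    z = (v - mean) / std
    return math.exp(-0.5 * z * z) / (std * math.sqrt(2 * math.pi))

x = 0.7
y = 2 * x + 1   # forward transformation
jacobian = 2.0  # df/dx, constant for an affine map

lhs = normal_pdf(x, 0.0, 1.0)                  # p_X(x)
rhs = normal_pdf(y, 1.0, 2.0) * abs(jacobian)  # p_Y(f(x)) * |det J|
assert abs(lhs - rhs) < 1e-12
```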
Zuko’s transformations extend PyTorch’s torch.distributions.transforms.Transform interface with additional functionality for efficient flow computation.
All transformations in Zuko implement these core methods:
```python
import torch
from torch.distributions.transforms import Transform

class MyTransform(Transform):
    def _call(self, x: torch.Tensor) -> torch.Tensor:
        """Forward transformation: y = f(x)"""
        ...

    def _inverse(self, y: torch.Tensor) -> torch.Tensor:
        """Inverse transformation: x = f^{-1}(y)"""
        ...

    def log_abs_det_jacobian(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Log absolute determinant of the Jacobian"""
        ...
```
Enhanced Interface: call_and_ladj
Zuko adds a crucial optimization to PyTorch’s Transform class (in zuko/transforms.py:46-56):
```python
def call_and_ladj(self, x: Tensor) -> tuple[Tensor, Tensor]:
    """Returns both the transformed value and the log-abs-det-Jacobian."""
    y = self.__call__(x)
    ladj = self.log_abs_det_jacobian(x, y)
    return y, ladj
```
Many transformations can compute $f(x)$ and $\log|\det J_f(x)|$ simultaneously, more efficiently than computing them separately. This method enables that optimization.
For complex transformations like splines, computing the Jacobian determinant requires intermediate values from the forward pass. call_and_ladj avoids redundant computation.
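As an illustration of the pattern (a toy class, not Zuko's implementation), consider a softplus-like transform whose forward pass and Jacobian both need the intermediate value e^x:

```python
import math

class SoftplusLike:
    """Toy transform f(x) = log(1 + e^x), with f'(x) = e^x / (1 + e^x)."""

    def __call__(self, x):
        return math.log1p(math.exp(x))

    def log_abs_det_jacobian(self, x, y):
        # Recomputes e^x, even though the forward pass already had it.
        return x - math.log1p(math.exp(x))

    def call_and_ladj(self, x):
        # Computes e^x once and reuses it for both outputs.
        ex = math.exp(x)
        y = math.log1p(ex)
        ladj = x - y  # log f'(x) = x - log(1 + e^x)
        return y, ladj

t = SoftplusLike()
y, ladj = t.call_and_ladj(0.5)
assert abs(y - t(0.5)) < 1e-12
assert abs(ladj - t.log_abs_det_jacobian(0.5, y)) < 1e-12
```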
Let’s examine a simple but illustrative transformation (from zuko/transforms.py:412-447):
```python
import math

from torch import Tensor
from torch.distributions import constraints
from torch.distributions.transforms import Transform

class MonotonicAffineTransform(Transform):
    r"""Creates the transformation f(x) = exp(a) * x + b.

    Arguments:
        shift: The shift term b, with shape (*,).
        scale: The unconstrained scale factor a, with shape (*,).
        slope: The minimum slope of the transformation.
    """

    domain = constraints.real
    codomain = constraints.real
    bijective = True
    sign = +1

    def __init__(self, shift: Tensor, scale: Tensor, slope: float = 1e-3):
        super().__init__()
        self.shift = shift
        # Constrain the scale to ensure a minimum slope
        self.log_scale = scale / (1 + abs(scale / math.log(slope)))
        self.scale = self.log_scale.exp()

    def _call(self, x: Tensor) -> Tensor:
        return x * self.scale + self.shift

    def _inverse(self, y: Tensor) -> Tensor:
        return (y - self.shift) / self.scale

    def log_abs_det_jacobian(self, x: Tensor, y: Tensor) -> Tensor:
        return self.log_scale.expand(x.shape)
```
Usage:
```python
import torch

from zuko.transforms import MonotonicAffineTransform

# Create the transformation
shift = torch.tensor([1.0, -0.5])
scale = torch.tensor([0.5, 1.0])
transform = MonotonicAffineTransform(shift, scale)

# Apply the forward transformation
x = torch.randn(3, 2)
y = transform(x)

# Apply the inverse
x_reconstructed = transform.inv(y)
print(torch.allclose(x, x_reconstructed))  # True

# Compute the log-abs-det-Jacobian
ladj = transform.log_abs_det_jacobian(x, y)
```
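To see what the scale constraint in `__init__` buys, here is a pure-Python check, mirroring the `log_scale` formula above, that the resulting scale always lies strictly inside (slope, 1/slope), no matter how extreme the unconstrained parameter gets:

```python
import math

slope = 1e-3
bound = -math.log(slope)  # ~ 6.91

def constrained_log_scale(a, slope=1e-3):
    # Same reparameterization as MonotonicAffineTransform above
    return a / (1 + abs(a / math.log(slope)))

for a in [-1e6, -10.0, 0.0, 10.0, 1e6]:
    ls = constrained_log_scale(a)
    assert -bound < ls < bound                # log-scale is bounded...
    assert slope < math.exp(ls) < 1 / slope   # ...so the scale never collapses
```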
Real flows stack multiple transformations. Zuko’s ComposedTransform (from zuko/transforms.py:59-161) handles this efficiently:
```python
import torch
from torch.distributions.transforms import AffineTransform, TanhTransform

from zuko.transforms import ComposedTransform

# Compose: f(x) = f_2(f_1(f_0(x)))
transform = ComposedTransform(
    AffineTransform(0.0, 2.0),  # f_0: scale by 2
    TanhTransform(),            # f_1: squash to (-1, 1)
    AffineTransform(0.5, 0.5),  # f_2: map to (0, 1)
)

x = torch.randn(100)
y = transform(x)

# The inverse automatically reverses the order
x_back = transform.inv(y)
```
Jacobian Computation
For composed transformations $f = f_n \circ \cdots \circ f_0$, the log-determinant follows the chain rule:

$$\log \left| \det J_f(x) \right| = \sum_{i=0}^{n} \log \left| \det J_{f_i}(x_i) \right|$$

where $x_0 = x$ and $x_{i+1} = f_i(x_i)$ are the intermediate values.
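A quick numeric check of this identity, using scalar affine maps with illustrative coefficients:

```python
import math

# Compose scalar maps f_i(x) = a_i * x + b_i: each contributes log|a_i|.
transforms = [(2.0, 1.0), (-0.5, 0.0), (3.0, -2.0)]  # (a_i, b_i) pairs

x = 0.7
acc = 0.0
for a, b in transforms:
    x = a * x + b
    acc += math.log(abs(a))

# The composite is itself affine with slope 2 * (-0.5) * 3 = -3,
# so the accumulated sum must equal log|-3|.
assert abs(acc - math.log(3.0)) < 1e-12
```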
Zuko’s implementation (from zuko/transforms.py:141-150):
```python
def call_and_ladj(self, x: Tensor) -> tuple[Tensor, Tensor]:
    event_dim = self.domain_dim
    acc = 0

    for t in self.transforms:
        x, ladj = t.call_and_ladj(x)
        acc = acc + _sum_rightmost(ladj, event_dim - t.domain.event_dim)
        event_dim += t.codomain.event_dim - t.domain.event_dim

    return x, acc
```
When composing transformations, ensure domain/codomain compatibility: the codomain of $f_i$ must match the domain of $f_{i+1}$.
Identity and Simple Mappings
```python
import torch

from zuko.transforms import (
    CircularShiftTransform,
    IdentityTransform,
    SignedPowerTransform,
    SoftclipTransform,
)

# Identity: f(x) = x (useful as a placeholder)
identity = IdentityTransform()

# Signed power: f(x) = sign(x) * |x|^alpha
alpha = torch.tensor(2.0)
power = SignedPowerTransform(alpha)

# Softclip: maps R to (-bound, bound) smoothly
softclip = SoftclipTransform(bound=5.0)

# Circular shift: for periodic domains
shift = CircularShiftTransform(bound=1.0)
```
Rational quadratic splines (RQS) are among the most powerful transformations (from zuko/transforms.py:449-568):
```python
import torch

from zuko.transforms import MonotonicRQSTransform

# Define the spline knots
K = 8  # number of bins
widths = torch.randn(K)
heights = torch.randn(K)
derivatives = torch.randn(K - 1)

# Create the transformation
rqs = MonotonicRQSTransform(
    widths,
    heights,
    derivatives,
    bound=5.0,
    slope=1e-3,
)

x = torch.linspace(-4, 4, 100)
y = rqs(x)
```
How Rational Quadratic Splines Work
RQS transformations piece together rational quadratic functions between knot points. Within each bin $k$:

$$y = y_0 + (y_1 - y_0) \, \frac{s z^2 + d_0 z(1-z)}{s + (d_0 + d_1 - 2s) z(1-z)}$$

where $z = (x - x_0)/(x_1 - x_0)$ is the normalized position within the bin, $s = (y_1 - y_0)/(x_1 - x_0)$ is the secant slope, and $d_0, d_1$ are the derivatives at the knots. This formulation ensures:
- Monotonicity (when the derivatives are positive)
- Smoothness at the knot boundaries
- Efficient inversion via the quadratic formula
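These properties can be verified numerically. The sketch below implements a single bin in pure Python, following the formula above (the function name and knot values are illustrative, not Zuko's API):

```python
def rq_bin(x, x0, x1, y0, y1, d0, d1):
    """One rational quadratic bin mapping [x0, x1] to [y0, y1]."""
    z = (x - x0) / (x1 - x0)   # normalized position in the bin
    s = (y1 - y0) / (x1 - x0)  # secant slope
    num = s * z**2 + d0 * z * (1 - z)
    den = s + (d0 + d1 - 2 * s) * z * (1 - z)
    return y0 + (y1 - y0) * num / den

x0, x1, y0, y1, d0, d1 = 0.0, 2.0, 1.0, 4.0, 0.5, 2.0

# The spline passes through both knots...
assert abs(rq_bin(x0, x0, x1, y0, y1, d0, d1) - y0) < 1e-12
assert abs(rq_bin(x1, x0, x1, y0, y1, d0, d1) - y1) < 1e-12

# ...and is monotonically increasing inside the bin (positive derivatives).
values = [rq_bin(x0 + i * 0.01 * (x1 - x0), x0, x1, y0, y1, d0, d1) for i in range(101)]
assert all(a < b for a, b in zip(values, values[1:]))
```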
Coupling and Autoregressive Transformations
These enable scalable flow architectures:
```python
import torch
from torch.distributions.transforms import AffineTransform

from zuko.transforms import AutoregressiveTransform, CouplingTransform

# Coupling: split the input, transform one half conditioned on the other
mask = torch.tensor([True, True, False, False])  # which dims pass through unchanged
conditioner = torch.nn.Linear(2, 2)  # in practice, created once and trained

def build_transform(x_a):
    # x_a is the unchanged part; build a transform for x_b
    shift = conditioner(x_a)
    return AffineTransform(shift, torch.ones_like(shift))

coupling = CouplingTransform(build_transform, mask)

# Autoregressive: each dimension depends on the previous ones
def build_ar_transform(x):
    # Placeholder: build a transformation whose parameters are computed from x
    return SomeTransform(params_from(x))

ar = AutoregressiveTransform(build_ar_transform, passes=5)
```
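The mechanics of a coupling layer can be sketched in two dimensions (the toy functions s and t below stand in for the conditioner networks; none of this is Zuko's API). Because the first coordinate passes through unchanged, the inverse only has to undo the affine map, and the Jacobian is triangular:

```python
import math

def s(a):  # stand-in "network" computing a log-scale from x_a
    return math.tanh(a)

def t(a):  # stand-in "network" computing a shift from x_a
    return 0.5 * a

def coupling_forward(x1, x2):
    # x1 passes through; x2 is transformed conditioned on x1
    return x1, x2 * math.exp(s(x1)) + t(x1)

def coupling_inverse(y1, y2):
    # y1 = x1 is available unchanged, so the conditioner is re-evaluated,
    # never inverted
    return y1, (y2 - t(y1)) * math.exp(-s(y1))

x1, x2 = 0.3, -1.2
y1, y2 = coupling_forward(x1, x2)
z1, z2 = coupling_inverse(y1, y2)
assert abs(z1 - x1) < 1e-12 and abs(z2 - x2) < 1e-12

# The Jacobian is triangular, so log|det J| is just the log-scale s(x1).
ladj = s(x1)
```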
Just like distributions, transformations can be lazy and context-dependent:
```python
import torch
import torch.nn as nn
from torch.distributions.transforms import Transform

from zuko.lazy import LazyTransform
from zuko.transforms import MonotonicRQSTransform

class ConditionalSpline(LazyTransform):
    """Spline transformation conditioned on a context."""

    def __init__(self, features: int, context: int, bins: int = 8):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(context, 64),
            nn.ReLU(),
            nn.Linear(64, features * (3 * bins - 1)),
        )
        self.features = features
        self.bins = bins

    def forward(self, c: torch.Tensor | None = None) -> Transform:
        # Compute the spline parameters from the context
        params = self.net(c).reshape(-1, self.features, 3 * self.bins - 1)
        widths = params[..., : self.bins]
        heights = params[..., self.bins : 2 * self.bins]
        derivatives = params[..., 2 * self.bins :]
        return MonotonicRQSTransform(widths, heights, derivatives)
```
Usage:
```python
import torch

# Create the lazy transform
transform = ConditionalSpline(features=2, context=5)

# Get the transformation for a specific context
context = torch.randn(32, 5)
t = transform(context)

# Now apply it
x = torch.randn(32, 2)
y = t(x)
```
Lazy transformations enable conditional flows, where the transformation's parameters are computed from a context by a neural network.
For simple cases without context:
```python
import torch
from torch.distributions.transforms import ExpTransform

from zuko.lazy import UnconditionalTransform

# Wrap a standard transform
lazy_exp = UnconditionalTransform(ExpTransform)

# Call to get the transform (the context is ignored)
transform = lazy_exp()

x = torch.randn(10)
y = transform(x)
```
Continuous-time flows using neural ODEs (from zuko/transforms.py:1076-1180):
```python
import torch
import torch.nn as nn

from zuko.transforms import FreeFormJacobianTransform

class VectorField(nn.Module):
    def __init__(self, features: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(features, 64),
            nn.Tanh(),
            nn.Linear(64, features),
        )

    def forward(self, t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

field = VectorField(features=2)

transform = FreeFormJacobianTransform(
    f=field,
    t0=0.0,
    t1=1.0,
    phi=list(field.parameters()),
    atol=1e-6,
    rtol=1e-5,
    exact=True,
)

x = torch.randn(100, 2)
y, ladj = transform.call_and_ladj(x)
```
Neural ODE transformations are powerful but computationally expensive. Use them when expressiveness is critical and you have sufficient compute budget.
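The underlying math can be illustrated with a toy one-dimensional ODE (pure Python, not FreeFormJacobianTransform): for dx/dt = -x the Jacobian trace is -1 everywhere, so integrating the instantaneous change of variables from t0 = 0 to t1 = 1 gives y = x e^{-1} and a log-density correction of exactly -1.

```python
import math

# Euler integration of dx/dt = -x and d(ladj)/dt = tr(df/dx) = -1.
x, ladj = 2.0, 0.0
steps = 10_000
dt = 1.0 / steps
for _ in range(steps):
    x += dt * (-x)     # Euler step of the ODE
    ladj += dt * (-1)  # accumulate the Jacobian trace

assert abs(x - 2.0 * math.exp(-1.0)) < 1e-3  # y = x0 * e^{-1}
assert abs(ladj - (-1.0)) < 1e-9             # log|det J| = -1
```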
```python
import torch

from zuko.transforms import LULinearTransform, RotationTransform

# LU decomposition for efficient inversion
LU = torch.randn(3, 3)
lu_transform = LULinearTransform(LU)

# Rotation (orthogonal matrix)
A = torch.randn(3, 3)
rotation = RotationTransform(A)
```
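To see why an LU parameterization is attractive: the determinant of a triangular factor is the product of its diagonal entries, so the log-abs-determinant becomes a sum of D logarithms instead of a full determinant computation. A pure-Python sketch with an illustrative upper-triangular matrix:

```python
import math

# Illustrative upper-triangular factor (not Zuko's internals).
U = [
    [2.0, 1.0, -3.0],
    [0.0, 0.5, 4.0],
    [0.0, 0.0, -2.0],
]

# log|det U| = sum of log|diagonal entries|: O(D) work.
log_abs_det = sum(math.log(abs(U[i][i])) for i in range(3))

# det(U) = 2 * 0.5 * (-2) = -2, so log|det U| = log(2).
assert abs(log_abs_det - math.log(2.0)) < 1e-12
```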
Always work in log-space for scale parameters and Jacobians. Zuko's transformations use log_abs_det_jacobian rather than computing determinants directly to avoid numerical issues.

```python
# Good: work in log space
log_scale = some_network(x)
scale = log_scale.exp()

# Bad: can overflow/underflow
scale = some_network(x)
log_scale = scale.log()
```
Minimum Slope Constraints
Many transformations include a slope parameter (e.g., 1e-3) to ensure numerical stability by bounding the Jacobian determinant away from zero:

```python
rqs = MonotonicRQSTransform(
    widths, heights, derivatives,
    slope=1e-3,  # ensures |det J| >= 1e-3
)
```
This prevents gradient explosion and improves training stability.
Composing for Expressiveness
Single transformations are often limited. Stack multiple transformations for greater expressiveness:

```python
transform = ComposedTransform(
    MonotonicRQSTransform(...),  # flexible nonlinearity
    RotationTransform(...),      # mix dimensions
    MonotonicRQSTransform(...),  # another nonlinear layer
)
```
Always verify your transformations are truly invertible:

```python
x = torch.randn(100, 2)
y = transform(x)
x_reconstructed = transform.inv(y)
assert torch.allclose(x, x_reconstructed, atol=1e-5)
```
| Transformation | Use Case | Jacobian Cost |
| --- | --- | --- |
| IdentityTransform | Placeholder | $O(1)$ |
| AffineTransform | Location-scale | $O(1)$ |
| MonotonicAffineTransform | Constrained affine | $O(1)$ |
| MonotonicRQSTransform | Flexible 1D | $O(K)$ |
| AutoregressiveTransform | Autoregressive flows | $O(D)$ |
| CouplingTransform | Scalable coupling flows | $O(D/2)$ |
| PermutationTransform | Mixing | $O(1)$ |
| RotationTransform | Orthogonal mixing | $O(1)$ |
| LULinearTransform | Linear flows | $O(D)$ |
| FreeFormJacobianTransform | Maximum flexibility | $O(D^2)$ or $O(D)$* |

*Depends on the `exact` parameter.
Next Steps

- Flow Architectures: see how transformations combine into complete flow models
- Custom Flows: learn to build custom transformation architectures