
What are Transformations?

In normalizing flows, transformations are invertible mappings that connect distributions. A transformation $f: \mathcal{X} \to \mathcal{Y}$ must be:
  1. Bijective: One-to-one and onto (invertible)
  2. Differentiable: We can compute gradients
  3. Tractable Jacobian: The determinant can be computed efficiently
These properties enable the change of variables formula: $p(X = x) = p(Y = f(x)) \left| \det \frac{\partial f(x)}{\partial x} \right|$
Zuko’s transformations extend PyTorch’s torch.distributions.transforms.Transform interface with additional functionality for efficient flow computation.
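The formula can be checked numerically with a familiar pair of distributions. This sketch uses plain PyTorch, independent of Zuko: with $f(x) = \exp(x)$ and $X$ standard normal, $Y = \exp(X)$ is log-normal, so $p_X(x)$ must equal $p_Y(e^x) \, e^x$.

```python
import torch

# Numerical check of the change of variables formula with f(x) = exp(x):
# if X ~ N(0, 1), then Y = exp(X) is log-normal, and the formula gives
# p_X(x) = p_Y(exp(x)) * |d exp(x) / dx| = p_Y(exp(x)) * exp(x).
x = torch.tensor(0.7)
p_x = torch.distributions.Normal(0.0, 1.0).log_prob(x).exp()
p_y = torch.distributions.LogNormal(0.0, 1.0).log_prob(x.exp()).exp()
assert torch.allclose(p_x, p_y * x.exp())
```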

The Transform Interface

All transformations in Zuko implement these core methods:
from torch.distributions.transforms import Transform
import torch

class MyTransform(Transform):
    def _call(self, x: torch.Tensor) -> torch.Tensor:
        """Forward transformation: y = f(x)"""
        pass
    
    def _inverse(self, y: torch.Tensor) -> torch.Tensor:
        """Inverse transformation: x = f^{-1}(y)"""
        pass
    
    def log_abs_det_jacobian(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Log absolute determinant of Jacobian"""
        pass
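As a toy illustration of this interface (not part of Zuko or PyTorch), here is a complete transform for $f(x) = x^3 + x$, which is strictly increasing and therefore bijective. The inverse has no closed form, so it is found by bisection for simplicity:

```python
import torch
from torch.distributions import constraints
from torch.distributions.transforms import Transform

class CubePlusX(Transform):
    """Toy transform f(x) = x^3 + x, strictly increasing hence bijective."""

    domain = constraints.real
    codomain = constraints.real
    bijective = True
    sign = +1

    def _call(self, x):
        return x**3 + x

    def _inverse(self, y):
        # No closed-form inverse; bisection works because f is monotone.
        lo = torch.full_like(y, -1e3)
        hi = torch.full_like(y, 1e3)
        for _ in range(200):
            mid = (lo + hi) / 2
            below = self._call(mid) < y
            lo = torch.where(below, mid, lo)
            hi = torch.where(below, hi, mid)
        return (lo + hi) / 2

    def log_abs_det_jacobian(self, x, y):
        return torch.log(3 * x**2 + 1)  # f'(x) = 3x^2 + 1 > 0

t = CubePlusX()
x = torch.randn(5)
assert torch.allclose(t.inv(t(x)), x, atol=1e-3)
```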

Enhanced Interface: call_and_ladj

Zuko adds a crucial optimization to PyTorch’s Transform class (in zuko/transforms.py:46-56):
def call_and_ladj(self, x: Tensor) -> tuple[Tensor, Tensor]:
    """Returns both transformed value and log-abs-det-Jacobian."""
    y = self.__call__(x)
    ladj = self.log_abs_det_jacobian(x, y)
    return y, ladj
Many transformations can compute $f(x)$ and $\log|\det J_f(x)|$ simultaneously more efficiently than separately. This method enables that optimization.
For complex transformations like splines, computing the Jacobian determinant requires intermediate values from the forward pass. call_and_ladj avoids redundant computation.
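The savings can be sketched on top of a standard PyTorch transform (this subclass is illustrative, not Zuko's actual code): for $y = \sigma(x)$, the log-abs-det-Jacobian is $\log y + \log(1 - y)$, so the forward value can be reused instead of recomputing the sigmoid.

```python
import torch
from torch.distributions.transforms import SigmoidTransform

# Illustrative subclass (not Zuko's code): the log-abs-det-Jacobian of the
# sigmoid is log(y) + log(1 - y), which reuses the forward output y.
class SigmoidWithLadj(SigmoidTransform):
    def call_and_ladj(self, x):
        y = torch.sigmoid(x)
        ladj = torch.log(y) + torch.log1p(-y)  # reuses y from the forward pass
        return y, ladj

t = SigmoidWithLadj()
x = torch.randn(5)
y, ladj = t.call_and_ladj(x)
assert torch.allclose(ladj, t.log_abs_det_jacobian(x, y), atol=1e-5)
```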

Example: Monotonic Affine Transform

Let’s examine a simple but illustrative transformation (from zuko/transforms.py:412-447):
# Imports needed to run this snippet standalone
import math

import torch
from torch import Tensor
from torch.distributions import constraints
from torch.distributions.transforms import Transform

class MonotonicAffineTransform(Transform):
    r"""Creates transformation f(x) = exp(a) * x + b.
    
    Arguments:
        shift: The shift term b, shape (*,)
        scale: The unconstrained scale factor a, shape (*,)
        slope: The minimum slope of the transformation
    """
    
    domain = constraints.real
    codomain = constraints.real
    bijective = True
    sign = +1
    
    def __init__(self, shift: Tensor, scale: Tensor, slope: float = 1e-3):
        super().__init__()
        self.shift = shift
        # Constrain scale to ensure minimum slope
        self.log_scale = scale / (1 + abs(scale / math.log(slope)))
        self.scale = self.log_scale.exp()
    
    def _call(self, x: Tensor) -> Tensor:
        return x * self.scale + self.shift
    
    def _inverse(self, y: Tensor) -> Tensor:
        return (y - self.shift) / self.scale
    
    def log_abs_det_jacobian(self, x: Tensor, y: Tensor) -> Tensor:
        return self.log_scale.expand(x.shape)
Usage:
import torch
from zuko.transforms import MonotonicAffineTransform

# Create transformation
shift = torch.tensor([1.0, -0.5])
scale = torch.tensor([0.5, 1.0])
transform = MonotonicAffineTransform(shift, scale)

# Apply forward
x = torch.randn(3, 2)
y = transform(x)

# Apply inverse
x_reconstructed = transform.inv(y)
print(torch.allclose(x, x_reconstructed))  # True

# Compute Jacobian
ladj = transform.log_abs_det_jacobian(x, y)

Composing Transformations

Real flows stack multiple transformations. Zuko’s ComposedTransform (from zuko/transforms.py:59-161) handles this efficiently:
from zuko.transforms import ComposedTransform
from torch.distributions.transforms import TanhTransform, AffineTransform

# Compose: f(x) = f_2(f_1(f_0(x)))
transform = ComposedTransform(
    AffineTransform(0.0, 2.0),      # f_0: scale by 2
    TanhTransform(),                 # f_1: squash to (-1, 1)
    AffineTransform(0.5, 0.5),       # f_2: map to (0, 1)
)

x = torch.randn(100)
y = transform(x)

# Inverse automatically reverses order
x_back = transform.inv(y)

Jacobian Computation

For composed transformations $f = f_n \circ \cdots \circ f_0$, the log-determinant follows the chain rule: $\log \left| \det J_f(x) \right| = \sum_{i=0}^{n} \log \left| \det J_{f_i}(x_i) \right|$, where $x_0 = x$ and $x_{i+1} = f_i(x_i)$ are the intermediate values. Zuko’s implementation (from zuko/transforms.py:141-150):
def call_and_ladj(self, x: Tensor) -> tuple[Tensor, Tensor]:
    event_dim = self.domain_dim
    acc = 0
    
    for t in self.transforms:
        x, ladj = t.call_and_ladj(x)
        acc = acc + _sum_rightmost(ladj, event_dim - t.domain.event_dim)
        event_dim += t.codomain.event_dim - t.domain.event_dim
    
    return x, acc
When composing transformations, ensure domain/codomain compatibility. The codomain of $f_i$ must match the domain of $f_{i+1}$.
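The chain-rule identity above can be verified numerically with PyTorch's built-in ComposeTransform, which implements the same summation:

```python
import torch
from torch.distributions.transforms import AffineTransform, ComposeTransform, TanhTransform

# Check: for f = f_1 ∘ f_0, the composed log-det equals the sum of the
# per-transform log-dets evaluated at the intermediate values.
f0 = AffineTransform(0.0, 2.0)
f1 = TanhTransform()
f = ComposeTransform([f0, f1])

x = torch.randn(4)
x1 = f0(x)  # intermediate value fed into f_1
total = f0.log_abs_det_jacobian(x, x1) + f1.log_abs_det_jacobian(x1, f1(x1))
assert torch.allclose(f.log_abs_det_jacobian(x, f(x)), total, atol=1e-6)
```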

Common Transformation Types

Identity and Simple Mappings

from zuko.transforms import IdentityTransform, SignedPowerTransform

# Identity: f(x) = x (useful as placeholder)
identity = IdentityTransform()

# Signed power: f(x) = sign(x) * |x|^α
alpha = torch.tensor(2.0)
power = SignedPowerTransform(alpha)

Bounded Transformations

from zuko.transforms import SoftclipTransform, CircularShiftTransform

# Softclip: maps R into (-B, B) smoothly
softclip = SoftclipTransform(bound=5.0)

# Circular shift: for periodic domains
shift = CircularShiftTransform(bound=1.0)

Monotonic Spline Transformations

Rational quadratic splines (RQS) are among the most powerful transformations (from zuko/transforms.py:449-568):
from zuko.transforms import MonotonicRQSTransform
import torch

# Define spline knots
K = 8  # Number of bins
widths = torch.randn(K)
heights = torch.randn(K)
derivatives = torch.randn(K - 1)

# Create transformation
rqs = MonotonicRQSTransform(
    widths,
    heights,
    derivatives,
    bound=5.0,
    slope=1e-3
)

x = torch.linspace(-4, 4, 100)
y = rqs(x)
RQS transformations piece together rational quadratic functions between knot points. Within each bin $k$:

$y = y_0 + (y_1 - y_0) \dfrac{s z^2 + d_0 z(1-z)}{s + (d_0 + d_1 - 2s) z(1-z)}$

where $z = (x - x_0)/(x_1 - x_0)$ is the normalized position, $s = (y_1 - y_0)/(x_1 - x_0)$ is the secant slope, and $d_0, d_1$ are the knot derivatives. This formulation ensures:
  • Monotonicity (when derivatives are positive)
  • Smoothness at knot boundaries
  • Efficient inversion via quadratic formula
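The bin formula can be evaluated directly to check these properties. This standalone sketch (plain PyTorch; `rq_bin` is a hypothetical helper, not part of Zuko) interpolates the knots exactly and stays monotone for positive derivatives:

```python
import torch

# Direct evaluation of the rational quadratic formula for a single bin
# with knots (x0, y0), (x1, y1) and boundary derivatives d0, d1.
def rq_bin(x, x0, x1, y0, y1, d0, d1):
    z = (x - x0) / (x1 - x0)          # normalized position in the bin
    s = (y1 - y0) / (x1 - x0)         # secant slope
    num = s * z**2 + d0 * z * (1 - z)
    den = s + (d0 + d1 - 2 * s) * z * (1 - z)
    return y0 + (y1 - y0) * num / den

x = torch.linspace(0.0, 1.0, 101)
y = rq_bin(x, 0.0, 1.0, 0.0, 2.0, 0.5, 1.5)

# The endpoints interpolate the knots, and the map is monotone increasing.
assert torch.allclose(y[0], torch.tensor(0.0))
assert torch.allclose(y[-1], torch.tensor(2.0))
assert (y[1:] > y[:-1]).all()
```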

Coupling and Autoregressive

These enable scalable flow architectures:
from zuko.transforms import CouplingTransform, AutoregressiveTransform
from torch.distributions.transforms import AffineTransform
import torch
import torch.nn as nn

# Coupling: split input, transform half conditioned on other half
mask = torch.tensor([True, True, False, False])  # Which dims are constant

net = nn.Linear(2, 2)  # Reuse one module; building it inside the hyper
                       # function would re-initialize its weights on every call

def build_transform(x_a):
    # x_a is the constant part; build the transform for x_b
    shift = net(x_a)
    return AffineTransform(shift, torch.ones_like(shift))

coupling = CouplingTransform(build_transform, mask)

# Autoregressive: each dimension depends on the previous ones
def build_ar_transform(x):
    # Build a transformation conditioned on x; SomeTransform and
    # params_from are placeholders for your own transform and network
    return SomeTransform(params_from(x))

ar = AutoregressiveTransform(build_ar_transform, passes=5)
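The key property of coupling, exact invertibility because the masked half passes through unchanged, can be demonstrated in a few lines of plain PyTorch (a sketch of the idea, not Zuko's CouplingTransform):

```python
import torch

# Minimal additive coupling step: the masked half is untouched and
# parameterizes a shift of the other half, so inversion is exact.
def couple(x, mask, net):
    x_a = x * mask                    # constant half
    shift = net(x_a) * (1 - mask)     # shift only the transformed half
    return x + shift

def uncouple(y, mask, net):
    y_a = y * mask                    # constant half is unchanged by couple
    shift = net(y_a) * (1 - mask)     # so the same shift can be recomputed
    return y - shift

mask = torch.tensor([1.0, 1.0, 0.0, 0.0])
net = torch.nn.Linear(4, 4)
x = torch.randn(8, 4)
y = couple(x, mask, net)
assert torch.allclose(uncouple(y, mask, net), x, atol=1e-6)
```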

LazyTransform Pattern

Just like distributions, transformations can be lazy and context-dependent:
from zuko.lazy import LazyTransform
from zuko.transforms import MonotonicRQSTransform
from torch.distributions.transforms import Transform
import torch
import torch.nn as nn

class ConditionalSpline(LazyTransform):
    """Spline transformation conditioned on context."""
    
    def __init__(self, features: int, context: int, bins: int = 8):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(context, 64),
            nn.ReLU(),
            nn.Linear(64, features * (3 * bins - 1)),
        )
        self.features = features
        self.bins = bins
    
    def forward(self, c: torch.Tensor) -> Transform:
        # Compute spline parameters from context
        params = self.net(c).reshape(-1, self.features, 3 * self.bins - 1)
        
        widths = params[..., :self.bins]
        heights = params[..., self.bins:2 * self.bins]
        derivatives = params[..., 2 * self.bins:]
        
        return MonotonicRQSTransform(widths, heights, derivatives)
Usage:
# Create lazy transform
transform = ConditionalSpline(features=2, context=5)

# Get transformation for specific context
context = torch.randn(32, 5)
t = transform(context)

# Now apply it
x = torch.randn(32, 2)
y = t(x)
Lazy transformations enable neural transformations where parameters are computed by neural networks from context.

UnconditionalTransform

For simple cases without context:
from zuko.lazy import UnconditionalTransform
from torch.distributions.transforms import ExpTransform

# Wrap a standard transform
lazy_exp = UnconditionalTransform(ExpTransform)

# Call to get the transform (ignores context)
transform = lazy_exp()
x = torch.randn(10)
y = transform(x)

Specialized Transformations

Free-Form Jacobian (FFJORD)

Continuous-time flows using neural ODEs (from zuko/transforms.py:1076-1180):
from zuko.transforms import FreeFormJacobianTransform
import torch.nn as nn

class VectorField(nn.Module):
    def __init__(self, features: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(features, 64),
            nn.Tanh(),
            nn.Linear(64, features),
        )
    
    def forward(self, t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

field = VectorField(features=2)
transform = FreeFormJacobianTransform(
    f=field,
    t0=0.0,
    t1=1.0,
    phi=list(field.parameters()),
    atol=1e-6,
    rtol=1e-5,
    exact=True,
)

x = torch.randn(100, 2)
y, ladj = transform.call_and_ladj(x)
Neural ODE transformations are powerful but computationally expensive. Use them when expressiveness is critical and you have sufficient compute budget.
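Under the hood, continuous-time flows integrate the instantaneous change of variables, $\frac{d}{dt}\log p(x(t)) = -\operatorname{tr}\frac{\partial f}{\partial x}$. The sketch below computes that trace exactly with autograd (the `exact=True` regime; Hutchinson's stochastic estimator replaces it for large $D$):

```python
import torch

# Exact trace of the Jacobian via autograd: one gradient call per dimension.
def jacobian_trace(f, x):
    x = x.requires_grad_(True)
    y = f(x)
    trace = torch.zeros(x.shape[:-1])
    for i in range(x.shape[-1]):
        grad_i = torch.autograd.grad(y[..., i].sum(), x, create_graph=True)[0]
        trace = trace + grad_i[..., i]  # accumulate the diagonal entries
    return trace

f = lambda x: torch.tanh(2.0 * x)  # toy elementwise vector field
x = torch.randn(3, 2)
tr = jacobian_trace(f, x)

# For f(x) = tanh(2x) elementwise, tr(J) = sum_i 2 * (1 - tanh(2 x_i)^2).
expected = (2.0 * (1.0 - torch.tanh(2.0 * x) ** 2)).sum(-1)
assert torch.allclose(tr, expected, atol=1e-5)
```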

Linear Transformations

from zuko.transforms import LULinearTransform, RotationTransform
import torch

# LU decomposition for efficient inversion
LU = torch.randn(3, 3)
lu_transform = LULinearTransform(LU)

# Rotation (orthogonal matrix)
A = torch.randn(3, 3)
rotation = RotationTransform(A)
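Why the LU parameterization pays off: with $W = LU$ ($L$ unit lower-triangular, $U$ upper-triangular), the log-determinant reduces to the diagonal of $U$, and solves become two triangular back-substitutions. A plain-PyTorch check:

```python
import torch

# For W = L @ U with L unit lower-triangular and U upper-triangular,
# det(W) = det(L) * det(U) = prod(diag(U)), so log|det W| costs O(D).
D = 4
L = torch.tril(torch.randn(D, D), -1) + torch.eye(D)
U = torch.triu(torch.randn(D, D), 1) + torch.diag(torch.rand(D) + 0.5)  # diag kept away from 0
W = L @ U

logdet = U.diagonal().abs().log().sum()
assert torch.allclose(logdet, torch.linalg.slogdet(W)[1], atol=1e-4)
```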

Transformation Best Practices

Always work in log-space for scale parameters and Jacobians. Zuko’s transformations use log_abs_det_jacobian rather than computing determinants directly to avoid numerical issues.
# Good: work in log space
log_scale = some_network(x)
scale = log_scale.exp()

# Bad: can overflow/underflow
scale = some_network(x)
log_scale = scale.log()
Many transformations include a slope parameter (e.g., 1e-3) to ensure numerical stability by bounding the Jacobian determinant away from zero:
rqs = MonotonicRQSTransform(
    widths, heights, derivatives,
    slope=1e-3  # Ensures |det J| >= 1e-3
)
This prevents gradient explosion and improves training stability.
Single transformations are often limited. Stack multiple transformations for greater expressiveness:
transform = ComposedTransform(
    MonotonicRQSTransform(...),  # Flexible nonlinear
    RotationTransform(...),       # Mix dimensions
    MonotonicRQSTransform(...),  # Another nonlinear layer
)
Always verify your transformations are truly invertible:
x = torch.randn(100, 2)
y = transform(x)
x_reconstructed = transform.inv(y)

assert torch.allclose(x, x_reconstructed, atol=1e-5)

Transformation Catalog

| Transformation | Use Case | Jacobian Cost |
| --- | --- | --- |
| IdentityTransform | Placeholder | $O(1)$ |
| AffineTransform | Location-scale | $O(1)$ |
| MonotonicAffineTransform | Constrained affine | $O(1)$ |
| MonotonicRQSTransform | Flexible 1D | $O(K)$ |
| AutoregressiveTransform | Autoregressive flows | $O(D)$ |
| CouplingTransform | Scalable flows | $O(D/2)$ |
| PermutationTransform | Mixing | $O(1)$ |
| RotationTransform | Orthogonal mixing | $O(1)$ |
| LULinearTransform | Linear flows | $O(D)$ |
| FreeFormJacobianTransform | Maximum flexibility | $O(D^2)$ or $O(D)$* |

*Depends on the exact parameter

Next Steps

Flow Architectures

See how transformations combine into complete flow models

Custom Flows

Learn to build custom transformation architectures
