The zuko.transforms module provides a collection of parameterizable transformations that serve as the building blocks for normalizing flows. All transforms are compatible with PyTorch’s distribution framework.

Overview

Transforms in Zuko are bijective mappings that can compute:
  • Forward transformation: y = f(x)
  • Inverse transformation: x = f^{-1}(y)
  • Log absolute determinant of the Jacobian: log |det J_f(x)|
All transforms inherit from torch.distributions.Transform and can be composed, inverted, and combined with distributions.
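The three quantities above can be illustrated with the scalar bijection f(x) = sinh(x). This is a minimal sketch in plain Python, not a zuko class. For a scalar map, the Jacobian determinant is simply f'(x) = cosh(x), and a finite-difference check confirms the analytic log-determinant.

```python
import math


def forward(x):
    # y = f(x) = sinh(x), a smooth bijection from R onto R
    return math.sinh(x)


def inverse(y):
    # x = f^{-1}(y) = asinh(y)
    return math.asinh(y)


def log_abs_det_jacobian(x):
    # For a scalar map, J_f(x) = f'(x) = cosh(x) > 0
    return math.log(math.cosh(x))


def numerical_ladj(x, eps=1e-6):
    # Central finite difference of f, used only to check the analytic formula
    return math.log(abs(forward(x + eps) - forward(x - eps)) / (2 * eps))


x = 0.7
print(abs(inverse(forward(x)) - x) < 1e-12)  # True
print(abs(log_abs_det_jacobian(x) - numerical_ladj(x)) < 1e-6)  # True
```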

Transform Categories

Autoregressive Transforms

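Transforms that apply autoregressive conditioning: each output dimension y_i depends only on x_i and the preceding dimensions x_{<i}. The Jacobian is therefore triangular, and its log-determinant is a sum of diagonal terms.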
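A minimal pure-Python sketch of an affine autoregressive transform follows. It assumes a hypothetical `conditioner` function in place of the masked neural network a real flow would use, and it is not zuko's own implementation. The sketch shows why the log-determinant reduces to a sum and why this particular inverse must be computed one dimension at a time.

```python
import math


def conditioner(prefix):
    # Hypothetical stand-in for a masked neural network: maps the preceding
    # dimensions x_{<i} to a (shift, log_scale) pair for dimension i
    s = sum(prefix)
    return 0.5 * s, math.tanh(s)


def ar_forward(x):
    # y_i = x_i * exp(log_scale_i) + shift_i, with parameters computed from x_{<i}
    y, ladj = [], 0.0
    for i, xi in enumerate(x):
        shift, log_scale = conditioner(x[:i])
        y.append(xi * math.exp(log_scale) + shift)
        # dy_i/dx_j = 0 for j > i, so the Jacobian is triangular and
        # log|det J| is the sum of the log-diagonal terms
        ladj += log_scale
    return y, ladj


def ar_inverse(y):
    # Sequential: recovering x_i requires the already-recovered x_{<i}
    x = []
    for yi in y:
        shift, log_scale = conditioner(x)
        x.append((yi - shift) * math.exp(-log_scale))
    return x


x = [0.3, -1.2, 2.0]
y, ladj = ar_forward(x)
x_rec = ar_inverse(y)
print(max(abs(a - b) for a, b in zip(x, x_rec)) < 1e-12)  # True
```

The first dimension has an empty prefix, so in this sketch it passes through unchanged.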

Coupling Transforms

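Transforms that split the input into two subsets, pass one subset through unchanged, and transform the other conditioned on it. This keeps both the inverse and the Jacobian determinant cheap to compute.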
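The coupling idea can be sketched in plain Python as an affine coupling layer. This is not zuko's API: `coupling_net` is a hypothetical stand-in for a neural network, and the even-length split is an assumption made for brevity. Because the first half is copied through unchanged, the inverse can recompute the exact same parameters in a single pass.

```python
import math


def coupling_net(x_a):
    # Hypothetical stand-in for a neural network: one (shift, log_scale)
    # pair per transformed element, computed from the untouched half x_a only
    return [(0.5 * v, 0.5 * math.tanh(v)) for v in x_a]


def coupling_forward(x):
    assert len(x) % 2 == 0, "this sketch splits an even-length input in half"
    d = len(x) // 2
    x_a, x_b = x[:d], x[d:]
    params = coupling_net(x_a)
    y_b = [v * math.exp(ls) + sh for v, (sh, ls) in zip(x_b, params)]
    # x_a is copied through, so the Jacobian is triangular and only
    # the scaled half contributes to log|det J|
    ladj = sum(ls for _, ls in params)
    return x_a + y_b, ladj


def coupling_inverse(y):
    d = len(y) // 2
    y_a, y_b = y[:d], y[d:]
    params = coupling_net(y_a)  # y_a equals x_a, so the parameters are recovered exactly
    x_b = [(v - sh) * math.exp(-ls) for v, (sh, ls) in zip(y_b, params)]
    return y_a + x_b


x = [0.5, -1.0, 2.0, 0.25]
y, ladj = coupling_forward(x)
print(y[:2] == x[:2])  # True: the first half is unchanged
print(max(abs(a - b) for a, b in zip(x, coupling_inverse(y))) < 1e-12)  # True
```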

Spline Transforms

Monotonic spline-based transformations for flexible, continuous mappings.
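To show how a monotonic spline stays invertible, the sketch below swaps in a simple piecewise-linear spline in plain Python. Spline flows typically use smoother rational-quadratic splines, so this illustrates the idea rather than zuko's implementation. The bin sizes, the `bound=5.0` interval, and the identity tails outside it are illustrative assumptions. Positive bin widths and heights define increasing knots, so the map is strictly increasing, and the log-determinant is the log of the slope of the current bin.

```python
import bisect
import math


class PiecewiseLinearSpline:
    """Monotonic piecewise-linear map on [-bound, bound], identity outside."""

    def __init__(self, widths, heights, bound=5.0):
        # Positive, unnormalized bin widths and heights (a network would predict these)
        self.bound = bound
        self.xk = self._knots(widths)
        self.yk = self._knots(heights)

    def _knots(self, sizes):
        # Normalize the sizes so the knots span exactly [-bound, bound]
        total = sum(sizes)
        knots = [-self.bound]
        for s in sizes:
            knots.append(knots[-1] + 2 * self.bound * s / total)
        knots[-1] = self.bound  # remove rounding drift at the last knot
        return knots

    @staticmethod
    def _bin(knots, v):
        # Index k with knots[k] <= v <= knots[k + 1]
        return min(bisect.bisect_right(knots, v) - 1, len(knots) - 2)

    def _slope(self, k):
        return (self.yk[k + 1] - self.yk[k]) / (self.xk[k + 1] - self.xk[k])

    def forward(self, x):
        if abs(x) > self.bound:
            return x  # identity tails
        k = self._bin(self.xk, x)
        return self.yk[k] + self._slope(k) * (x - self.xk[k])

    def inverse(self, y):
        if abs(y) > self.bound:
            return y
        k = self._bin(self.yk, y)
        return self.xk[k] + (y - self.yk[k]) / self._slope(k)

    def log_abs_det_jacobian(self, x):
        if abs(x) > self.bound:
            return 0.0
        # The derivative is constant within a bin: it is that bin's slope
        return math.log(self._slope(self._bin(self.xk, x)))


spline = PiecewiseLinearSpline(widths=[1.0, 2.0, 1.0], heights=[2.0, 1.0, 1.0])
print(spline.forward(0.0))  # 1.25
print(spline.inverse(spline.forward(-4.0)))  # -4.0
```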

Polynomial Transforms

Monotonic transformations based on polynomial functions, such as Bernstein and sum-of-squares polynomials.
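One way to guarantee that a polynomial is monotonic is to integrate a strictly positive polynomial, in the spirit of sum-of-squares polynomial flows. The plain-Python sketch below uses a single squared linear term plus a small constant `eps` (an illustrative assumption), which integrates to a closed-form cubic. The inverse uses bracketing and bisection, chosen here for simplicity rather than taken from zuko.

```python
import math


class SquaredPolynomial:
    """f(x) = shift + integral_0^x [(a0 + a1 * t)^2 + eps] dt, a monotonic cubic."""

    def __init__(self, a0, a1, shift=0.0, eps=1e-3):
        self.a0, self.a1, self.shift, self.eps = a0, a1, shift, eps

    def derivative(self, x):
        # A square plus eps > 0 is strictly positive, so f is strictly increasing
        return (self.a0 + self.a1 * x) ** 2 + self.eps

    def forward(self, x):
        # Closed-form integral of the derivative above
        a0, a1 = self.a0, self.a1
        return (
            self.shift
            + (a0 ** 2 + self.eps) * x
            + a0 * a1 * x ** 2
            + a1 ** 2 * x ** 3 / 3
        )

    def log_abs_det_jacobian(self, x):
        return math.log(self.derivative(x))

    def inverse(self, y, iters=200):
        # Bracket the root by doubling, then bisect a fixed number of times
        lo, hi = -1.0, 1.0
        while self.forward(lo) > y:
            lo *= 2.0
        while self.forward(hi) < y:
            hi *= 2.0
        for _ in range(iters):
            mid = 0.5 * (lo + hi)
            if self.forward(mid) < y:
                lo = mid
            else:
                hi = mid
        return 0.5 * (lo + hi)


poly = SquaredPolynomial(a0=0.5, a1=-1.0, shift=0.2)
x = 1.7
print(abs(poly.inverse(poly.forward(x)) - x) < 1e-9)  # True
```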

Neural Transforms

Transforms constructed from neural networks and learned functions.
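A common way to build a monotonic transform from a neural network is to integrate a strictly positive network output, f(x) = ∫_0^x g(t) dt with g > 0, as in unconstrained monotonic neural networks. The plain-Python sketch below assumes a hypothetical fixed-weight `positive_net` in place of a trained network. It uses the trapezoidal rule for the integral and bisection for the inverse; both are illustrative choices, not zuko's implementation. The log-determinant needs no integral at all, because f'(x) = g(x).

```python
import math


def positive_net(t):
    # Hypothetical one-hidden-unit "network" with fixed weights; softplus keeps it > 0
    hidden = math.tanh(1.5 * t - 0.3)
    return math.log1p(math.exp(2.0 * hidden + 0.1))


def neural_forward(x, steps=1000):
    # f(x) = integral from 0 to x of positive_net(t) dt (trapezoidal rule)
    dt = x / steps
    total = 0.5 * (positive_net(0.0) + positive_net(x))
    total += sum(positive_net(k * dt) for k in range(1, steps))
    return total * dt


def neural_ladj(x):
    # Fundamental theorem of calculus: f'(x) = positive_net(x)
    return math.log(positive_net(x))


def neural_inverse(y, iters=100):
    # f is strictly increasing, so bracket the solution and bisect
    lo, hi = -1.0, 1.0
    while neural_forward(lo) > y:
        lo *= 2.0
    while neural_forward(hi) < y:
        hi *= 2.0
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if neural_forward(mid) < y:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)


x = 0.8
print(abs(neural_inverse(neural_forward(x)) - x) < 1e-9)  # True
```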

Utility Transforms

Basic transformations for common operations.
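Permutations are a typical utility operation: zuko's PermutationTransform, used in the composition example on this page, reorders dimensions so that later layers condition on different variables. The plain-Python sketch below illustrates the underlying math rather than that class. The inverse is the inverse permutation, and the log-determinant is zero because a permutation matrix has determinant ±1.

```python
def permute(x, order):
    # y[i] = x[order[i]]: the values are only reordered
    return [x[j] for j in order]


def inverse_order(order):
    # The inverse permutation sends each position back: inv[order[i]] = i
    inv = [0] * len(order)
    for i, j in enumerate(order):
        inv[j] = i
    return inv


def log_abs_det_jacobian(x):
    # A permutation matrix has determinant +1 or -1, so log|det J| = 0
    return 0.0


x = [10.0, 20.0, 30.0]
order = [2, 0, 1]
y = permute(x, order)
print(y)  # [30.0, 10.0, 20.0]
print(permute(y, inverse_order(order)))  # [10.0, 20.0, 30.0]
```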

Composition Transforms

Transforms for combining and modifying other transforms.
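Composition follows the chain rule: the forward pass applies each transform in order, the log-determinants add up (each evaluated at that step's own input), and the inverse undoes the steps in reverse order. The plain-Python sketch below uses hypothetical `Affine`, `Exp`, and `Composed` classes to illustrate this bookkeeping. It is not zuko's ComposedTransform.

```python
import math


class Affine:
    """Hypothetical scalar bijection y = scale * x + shift (scale != 0)."""

    def __init__(self, shift, scale):
        self.shift, self.scale = shift, scale

    def forward(self, x):
        return self.scale * x + self.shift

    def inverse(self, y):
        return (y - self.shift) / self.scale

    def ladj(self, x):
        return math.log(abs(self.scale))


class Exp:
    """Hypothetical scalar bijection y = exp(x) from R onto (0, inf)."""

    def forward(self, x):
        return math.exp(x)

    def inverse(self, y):
        return math.log(y)

    def ladj(self, x):
        return x  # log|d/dx exp(x)| = log(exp(x)) = x


class Composed:
    """Applies transforms in order; log-determinants add up by the chain rule."""

    def __init__(self, *transforms):
        self.transforms = transforms

    def forward_and_ladj(self, x):
        total = 0.0
        for t in self.transforms:
            total += t.ladj(x)  # evaluated at this step's own input
            x = t.forward(x)
        return x, total

    def inverse(self, y):
        # Undo the steps in reverse order
        for t in reversed(self.transforms):
            y = t.inverse(y)
        return y


flow = Composed(Affine(shift=1.0, scale=2.0), Exp())
y, ladj = flow.forward_and_ladj(0.5)
# Analytically f(x) = exp(2x + 1), so log f'(0.5) = log(2) + 2
print(abs(ladj - (math.log(2.0) + 2.0)) < 1e-12)  # True
print(abs(flow.inverse(y) - 0.5) < 1e-12)  # True
```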

Basic Usage

```python
import torch
import zuko

# Create a simple element-wise affine transformation
shift = torch.tensor([1.0, 2.0])
scale = torch.tensor([0.5, 0.5])  # unconstrained scale parameter; the effective scale factor is positive
transform = zuko.transforms.MonotonicAffineTransform(shift, scale)

# Forward transformation
x = torch.randn(10, 2)
y = transform(x)

# Inverse transformation (recovers x up to floating-point error)
x_reconstructed = transform.inv(y)

# Log absolute determinant of the Jacobian (element-wise, same shape as x)
ladj = transform.log_abs_det_jacobian(x, y)
```
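The log-determinant is what turns a transform into a density model. If x ~ p_X and y = f(x), then by the change of variables log p_Y(y) = log p_X(f^{-1}(y)) - log|det J_f(f^{-1}(y))|. The plain-Python sketch below checks this for an affine map applied to a standard normal, where `scale` is taken directly as the positive scale factor. The result is compared against the closed-form density of N(shift, scale^2).

```python
import math


def normal_log_prob(v, mean=0.0, std=1.0):
    # Log density of the univariate Gaussian N(mean, std^2)
    z = (v - mean) / std
    return -0.5 * z * z - math.log(std) - 0.5 * math.log(2.0 * math.pi)


shift, scale = 1.0, 0.5  # y = scale * x + shift with a positive scale factor


def flow_log_prob(y):
    x = (y - shift) / scale  # x = f^{-1}(y)
    ladj = math.log(scale)  # log|det J_f(x)| of an affine map
    return normal_log_prob(x) - ladj  # change of variables


# An affine image of N(0, 1) is N(shift, scale^2), so both values should agree
y = 1.3
print(abs(flow_log_prob(y) - normal_log_prob(y, mean=shift, std=scale)) < 1e-12)  # True
```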

Composing Transforms

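```python
import torch
import zuko

# Create multiple transforms
t1 = zuko.transforms.MonotonicAffineTransform(  # shift only: a zero unconstrained scale gives unit scale
    torch.tensor([1.0, 0.0]),
    torch.tensor([0.0, 0.0]),
)
t2 = zuko.transforms.PermutationTransform(torch.tensor([1, 0]))  # swap the two dimensions
t3 = zuko.transforms.MonotonicAffineTransform(
    torch.tensor([0.0, 0.0]),
    torch.tensor([1.0, 1.0]),
)

# Compose transforms (applied in order: t1, then t2, then t3)
transform = zuko.transforms.ComposedTransform(t1, t2, t3)

# Use the composed transformation
x = torch.randn(10, 2)
y = transform(x)
x_reconstructed = transform.inv(y)  # applies the inverses of t3, t2, t1 in that order
```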

Next Steps

  • Autoregressive: learn about autoregressive transformations
  • Coupling: explore coupling transformations
  • Spline: discover spline-based transformations
  • Polynomial: work with polynomial transformations
