Neural transforms use neural networks to construct flexible, learnable transformations while maintaining important properties like monotonicity and invertibility.

Overview

Zuko provides several neural transformation approaches:
  1. UnconstrainedMonotonicTransform: Monotonic transformation via integration of positive functions
  2. MonotonicTransform: Wrapper for arbitrary monotonic functions with numerical inverse
  3. GaussianizationTransform: Gaussianization via learned affine CDFs
  4. FreeFormJacobianTransform: Continuous normalizing flows via ODEs (FFJORD)

UnconstrainedMonotonicTransform

Mathematical Formulation

The UnconstrainedMonotonicTransform creates a monotonic transformation by integrating a positive function: $f(x) = \int_0^x g(u) \, du$, where $g(x) > 0$ is any positive function (e.g., a neural network with positive output). The integral is computed using Gauss-Legendre quadrature. The Jacobian is simply $\frac{df}{dx} = g(x)$.
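As an illustration of the quadrature step, the following standalone NumPy sketch (not Zuko's internal implementation) evaluates $f(x) = \int_0^x g(u) \, du$ with Gauss-Legendre nodes:

```python
import numpy as np

def integrate_gauss_legendre(g, x, n=32):
    """Approximate f(x) = integral of g on [0, x] with n-point Gauss-Legendre."""
    # Nodes and weights for the reference interval [-1, 1].
    nodes, weights = np.polynomial.legendre.leggauss(n)
    # Affine map from [-1, 1] to [0, x]: u = x/2 * (t + 1), du = x/2 dt.
    u = 0.5 * x * (nodes + 1.0)
    return 0.5 * x * np.sum(weights * g(u))

# g(u) = exp(u) is positive, so f(x) = e^x - 1 is monotonic by construction.
approx = integrate_gauss_legendre(np.exp, 2.0)
exact = np.exp(2.0) - 1.0
print(abs(approx - exact))  # quadrature error is tiny for smooth g
```

Note that the log-determinant comes for free: since $f'(x) = g(x)$, no extra integration is needed for the Jacobian.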

Class Definition

class UnconstrainedMonotonicTransform(MonotonicTransform)
  • g (Callable[[Tensor], Tensor], default: None): A positive univariate function $g$. If None, self.g is used instead. The function should return strictly positive values.
  • n (int, default: 32): The number of points for Gauss-Legendre quadrature. Higher values increase accuracy but are slower.
  • kwargs: Keyword arguments passed to MonotonicTransform (e.g., bound, eps, phi).

Properties

  • Domain: constraints.real
  • Codomain: constraints.real
  • Bijective: True
  • Sign: +1

Usage Example

import torch
import torch.nn as nn
import zuko

class PositiveNet(nn.Module):
    """Neural network with positive output."""
    def __init__(self, context_dim: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(context_dim + 1, 64),  # +1 for x input
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Softplus(),  # Ensure positive output
        )
        self.context = None
    
    def set_context(self, context):
        self.context = context
    
    def forward(self, x):
        # Concatenate x with context
        if self.context is not None:
            # Broadcast the context over any leading batch dimensions of x
            context = self.context.expand(*x.shape[:-1], -1)
            x_with_context = torch.cat([x, context], dim=-1)
        else:
            x_with_context = x
        return self.net(x_with_context) + 0.1  # Enforce a minimum positive value

# Create neural network
context_dim = 5
net = PositiveNet(context_dim)

# Set context
context = torch.randn(1, context_dim)
net.set_context(context)

# Create transformation
transform = zuko.transforms.UnconstrainedMonotonicTransform(
    g=net,
    n=32,  # Quadrature points
    bound=10.0,
    eps=1e-6,
    phi=list(net.parameters()),  # For inverse gradient computation
)

# Apply transformation
x = torch.randn(32, 1)
y, ladj = transform.call_and_ladj(x)

print(f"Output: {y.shape}, LADJ: {ladj.shape}")  # [32, 1], [32, 1]

# Inverse (uses bisection)
x_reconstructed = transform.inv(y)
print(f"Reconstruction error: {(x - x_reconstructed).abs().max()}")

Integration-Based Coupling

import torch
import torch.nn as nn
import zuko

class IntegrationCouplingNet(nn.Module):
    """Network producing positive functions for integration."""
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_features, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
        )
        # Separate heads for each output dimension
        self.heads = nn.ModuleList([
            nn.Sequential(
                nn.Linear(128 + 1, 32),  # +1 for x value
                nn.ReLU(),
                nn.Linear(32, 1),
                nn.Softplus()
            )
            for _ in range(out_features)
        ])
    
    def forward(self, x_a):
        batch_size = x_a.shape[0]
        features = self.net(x_a)  # [batch, 128]
        
        # Create a function for each output dimension
        def g(x_b):
            # x_b has shape [batch, 1] for a single dimension
            combined = torch.cat([features, x_b], dim=-1)  # [batch, 129]
            outputs = []
            for head in self.heads:
                outputs.append(head(combined))
            return torch.cat(outputs, dim=-1)  # [batch, out_features]
        
        # Create transforms for each dimension
        transforms = []
        for i in range(len(self.heads)):
            def g_i(x, idx=i):
                return g(x)[:, idx:idx+1]
            
            transforms.append(
                zuko.transforms.UnconstrainedMonotonicTransform(
                    g=g_i,
                    n=16,
                    phi=list(self.parameters())
                )
            )
        
        # Compose transforms for all dimensions
        class MultiDimTransform:
            def __init__(self, transforms):
                self.transforms = transforms
            
            def __call__(self, x):
                outputs = []
                for i, t in enumerate(self.transforms):
                    outputs.append(t(x[:, i:i+1]))
                return torch.cat(outputs, dim=-1)
            
            def call_and_ladj(self, x):
                outputs = []
                ladjs = []
                for i, t in enumerate(self.transforms):
                    y_i, ladj_i = t.call_and_ladj(x[:, i:i+1])
                    outputs.append(y_i)
                    ladjs.append(ladj_i)
                return torch.cat(outputs, dim=-1), torch.cat(ladjs, dim=-1).sum(dim=-1)
        
        return MultiDimTransform(transforms)

# Create coupling flow with integration
features = 10
mask = torch.zeros(features, dtype=torch.bool)
mask[:features // 2] = True

net = IntegrationCouplingNet(
    in_features=int(mask.sum()),
    out_features=int((~mask).sum()),
)

transform = zuko.transforms.CouplingTransform(meta=net, mask=mask)

x = torch.randn(16, features)
y, ladj = transform.call_and_ladj(x)
print(f"Output: {y.shape}, LADJ: {ladj.shape}")  # [16, 10], [16]

MonotonicTransform

Description

The MonotonicTransform wraps any monotonic univariate function, providing automatic inverse computation via bisection.
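The numerical inverse can be sketched as a plain bisection loop over the search interval $[-B, B]$; this is an illustrative standalone routine, not Zuko's exact implementation:

```python
import torch

def bisection_inverse(f, y, bound=10.0, eps=1e-6):
    """Solve f(x) = y for an increasing f via bisection on [-bound, bound]."""
    lo = torch.full_like(y, -bound)
    hi = torch.full_like(y, bound)
    while (hi - lo).max() > eps:
        mid = (lo + hi) / 2
        # Where f(mid) < y the root lies to the right, so raise the lower bound.
        mask = f(mid) < y
        lo = torch.where(mask, mid, lo)
        hi = torch.where(mask, hi, mid)
    return (lo + hi) / 2

# Invert f(x) = x^3 + x, which is strictly increasing.
f = lambda x: x**3 + x
y = torch.tensor([2.0, -10.0])
x = bisection_inverse(f, y)
print((f(x) - y).abs().max())  # small residual
```

Each iteration halves the interval, so the cost is logarithmic in bound / eps; this is why smaller eps values make the inverse slower.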

Class Definition

class MonotonicTransform(Transform)
  • f (Callable[[Tensor], Tensor], default: None): A monotonic univariate function $f_\phi(x)$. If None, self.f is used instead.
  • phi (Iterable[Tensor], default: ()): The parameters $\phi$ of $f_\phi$. Providing the parameters is required to make the inverse transformation trainable.
  • bound (float, default: 10.0): The domain bound $B$, defining the bisection search interval $[-B, B]$.
  • eps (float, default: 1e-6): The absolute tolerance for the inverse transformation.

Usage Example

import torch
import torch.nn as nn
import zuko

class MonotonicNet(nn.Module):
    """Strictly increasing network: positive weights + monotonic activations."""
    def __init__(self, hidden_dim: int = 64):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Linear(1, hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Linear(hidden_dim, 1),
        ])

    def forward(self, x):
        h = x
        for i, layer in enumerate(self.layers):
            # Exponentiating the weights makes them strictly positive, so each
            # layer (composed with the monotonic tanh) is strictly increasing.
            h = nn.functional.linear(h, layer.weight.exp(), layer.bias)
            if i < len(self.layers) - 1:
                h = torch.tanh(h)
        return h

# Create monotonic neural network
net = MonotonicNet()

# Wrap in MonotonicTransform
transform = zuko.transforms.MonotonicTransform(
    f=net,
    phi=list(net.parameters()),
    bound=10.0,
    eps=1e-6
)

# Apply transformation
x = torch.randn(32, 1, requires_grad=True)
y, ladj = transform.call_and_ladj(x)

# Inverse
x_reconstructed = transform.inv(y)
print(f"Reconstruction error: {(x - x_reconstructed).abs().max()}")  # Should be < eps

# Gradients work through inverse
loss = y.sum()
loss.backward()
print(f"Gradients computed: {x.grad is not None}")  # True

GaussianizationTransform

Mathematical Formulation

The Gaussianization transformation maps data to a Gaussian distribution: $f(x) = \Phi^{-1}\left( \frac{1}{K} \sum_{i=1}^K \Phi(\exp(a_i) x + b_i) \right)$, where $\Phi$ is the CDF of the standard normal distribution $\mathcal{N}(0, 1)$.
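A minimal sketch of the forward formula, using torch.distributions.Normal for $\Phi$ and $\Phi^{-1}$ (illustrative only; the a and b tensors stand in for the unconstrained scale and shift parameters):

```python
import torch

def gaussianize(x, a, b):
    """f(x) = Phi^{-1}( mean_i Phi(exp(a_i) * x + b_i) ), a mixture of Gaussian CDFs."""
    normal = torch.distributions.Normal(0.0, 1.0)
    # Broadcast x over the K components: (..., 1) against (K,).
    u = normal.cdf(x.unsqueeze(-1) * a.exp() + b).mean(dim=-1)
    # Clamp away from {0, 1} so the inverse CDF stays finite.
    u = u.clamp(1e-6, 1 - 1e-6)
    return normal.icdf(u)

a = torch.zeros(4)  # exp(a_i) = 1
b = torch.zeros(4)  # b_i = 0
x = torch.linspace(-2, 2, 5)
# With identical standard components the transform reduces to the identity.
print((gaussianize(x, a, b) - x).abs().max())
```

The mixture of CDFs is strictly increasing in x, so the composition with $\Phi^{-1}$ is monotonic and invertible (by bisection, as in MonotonicTransform).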

Class Definition

class GaussianizationTransform(MonotonicTransform)
  • shift (Tensor): The shift terms $b$, with shape $(*, K)$.
  • scale (Tensor): The unconstrained scale factors $a$, with shape $(*, K)$.
  • kwargs: Keyword arguments passed to MonotonicTransform.

Usage Example

import torch
import torch.nn as nn
import zuko
import matplotlib.pyplot as plt

# Create Gaussianization transform
K = 8  # Number of components
shift = torch.randn(1, K)
scale = torch.randn(1, K) * 0.5

transform = zuko.transforms.GaussianizationTransform(
    shift=shift,
    scale=scale,
    bound=10.0,
    eps=1e-6
)

# Apply to non-Gaussian data
x = torch.cat([torch.randn(500, 1) * 0.5 - 2, torch.randn(500, 1) * 0.3 + 2])  # Bimodal
y = transform(x)

# Visualize Gaussianization
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

axes[0].hist(x.detach().numpy(), bins=50, alpha=0.7, density=True)
axes[0].set_title('Original Distribution')
axes[0].set_xlabel('x')
axes[0].set_ylabel('Density')

axes[1].hist(y.detach().numpy(), bins=50, alpha=0.7, density=True)
axes[1].set_title('Gaussianized Distribution')
axes[1].set_xlabel('y')
axes[1].set_ylabel('Density')

# Overlay standard normal
import numpy as np
from scipy import stats
z = np.linspace(-4, 4, 100)
axes[1].plot(z, stats.norm.pdf(z), 'r-', lw=2, label='N(0,1)')
axes[1].legend()

plt.tight_layout()
plt.show()

FreeFormJacobianTransform

Mathematical Formulation

The Free-Form Jacobian (FFJORD) transformation uses continuous normalizing flows based on ODEs: $x(t_1) = x_0 + \int_{t_0}^{t_1} f_\phi(t, x(t)) \, dt$. The log-determinant is computed by augmenting the ODE with a trace equation: $\frac{d \log |\det J|}{dt} = \mathrm{tr}\left( \frac{\partial f_\phi}{\partial x} \right)$
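The augmented ODE can be sketched with a fixed-step Euler integrator and an exact autograd trace. This toy example (not Zuko's adaptive solver) uses linear dynamics $f(t, x) = Ax$, for which the true log-determinant after time $t$ is $t \cdot \mathrm{tr}(A)$:

```python
import torch

def trace_df_dx(f, t, x):
    """Exact trace of df/dx, one backward pass per dimension (fine for small d)."""
    x = x.detach().requires_grad_(True)
    y = f(t, x)
    tr = torch.zeros(x.shape[0])
    for i in range(x.shape[-1]):
        grad = torch.autograd.grad(y[:, i].sum(), x, retain_graph=True)[0]
        tr = tr + grad[:, i]  # pick out the diagonal entry dy_i / dx_i
    return tr

def ffjord_euler(f, x, t0=0.0, t1=1.0, steps=1000):
    """Explicit-Euler integration of the augmented (state, log-det) ODE."""
    dt = (t1 - t0) / steps
    ladj = torch.zeros(x.shape[0])
    t = t0
    for _ in range(steps):
        ladj = ladj + dt * trace_df_dx(f, t, x)
        x = x + dt * f(t, x)
        t = t + dt
    return x, ladj

# Linear dynamics f(t, x) = A x: the exact log-det at t = 1 is tr(A).
A = torch.tensor([[0.5, 0.2], [0.0, -0.3]])
f = lambda t, x: x @ A.T
x1, ladj = ffjord_euler(f, torch.randn(8, 2))
print(ladj[0].item())  # ≈ tr(A) = 0.2
```

In practice an adaptive solver replaces the fixed Euler steps, and the exact trace (d backward passes) is what the exact=True option computes.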

Class Definition

class FreeFormJacobianTransform(Transform)
  • f (Callable[[Tensor, Tensor], Tensor]): A system of first-order ODEs $f_\phi(t, x)$. Takes the time $t$ and the state $x$ as inputs.
  • t0 (float | Tensor, default: 0.0): The initial integration time $t_0$.
  • t1 (float | Tensor, default: 1.0): The final integration time $t_1$.
  • phi (Iterable[Tensor], default: ()): The parameters $\phi$ of $f_\phi$.
  • atol (float, default: 1e-6): The absolute integration tolerance.
  • rtol (float, default: 1e-5): The relative integration tolerance.
  • exact (bool, default: True): Whether to compute the exact log-determinant (via the full Jacobian) or a stochastic estimate (via Hutchinson's trace estimator).
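With exact=False, the trace is replaced by Hutchinson's estimator, $\mathrm{tr}(J) \approx \mathbb{E}[v^\top J v]$ for random $v$ with $\mathbb{E}[vv^\top] = I$, which costs one vector-Jacobian product per sample instead of one backward pass per dimension. A standalone sketch (not Zuko's implementation):

```python
import torch

def hutchinson_trace(f, x, n_samples=2048):
    """Estimate tr(df/dx) as the average of v^T (df/dx) v over Rademacher v."""
    x = x.detach().requires_grad_(True)
    y = f(x)
    estimates = []
    for _ in range(n_samples):
        v = torch.randint(0, 2, y.shape).to(y.dtype) * 2 - 1  # ±1 entries
        # One VJP gives v^T J; dotting with v yields v^T J v.
        vjp = torch.autograd.grad(y, x, grad_outputs=v, retain_graph=True)[0]
        estimates.append((vjp * v).sum(dim=-1))
    return torch.stack(estimates).mean(dim=0)

torch.manual_seed(0)
A = torch.tensor([[1.0, 0.3], [0.1, -2.0]])
f = lambda x: x @ A.T
est = hutchinson_trace(f, torch.randn(1, 2))
print(est)  # ≈ tr(A) = -1.0, up to Monte Carlo noise
```

During training a single sample per step is typical; the noise averages out over gradient updates, which is what makes FFJORD scalable in high dimensions.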

Properties

  • Domain: constraints.real_vector
  • Codomain: constraints.real_vector
  • Bijective: True

Usage Example

import torch
import torch.nn as nn
import zuko

class ODENet(nn.Module):
    """Neural ODE for FFJORD."""
    def __init__(self, features: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(features + 1, 64),  # +1 for time
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, features),
        )
    
    def forward(self, t, x):
        # Broadcast the (scalar) time over the batch and prepend it to the state
        t_vec = torch.ones_like(x[..., :1]) * t
        tx = torch.cat([t_vec, x], dim=-1)
        return self.net(tx)

# Create FFJORD transformation
features = 5
net = ODENet(features)

transform = zuko.transforms.FreeFormJacobianTransform(
    f=net,
    t0=0.0,
    t1=1.0,
    phi=list(net.parameters()),
    atol=1e-5,
    rtol=1e-4,
    exact=False  # Use stochastic trace estimator for efficiency
)

# Apply transformation
x = torch.randn(32, features)
y, ladj = transform.call_and_ladj(x)

print(f"Output: {y.shape}, LADJ: {ladj.shape}")  # [32, 5], [32]

# Inverse (integrates backwards)
x_reconstructed = transform.inv(y)
print(f"Reconstruction error: {(x - x_reconstructed).abs().max()}")

FFJORD Training Example

import torch
import torch.nn as nn
import torch.optim as optim
import zuko

class CNF(nn.Module):
    """Continuous Normalizing Flow."""
    def __init__(self, features: int):
        super().__init__()
        self.features = features
        self.ode_net = ODENet(features)
        self.transform = zuko.transforms.FreeFormJacobianTransform(
            f=self.ode_net,
            t0=0.0,
            t1=1.0,
            phi=list(self.ode_net.parameters()),
            exact=False
        )
        # Base distribution
        self.register_buffer('base_loc', torch.zeros(features))
        self.register_buffer('base_scale', torch.ones(features))
    
    def forward(self, x):
        """Compute log probability."""
        z, ladj = self.transform.call_and_ladj(x)
        # Log probability in base space
        base_dist = torch.distributions.Normal(self.base_loc, self.base_scale)
        log_prob = base_dist.log_prob(z).sum(dim=-1) + ladj
        return log_prob
    
    def sample(self, n_samples):
        """Generate samples."""
        base_dist = torch.distributions.Normal(self.base_loc, self.base_scale)
        z = base_dist.sample((n_samples,))
        x = self.transform.inv(z)
        return x

# Create and train CNF
model = CNF(features=2)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Training data (e.g., two moons)
from sklearn.datasets import make_moons
X_train, _ = make_moons(n_samples=1000, noise=0.05)
X_train = torch.as_tensor(X_train, dtype=torch.float32)

# Training loop
for epoch in range(100):
    optimizer.zero_grad()
    log_prob = model(X_train)
    loss = -log_prob.mean()
    loss.backward()
    optimizer.step()
    
    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

# Generate samples
with torch.no_grad():
    samples = model.sample(1000)
    print(f"Generated samples: {samples.shape}")  # [1000, 2]

Comparison of Neural Transforms

| Transform | Computation | Inverse | Best For |
| --- | --- | --- | --- |
| UnconstrainedMonotonic | Quadrature | Bisection | Flexible monotonic functions |
| Monotonic | User function | Bisection | Wrapping existing functions |
| Gaussianization | CDF composition | Bisection | Normalization tasks |
| FreeFormJacobian | ODE solver | Reverse ODE | Maximum flexibility, CNFs |

Key Considerations

Computational Cost

  • Integration-based: Quadrature adds overhead but ensures monotonicity
  • FFJORD: ODE solving is expensive, especially with exact trace
  • Inverse: Bisection methods slower than closed-form

Accuracy vs Speed

  • Quadrature points (n): More = accurate but slower
  • ODE tolerances (atol, rtol): Smaller = accurate but slower
  • Bisection tolerance (eps): Smaller = accurate but slower

When to Use Neural Transforms

Use neural transforms when:
  • Maximum flexibility is needed
  • Data has complex, unknown structure
  • Other transforms are too restrictive
Avoid when:
  • Speed is critical (use splines or polynomials)
  • Interpretability is important
  • Training data is limited

References

Chen, S. S., & Gopinath, R. A. (2000). Gaussianization.
https://papers.nips.cc/paper/1856-gaussianization
Grathwohl, W., Chen, R. T., Bettencourt, J., Sutskever, I., & Duvenaud, D. (2018). FFJORD: Free-form Continuous Dynamics for Scalable Reversible Generative Models.
https://arxiv.org/abs/1810.01367
Chen, R. T., Rubanova, Y., Bettencourt, J., & Duvenaud, D. K. (2018). Neural Ordinary Differential Equations.
https://arxiv.org/abs/1806.07366
