Polynomial transforms create monotonic bijections using polynomial functions. These transforms are smooth and expressive, and their monotonicity is guaranteed by construction rather than by ad hoc constraints.

Overview

Zuko provides three polynomial-based transformations:
  1. SOSPolynomialTransform: Sum-of-squares polynomial transformation
  2. BernsteinTransform: Bernstein polynomial transformation
  3. BoundedBernsteinTransform: Bounded Bernstein polynomial with identity bounds

SOSPolynomialTransform

Mathematical Formulation

The sum-of-squares (SOS) polynomial transformation is defined as

$$f(x) = \int_0^x \frac{1}{K} \sum_{i=1}^{K} \left( 1 + \sum_{j=0}^{L} a_{i,j} \, u^j \right)^2 du$$

The transformation is the integral of a sum of squared polynomials, ensuring monotonicity since the integrand is non-negative everywhere (the slope parameter additionally enforces a strictly positive minimum derivative).
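
Because the integrand is a sum of squares, the derivative f'(x) = (1/K) Σᵢ (1 + Σⱼ a_{i,j} xʲ)² can never be negative, which is the whole monotonicity argument. A minimal pure-Python sketch (a hypothetical helper with random coefficients, not zuko's implementation) checks this numerically:

```python
import random

def sos_derivative(x, a):
    """f'(x) = (1/K) * sum_i (1 + sum_j a[i][j] * x**j)**2, a sum of squares."""
    K = len(a)
    total = 0.0
    for row in a:
        p = 1.0 + sum(c * x**j for j, c in enumerate(row))
        total += p * p
    return total / K

random.seed(0)
K, L = 4, 3
a = [[random.gauss(0, 1) for _ in range(L + 1)] for _ in range(K)]

# The derivative is non-negative at every point, so its integral is monotone.
assert all(sos_derivative(x / 10, a) >= 0.0 for x in range(-100, 101))
```

With all coefficients zero, each squared term is exactly 1, so the derivative is 1 and the transform reduces to the identity.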

Class Definition

class SOSPolynomialTransform(UnconstrainedMonotonicTransform)

  • a (Tensor): The polynomial coefficients a, with shape (*, K, L + 1), where K is the number of polynomials and L is their degree.
  • slope (float, default: 1e-3): The minimum slope of the transformation, ensuring numerical stability.
  • kwargs: Keyword arguments passed to UnconstrainedMonotonicTransform.

Properties

  • Domain: constraints.real
  • Codomain: constraints.real
  • Bijective: True
  • Sign: +1

Usage Example

import torch
import zuko

# Define polynomial parameters
batch_size = 32
K = 4  # Number of polynomials
L = 3  # Polynomial degree

a = torch.randn(batch_size, K, L + 1)

# Create SOS polynomial transformation
transform = zuko.transforms.SOSPolynomialTransform(
    a=a,
    slope=1e-3,
    bound=10.0,
    eps=1e-6
)

# Apply transformation
x = torch.randn(batch_size, 1)
y = transform(x)

# Compute log determinant
ladj = transform.log_abs_det_jacobian(x, y)
print(f"Output: {y.shape}, LADJ: {ladj.shape}")  # [32, 1], [32, 1]

# Inverse (uses bisection method)
x_reconstructed = transform.inv(y)
print(f"Reconstruction error: {(x - x_reconstructed).abs().max()}")  # Should be small

SOS in Neural Flows

import torch
import torch.nn as nn
import zuko

class SOSFlowLayer(nn.Module):
    """Flow layer using SOS polynomial transforms."""
    def __init__(self, features: int, K: int = 4, L: int = 3):
        super().__init__()
        self.features = features
        self.K = K
        self.L = L
        
        # Network to produce polynomial coefficients
        self.net = nn.Sequential(
            nn.Linear(features // 2, 64),
            nn.ReLU(),
            nn.Linear(64, (features // 2) * K * (L + 1)),
        )
        
        # Coupling mask (registered as a buffer so it moves with the module)
        mask = torch.zeros(features, dtype=torch.bool)
        mask[:features // 2] = True
        self.register_buffer("mask", mask)
    
    def forward(self, x):
        # Split input
        x_a = x[..., self.mask]
        x_b = x[..., ~self.mask]
        
        # Get polynomial coefficients from x_a
        a = self.net(x_a)
        a = a.view(x.shape[0], -1, self.K, self.L + 1)
        
        # Apply SOS transform to each dimension of x_b
        y_b = torch.zeros_like(x_b)
        ladj = torch.zeros(x.shape[0], device=x.device)
        
        for i in range(x_b.shape[-1]):
            # One transform per dimension; a[:, i] has batch shape (batch,)
            # matching x_b[..., i], so coefficients and inputs broadcast cleanly
            sos = zuko.transforms.SOSPolynomialTransform(a[:, i])
            y_b[..., i], ladj_i = sos.call_and_ladj(x_b[..., i])
            ladj = ladj + ladj_i
        
        # Merge
        y = torch.empty_like(x)
        y[..., self.mask] = x_a
        y[..., ~self.mask] = y_b
        
        return y, ladj

# Create and use SOS flow
layer = SOSFlowLayer(features=10, K=4, L=3)
x = torch.randn(32, 10)
y, ladj = layer(x)
print(f"Output: {y.shape}, LADJ: {ladj.shape}")  # [32, 10], [32]

BernsteinTransform

Mathematical Formulation

The Bernstein polynomial transformation uses Bernstein basis polynomials:

$$f(x) = \frac{1}{M+1} \sum_{i=0}^{M} b_{i+1,\,M-i+1}\!\left(\frac{x+B}{2B}\right) \theta_i$$

where b_{i,j} denotes the density of the Beta(i, j) distribution. The transformation is defined over [-B, B] and linearly extrapolated outside this domain.
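
Since (1/(M+1)) times the Beta(i+1, M-i+1) density equals the classical Bernstein basis C(M, i) tⁱ (1-t)^{M-i}, the sum collapses to a standard Bernstein polynomial in t = (x + B) / (2B), which is monotone whenever the coefficients θᵢ are non-decreasing (zuko derives such coefficients internally from the unconstrained θ). The sketch below uses the full coefficient vector directly as a hypothetical pure-Python illustration, not zuko's implementation:

```python
from math import comb

def bernstein_f(x, theta, B):
    """f(x) = sum_i theta[i] * C(M, i) * t**i * (1-t)**(M-i), t = (x+B)/(2B)."""
    M = len(theta) - 1
    t = (x + B) / (2 * B)
    return sum(th * comb(M, i) * t**i * (1 - t)**(M - i) for i, th in enumerate(theta))

B = 3.0
theta = [-3.0, -2.5, -1.0, 0.2, 0.5, 1.8, 3.0]  # non-decreasing -> monotone f

xs = [-3 + 6 * k / 200 for k in range(201)]
ys = [bernstein_f(x, theta, B) for x in xs]
assert all(y0 <= y1 for y0, y1 in zip(ys, ys[1:]))       # monotone on [-B, B]
assert abs(bernstein_f(-B, theta, B) - theta[0]) < 1e-9  # f(-B) = theta_0
assert abs(bernstein_f(B, theta, B) - theta[-1]) < 1e-9  # f(B)  = theta_M
```

Note that the endpoints of f are exactly the first and last coefficients, which is what makes imposing boundary conditions (as in BoundedBernsteinTransform below) straightforward.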

Class Definition

class BernsteinTransform(MonotonicTransform)

  • theta (Tensor): The unconstrained polynomial coefficients θ, with shape (*, M - 2), where M is the order of the polynomial.
  • bound (float, default: 5.0): The polynomial's domain bound B. The transformation is defined over [-B, B].
  • kwargs: Keyword arguments passed to MonotonicTransform.

Properties

  • Domain: constraints.real
  • Codomain: constraints.real
  • Bijective: True
  • Sign: +1

Usage Example

import torch
import zuko

# Define Bernstein coefficients
batch_size = 32
M = 10  # Order (number of basis polynomials)

theta = torch.randn(batch_size, M - 2)

# Create Bernstein transformation
transform = zuko.transforms.BernsteinTransform(
    theta=theta,
    bound=5.0,
    eps=1e-6
)

# Apply transformation
x = torch.randn(batch_size, 1)
y = transform(x)

# Efficient forward + LADJ
y, ladj = transform.call_and_ladj(x)
print(f"Output: {y.shape}, LADJ: {ladj.shape}")  # [32, 1], [32, 1]

# Inverse (uses bisection)
x_reconstructed = transform.inv(y)
print(f"Reconstruction error: {(x - x_reconstructed).abs().max()}")

Visualizing Bernstein Transformation

import torch
import matplotlib.pyplot as plt
import zuko

# Create Bernstein transform
M = 10
theta = torch.randn(1, M - 2) * 0.5
transform = zuko.transforms.BernsteinTransform(theta, bound=3.0)

# Evaluate on grid
x = torch.linspace(-4, 4, 1000).unsqueeze(-1)
y = transform(x)

# Plot
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.plot(x.squeeze(), y.squeeze(), label='f(x)', linewidth=2)
plt.plot(x.squeeze(), x.squeeze(), 'k--', alpha=0.3, label='Identity')
plt.axvline(-3, color='r', linestyle='--', alpha=0.5, label='Bounds')
plt.axvline(3, color='r', linestyle='--', alpha=0.5)
plt.xlabel('x', fontsize=12)
plt.ylabel('y', fontsize=12)
plt.title('Bernstein Polynomial Transformation', fontsize=14)
plt.legend()
plt.grid(True, alpha=0.3)

plt.subplot(1, 2, 2)
ladj = transform.log_abs_det_jacobian(x, y)
plt.plot(x.squeeze(), ladj.squeeze(), linewidth=2)
plt.axvline(-3, color='r', linestyle='--', alpha=0.5)
plt.axvline(3, color='r', linestyle='--', alpha=0.5)
plt.xlabel('x', fontsize=12)
plt.ylabel('log |df/dx|', fontsize=12)
plt.title('Log Determinant', fontsize=14)
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

BoundedBernsteinTransform

Description

The BoundedBernsteinTransform is a specialized version of BernsteinTransform that ensures:
  1. The domain and codomain both equal [-B, B]
  2. The first derivative at the bounds equals 1: f'(-B) = f'(B) = 1
  3. The second derivative at the bounds equals 0: f''(-B) = f''(B) = 0
These constraints ensure a smooth transition to the identity outside the bounds, making the transform well suited for chaining in flows.
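
The M - 5 parameter count follows from this algebra: writing f as a Bernstein polynomial with M + 1 coefficients θ₀, …, θ_M, the boundary conditions pin six of them (θ₀ = -B and θ_M = B fix the endpoints; θ₁ - θ₀ = θ_M - θ_{M-1} = 2B/M fixes the unit slope; θ₂ - 2θ₁ + θ₀ = θ_M - 2θ_{M-1} + θ_{M-2} = 0 kills the curvature), leaving M - 5 free. The pure-Python finite-difference check below illustrates this constraint algebra; it is a sketch, not zuko's implementation:

```python
from math import comb

def bernstein_f(x, theta, B):
    M = len(theta) - 1
    t = (x + B) / (2 * B)
    return sum(th * comb(M, i) * t**i * (1 - t)**(M - i) for i, th in enumerate(theta))

M, B = 10, 5.0
step = 2 * B / M
# Pin 6 of the M + 1 coefficients; the remaining M - 5 (indices 3..M-3) are free.
free = [-2.5, -1.0, 0.5, 1.0, 2.5]  # any non-decreasing choice within the range
theta = [-B, -B + step, -B + 2 * step] + free + [B - 2 * step, B - step, B]

# Finite-difference first and second derivatives
h = 1e-4
d1 = lambda x: (bernstein_f(x + h, theta, B) - bernstein_f(x - h, theta, B)) / (2 * h)
d2 = lambda x: (bernstein_f(x + h, theta, B) - 2 * bernstein_f(x, theta, B)
                + bernstein_f(x - h, theta, B)) / h**2

assert abs(bernstein_f(-B, theta, B) + B) < 1e-9          # f(-B) = -B
assert abs(bernstein_f(B, theta, B) - B) < 1e-9           # f(B)  =  B
assert abs(d1(-B) - 1) < 1e-6 and abs(d1(B) - 1) < 1e-6   # f'(±B) = 1
assert abs(d2(-B)) < 1e-3 and abs(d2(B)) < 1e-3           # f''(±B) ≈ 0
```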

Class Definition

class BoundedBernsteinTransform(BernsteinTransform)

  • theta (Tensor): The unconstrained polynomial coefficients θ, with shape (*, M - 5). The dimension is reduced relative to BernsteinTransform because additional coefficients are pinned by the boundary constraints.
  • kwargs: Keyword arguments passed to BernsteinTransform.

Properties

  • Domain: constraints.real
  • Codomain: constraints.real
  • Bijective: True
  • Sign: +1

Usage Example

import torch
import zuko

# Define coefficients (fewer parameters due to constraints)
batch_size = 32
M = 10

theta = torch.randn(batch_size, M - 5)  # Note: M - 5, not M - 2

# Create bounded Bernstein transformation
transform = zuko.transforms.BoundedBernsteinTransform(
    theta=theta,
    bound=5.0,
    eps=1e-6
)

# Apply transformation
x = torch.randn(batch_size, 1)
y, ladj = transform.call_and_ladj(x)

print(f"Output: {y.shape}, LADJ: {ladj.shape}")  # [32, 1], [32, 1]

# Check boundary behavior
x_bound = torch.tensor([[5.0], [-5.0]])
theta_test = torch.randn(1, M - 5)
transform_test = zuko.transforms.BoundedBernsteinTransform(theta_test, bound=5.0)

y_bound = transform_test(x_bound)
print(f"At x={x_bound.squeeze().tolist()}: y={y_bound.squeeze().tolist()}")
# Should be close to [5.0, -5.0]

Bounded Bernstein in Flow

import torch
import torch.nn as nn
import zuko

class BoundedBernsteinFlow(nn.Module):
    """Flow using bounded Bernstein transforms."""
    def __init__(self, features: int, order: int = 10, layers: int = 3):
        super().__init__()
        self.transforms = []
        
        for i in range(layers):
            # Coupling mask
            mask = torch.zeros(features, dtype=torch.bool)
            if i % 2 == 0:
                mask[:features // 2] = True
            else:
                mask[features // 2:] = True
            
            # Network for coefficients
            n_const = mask.sum().item()
            n_transform = (~mask).sum().item()
            
            class BernsteinNet(nn.Module):
                def __init__(self, in_dim, out_dim, order):
                    super().__init__()
                    self.net = nn.Sequential(
                        nn.Linear(in_dim, 64),
                        nn.ReLU(),
                        nn.Linear(64, out_dim * (order - 5)),
                    )
                    self.order = order
                    self.out_dim = out_dim
                
                def forward(self, x_a):
                    theta = self.net(x_a)
                    theta = theta.view(x_a.shape[0], self.out_dim, self.order - 5)
                    # Create independent transforms for each dimension
                    return zuko.transforms.DependentTransform(
                        zuko.transforms.BoundedBernsteinTransform(
                            theta, bound=5.0
                        ),
                        reinterpreted=1
                    )
            
            net = BernsteinNet(n_const, n_transform, order)
            self.add_module(f"net_{i}", net)  # register so the network's parameters are trainable
            coupling = zuko.transforms.CouplingTransform(meta=net, mask=mask)
            self.transforms.append(coupling)
        
        self.flow = zuko.transforms.ComposedTransform(*self.transforms)
    
    def forward(self, x):
        return self.flow.call_and_ladj(x)

# Create and use flow
flow = BoundedBernsteinFlow(features=8, order=10, layers=3)
x = torch.randn(32, 8)
y, ladj = flow(x)
print(f"Output: {y.shape}, LADJ: {ladj.shape}")  # [32, 8], [32]

Comparison of Polynomial Transforms

| Transform | Parameters | Computation | Best For |
| --- | --- | --- | --- |
| SOSPolynomialTransform | K × (L + 1) | Quadrature | Smooth, guaranteed positive derivative |
| BernsteinTransform | M - 2 | Beta distributions | Interpretable, smooth extrapolation |
| BoundedBernsteinTransform | M - 5 | Beta distributions | Chaining in flows, identity at the bounds |
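
To make the parameter column concrete, here are the counts per transformed dimension for the hyperparameters used in the examples above (K = 4, L = 3, M = 10):

```python
K, L, M = 4, 3, 10
counts = {
    "SOSPolynomialTransform": K * (L + 1),  # K polynomials, each of degree L
    "BernsteinTransform": M - 2,
    "BoundedBernsteinTransform": M - 5,     # 3 fewer: extra boundary constraints
}
print(counts)
# {'SOSPolynomialTransform': 16, 'BernsteinTransform': 8, 'BoundedBernsteinTransform': 5}
```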

Key Considerations

Polynomial Order

Higher order polynomials are more expressive but:
  • Require more parameters
  • May be harder to optimize
  • Can have numerical instabilities
Recommended: Start with moderate orders (M=10-20 for Bernstein, L=3-5 for SOS)

Bound Selection

The bound parameter determines the transformation’s effective range:
  • Should cover typical data range
  • Larger bounds: More extrapolation (linear region)
  • Smaller bounds: More polynomial region, but less coverage

Inverse Computation

Both Bernstein transforms use bisection for inverse:
  • Slower than closed-form solutions (e.g., RQS)
  • Accuracy controlled by eps parameter
  • Trade-off between speed and precision
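
The idea behind the bisection inverse can be sketched in a few lines: repeatedly halve a bracketing interval until its width drops below the tolerance. This is a toy scalar version (a hypothetical helper, not zuko's implementation, which vectorizes the search over tensors):

```python
def bisection_inverse(f, y, lo, hi, eps=1e-6):
    """Solve f(x) = y for a monotonically increasing f on [lo, hi]."""
    while hi - lo > eps:
        mid = (lo + hi) / 2
        if f(mid) < y:
            lo = mid  # solution lies in the upper half
        else:
            hi = mid  # solution lies in the lower half
    return (lo + hi) / 2

# Example: invert a monotone cubic. Each iteration halves the bracket, so the
# cost is about log2((hi - lo) / eps) evaluations of f.
f = lambda x: x**3 + x
x = bisection_inverse(f, y=10.0, lo=-10.0, hi=10.0, eps=1e-9)
assert abs(f(x) - 10.0) < 1e-6
```

Halving the tolerance adds only one extra iteration, which is the speed/precision trade-off the eps parameter controls.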

Advantages and Disadvantages

Advantages

  1. Theoretical guarantees: Monotonicity ensured by construction
  2. Smooth: Infinitely differentiable (SOS) or C^1 (Bernstein)
  3. Interpretable: Polynomial coefficients have clear meaning
  4. Flexible: Can approximate many functions

Disadvantages

  1. Inverse cost: Bisection method is slower than analytical inverse
  2. Extrapolation: Linear outside bounds (Bernstein) or requires large bounds (SOS)
  3. Parameter efficiency: May need many parameters for complex functions

References

Jaini, P., Selby, K. A., & Yu, Y. (2019). Sum-of-Squares Polynomial Flow.
https://arxiv.org/abs/1905.02325
Sick, B., Hothorn, T., & Dürr, O. (2020). Deep transformation models: Tackling complex regression problems with neural network based transformation models.
https://arxiv.org/abs/2004.00464
Arpogaus, M., Voss, M., Sick, B., & Nigge, M. (2022). Short-Term Density Forecasting of Low-Voltage Load using Bernstein-Polynomial Normalizing Flows.
https://arxiv.org/abs/2204.13939
