Polynomial transforms create monotonic bijections using polynomial functions. These transforms are smooth, expressive, and have well-defined theoretical properties.
Overview
Zuko provides three polynomial-based transformations:
SOSPolynomialTransform : Sum-of-squares polynomial transformation
BernsteinTransform : Bernstein polynomial transformation
BoundedBernsteinTransform : Bounded Bernstein polynomial with identity bounds
SOSPolynomialTransform
The sum-of-squares (SOS) polynomial transformation is defined as:
$$f(x) = \int_0^x \frac{1}{K} \sum_{i=1}^{K} \left( 1 + \sum_{j=0}^{L} a_{i,j} \, u^j \right)^2 du$$
The transformation is the integral of a sum of squared polynomials; since the integrand is non-negative by construction, the transformation is guaranteed to be monotonic.
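As a sanity check of this definition, the integral can be approximated with trapezoidal quadrature. The helper `sos_forward` below is a hypothetical sketch, not zuko's implementation:

```python
import torch

def sos_forward(x, a, n_points=257):
    """Approximate f(x) by trapezoidal quadrature of the squared-polynomial integrand."""
    K, L1 = a.shape                                       # a has shape (K, L + 1)
    t = torch.linspace(0.0, 1.0, n_points)                # quadrature nodes on [0, 1]
    u = x[:, None] * t[None, :]                           # rescaled to [0, x]
    powers = u[..., None] ** torch.arange(L1)             # u^j for j = 0..L
    inner = 1.0 + torch.einsum("bnj,kj->bnk", powers, a)  # 1 + sum_j a_ij u^j
    integrand = inner.square().mean(dim=-1)               # (1/K) sum_i (...)^2
    return torch.trapezoid(integrand, u, dim=-1)

torch.manual_seed(0)
a = 0.3 * torch.randn(4, 4)    # K = 4 polynomials of degree L = 3
x = torch.linspace(-3.0, 3.0, 101)
y = sos_forward(x, a)

# The integrand is a sum of squares, hence non-negative, so f is non-decreasing
assert (y.diff() >= -1e-5).all()
assert y[50].abs() < 1e-6      # f(0) = 0
```

Because the quadrature only ever sums a non-negative integrand, monotonicity survives the numerical approximation (up to a small tolerance).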
Class Definition
class SOSPolynomialTransform(UnconstrainedMonotonicTransform)
a : The polynomial coefficients $a$, with shape $(*, K, L + 1)$, where $K$ is the number of polynomials and $L$ is their degree.
slope : The minimum slope of the transformation, ensuring numerical stability.
kwargs : Keyword arguments passed to UnconstrainedMonotonicTransform.
Properties
Domain : constraints.real
Codomain : constraints.real
Bijective : True
Sign : +1
Usage Example
```python
import torch
import zuko

# Define polynomial parameters
batch_size = 32
K = 4  # Number of polynomials
L = 3  # Polynomial degree
a = torch.randn(batch_size, K, L + 1)

# Create SOS polynomial transformation
transform = zuko.transforms.SOSPolynomialTransform(
    a=a,
    slope=1e-3,
    bound=10.0,
    eps=1e-6,
)

# Apply transformation
x = torch.randn(batch_size, 1)
y = transform(x)

# Compute log determinant
ladj = transform.log_abs_det_jacobian(x, y)
print(f"Output: {y.shape}, LADJ: {ladj.shape}")  # [32, 1], [32, 1]

# Inverse (uses bisection method)
x_reconstructed = transform.inv(y)
print(f"Reconstruction error: {(x - x_reconstructed).abs().max()}")  # Should be small
```
SOS in Neural Flows
```python
import torch
import torch.nn as nn
import zuko

class SOSFlowLayer(nn.Module):
    """Flow layer using SOS polynomial transforms."""

    def __init__(self, features: int, K: int = 4, L: int = 3):
        super().__init__()
        self.features = features
        self.K = K
        self.L = L

        # Network to produce polynomial coefficients
        self.net = nn.Sequential(
            nn.Linear(features // 2, 64),
            nn.ReLU(),
            nn.Linear(64, (features // 2) * K * (L + 1)),
        )

        # Coupling mask
        self.mask = torch.zeros(features, dtype=torch.bool)
        self.mask[:features // 2] = True

    def forward(self, x):
        # Split input
        x_a = x[..., self.mask]
        x_b = x[..., ~self.mask]

        # Get polynomial coefficients from x_a
        a = self.net(x_a)
        a = a.view(x.shape[0], -1, self.K, self.L + 1)

        # Apply SOS transform to each dimension of x_b
        y_b = torch.zeros_like(x_b)
        ladj = torch.zeros(x.shape[0], device=x.device)

        for i in range(x_b.shape[-1]):
            sos = zuko.transforms.SOSPolynomialTransform(a[:, i])
            y_b[..., i:i + 1], ladj_i = sos.call_and_ladj(x_b[..., i:i + 1])
            ladj = ladj + ladj_i.squeeze(-1)

        # Merge
        y = torch.empty_like(x)
        y[..., self.mask] = x_a
        y[..., ~self.mask] = y_b

        return y, ladj

# Create and use SOS flow
layer = SOSFlowLayer(features=10, K=4, L=3)
x = torch.randn(32, 10)
y, ladj = layer(x)
print(f"Output: {y.shape}, LADJ: {ladj.shape}")  # [32, 10], [32]
```
BernsteinTransform
The Bernstein polynomial transformation uses Bernstein basis polynomials:
f ( x ) = 1 M + 1 ∑ i = 0 M b i + 1 , M − i + 1 ( x + B 2 B ) θ i f(x) = \frac{1}{M+1} \sum_{i=0}^M b_{i+1,M-i+1}\left(\frac{x+B}{2B}\right) \theta_i f ( x ) = M + 1 1 i = 0 ∑ M b i + 1 , M − i + 1 ( 2 B x + B ) θ i
where $b_{i,j}$ are the Bernstein basis polynomials (the densities of Beta distributions). The transformation is defined over $[-B, B]$ and linearly extrapolated outside this domain.
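The Beta connection can be made concrete: the density of a Beta(i + 1, M − i + 1) distribution equals (M + 1) times the Bernstein basis polynomial of degree M. The sketch below uses a hypothetical `bernstein_eval` helper and ignores zuko's coefficient reparametrization and out-of-domain extrapolation:

```python
import torch

def bernstein_eval(x, theta, bound):
    """Evaluate f(x) = (1 / (M + 1)) * sum_i Beta(i+1, M-i+1).pdf(t) * theta_i."""
    M = theta.shape[-1] - 1                    # here theta holds all M + 1 coefficients
    t = (x + bound) / (2 * bound)              # map [-B, B] onto [0, 1]
    t = t.clamp(1e-6, 1 - 1e-6)                # Beta densities are supported on (0, 1)
    i = torch.arange(M + 1.0)
    basis = torch.distributions.Beta(i + 1, M - i + 1).log_prob(t[..., None]).exp()
    return (basis * theta).sum(-1) / (M + 1)

# Linearly spaced coefficients reproduce a linear map: here f(x) = x on [-B, B]
theta = torch.linspace(-3.0, 3.0, 11)          # order M = 10
x = torch.linspace(-2.9, 2.9, 100)
y = bernstein_eval(x, theta, bound=3.0)
assert torch.allclose(y, x, atol=1e-3)
```

The final assertion works because Bernstein bases reproduce linear functions exactly; more generally, increasing coefficients yield an increasing polynomial.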
Class Definition
class BernsteinTransform(MonotonicTransform)
theta : The unconstrained polynomial coefficients $\theta$, with shape $(*, M - 2)$, where $M$ is the order.
bound : The polynomial's domain bound $B$. The polynomial is defined over $[-B, B]$.
kwargs : Keyword arguments passed to MonotonicTransform.
Properties
Domain : constraints.real
Codomain : constraints.real
Bijective : True
Sign : +1
Usage Example
```python
import torch
import zuko

# Define Bernstein coefficients
batch_size = 32
M = 10  # Order (number of basis polynomials)
theta = torch.randn(batch_size, M - 2)

# Create Bernstein transformation
transform = zuko.transforms.BernsteinTransform(
    theta=theta,
    bound=5.0,
    eps=1e-6,
)

# Apply transformation
x = torch.randn(batch_size, 1)
y = transform(x)

# Efficient forward + LADJ
y, ladj = transform.call_and_ladj(x)
print(f"Output: {y.shape}, LADJ: {ladj.shape}")  # [32, 1], [32, 1]

# Inverse (uses bisection)
x_reconstructed = transform.inv(y)
print(f"Reconstruction error: {(x - x_reconstructed).abs().max()}")
```
```python
import torch
import matplotlib.pyplot as plt
import zuko

# Create Bernstein transform
M = 10
theta = torch.randn(1, M - 2) * 0.5
transform = zuko.transforms.BernsteinTransform(theta, bound=3.0)

# Evaluate on grid
x = torch.linspace(-4, 4, 1000).unsqueeze(-1)
y = transform(x)

# Plot the transformation and its log determinant
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.plot(x.squeeze(), y.squeeze(), label='f(x)', linewidth=2)
plt.plot(x.squeeze(), x.squeeze(), 'k--', alpha=0.3, label='Identity')
plt.axvline(-3, color='r', linestyle='--', alpha=0.5, label='Bounds')
plt.axvline(3, color='r', linestyle='--', alpha=0.5)
plt.xlabel('x', fontsize=12)
plt.ylabel('y', fontsize=12)
plt.title('Bernstein Polynomial Transformation', fontsize=14)
plt.legend()
plt.grid(True, alpha=0.3)

plt.subplot(1, 2, 2)
ladj = transform.log_abs_det_jacobian(x, y)
plt.plot(x.squeeze(), ladj.squeeze(), linewidth=2)
plt.axvline(-3, color='r', linestyle='--', alpha=0.5)
plt.axvline(3, color='r', linestyle='--', alpha=0.5)
plt.xlabel('x', fontsize=12)
plt.ylabel('log |df/dx|', fontsize=12)
plt.title('Log Determinant', fontsize=14)
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()
```
BoundedBernsteinTransform
Description
The BoundedBernsteinTransform is a specialized version of BernsteinTransform that ensures:
Domain and codomain both equal $[-B, B]$
First derivative at the bounds equals 1: $f'(-B) = f'(B) = 1$
Second derivative at the bounds equals 0: $f''(-B) = f''(B) = 0$
These constraints ensure smooth transition to identity outside bounds, making it ideal for chaining in flows.
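These conditions can be reproduced by construction. The sketch below pins the first and last three Bernstein coefficients (my own construction from the stated conditions, not zuko's exact parametrization), leaving $M - 5$ free values, and checks the boundary identities by finite differences:

```python
import torch

torch.manual_seed(0)
B, M = 3.0, 10
step = 2 * B / M                 # equal spacing forces f' = 1 and f'' = 0 at the bounds
free = torch.sort(torch.rand(M - 5, dtype=torch.float64) * 2 * B - B).values
c = torch.cat([                  # M + 1 Bernstein coefficients, 6 of them fixed
    torch.tensor([-B, -B + step, -B + 2 * step], dtype=torch.float64),
    free,                        # interior coefficients (boundary identities hold regardless)
    torch.tensor([B - 2 * step, B - step, B], dtype=torch.float64),
])

def f(x):
    """Bernstein polynomial with coefficients c, mapped from [-B, B] to [0, 1]."""
    t = (x + B) / (2 * B)
    i = torch.arange(M + 1, dtype=torch.float64)
    log_binom = (torch.lgamma(torch.tensor(M + 1.0, dtype=torch.float64))
                 - torch.lgamma(i + 1) - torch.lgamma(M - i + 1))
    basis = log_binom.exp() * t ** i * (1 - t) ** (M - i)
    return (basis * c).sum()

h = 1e-4
xB = torch.tensor(B, dtype=torch.float64)
assert abs(f(xB) - B) < 1e-9                       # f(B) = B
assert abs(f(-xB) + B) < 1e-9                      # f(-B) = -B
assert abs((f(xB) - f(xB - h)) / h - 1) < 1e-4     # f'(B) ≈ 1
assert abs((f(-xB + h) - f(-xB)) / h - 1) < 1e-4   # f'(-B) ≈ 1
```

Counting parameters, six of the $M + 1$ coefficients are fixed by the six boundary conditions, which is why the transform takes $M - 5$ unconstrained values.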
Class Definition
class BoundedBernsteinTransform(BernsteinTransform)
theta : The unconstrained polynomial coefficients $\theta$, with shape $(*, M - 5)$. Note the reduced dimension compared to BernsteinTransform, due to the boundary constraints.
kwargs : Keyword arguments passed to BernsteinTransform.
Properties
Domain : constraints.real
Codomain : constraints.real
Bijective : True
Sign : +1
Usage Example
```python
import torch
import zuko

# Define coefficients (fewer parameters due to constraints)
batch_size = 32
M = 10
theta = torch.randn(batch_size, M - 5)  # Note: M - 5, not M - 2

# Create bounded Bernstein transformation
transform = zuko.transforms.BoundedBernsteinTransform(
    theta=theta,
    bound=5.0,
    eps=1e-6,
)

# Apply transformation
x = torch.randn(batch_size, 1)
y, ladj = transform.call_and_ladj(x)
print(f"Output: {y.shape}, LADJ: {ladj.shape}")  # [32, 1], [32, 1]

# Check boundary behavior
x_bound = torch.tensor([[5.0], [-5.0]])
theta_test = torch.randn(1, M - 5)
transform_test = zuko.transforms.BoundedBernsteinTransform(theta_test, bound=5.0)
y_bound = transform_test(x_bound)
print(f"At x={x_bound.squeeze().tolist()}: y={y_bound.squeeze().tolist()}")
# Should be close to [5.0, -5.0]
```
Bounded Bernstein in Flow
```python
import torch
import torch.nn as nn
import zuko

class BoundedBernsteinFlow(nn.Module):
    """Flow using bounded Bernstein transforms."""

    def __init__(self, features: int, order: int = 10, layers: int = 3):
        super().__init__()

        class BernsteinNet(nn.Module):
            """Hyper-network producing Bernstein coefficients from the constant half."""

            def __init__(self, in_dim, out_dim, order):
                super().__init__()
                self.net = nn.Sequential(
                    nn.Linear(in_dim, 64),
                    nn.ReLU(),
                    nn.Linear(64, out_dim * (order - 5)),
                )
                self.order = order
                self.out_dim = out_dim

            def forward(self, x_a):
                theta = self.net(x_a)
                theta = theta.view(x_a.shape[0], self.out_dim, self.order - 5)
                # Create independent transforms for each dimension
                return zuko.transforms.DependentTransform(
                    zuko.transforms.BoundedBernsteinTransform(theta, bound=5.0),
                    reinterpreted=1,
                )

        self.nets = nn.ModuleList()  # register the nets so their parameters are trainable
        self.transforms = []

        for i in range(layers):
            # Coupling mask, alternated between layers
            mask = torch.zeros(features, dtype=torch.bool)
            if i % 2 == 0:
                mask[:features // 2] = True
            else:
                mask[features // 2:] = True

            n_const = mask.sum().item()
            n_transform = (~mask).sum().item()

            net = BernsteinNet(n_const, n_transform, order)
            self.nets.append(net)
            self.transforms.append(zuko.transforms.CouplingTransform(meta=net, mask=mask))

        self.flow = zuko.transforms.ComposedTransform(*self.transforms)

    def forward(self, x):
        return self.flow.call_and_ladj(x)

# Create and use flow
flow = BoundedBernsteinFlow(features=8, order=10, layers=3)
x = torch.randn(32, 8)
y, ladj = flow(x)
print(f"Output: {y.shape}, LADJ: {ladj.shape}")  # [32, 8], [32]
```
| Transform | Parameters | Computation | Best for |
|---|---|---|---|
| SOSPolynomialTransform | $K \times (L + 1)$ | Quadrature | Smooth, guaranteed positive derivative |
| BernsteinTransform | $M - 2$ | Beta distributions | Interpretable, smooth extrapolation |
| BoundedBernsteinTransform | $M - 5$ | Beta distributions | Chaining in flows, identity bounds |
Key Considerations
Polynomial Order
Higher order polynomials are more expressive but:
Require more parameters
May be harder to optimize
Can have numerical instabilities
Recommended : Start with moderate orders (M=10-20 for Bernstein, L=3-5 for SOS)
Bound Selection
The bound parameter determines the transformation’s effective range:
Should cover typical data range
Larger bounds: More extrapolation (linear region)
Smaller bounds: More polynomial region, but less coverage
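As a rough illustration (this heuristic is an assumption of mine, not a zuko recommendation), the bound can be chosen from an upper quantile of the training data plus a safety margin:

```python
import torch

torch.manual_seed(0)
data = 2.0 * torch.randn(10_000)                  # stand-in for training data
bound = 1.25 * data.abs().quantile(0.999).item()  # margin beyond the 99.9th percentile
coverage = (data.abs() <= bound).double().mean().item()
assert coverage >= 0.999                          # almost all data in the polynomial region
```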
Inverse Computation
Both Bernstein transforms use bisection for inverse:
Slower than closed-form solutions (e.g., RQS)
Accuracy controlled by eps parameter
Trade-off between speed and precision
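The idea is easy to sketch. The generic `bisection_inverse` below is illustrative, not zuko's MonotonicTransform code; each iteration halves the bracket, so accuracy improves by one bit per step until it reaches an eps-like tolerance:

```python
import torch

def bisection_inverse(f, y, lower, upper, n_iter=60):
    """Invert an increasing f elementwise by repeatedly halving a bracket."""
    lower = torch.full_like(y, lower)
    upper = torch.full_like(y, upper)
    for _ in range(n_iter):                # error ~ (upper - lower) / 2**n_iter
        go_right = f((lower + upper) / 2) < y  # root lies in the upper half-bracket
        mid = (lower + upper) / 2
        lower = torch.where(go_right, mid, lower)
        upper = torch.where(go_right, upper, mid)
    return (lower + upper) / 2

def f(x):
    return x ** 3 + x                      # a monotonic toy transformation

x_true = torch.linspace(-2.0, 2.0, 9, dtype=torch.float64)
x_rec = bisection_inverse(f, f(x_true), lower=-10.0, upper=10.0)
assert torch.allclose(x_rec, x_true, atol=1e-9)
```

The cost is one forward evaluation per iteration, which is why bisection-based inverses are slower than closed-form ones.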
Advantages and Disadvantages
Advantages
Theoretical guarantees : Monotonicity ensured by construction
Smooth : Infinitely differentiable (SOS) or $C^1$ (Bernstein)
Interpretable : Polynomial coefficients have clear meaning
Flexible : Can approximate many functions
Disadvantages
Inverse cost : Bisection method is slower than analytical inverse
Extrapolation : Linear outside bounds (Bernstein) or requires large bounds (SOS)
Parameter efficiency : May need many parameters for complex functions
References
Sum-of-Squares Polynomial Flow
Deep Transformation Models
Arpogaus, M., Voss, M., Sick, B., & Nigge, M. (2022). Short-Term Density Forecasting of Low-Voltage Load using Bernstein-Polynomial Normalizing Flows. https://arxiv.org/abs/2204.13939