The MonotonicRQSTransform class implements a monotonic rational-quadratic spline: a highly expressive, flexible transformation with guaranteed invertibility.
Rational quadratic splines (RQS) divide the domain into K bins and define a smooth, monotonic transformation within each bin using rational quadratic functions. The transformation is:
f(x) = y0 + (y1 − y0) · [s·z² + d0·z(1 − z)] / [s + (d0 + d1 − 2s)·z(1 − z)]
where:
- z = (x − x0) / (x1 − x0) is the normalized position within the bin
- s = (y1 − y0) / (x1 − x0) is the secant slope
- d0,d1 are the derivatives at the bin boundaries
- (x0, y0) and (x1, y1) are the knots delimiting the current bin
The spline is defined over the interval [−B,B] and extends linearly outside this region.
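To make the formula concrete, here is a minimal plain-Python sketch of one bin's forward map and its closed-form inverse, obtained by solving a quadratic in z. The helpers rqs_bin and rqs_bin_inv are illustrative only and not part of zuko's API.

```python
import math

def rqs_bin(x, x0, x1, y0, y1, d0, d1):
    """Evaluate the rational-quadratic spline on one bin [x0, x1] -> [y0, y1]."""
    z = (x - x0) / (x1 - x0)   # normalized position within the bin
    s = (y1 - y0) / (x1 - x0)  # secant slope
    num = (y1 - y0) * (s * z**2 + d0 * z * (1 - z))
    den = s + (d0 + d1 - 2 * s) * z * (1 - z)
    return y0 + num / den

def rqs_bin_inv(y, x0, x1, y0, y1, d0, d1):
    """Invert the bin map by solving a quadratic a z^2 + b z + c = 0 for z."""
    s = (y1 - y0) / (x1 - x0)
    dy = y - y0
    a = (y1 - y0) * (s - d0) + dy * (d0 + d1 - 2 * s)
    b = (y1 - y0) * d0 - dy * (d0 + d1 - 2 * s)
    c = -s * dy
    # Numerically stable root; also well-defined when a == 0 (identity-like bins)
    z = 2 * c / (-b - math.sqrt(b**2 - 4 * a * c))
    return x0 + z * (x1 - x0)

# Round trip through one bin mapping [0, 2] -> [0, 1] with boundary slopes 0.5 and 2
y = rqs_bin(1.0, 0.0, 2.0, 0.0, 1.0, 0.5, 2.0)
x = rqs_bin_inv(y, 0.0, 2.0, 0.0, 1.0, 0.5, 2.0)
```

Note the stable form of the quadratic root, 2c / (−b − √(b² − 4ac)): it avoids catastrophic cancellation and degenerates gracefully to −c/b when the quadratic coefficient vanishes.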
Class Definition
class MonotonicRQSTransform(Transform)
Creates a monotonic rational-quadratic spline (RQS) transformation.
- widths: The unconstrained bin widths, with shape (∗, K). These are normalized using softmax to ensure they sum to the domain width.
- heights: The unconstrained bin heights, with shape (∗, K). These are normalized using softmax to ensure they sum to the codomain height.
- derivatives: The unconstrained knot derivatives, with shape (∗, K − 1). These control the slope at the interior bin boundaries and are constrained to be positive.
- bound: The spline's (co)domain bound B. The spline is defined over [−B, B].
- slope: The minimum slope of the transformation. This ensures numerical stability and invertibility.
Properties
- Domain:
constraints.real
- Codomain:
constraints.real
- Bijective:
True
- Sign:
+1
Implementation Details
Parameter Constraints
The unconstrained parameters are processed to ensure monotonicity:
import math

import torch
import torch.nn.functional as F

# Soft clamp: bounds the raw parameters by a range determined by the minimum slope
widths = widths / (1 + abs(2 * widths / math.log(slope)))
heights = heights / (1 + abs(2 * heights / math.log(slope)))
derivatives = derivatives / (1 + abs(derivatives / math.log(slope)))

# Positive bin widths/heights that sum to one, and positive knot derivatives
widths = F.softmax(widths, dim=-1)
heights = F.softmax(heights, dim=-1)
derivatives = torch.exp(derivatives)
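The effect of these constraints can be checked on toy values. The plain-Python sketch below mirrors the tensor operations above (softmax over the last dimension, soft clamping by log(slope)); softmax here is a local stand-in for F.softmax, not a library call.

```python
import math

def softmax(xs):
    # Stabilized softmax: positive weights that sum to 1
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

slope = 1e-3
L = math.log(slope)  # negative; |L| = -log(slope)

# Soft clamp: |w / (1 + |2w/L|)| is strictly below |L| / 2 for any raw w
raw_widths = [1.5, -0.3, 0.8, -2.0]
clamped = [w / (1 + abs(2 * w / L)) for w in raw_widths]
widths = softmax(clamped)

# The derivative exponent is soft-clamped to (-|L|, |L|), so after exp the
# knot slopes lie strictly inside (slope, 1/slope) -- hence "minimum slope"
raw_derivs = [0.4, -1.2, 2.5]
derivs = [math.exp(d / (1 + abs(d / L))) for d in raw_derivs]
```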
Bin Search
The transformation locates the bin containing each input with a vectorized search over the knots, implemented as a comparison count:
import torch
from torch import LongTensor, Tensor

def searchsorted(seq: Tensor, value: Tensor) -> LongTensor:
    # Index of the bin containing each value: count the knots strictly below it
    return torch.sum(seq < value[..., None], dim=-1)
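The same lookup can be sketched for a single scalar in plain Python and cross-checked against the standard library's bisect; searchsorted_scalar is an illustrative helper, not part of zuko.

```python
from bisect import bisect_left

def searchsorted_scalar(edges, value):
    # Counting how many edges lie strictly below the value gives the bin index,
    # mirroring torch.sum(seq < value[..., None], dim=-1)
    return sum(1 for e in edges if e < value)

edges = [-3.0, -1.0, 0.5, 2.0, 3.0]  # K + 1 knot positions for K = 4 bins
idx = searchsorted_scalar(edges, 0.7)  # 0.7 lies between edges[2] and edges[3]
```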
Usage Examples
import torch
import zuko
# Define spline parameters
batch_size = 32
K = 8 # Number of bins
widths = torch.randn(batch_size, K)
heights = torch.randn(batch_size, K)
derivatives = torch.randn(batch_size, K - 1)
# Create RQS transformation
transform = zuko.transforms.MonotonicRQSTransform(
    widths=widths,
    heights=heights,
    derivatives=derivatives,
    bound=5.0,
    slope=1e-3,
)
# Apply transformation
x = torch.randn(batch_size, 1)
y = transform(x)
# Inverse transformation
x_reconstructed = transform.inv(y)
# Compute log determinant
ladj = transform.log_abs_det_jacobian(x, y)
print(f"LADJ shape: {ladj.shape}") # [32, 1]
# Efficient forward + LADJ
y, ladj = transform.call_and_ladj(x)
RQS in Coupling Layer
import torch
import torch.nn as nn
import zuko
class RQSCouplingNet(nn.Module):
    """Network that produces RQS parameters for coupling."""

    def __init__(self, in_features: int, out_features: int, bins: int = 8):
        super().__init__()
        self.bins = bins
        # Output: widths, heights, derivatives for each output dimension
        self.net = nn.Sequential(
            nn.Linear(in_features, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, out_features * (bins * 2 + bins - 1)),
        )

    def forward(self, x_a):
        batch_size = x_a.shape[0]
        params = self.net(x_a)  # [batch, out_features * (3K - 1)]
        # Reshape to [batch, out_features, params_per_dim]
        params = params.view(batch_size, -1, self.bins * 2 + self.bins - 1)
        # Split into widths, heights, derivatives
        widths = params[..., :self.bins]
        heights = params[..., self.bins:2 * self.bins]
        derivatives = params[..., 2 * self.bins:]
        # One independent RQS per dimension, wrapped so the batch of splines
        # acts as a single transform over the last axis
        return zuko.transforms.DependentTransform(
            zuko.transforms.MonotonicRQSTransform(
                widths, heights, derivatives, bound=5.0
            ),
            reinterpreted=1,
        )
# Create coupling with RQS
features = 16
mask = torch.zeros(features, dtype=torch.bool)
mask[:features // 2] = True

net = RQSCouplingNet(
    in_features=int(mask.sum()),      # conditioning half
    out_features=int((~mask).sum()),  # transformed half
    bins=8,
)
transform = zuko.transforms.CouplingTransform(meta=net, mask=mask)
# Apply to data
x = torch.randn(64, features)
y, ladj = transform.call_and_ladj(x)
print(f"Output: {y.shape}, LADJ: {ladj.shape}") # [64, 16], [64]
Conditional RQS Flow
import torch
import torch.nn as nn
import zuko
class ConditionalRQSFlow(nn.Module):
    """Flow with RQS transformations conditioned on context."""

    def __init__(self, features: int, context_dim: int, bins: int = 8, layers: int = 3):
        super().__init__()
        self.features = features
        self.context_nets = nn.ModuleList()
        self.masks = []

        for i in range(layers):
            # Network to produce RQS parameters from context
            net = nn.Sequential(
                nn.Linear(context_dim, 128),
                nn.ReLU(),
                nn.Linear(128, features * (bins * 2 + bins - 1)),
            )
            self.context_nets.append(net)

            # Alternating coupling mask
            mask = torch.zeros(features, dtype=torch.bool)
            if i % 2 == 0:
                mask[:features // 2] = True
            else:
                mask[features // 2:] = True
            self.masks.append(mask)

    def forward(self, x, context):
        """Apply conditional transformation."""
        ladj_total = torch.zeros(x.shape[0], device=x.device)

        for net, mask in zip(self.context_nets, self.masks):
            # Get RQS parameters from context
            params = net(context)
            params = params.view(x.shape[0], self.features, -1)

            # Each dimension has P = 3K - 1 parameters, so K = (P + 1) // 3
            bins = (params.shape[-1] + 1) // 3
            widths = params[..., :bins]
            heights = params[..., bins:2 * bins]
            derivatives = params[..., 2 * bins:]

            # Split input into untouched and transformed halves
            x_a = x[..., mask]
            x_b = x[..., ~mask]

            # Select the parameters of the transformed dimensions
            widths_b = widths[..., ~mask, :]
            heights_b = heights[..., ~mask, :]
            derivatives_b = derivatives[..., ~mask, :]

            # Apply RQS transformation
            rqs = zuko.transforms.DependentTransform(
                zuko.transforms.MonotonicRQSTransform(
                    widths_b, heights_b, derivatives_b
                ),
                reinterpreted=1,
            )
            y_b, ladj = rqs.call_and_ladj(x_b)
            ladj_total = ladj_total + ladj

            # Merge back
            y = torch.empty_like(x)
            y[..., mask] = x_a
            y[..., ~mask] = y_b
            x = y

        return x, ladj_total
# Create and use conditional flow
flow = ConditionalRQSFlow(features=10, context_dim=5, bins=8, layers=3)
x = torch.randn(32, 10)
context = torch.randn(32, 5)
y, ladj = flow(x, context)
print(f"Output: {y.shape}, LADJ: {ladj.shape}") # [32, 10], [32]
Visualizing the Transformation
import torch
import matplotlib.pyplot as plt
import zuko
# Create a simple RQS
K = 8
widths = torch.randn(1, K) * 0.5
heights = torch.randn(1, K) * 0.5
derivatives = torch.randn(1, K - 1) * 0.5
transform = zuko.transforms.MonotonicRQSTransform(
    widths, heights, derivatives,
    bound=3.0,
)
# Evaluate on a grid
x = torch.linspace(-4, 4, 1000).unsqueeze(-1)
y = transform(x)
# Plot transformation
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.plot(x.squeeze(), y.squeeze(), label='f(x)')
plt.plot(x.squeeze(), x.squeeze(), 'k--', alpha=0.3, label='Identity')
plt.axvline(-3, color='r', linestyle='--', alpha=0.5, label='Bounds')
plt.axvline(3, color='r', linestyle='--', alpha=0.5)
plt.xlabel('x')
plt.ylabel('y')
plt.title('RQS Transformation')
plt.legend()
plt.grid(True, alpha=0.3)
plt.subplot(1, 2, 2)
ladj = transform.log_abs_det_jacobian(x, y)
plt.plot(x.squeeze(), ladj.squeeze())
plt.xlabel('x')
plt.ylabel('log |df/dx|')
plt.title('Log Absolute Jacobian Determinant')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
Key Considerations
Number of Bins
The number of bins K controls expressiveness:
- K = 4-8: Good for most applications, efficient
- K = 16-32: Very expressive, more parameters
- K > 32: Diminishing returns, slower
Bound Selection
The bound B determines the spline’s domain:
- Should cover the typical range of your data
- Too small: Most data falls in linear region
- Too large: Wasted capacity, numerical issues
- Typical: 3.0 - 5.0 for normalized data
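As a sanity check when choosing B for standardized data, the probability mass of a standard normal falling inside [−B, B] can be computed with the error function; everything outside lands in the spline's linear tails. This is a plain-Python sketch, and normal_mass_inside is an illustrative helper.

```python
import math

def normal_mass_inside(bound):
    # P(|X| <= B) for X ~ N(0, 1), via the error function
    return math.erf(bound / math.sqrt(2))

# B = 3 already covers ~99.7% of standard-normal data; B = 5 is essentially all of it
coverage = {b: normal_mass_inside(b) for b in (1.0, 3.0, 5.0)}
```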
Minimum Slope
The slope parameter ensures numerical stability:
- The default of 1e-3 works well for most cases
- Smaller values: More flexible but less stable
- Larger values: More stable but less expressive
Computational Cost
RQS transforms are more expensive than affine transforms:
- Forward/inverse: O(K) per evaluation (bin search + quadratic solve)
- Parameters: 3K−1 per dimension
- Trade-off: Expressiveness vs speed
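For a rough sense of the parameter cost, the 3K − 1 figure can be tallied directly; rqs_param_count is a hypothetical helper for illustration, not a zuko function.

```python
def rqs_param_count(features, bins, layers=1):
    # Each 1-D spline needs K widths, K heights, and K - 1 derivatives
    return layers * features * (3 * bins - 1)

small = rqs_param_count(features=16, bins=8)   # 16 dims at K = 8
large = rqs_param_count(features=16, bins=32)  # 16 dims at K = 32
```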
Advantages
- Highly expressive: Can approximate complex monotonic functions
- Smooth: C¹ continuous everywhere
- Guaranteed invertible: Monotonicity ensures bijectivity
- Bounded: Well-defined behavior over entire domain
- Efficient inverse: Closed-form solution using quadratic formula