The MonotonicRQSTransform implements a monotonic rational-quadratic spline transformation: a highly expressive and flexible bijection with guaranteed invertibility.

Mathematical Formulation

Rational quadratic splines (RQS) divide the domain into $K$ bins and define a smooth, monotonic transformation within each bin using rational quadratic functions. The transformation is:

$$f(x) = y_0 + (y_1 - y_0)\,\frac{s z^2 + d_0 z (1 - z)}{s + (d_0 + d_1 - 2s)\, z (1 - z)}$$

where:
  • $z = \frac{x - x_0}{x_1 - x_0}$ is the normalized position within the bin
  • $s = \frac{y_1 - y_0}{x_1 - x_0}$ is the secant slope
  • $d_0, d_1$ are the derivatives at the bin boundaries
  • $(x_0, y_0)$ and $(x_1, y_1)$ are the bin boundaries

The spline is defined over the interval $[-B, B]$ and extends linearly outside this region.
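To make the formula concrete, here is a minimal pure-Python evaluation of a single bin. This is an illustrative sketch, not zuko's implementation; the bin endpoints, derivatives, and evaluation points are made up for the example:

```python
def rqs_bin(x, x0, x1, y0, y1, d0, d1):
    """Evaluate the rational-quadratic spline formula within one bin."""
    z = (x - x0) / (x1 - x0)  # normalized position in [0, 1]
    s = (y1 - y0) / (x1 - x0)  # secant slope of the bin
    num = s * z**2 + d0 * z * (1 - z)
    den = s + (d0 + d1 - 2 * s) * z * (1 - z)
    return y0 + (y1 - y0) * num / den

# Bin from (0, 0) to (1, 2) with unit derivatives at both boundaries
print(rqs_bin(0.0, 0, 1, 0, 2, 1.0, 1.0))  # 0.0 (matches y0)
print(rqs_bin(0.5, 0, 1, 0, 2, 1.0, 1.0))  # 1.0
print(rqs_bin(1.0, 0, 1, 0, 2, 1.0, 1.0))  # 2.0 (matches y1)
```

Note that the formula interpolates the bin endpoints exactly, regardless of the derivative values.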

Class Definition

class MonotonicRQSTransform(Transform)
Creates a monotonic rational-quadratic spline (RQS) transformation.
widths
Tensor
The unconstrained bin widths, with shape $(*, K)$. These are normalized using softmax to ensure they sum to the domain width.
heights
Tensor
The unconstrained bin heights, with shape $(*, K)$. These are normalized using softmax to ensure they sum to the codomain height.
derivatives
Tensor
The unconstrained knot derivatives, with shape $(*, K - 1)$. These control the slope at bin boundaries and are constrained to be positive.
bound
float
default:"5.0"
The spline’s (co)domain bound $B$. The spline is defined over $[-B, B]$.
slope
float
default:"1e-3"
The minimum slope of the transformation. This ensures numerical stability and invertibility.

Properties

  • Domain: constraints.real
  • Codomain: constraints.real
  • Bijective: True
  • Sign: +1

Implementation Details

Parameter Constraints

The unconstrained parameters are processed to ensure monotonicity:
# Soft-clip the raw parameters to a bounded range for numerical stability
widths = widths / (1 + abs(2 * widths / math.log(slope)))
heights = heights / (1 + abs(2 * heights / math.log(slope)))
derivatives = derivatives / (1 + abs(derivatives / math.log(slope)))

# Normalize bin widths and heights so they sum to one (then scaled to the
# domain); exponentiate derivatives so boundary slopes are strictly positive
widths = F.softmax(widths, dim=-1)
heights = F.softmax(heights, dim=-1)
derivatives = torch.exp(derivatives)
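The soft-clipping step bounds the magnitude of the raw parameters no matter how extreme the network outputs are. A standalone sketch of that behavior (the `softclip` helper is illustrative, not part of zuko):

```python
import math

slope = 1e-3
L = math.log(slope)  # ≈ -6.91

def softclip(w, scale=2.0):
    """Illustrative soft-clip: |output| stays below |log(slope)| / scale."""
    return w / (1 + abs(scale * w / L))

# Even extreme raw network outputs map into a bounded, monotonic range
for w in (-1e6, -1.0, 0.0, 1.0, 1e6):
    print(f"{w:>12}: {softclip(w):+.4f}")
```

With the default `slope=1e-3`, the clipped values stay within roughly ±3.45, keeping the subsequent softmax and exp well conditioned.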
The transformation locates each value's bin by counting, in a vectorized way, how many knots lie below it:
def searchsorted(seq: Tensor, value: Tensor) -> LongTensor:
    return torch.sum(seq < value[..., None], dim=-1)
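In pure Python, the lookup amounts to counting the knots strictly below the value; the resulting count is the index of the containing bin (the `bin_index` helper and knot values are illustrative):

```python
def bin_index(knots, value):
    """Count knots strictly below value: the index of the containing bin."""
    return sum(k < value for k in knots)

knots = [-5.0, -2.5, 0.0, 2.5, 5.0]  # K = 4 bins between 5 knots
print(bin_index(knots, -3.0))  # 1 -> first bin, [-5.0, -2.5)
print(bin_index(knots, 0.7))   # 3 -> third bin, [0.0, 2.5)
```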

Usage Examples

Basic RQS Transform

import torch
import zuko

# Define spline parameters
batch_size = 32
K = 8  # Number of bins

widths = torch.randn(batch_size, K)
heights = torch.randn(batch_size, K)
derivatives = torch.randn(batch_size, K - 1)

# Create RQS transformation
transform = zuko.transforms.MonotonicRQSTransform(
    widths=widths,
    heights=heights,
    derivatives=derivatives,
    bound=5.0,
    slope=1e-3
)

# Apply transformation
x = torch.randn(batch_size)  # batch shape must match the parameters' batch shape
y = transform(x)

# Inverse transformation
x_reconstructed = transform.inv(y)

# Compute log determinant
ladj = transform.log_abs_det_jacobian(x, y)
print(f"LADJ shape: {ladj.shape}")  # torch.Size([32])

# Efficient forward + LADJ
y, ladj = transform.call_and_ladj(x)

RQS in Coupling Layer

import torch
import torch.nn as nn
import zuko

class RQSCouplingNet(nn.Module):
    """Network that produces RQS parameters for coupling."""
    def __init__(self, in_features: int, out_features: int, bins: int = 8):
        super().__init__()
        self.bins = bins
        # Output: widths, heights, derivatives for each output dimension
        self.net = nn.Sequential(
            nn.Linear(in_features, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, out_features * (bins * 2 + bins - 1)),
        )
        
    def forward(self, x_a):
        batch_size = x_a.shape[0]
        params = self.net(x_a)  # [batch, out_features * (2K + K - 1)]
        
        # Reshape to [batch, out_features, params_per_dim]
        params = params.view(batch_size, -1, self.bins * 2 + self.bins - 1)
        
        # Split into widths, heights, derivatives
        widths = params[..., :self.bins]
        heights = params[..., self.bins:2*self.bins]
        derivatives = params[..., 2*self.bins:]
        
        # Create independent RQS for each dimension
        # Note: This creates a dependent transform wrapper
        return zuko.transforms.DependentTransform(
            zuko.transforms.MonotonicRQSTransform(
                widths, heights, derivatives,
                bound=5.0
            ),
            reinterpreted=1
        )

# Create coupling with RQS
features = 16
mask = torch.zeros(features, dtype=torch.bool)
mask[:features // 2] = True

net = RQSCouplingNet(
    in_features=int(mask.sum()),
    out_features=int((~mask).sum()),
    bins=8
)

transform = zuko.transforms.CouplingTransform(meta=net, mask=mask)

# Apply to data
x = torch.randn(64, features)
y, ladj = transform.call_and_ladj(x)
print(f"Output: {y.shape}, LADJ: {ladj.shape}")  # [64, 16], [64]

Conditional RQS Flow

import torch
import torch.nn as nn
import zuko

class ConditionalRQSFlow(nn.Module):
    """Flow with RQS transformations conditioned on context."""
    def __init__(self, features: int, context_dim: int, bins: int = 8, layers: int = 3):
        super().__init__()
        self.features = features
        self.context_nets = nn.ModuleList()
        self.transforms = []
        
        for i in range(layers):
            # Network to produce RQS parameters from context
            net = nn.Sequential(
                nn.Linear(context_dim, 128),
                nn.ReLU(),
                nn.Linear(128, features * (bins * 2 + bins - 1)),
            )
            self.context_nets.append(net)
            
            # Mask for coupling
            mask = torch.zeros(features, dtype=torch.bool)
            if i % 2 == 0:
                mask[:features // 2] = True
            else:
                mask[features // 2:] = True
            
            self.transforms.append(mask)
    
    def forward(self, x, context):
        """Apply conditional transformation."""
        ladj_total = torch.zeros(x.shape[0], device=x.device)
        
        for net, mask in zip(self.context_nets, self.transforms):
            # Get RQS parameters from context
            params = net(context)
            params = params.view(x.shape[0], self.features, -1)
            
            bins = (params.shape[-1] + 1) // 3  # params per dim = 3 * bins - 1
            widths = params[..., :bins]
            heights = params[..., bins:2*bins]
            derivatives = params[..., 2*bins:]
            
            # Split input
            x_a = x[..., mask]
            x_b = x[..., ~mask]
            
            # Get parameters for transformed dimensions
            widths_b = widths[..., ~mask, :]
            heights_b = heights[..., ~mask, :]
            derivatives_b = derivatives[..., ~mask, :]
            
            # Apply RQS transformation
            rqs = zuko.transforms.DependentTransform(
                zuko.transforms.MonotonicRQSTransform(
                    widths_b, heights_b, derivatives_b
                ),
                reinterpreted=1
            )
            
            y_b, ladj = rqs.call_and_ladj(x_b)
            ladj_total = ladj_total + ladj
            
            # Merge back
            y = torch.empty_like(x)
            y[..., mask] = x_a
            y[..., ~mask] = y_b
            x = y
            
        return x, ladj_total

# Create and use conditional flow
flow = ConditionalRQSFlow(features=10, context_dim=5, bins=8, layers=3)

x = torch.randn(32, 10)
context = torch.randn(32, 5)

y, ladj = flow(x, context)
print(f"Output: {y.shape}, LADJ: {ladj.shape}")  # [32, 10], [32]

Visualizing RQS Transformation

import torch
import matplotlib.pyplot as plt
import zuko

# Create a simple RQS
K = 8
widths = torch.randn(1, K) * 0.5
heights = torch.randn(1, K) * 0.5
derivatives = torch.randn(1, K - 1) * 0.5

transform = zuko.transforms.MonotonicRQSTransform(
    widths, heights, derivatives,
    bound=3.0
)

# Evaluate on a grid
x = torch.linspace(-4, 4, 1000).unsqueeze(-1)
y = transform(x)

# Plot transformation
plt.figure(figsize=(10, 5))

plt.subplot(1, 2, 1)
plt.plot(x.squeeze(), y.squeeze(), label='f(x)')
plt.plot(x.squeeze(), x.squeeze(), 'k--', alpha=0.3, label='Identity')
plt.axvline(-3, color='r', linestyle='--', alpha=0.5, label='Bounds')
plt.axvline(3, color='r', linestyle='--', alpha=0.5)
plt.xlabel('x')
plt.ylabel('y')
plt.title('RQS Transformation')
plt.legend()
plt.grid(True, alpha=0.3)

plt.subplot(1, 2, 2)
ladj = transform.log_abs_det_jacobian(x, y)
plt.plot(x.squeeze(), ladj.squeeze())
plt.xlabel('x')
plt.ylabel('log |df/dx|')
plt.title('Log Absolute Jacobian Determinant')
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

Key Considerations

Number of Bins

The number of bins $K$ controls expressiveness:
  • K = 4-8: Good for most applications, efficient
  • K = 16-32: Very expressive, more parameters
  • K > 32: Diminishing returns, slower

Bound Selection

The bound $B$ determines the spline’s domain:
  • Should cover the typical range of your data
  • Too small: Most data falls in linear region
  • Too large: Wasted capacity, numerical issues
  • Typical: 3.0 - 5.0 for normalized data
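For roughly standard-normal data, the fraction of points that land in the linear tails outside $[-B, B]$ can be computed in closed form, which helps when picking the bound (the `tail_fraction` helper is illustrative):

```python
import math

def tail_fraction(B):
    """Fraction of standard-normal samples falling outside [-B, B]."""
    return math.erfc(B / math.sqrt(2))

for B in (2.0, 3.0, 5.0):
    print(f"B = {B}: {tail_fraction(B):.2e}")
```

At $B = 3$ about 0.27% of standard-normal data falls in the linear region; at $B = 5$ the fraction is negligible.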

Minimum Slope

The slope parameter ensures numerical stability:
  • Default 1e-3 works well for most cases
  • Smaller values: More flexible but less stable
  • Larger values: More stable but less expressive

Computational Cost

RQS transforms are more expensive than affine transforms:
  • Forward/inverse: $O(K)$ per evaluation (bin search + quadratic solve)
  • Parameters: $3K - 1$ per dimension
  • Trade-off: Expressiveness vs speed
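The per-dimension parameter count grows linearly with the number of bins; it is also why the coupling networks in the examples above output `bins * 2 + bins - 1` values per transformed dimension (the helper name here is illustrative):

```python
def rqs_params_per_dim(K):
    """K widths + K heights + (K - 1) interior derivatives."""
    return 3 * K - 1

for K in (4, 8, 16, 32):
    print(f"K = {K:>2}: {rqs_params_per_dim(K)} parameters per dimension")
```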

Advantages

  1. Highly expressive: Can approximate complex monotonic functions
  2. Smooth: $C^1$ continuous everywhere
  3. Guaranteed invertible: Monotonicity ensures bijectivity
  4. Bounded: Well-defined behavior over entire domain
  5. Efficient inverse: Closed-form solution using quadratic formula
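The closed-form inverse reduces to a quadratic in the normalized position $z$. A single-bin sketch, following the derivation in the Neural Spline Flows paper (illustrative, not zuko's exact code):

```python
import math

def rqs_bin_inverse(y, x0, x1, y0, y1, d0, d1):
    """Invert the rational-quadratic bin transform via the quadratic formula."""
    s = (y1 - y0) / (x1 - x0)  # secant slope
    w = d0 + d1 - 2 * s
    xi, dy = y - y0, y1 - y0
    # Coefficients of a*z^2 + b*z + c = 0
    a = dy * (s - d0) + xi * w
    b = dy * d0 - xi * w
    c = -s * xi
    # Numerically stable root, which lies in [0, 1] for monotonic parameters
    z = 2 * c / (-b - math.sqrt(b * b - 4 * a * c))
    return x0 + z * (x1 - x0)

# Inverts the bin from (0, 0) to (1, 2) with unit boundary derivatives
print(rqs_bin_inverse(1.0, 0, 1, 0, 2, 1.0, 1.0))  # 0.5
print(rqs_bin_inverse(2.0, 0, 1, 0, 2, 1.0, 1.0))  # 1.0
```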

References

Durkan, C., Bekasov, A., Murray, I., & Papamakarios, G. (2019). Neural Spline Flows.
https://arxiv.org/abs/1906.04032
