The CouplingTransform implements a coupling scheme that splits inputs into two parts and applies a transformation to one part conditioned on the other.

Mathematical Formulation

The coupling transformation splits the input $x$ into two parts $(x_a, x_b)$ and transforms only one part:

$$
\begin{align}
y_a &= x_a \\
y_b &= f(x_b \mid x_a)
\end{align}
$$

where $f$ is a conditional transformation parameterized by $x_a$. The split is defined by a binary mask.
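As a concrete illustration of these equations, here is a minimal pure-Python sketch with a hypothetical conditioner $f(x_b \mid x_a) = x_b \exp(s(x_a)) + t(x_a)$; the functions `s` and `t` below are made up for illustration, not part of zuko:

```python
import math

# Hypothetical conditioner parameters derived from the constant part x_a
def s(x_a):  # log-scale
    return 0.5 * sum(x_a)

def t(x_a):  # shift
    return sum(x_a) / len(x_a)

def coupling_forward(x, split):
    x_a, x_b = x[:split], x[split:]
    y_b = [v * math.exp(s(x_a)) + t(x_a) for v in x_b]
    return x_a + y_b  # y_a = x_a is passed through unchanged

def coupling_inverse(y, split):
    y_a, y_b = y[:split], y[split:]
    # Since y_a = x_a, the conditioner parameters can be recomputed directly
    x_b = [(v - t(y_a)) / math.exp(s(y_a)) for v in y_b]
    return y_a + x_b

x = [0.1, -0.2, 0.3, 0.4]
y = coupling_forward(x, split=2)
x_rec = coupling_inverse(y, split=2)
```

Because the constant part is copied through, inverting the transform only requires inverting $f$, which is what makes coupling layers cheap to invert.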

Class Definition

class CouplingTransform(Transform)
Transform via a coupling scheme.
  • meta (Callable[[Tensor], Transform]): A function that returns a transformation $f$ given $x_a$. The meta-function receives the constant part and produces the transformation applied to the other part.
  • mask (BoolTensor): A coupling mask defining the split $x \to (x_a, x_b)$. Ones correspond to the constant part $x_a$ and zeros to the transformed part $x_b$.

Properties

  • Domain: constraints.real_vector
  • Codomain: constraints.real_vector
  • Bijective: True

Implementation Details

Splitting and Merging

The transform maintains indices for efficient splitting:
self.idx_a = mask.nonzero().squeeze(-1)  # Indices where mask is True
self.idx_b = (~mask).nonzero().squeeze(-1)  # Indices where mask is False
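For example, a mask `[True, True, False, False]` yields `idx_a = [0, 1]` and `idx_b = [2, 3]`. A minimal pure-Python sketch of the split/merge logic (the `split` and `merge` helpers here are illustrative stand-ins, not the actual zuko methods):

```python
mask = [True, True, False, False]
idx_a = [i for i, m in enumerate(mask) if m]      # constant positions
idx_b = [i for i, m in enumerate(mask) if not m]  # transformed positions

def split(x):
    return [x[i] for i in idx_a], [x[i] for i in idx_b]

def merge(x_a, y_b, n):
    # Scatter both parts back to their original positions
    out = [0.0] * n
    for i, v in zip(idx_a, x_a):
        out[i] = v
    for i, v in zip(idx_b, y_b):
        out[i] = v
    return out

x = [10.0, 20.0, 30.0, 40.0]
x_a, x_b = split(x)
assert merge(x_a, x_b, len(x)) == x  # split followed by merge is lossless
```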

Forward Pass

The forward pass keeps $x_a$ constant and transforms $x_b$:
def _call(self, x: Tensor) -> Tensor:
    x_a, x_b = self.split(x)
    y_b = self.meta(x_a)(x_b)
    return self.merge(x_a, y_b, x.shape)

Inverse Pass

The inverse is computed efficiently in a single pass:
def _inverse(self, y: Tensor) -> Tensor:
    y_a, y_b = self.split(y)
    x_b = self.meta(y_a).inv(y_b)
    return self.merge(y_a, x_b, y.shape)
This is more efficient than autoregressive transforms because $y_a = x_a$ is known.
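To make the contrast concrete, here is a toy pure-Python comparison with hypothetical shift-only conditioners: the coupling inverse evaluates its conditioner once for all transformed dimensions, while an autoregressive inverse must recover dimensions one at a time:

```python
def coupling_fwd(x, split):
    x_a, x_b = x[:split], x[split:]
    shift = sum(x_a)
    return list(x_a) + [v + shift for v in x_b]

def coupling_inv(y, split):
    y_a, y_b = y[:split], y[split:]
    shift = sum(y_a)  # y_a = x_a, so one conditioner call inverts all of y_b
    return list(y_a) + [v - shift for v in y_b]

def autoregressive_fwd(x):
    # Dimension i is shifted by a function of x[:i]
    return [v + sum(x[:i]) for i, v in enumerate(x)]

def autoregressive_inv(y):
    x = []
    for v in y:  # dim i needs x[:i], so inversion is inherently sequential
        x.append(v - sum(x))
    return x

x = [1.0, 2.0, 3.0, 4.0]
assert coupling_inv(coupling_fwd(x, 2), 2) == x
assert autoregressive_inv(autoregressive_fwd(x)) == x
```

Both are invertible, but the autoregressive inverse requires as many conditioner evaluations as there are dimensions.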

Usage Examples

Basic Coupling Transform

import torch
import torch.nn as nn
import zuko

class CouplingNet(nn.Module):
    """Neural network for coupling transformation."""
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_features, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, out_features * 2),  # shift and scale
        )
        
    def forward(self, x_a):
        params = self.net(x_a)
        shift, scale = params.chunk(2, dim=-1)
        return zuko.transforms.MonotonicAffineTransform(shift, scale)

# Create coupling mask (split at midpoint)
features = 10
mask = torch.zeros(features, dtype=torch.bool)
mask[:features // 2] = True  # First half is constant

# Create coupling transform
net = CouplingNet(in_features=int(mask.sum()), out_features=int((~mask).sum()))
transform = zuko.transforms.CouplingTransform(meta=net, mask=mask)

# Apply transformation
x = torch.randn(32, features)
y = transform(x)

# Inverse is efficient (single pass)
x_reconstructed = transform.inv(y)

# Log determinant
ladj = transform.log_abs_det_jacobian(x, y)
print(f"Log determinant shape: {ladj.shape}")  # [32]

Checkerboard Coupling Pattern

import torch
import zuko

# Create checkerboard mask for 2D data (e.g., images)
def checkerboard_mask(height, width, flip=False):
    """Create checkerboard mask for spatial coupling."""
    row_idx = torch.arange(height).view(-1, 1)
    col_idx = torch.arange(width).view(1, -1)
    mask = ((row_idx + col_idx) % 2 == 0)
    if flip:
        mask = ~mask
    return mask.flatten()

# Create coupling with checkerboard pattern
mask = checkerboard_mask(8, 8, flip=False)
net = CouplingNet(in_features=int(mask.sum()), out_features=int((~mask).sum()))
transform = zuko.transforms.CouplingTransform(meta=net, mask=mask)

# Apply to image-like data
x = torch.randn(16, 64)  # 16 images, 8x8 flattened
y = transform(x)
print(f"Transformed shape: {y.shape}")  # [16, 64]

Multi-Scale Coupling Flow

import torch
import torch.nn as nn
import zuko

class CouplingFlow(nn.Module):
    """Multi-layer coupling flow with alternating masks."""
    def __init__(self, features: int, layers: int = 4):
        super().__init__()
        # Register the coupling networks in a ModuleList so their
        # parameters are visible to optimizers (a plain Python list is not)
        self.nets = nn.ModuleList()
        transforms = []
        
        for i in range(layers):
            # Alternate which half is held constant
            mask = torch.zeros(features, dtype=torch.bool)
            if i % 2 == 0:
                mask[:features // 2] = True  # First half constant
            else:
                mask[features // 2:] = True  # Second half constant
            
            # Create coupling network
            net = CouplingNet(
                in_features=int(mask.sum()),
                out_features=int((~mask).sum()),
            )
            self.nets.append(net)
            
            # Add coupling transform
            transforms.append(zuko.transforms.CouplingTransform(meta=net, mask=mask))
        
        # Compose all transforms
        self.flow = zuko.transforms.ComposedTransform(*transforms)
        
    def forward(self, x):
        return self.flow(x)
    
    def log_prob(self, x, base_dist):
        """Compute log probability under the flow."""
        y, ladj = self.flow.call_and_ladj(x)
        log_prob = base_dist.log_prob(y) + ladj
        return log_prob

# Create and use coupling flow
flow = CouplingFlow(features=20, layers=4)
base_dist = torch.distributions.MultivariateNormal(
    torch.zeros(20),
    torch.eye(20)
)

x = torch.randn(64, 20)
log_prob = flow.log_prob(x, base_dist)
print(f"Log probability shape: {log_prob.shape}")  # [64]

RealNVP-Style Coupling

import torch
import torch.nn as nn
import zuko

class RealNVPCouplingNet(nn.Module):
    """Coupling network producing rational-quadratic spline transformations
    (as in Neural Spline Flows, a more expressive variant of RealNVP's
    affine coupling)."""
    def __init__(self, in_features: int, out_features: int, bins: int = 8):
        super().__init__()
        self.out_features = out_features
        self.bins = bins
        # Each transformed feature needs `bins` widths, `bins` heights
        # and `bins - 1` derivatives
        self.net = nn.Sequential(
            nn.Linear(in_features, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, out_features * (3 * bins - 1)),
        )
        
    def forward(self, x_a):
        params = self.net(x_a)
        # Reshape to (..., out_features, 3 * bins - 1) and split into
        # the spline parameters expected per transformed feature
        params = params.unflatten(-1, (self.out_features, 3 * self.bins - 1))
        widths, heights, derivatives = params.split(
            (self.bins, self.bins, self.bins - 1), dim=-1
        )
        return zuko.transforms.MonotonicRQSTransform(
            widths, heights, derivatives, bound=5.0
        )

# Create RealNVP-style flow
features = 16
mask = torch.zeros(features, dtype=torch.bool)
mask[:features // 2] = True

net = RealNVPCouplingNet(
    in_features=int(mask.sum()),
    out_features=int((~mask).sum())
)
transform = zuko.transforms.CouplingTransform(meta=net, mask=mask)

# Apply expressive transformation
x = torch.randn(32, features)
y, ladj = transform.call_and_ladj(x)
print(f"Output: {y.shape}, LADJ: {ladj.shape}")  # [32, 16], [32]

Key Considerations

Mask Design

The coupling mask determines expressiveness:
  • Split at midpoint: Simple, balanced computation
  • Alternating patterns: Better mixing between layers
  • Checkerboard: Good for spatial data (images)
  • Random: Can be effective but less interpretable
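The mask patterns above can each be built in a few lines (a sketch; the variable names are illustrative):

```python
import torch

features = 8

# Midpoint split: first half constant
mid = torch.zeros(features, dtype=torch.bool)
mid[: features // 2] = True

# Alternating pattern: every other dimension constant
alternating = torch.arange(features) % 2 == 0

# Checkerboard for 4x4 spatial data, flattened
row = torch.arange(4).view(-1, 1)
col = torch.arange(4).view(1, -1)
checker = ((row + col) % 2 == 0).flatten()

# Random split (seed fixed for reproducibility)
gen = torch.Generator().manual_seed(0)
rand = torch.rand(features, generator=gen) < 0.5
```

Any of these boolean tensors can be passed directly as the `mask` argument of `CouplingTransform`.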

Inverse Efficiency

Coupling transforms have an efficient inverse because:
  1. The constant part is unchanged: $y_a = x_a$
  2. The inverse only needs to invert $f$: $x_b = f^{-1}(y_b \mid y_a)$
  3. No iterative refinement is needed
This makes coupling transforms ideal for:
  • Sampling from learned distributions
  • Variational inference
  • Applications requiring fast inverse computation

Expressiveness vs Autoregressive

Coupling advantages:
  • Fast inverse (single pass)
  • Parallelizable computation
  • Stable gradients
Autoregressive advantages:
  • More expressive per layer
  • Full autoregressive conditioning
Best practice: Use multiple coupling layers with alternating masks to achieve high expressiveness while maintaining efficiency.
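A quick sanity check of this recipe: with two complementary masks, every dimension is transformed by at least one layer:

```python
import torch

features = 6
mask1 = torch.zeros(features, dtype=torch.bool)
mask1[: features // 2] = True   # layer 1 keeps the first half constant
mask2 = ~mask1                  # layer 2 keeps the second half constant

# A dimension is transformed by a layer when its mask entry is False
transformed_somewhere = (~mask1) | (~mask2)
assert transformed_somewhere.all()
```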

References

Dinh, L., Krueger, D., & Bengio, Y. (2014). NICE: Non-linear Independent Components Estimation.
https://arxiv.org/abs/1410.8516
Dinh, L., Sohl-Dickstein, J., & Bengio, S. (2016). Density estimation using Real NVP.
https://arxiv.org/abs/1605.08803
Kingma, D. P., & Dhariwal, P. (2018). Glow: Generative Flow using Invertible 1x1 Convolutions.
https://arxiv.org/abs/1807.03039
