Utility transforms provide essential building blocks for constructing normalizing flows, including identity, affine, permutation, and composition operations.

Composition and Structure

ComposedTransform

Creates a composition of multiple transformations.
class ComposedTransform(Transform)
Mathematical Formulation: $f(x) = f_n \circ \cdots \circ f_0(x)$. This is an optimized version of PyTorch's ComposeTransform with better event dimension handling.
transforms
Transform
A sequence of transformations $f_i$ to compose. Provided as variadic arguments.
Usage Example:
import torch
import zuko

# Create individual transforms
t1 = zuko.transforms.AdditiveTransform(torch.tensor([1.0, 2.0]))
t2 = zuko.transforms.PermutationTransform(torch.tensor([1, 0]))
t3 = zuko.transforms.MonotonicAffineTransform(
    torch.tensor([0.0, 0.0]),
    torch.tensor([1.0, 1.0])
)

# Compose transforms
transform = zuko.transforms.ComposedTransform(t1, t2, t3)

# Apply composed transformation
x = torch.randn(32, 2)
y, ladj = transform.call_and_ladj(x)

print(f"Output: {y.shape}, LADJ: {ladj.shape}")  # [32, 2], [32]

# Inverse
x_reconstructed = transform.inv(y)
print(f"Reconstruction error: {(x - x_reconstructed).abs().max()}")
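The log-determinant of a composition follows the chain rule: each component's LADJ is evaluated at the intermediate point and the terms are summed. As a torch-only sketch (using PyTorch's built-in `torch.distributions.transforms` rather than zuko, which this page documents), this can be verified directly:

```python
import torch
from torch.distributions.transforms import AffineTransform, ComposeTransform, ExpTransform

# Chain rule for log-determinants:
#   log|det J_{g∘h}(x)| = log|det J_h(x)| + log|det J_g(h(x))|
h = AffineTransform(loc=1.0, scale=2.0)
g = ExpTransform()
f = ComposeTransform([h, g])  # applies h first, then g

x = torch.randn(8)
mid = h(x)     # intermediate point h(x)
y = f(x)       # g(h(x))

ladj_sum = h.log_abs_det_jacobian(x, mid) + g.log_abs_det_jacobian(mid, y)
ladj_composed = f.log_abs_det_jacobian(x, y)
assert torch.allclose(ladj_composed, ladj_sum)
```

zuko's `ComposedTransform.call_and_ladj` performs the same accumulation in a single forward pass, avoiding the recomputation of intermediates.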

DependentTransform

Wraps a base transformation to treat right-most dimensions as dependent.
class DependentTransform(Transform)
This is an optimized version of PyTorch’s IndependentTransform.
base
Transform
The base transformation to wrap.
reinterpreted
int
The number of dimensions to treat as dependent (summed in log determinant).
Usage Example:
import torch
import zuko

# Create a per-element transformation
shift = torch.randn(1, 10)  # Per-dimension shifts
scale = torch.randn(1, 10)
base_transform = zuko.transforms.MonotonicAffineTransform(shift, scale)

# Wrap to treat all 10 dimensions as dependent
transform = zuko.transforms.DependentTransform(base_transform, reinterpreted=1)

# Apply transformation
x = torch.randn(32, 10)
y, ladj = transform.call_and_ladj(x)

print(f"Output: {y.shape}, LADJ: {ladj.shape}")  # [32, 10], [32] (summed over dims)
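The same reinterpretation is available in plain PyTorch as `IndependentTransform`, which zuko's version optimizes. A torch-only sketch showing how wrapping changes the LADJ shape:

```python
import torch
from torch.distributions.transforms import AffineTransform, IndependentTransform

base = AffineTransform(loc=0.0, scale=2.0)  # elementwise, event_dim = 0
wrapped = IndependentTransform(base, reinterpreted_batch_ndims=1)

x = torch.randn(32, 10)
y = wrapped(x)

# Elementwise LADJ has shape [32, 10]; the wrapped transform sums the
# right-most dimension, yielding one value per sample: shape [32].
ladj_elem = base.log_abs_det_jacobian(x, y)
ladj_dep = wrapped.log_abs_det_jacobian(x, y)
assert ladj_dep.shape == (32,)
assert torch.allclose(ladj_dep, ladj_elem.sum(dim=-1))
```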

Basic Transforms

IdentityTransform

The identity transformation $f(x) = x$.
class IdentityTransform(Transform)
Mathematical Formulation: $f(x) = x, \quad \log |\det J_f| = 0$

Usage Example:
import torch
import zuko

transform = zuko.transforms.IdentityTransform()

x = torch.randn(32, 5)
y = transform(x)
assert torch.equal(x, y)  # y is exactly x

ladj = transform.log_abs_det_jacobian(x, y)
assert torch.all(ladj == 0)  # log determinant is zero

AdditiveTransform

Translation transformation.
class AdditiveTransform(Transform)
Mathematical Formulation: $f(x) = x + b, \quad \log |\det J_f| = 0$
shift
Tensor
The shift term $b$, with shape $(*,)$.
Usage Example:
import torch
import zuko

shift = torch.tensor([1.0, -2.0, 3.0])
transform = zuko.transforms.AdditiveTransform(shift)

x = torch.randn(32, 3)
y = transform(x)

assert torch.allclose(y, x + shift)

# Used in NICE architecture
mask = torch.zeros(6, dtype=torch.bool)
mask[:3] = True

class NICENet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(3, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 3),
        )
    
    def forward(self, x_a):
        shift = self.net(x_a)
        return zuko.transforms.AdditiveTransform(shift)

net = NICENet()
coupling = zuko.transforms.CouplingTransform(meta=net, mask=mask)
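zuko's `CouplingTransform` handles the masking and inversion bookkeeping above. As a torch-only sanity check of the additive coupling idea from NICE (a sketch, not zuko's API): because the shift applied to one half depends only on the other, unchanged half, the Jacobian is unit-triangular, the LADJ is zero, and inversion is exact subtraction.

```python
import torch

# Additive coupling (NICE): split x into (x_a, x_b) and shift x_b by a
# function of x_a. The coupling is volume-preserving (ladj = 0) and the
# inverse subtracts the same shift, recomputed from the untouched half.
torch.manual_seed(0)
net = torch.nn.Linear(3, 3)  # stands in for an arbitrary conditioner

x = torch.randn(32, 6)
x_a, x_b = x[:, :3], x[:, 3:]

y_b = x_b + net(x_a)                # forward: shift the second half
y = torch.cat([x_a, y_b], dim=-1)

x_b_rec = y[:, 3:] - net(y[:, :3])  # inverse: y keeps x_a, so net(x_a) is known
assert torch.allclose(x_b_rec, x_b, atol=1e-6)
```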

MonotonicAffineTransform

Affine transformation with positive scale.
class MonotonicAffineTransform(Transform)
Mathematical Formulation: $f(x) = \exp(a) \cdot x + b, \quad \log |\det J_f| = a$
shift
Tensor
The shift term $b$, with shape $(*,)$.
scale
Tensor
The unconstrained scale factor $a$, with shape $(*,)$.
slope
float
default:"1e-3"
The minimum slope of the transformation.
Usage Example:
import torch
import zuko

shift = torch.tensor([1.0, 2.0])
scale = torch.tensor([0.5, -0.3])  # Unconstrained

transform = zuko.transforms.MonotonicAffineTransform(shift, scale, slope=1e-3)

x = torch.randn(32, 2)
y, ladj = transform.call_and_ladj(x)

print(f"Output: {y.shape}, LADJ: {ladj.shape}")  # [32, 2], [32, 2]
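For the idealized map $f(x) = e^a x + b$, the derivative is $e^a$, so the per-element LADJ is exactly $a$. (zuko's actual transform additionally enforces the minimum `slope`, so its LADJ can differ slightly near very negative $a$.) A torch-only autograd check of the idealized formula:

```python
import torch

# For f(x) = exp(a) * x + b, df/dx = exp(a), hence log|df/dx| = a.
a = torch.tensor([0.5, -0.3])
b = torch.tensor([1.0, 2.0])

x = torch.randn(32, 2, requires_grad=True)
y = torch.exp(a) * x + b

# Autograd recovers the same per-element derivative:
(dydx,) = torch.autograd.grad(y.sum(), x)
assert torch.allclose(dydx.log(), a.expand_as(dydx))
```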

Linear Transforms

PermutationTransform

Permutes elements according to a specified order.
class PermutationTransform(Transform)
order
LongTensor
The permutation order, with shape $(*, D)$.
Usage Example:
import torch
import zuko

# Reverse ordering
order = torch.tensor([4, 3, 2, 1, 0])
transform = zuko.transforms.PermutationTransform(order)

x = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]])
y = transform(x)
print(y)  # tensor([[5.0, 4.0, 3.0, 2.0, 1.0]])

# Random permutation for flow
features = 10
random_order = torch.randperm(features)
transform = zuko.transforms.PermutationTransform(random_order)

# Use between coupling layers
transforms = [
    coupling_layer_1,
    zuko.transforms.PermutationTransform(torch.randperm(features)),
    coupling_layer_2,
    zuko.transforms.PermutationTransform(torch.randperm(features)),
    coupling_layer_3,
]
flow = zuko.transforms.ComposedTransform(*transforms)

RotationTransform

Orthogonal rotation transformation.
class RotationTransform(Transform)
Mathematical Formulation: $f(x) = R x, \quad R = \exp(A - A^T)$. Because $A - A^T$ is skew-symmetric, $R$ is orthogonal, ensuring $\log |\det R| = 0$.
A
Tensor
A square matrix $A$, with shape $(*, D, D)$.
Usage Example:
import torch
import zuko

# Create rotation matrix
D = 3
A = torch.randn(D, D)
transform = zuko.transforms.RotationTransform(A)

x = torch.randn(32, D)
y = transform(x)

# Check orthogonality: R^T R = I
R = transform.R
identity = R.T @ R
print(f"Orthogonality error: {(identity - torch.eye(D)).abs().max()}")  # ~0

# Log determinant is zero for orthogonal matrix
ladj = transform.log_abs_det_jacobian(x, y)
assert torch.allclose(ladj, torch.zeros_like(ladj))
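The orthogonality argument can be checked without zuko at all, using PyTorch's matrix exponential directly (a sketch of the underlying linear algebra, not zuko's implementation):

```python
import torch

# Why exp(A - A^T) is orthogonal: S = A - A^T is skew-symmetric (S^T = -S),
# so exp(S)^T = exp(S^T) = exp(-S) = exp(S)^{-1}, i.e. R^T = R^{-1}.
torch.manual_seed(0)
D = 3
A = torch.randn(D, D)
S = A - A.T
R = torch.linalg.matrix_exp(S)

assert torch.allclose(R.T @ R, torch.eye(D), atol=1e-4)
# det(R) = exp(trace(S)) = exp(0) = 1, hence log|det R| = 0
assert torch.allclose(torch.linalg.det(R), torch.tensor(1.0), atol=1e-4)
```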

LULinearTransform

Linear transformation using LU decomposition.
class LULinearTransform(Transform)
Mathematical Formulation: $f(x) = LUx, \quad \log |\det J_f| = \sum_i \log |U_{ii}|$
LU
Tensor
A matrix whose lower and upper triangular parts are the non-zero elements of $L$ and $U$, with shape $(*, D, D)$.
Usage Example:
import torch
import zuko

# Create LU transform
D = 4
LU = torch.randn(D, D)
transform = zuko.transforms.LULinearTransform(LU)

x = torch.randn(32, D)
y, ladj = transform.call_and_ladj(x)

print(f"Output: {y.shape}, LADJ: {ladj.shape}")  # [32, 4], [32]

# Used in Glow architecture
class ConvLU(torch.nn.Module):
    def __init__(self, features):
        super().__init__()
        self.LU = torch.nn.Parameter(torch.randn(features, features))
    
    def forward(self):
        return zuko.transforms.LULinearTransform(self.LU)

conv_lu = ConvLU(features=8)
transform = conv_lu()
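The point of the LU parameterization is that the determinant becomes trivially cheap. Assuming the Glow/scipy convention that $L$ carries a unit diagonal (so the stored diagonal belongs to $U$), a torch-only sketch:

```python
import torch

# With L unit lower triangular and U upper triangular, det(LU) = prod(diag(U)),
# so the LADJ reduces to a sum of logs of U's diagonal — no O(D^3) determinant.
torch.manual_seed(0)
D = 4
LU = torch.randn(D, D)
L = torch.tril(LU, diagonal=-1) + torch.eye(D)  # unit diagonal
U = torch.triu(LU)

W = L @ U
logdet_direct = torch.linalg.slogdet(W).logabsdet
logdet_diag = U.diagonal().abs().log().sum()
assert torch.allclose(logdet_direct, logdet_diag, atol=1e-4)
```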

Trigonometric Transforms

CosTransform

Cosine transformation.
class CosTransform(Transform)
Mathematical Formulation: $f(x) = -\cos(x), \quad x \in [0, \pi], \quad y \in [-1, 1]$

Usage Example:
import torch
import zuko

transform = zuko.transforms.CosTransform()

x = torch.linspace(0, 3.14159, 100).unsqueeze(-1)
y = transform(x)

ladj = transform.log_abs_det_jacobian(x, y)
print(f"LADJ: {ladj.shape}")  # [100, 1]

SinTransform

Sine transformation.
class SinTransform(Transform)
Mathematical Formulation: $f(x) = \sin(x), \quad x \in [-\pi/2, \pi/2], \quad y \in [-1, 1]$

Usage Example:
import torch
import zuko

transform = zuko.transforms.SinTransform()

x = torch.linspace(-1.5708, 1.5708, 100).unsqueeze(-1)
y = transform(x)

ladj = transform.log_abs_det_jacobian(x, y)

Bounded Transforms

SoftclipTransform

Smooth clipping to bounded interval.
class SoftclipTransform(Transform)
Mathematical Formulation: $f(x) = \frac{x}{1 + |x/B|}, \quad \mathbb{R} \to [-B, B]$
bound
float
default:"1.0"
The codomain bound $B$.
Usage Example:
import torch
import zuko
import matplotlib.pyplot as plt

transform = zuko.transforms.SoftclipTransform(bound=3.0)

x = torch.linspace(-10, 10, 200).unsqueeze(-1)
y = transform(x)

plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.plot(x.squeeze(), y.squeeze())
plt.axhline(3.0, color='r', linestyle='--', alpha=0.5)
plt.axhline(-3.0, color='r', linestyle='--', alpha=0.5)
plt.xlabel('x')
plt.ylabel('y')
plt.title('Softclip Transformation')
plt.grid(True, alpha=0.3)

plt.subplot(1, 2, 2)
ladj = transform.log_abs_det_jacobian(x, y)
plt.plot(x.squeeze(), ladj.squeeze())
plt.xlabel('x')
plt.ylabel('log |df/dx|')
plt.title('Log Determinant')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
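The softclip map has a closed-form inverse: solving $y = x / (1 + |x/B|)$ for $x$ gives $f^{-1}(y) = y / (1 - |y/B|)$, valid for $|y| < B$. A torch-only sketch (independent of zuko's implementation details):

```python
import torch

# Softclip and its algebraic inverse, assuming f(x) = x / (1 + |x/B|).
B = 3.0

def softclip(x):
    return x / (1 + x.abs() / B)

def softclip_inv(y):
    # Valid for |y| < B; the denominator vanishes as |y| -> B.
    return y / (1 - y.abs() / B)

x = torch.linspace(-10, 10, 200)
y = softclip(x)
assert y.abs().max() < B                       # outputs stay inside (-B, B)
assert torch.allclose(softclip_inv(y), x, atol=1e-4)
```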

CircularShiftTransform

Circular shift on bounded interval.
class CircularShiftTransform(Transform)
Mathematical Formulation: $f(x) = (x \bmod 2B) - B, \quad x \in [-B, B]$
bound
float
default:"1.0"
The domain bound $B$.
Usage Example:
import torch
import zuko

transform = zuko.transforms.CircularShiftTransform(bound=2.0)

x = torch.tensor([[-1.5], [0.0], [1.5], [1.9]])
y = transform(x)

print(f"x: {x.squeeze()}")
print(f"y: {y.squeeze()}")

# Inverse is the same operation (circular)
x_reconstructed = transform.inv(y)
assert torch.allclose(x, x_reconstructed)
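Why the inverse is the same operation: for $x \in [-B, B)$, applying $f(x) = (x \bmod 2B) - B$ twice returns the input, so $f$ is its own inverse on this domain. A torch-only sketch of that modular arithmetic:

```python
import torch

# A sketch assuming f(x) = (x mod 2B) - B. torch.remainder returns values
# in [0, 2B), so f wraps the interval [-B, B) around itself; applying f
# twice gives back the input, i.e. f is an involution on [-B, B).
B = 2.0

def circular_shift(x):
    return torch.remainder(x, 2 * B) - B

x = torch.tensor([-1.5, 0.0, 1.5, 1.9])
y = circular_shift(x)
assert torch.allclose(circular_shift(y), x)    # f(f(x)) = x
assert (y >= -B).all() and (y < B).all()       # outputs stay in [-B, B)
```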

SignedPowerTransform

Signed power transformation.
class SignedPowerTransform(Transform)
Mathematical Formulation: $f(x) = \operatorname{sign}(x) |x|^{\exp(\alpha)}$
alpha
Tensor
The unconstrained exponent $\alpha$, with shape $(*,)$.
Usage Example:
import torch
import zuko

# Create signed power transforms with different exponents
alpha = torch.tensor([0.0, 0.5, -0.5])  # exp(alpha) = [1.0, 1.65, 0.61]
transform = zuko.transforms.SignedPowerTransform(alpha)

x = torch.randn(32, 3)
y, ladj = transform.call_and_ladj(x)

print(f"Output: {y.shape}, LADJ: {ladj.shape}")  # [32, 3], [32, 3]

# Visualize effect
import matplotlib.pyplot as plt

alphas = [torch.tensor([a]) for a in [-1.0, -0.5, 0.0, 0.5, 1.0]]
x_test = torch.linspace(-3, 3, 100).unsqueeze(-1)

plt.figure(figsize=(10, 6))
for alpha in alphas:
    t = zuko.transforms.SignedPowerTransform(alpha)
    y_test = t(x_test)
    plt.plot(x_test.squeeze(), y_test.squeeze(), label=f'α={alpha.item():.1f}')

plt.plot(x_test.squeeze(), x_test.squeeze(), 'k--', alpha=0.3, label='Identity')
plt.xlabel('x')
plt.ylabel('y')
plt.title('Signed Power Transformations')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()
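For $f(x) = \operatorname{sign}(x)\,|x|^p$ with $p = \exp(\alpha)$, the derivative is $p\,|x|^{p-1} > 0$ on both sides of the origin, so $\log |f'(x)| = \log p + (p - 1)\log |x|$. A torch-only autograd check of that formula (independent of zuko's implementation):

```python
import torch

# d/dx [sign(x) |x|^p] = p |x|^(p-1) for x != 0,
# hence log|df/dx| = log(p) + (p - 1) * log|x|.
torch.manual_seed(0)
alpha = torch.tensor(0.5)
p = alpha.exp()

x = torch.randn(32).requires_grad_()
y = x.sign() * x.abs() ** p

(dydx,) = torch.autograd.grad(y.sum(), x)
ladj = p.log() + (p - 1) * x.detach().abs().log()
assert torch.allclose(dydx.log(), ladj, atol=1e-4)
```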

Complete Flow Example

Here’s a complete example combining multiple utility transforms:
import torch
import torch.nn as nn
import zuko

class CompleteFlow(nn.Module):
    """Flow combining various utility transforms."""
    def __init__(self, features: int, layers: int = 4):
        super().__init__()
        self.transforms = []
        
        for i in range(layers):
            # Coupling layer
            mask = torch.zeros(features, dtype=torch.bool)
            if i % 2 == 0:
                mask[:features // 2] = True
            else:
                mask[features // 2:] = True
            
            class AffineNet(nn.Module):
                def __init__(self, in_dim, out_dim):
                    super().__init__()
                    self.net = nn.Sequential(
                        nn.Linear(in_dim, 64),
                        nn.ReLU(),
                        nn.Linear(64, out_dim * 2)
                    )
                
                def forward(self, x_a):
                    params = self.net(x_a)
                    shift, scale = params.chunk(2, dim=-1)
                    return zuko.transforms.MonotonicAffineTransform(shift, scale)
            
            net = AffineNet(int(mask.sum()), int((~mask).sum()))
            coupling = zuko.transforms.CouplingTransform(meta=net, mask=mask)
            self.transforms.append(coupling)
            # Register the conditioner so its parameters reach flow.parameters();
            # a plain Python list of transforms does not register submodules.
            self.add_module(f"net_{i}", net)
            
            # Random permutation (except last layer)
            if i < layers - 1:
                perm = torch.randperm(features)
                self.transforms.append(
                    zuko.transforms.PermutationTransform(perm)
                )
        
        # Compose all transforms
        self.flow = zuko.transforms.ComposedTransform(*self.transforms)
        
        # Base distribution
        self.register_buffer('base_loc', torch.zeros(features))
        self.register_buffer('base_scale', torch.ones(features))
    
    def forward(self, x):
        z, ladj = self.flow.call_and_ladj(x)
        base_dist = torch.distributions.Normal(self.base_loc, self.base_scale)
        log_prob = base_dist.log_prob(z).sum(dim=-1) + ladj
        return log_prob
    
    def sample(self, n_samples):
        base_dist = torch.distributions.Normal(self.base_loc, self.base_scale)
        z = base_dist.sample((n_samples,))
        x = self.flow.inv(z)
        return x

# Create and use flow
flow = CompleteFlow(features=10, layers=4)

# Training
optimizer = torch.optim.Adam(flow.parameters(), lr=1e-3)
data = torch.randn(128, 10)  # Your data here

for epoch in range(100):
    optimizer.zero_grad()
    log_prob = flow(data)
    loss = -log_prob.mean()
    loss.backward()
    optimizer.step()
    
    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

# Sampling
with torch.no_grad():
    samples = flow.sample(1000)
    print(f"Generated samples: {samples.shape}")  # [1000, 10]

Summary Table

| Transform | Operation | Domain | Codomain | LADJ |
| --- | --- | --- | --- | --- |
| Identity | $f(x) = x$ | $\mathbb{R}$ | $\mathbb{R}$ | $0$ |
| Additive | $f(x) = x + b$ | $\mathbb{R}$ | $\mathbb{R}$ | $0$ |
| MonotonicAffine | $f(x) = e^a x + b$ | $\mathbb{R}$ | $\mathbb{R}$ | $a$ |
| Permutation | Reorder elements | $\mathbb{R}^D$ | $\mathbb{R}^D$ | $0$ |
| Rotation | $f(x) = Rx$ | $\mathbb{R}^D$ | $\mathbb{R}^D$ | $0$ |
| LULinear | $f(x) = LUx$ | $\mathbb{R}^D$ | $\mathbb{R}^D$ | $\sum_i \log \lvert U_{ii} \rvert$ |
| Softclip | Smooth clipping | $\mathbb{R}$ | $[-B, B]$ | Non-constant |
| SignedPower | $\operatorname{sign}(x) \cdot \lvert x \rvert^{\exp(\alpha)}$ | $\mathbb{R}$ | $\mathbb{R}$ | Non-constant |

References

Dinh, L., Krueger, D., & Bengio, Y. (2014). NICE: Non-linear Independent Components Estimation. https://arxiv.org/abs/1410.8516

Kingma, D. P., & Dhariwal, P. (2018). Glow: Generative Flow with Invertible 1x1 Convolutions. https://arxiv.org/abs/1807.03039
