Utility transforms provide essential building blocks for constructing normalizing flows, including identity, affine, permutation, and composition operations.
Composition and Structure
Creates a composition of multiple transformations.
class ComposedTransform(Transform)
Mathematical Formulation:
f(x) = f_n \circ \cdots \circ f_0(x)
This is an optimized version of PyTorch’s ComposeTransform with better event dimension handling.
A sequence of transformations f_i to compose, provided as variadic arguments.
Usage Example:
import torch
import zuko

# Create individual transforms
t1 = zuko.transforms.AdditiveTransform(torch.tensor([1.0, 2.0]))
t2 = zuko.transforms.PermutationTransform(torch.tensor([1, 0]))
t3 = zuko.transforms.MonotonicAffineTransform(
    torch.tensor([0.0, 0.0]),
    torch.tensor([1.0, 1.0]),
)

# Compose transforms
transform = zuko.transforms.ComposedTransform(t1, t2, t3)

# Apply composed transformation
x = torch.randn(32, 2)
y, ladj = transform.call_and_ladj(x)
print(f"Output: {y.shape}, LADJ: {ladj.shape}")  # [32, 2], [32]

# Inverse
x_reconstructed = transform.inv(y)
print(f"Reconstruction error: {(x - x_reconstructed).abs().max()}")
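For intuition, the chain rule behind the composed LADJ can be seen with PyTorch's own `ComposeTransform`, which `ComposedTransform` optimizes: the log-determinant of a composition is the sum of each component's log-determinant evaluated at the intermediate values. A minimal sketch in plain `torch.distributions` (no zuko required):

```python
import torch
import torch.distributions.transforms as T

# Chain rule: log|det J_{f2∘f1}(x)| = log|det J_{f1}(x)| + log|det J_{f2}(f1(x))|
t1 = T.AffineTransform(loc=1.0, scale=2.0)
t2 = T.ExpTransform()
composed = T.ComposeTransform([t1, t2])

x = torch.randn(8)
y = composed(x)
ladj = composed.log_abs_det_jacobian(x, y)

# Accumulate the per-component terms by hand at the intermediate point
mid = t1(x)
manual = t1.log_abs_det_jacobian(x, mid) + t2.log_abs_det_jacobian(mid, y)
assert torch.allclose(ladj, manual, atol=1e-5)
```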
Wraps a base transformation to treat right-most dimensions as dependent.
class DependentTransform(Transform)
This is an optimized version of PyTorch’s IndependentTransform.
The base transformation to wrap.
The number of dimensions to treat as dependent (summed in log determinant).
Usage Example:
import torch
import zuko

# Create a per-element transformation
shift = torch.randn(1, 10)  # Per-dimension shifts
scale = torch.randn(1, 10)  # Unconstrained per-dimension scales
base_transform = zuko.transforms.MonotonicAffineTransform(shift, scale)

# Wrap to treat all 10 dimensions as dependent
transform = zuko.transforms.DependentTransform(base_transform, reinterpreted=1)

# Apply transformation
x = torch.randn(32, 10)
y, ladj = transform.call_and_ladj(x)
print(f"Output: {y.shape}, LADJ: {ladj.shape}")  # [32, 10], [32] (summed over dims)
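The same reduction can be reproduced with PyTorch's `IndependentTransform`, which `DependentTransform` optimizes. A minimal sketch in plain `torch.distributions`:

```python
import torch
import torch.distributions.transforms as T

# Wrapping an elementwise transform reinterprets the last dimension as part
# of the event, so per-element LADJ terms are summed over it.
wrapped = T.IndependentTransform(T.ExpTransform(), reinterpreted_batch_ndims=1)

x = torch.randn(32, 10)
y = wrapped(x)
ladj = wrapped.log_abs_det_jacobian(x, y)

assert ladj.shape == (32,)               # batch only: event dim summed out
assert torch.allclose(ladj, x.sum(-1))   # ExpTransform's elementwise LADJ is x
```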
The identity transformation f(x) = x.
class IdentityTransform(Transform)
Mathematical Formulation:
f(x) = x, \quad \log |\det J_f| = 0
Usage Example:
import torch
import zuko

transform = zuko.transforms.IdentityTransform()

x = torch.randn(32, 5)
y = transform(x)
assert torch.equal(x, y)  # y is exactly x

ladj = transform.log_abs_det_jacobian(x, y)
assert torch.all(ladj == 0)  # log determinant is zero
Translation transformation.
class AdditiveTransform(Transform)
Mathematical Formulation:
f(x) = x + b, \quad \log |\det J_f| = 0
The shift term b, with shape (*,).
Usage Example:
import torch
import zuko

shift = torch.tensor([1.0, -2.0, 3.0])
transform = zuko.transforms.AdditiveTransform(shift)

x = torch.randn(32, 3)
y = transform(x)
assert torch.allclose(y, x + shift)

# Used in the NICE architecture as an additive coupling layer
mask = torch.zeros(6, dtype=torch.bool)
mask[:3] = True

class NICENet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(3, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 3),
        )

    def forward(self, x_a):
        shift = self.net(x_a)
        return zuko.transforms.AdditiveTransform(shift)

net = NICENet()
coupling = zuko.transforms.CouplingTransform(meta=net, mask=mask)
Affine transformation with positive scale.
class MonotonicAffineTransform(Transform)
Mathematical Formulation:
f(x) = \exp(a) \cdot x + b, \quad \log |\det J_f| = a
The shift term b, with shape (*,).
The unconstrained scale factor a, with shape (*,).
The minimum slope of the transformation.
Usage Example:
import torch
import zuko

shift = torch.tensor([1.0, 2.0])
scale = torch.tensor([0.5, -0.3])  # Unconstrained
transform = zuko.transforms.MonotonicAffineTransform(shift, scale, slope=1e-3)

x = torch.randn(32, 2)
y, ladj = transform.call_and_ladj(x)
print(f"Output: {y.shape}, LADJ: {ladj.shape}")  # [32, 2], [32, 2]
Permutes elements according to a specified order.
class PermutationTransform(Transform)
The permutation order, with shape (*, D).
Usage Example:
import torch
import zuko

# Reverse ordering
order = torch.tensor([4, 3, 2, 1, 0])
transform = zuko.transforms.PermutationTransform(order)

x = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]])
y = transform(x)
print(y)  # tensor([[5.0, 4.0, 3.0, 2.0, 1.0]])

# Random permutation for a flow
features = 10
random_order = torch.randperm(features)
transform = zuko.transforms.PermutationTransform(random_order)

# Use between coupling layers (coupling_layer_1/2/3 defined elsewhere)
transforms = [
    coupling_layer_1,
    zuko.transforms.PermutationTransform(torch.randperm(features)),
    coupling_layer_2,
    zuko.transforms.PermutationTransform(torch.randperm(features)),
    coupling_layer_3,
]
flow = zuko.transforms.ComposedTransform(*transforms)
Orthogonal rotation transformation.
class RotationTransform(Transform)
Mathematical Formulation:
f(x) = Rx, \quad R = \exp(A - A^T)
Because A - A^T is skew-symmetric, R is orthogonal, ensuring \log |\det R| = 0.
A square matrix A, with shape (*, D, D).
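As a sanity check on this parameterization (plain PyTorch, no zuko), the matrix exponential of any skew-symmetric matrix is orthogonal with determinant +1:

```python
import torch

D = 3
A = torch.randn(D, D)
S = A - A.T              # skew-symmetric part: S^T = -S
R = torch.matrix_exp(S)  # orthogonal: R^T R = I

assert torch.allclose(R.T @ R, torch.eye(D), atol=1e-5)

# det(exp(S)) = exp(tr(S)) = exp(0) = 1, so the log-determinant vanishes
sign, logabsdet = torch.linalg.slogdet(R)
assert torch.allclose(logabsdet, torch.tensor(0.0), atol=1e-4)
```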
Usage Example:
import torch
import zuko

# Create a rotation transform from an arbitrary square matrix
D = 3
A = torch.randn(D, D)
transform = zuko.transforms.RotationTransform(A)

x = torch.randn(32, D)
y = transform(x)

# Check orthogonality: R^T R = I
R = transform.R
identity = R.T @ R
print(f"Orthogonality error: {(identity - torch.eye(D)).abs().max()}")  # ~0

# Log determinant is zero for an orthogonal matrix
ladj = transform.log_abs_det_jacobian(x, y)
assert torch.allclose(ladj, torch.zeros_like(ladj))
Linear transformation using LU decomposition.
class LULinearTransform(Transform)
Mathematical Formulation:
f(x) = LUx, \quad \log |\det J_f| = \sum_i \log |U_{ii}|
Here L is unit lower triangular (ones on its diagonal), so the log-determinant depends only on the diagonal of U.
A matrix whose lower and upper triangular parts are the non-zero elements of L and U, with shape (*, D, D).
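The packing convention assumed above is that L takes the strictly lower triangle with ones on the diagonal, while U takes the upper triangle including the diagonal, so the log-determinant comes from U alone. A plain-PyTorch sketch of that convention:

```python
import torch

D = 4
LU = torch.randn(D, D)
L = torch.tril(LU, diagonal=-1) + torch.eye(D)  # unit lower triangular
U = torch.triu(LU)                              # upper triangular, with diagonal
W = L @ U

# det(L) = 1, so log|det W| is the sum of log|U_ii|
manual = U.diagonal().abs().log().sum()
sign, logabsdet = torch.linalg.slogdet(W)
assert torch.allclose(logabsdet, manual, atol=1e-4)
```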
Usage Example:
import torch
import zuko

# Create LU transform
D = 4
LU = torch.randn(D, D)
transform = zuko.transforms.LULinearTransform(LU)

x = torch.randn(32, D)
y, ladj = transform.call_and_ladj(x)
print(f"Output: {y.shape}, LADJ: {ladj.shape}")  # [32, 4], [32]

# Used in the Glow architecture (invertible 1x1 convolutions)
class ConvLU(torch.nn.Module):
    def __init__(self, features):
        super().__init__()
        self.LU = torch.nn.Parameter(torch.randn(features, features))

    def forward(self):
        return zuko.transforms.LULinearTransform(self.LU)

conv_lu = ConvLU(features=8)
transform = conv_lu()
Cosine transformation.
class CosTransform(Transform)
Mathematical Formulation:
f(x) = -\cos(x), \quad x \in [0, \pi], \quad y \in [-1, 1]
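The Jacobian is elementwise and follows directly from the derivative: f'(x) = sin(x), which is positive on the interior of [0, π], so log|det J_f| = log sin(x). A quick autograd check of that derivative:

```python
import torch

# Differentiate f(x) = -cos(x) on interior points of [0, π]
x = torch.linspace(0.1, 3.0, 5, requires_grad=True)
y = -torch.cos(x)
(grad,) = torch.autograd.grad(y.sum(), x)

assert torch.all(grad > 0)  # strictly increasing on the interior
assert torch.allclose(grad, torch.sin(x).detach(), atol=1e-6)
```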
Usage Example:
import torch
import zuko

transform = zuko.transforms.CosTransform()

# Stay inside (0, π): the derivative sin(x) vanishes at the endpoints,
# where the log-determinant diverges to -inf
x = torch.linspace(0.01, 3.13, 100).unsqueeze(-1)
y = transform(x)
ladj = transform.log_abs_det_jacobian(x, y)
print(f"LADJ: {ladj.shape}")  # [100, 1]
Sine transformation.
class SinTransform(Transform)
Mathematical Formulation:
f(x) = \sin(x), \quad x \in [-\pi/2, \pi/2], \quad y \in [-1, 1]
Usage Example:
import torch
import zuko

transform = zuko.transforms.SinTransform()

# Stay inside (-π/2, π/2): the derivative cos(x) vanishes at the endpoints
x = torch.linspace(-1.57, 1.57, 100).unsqueeze(-1)
y = transform(x)
ladj = transform.log_abs_det_jacobian(x, y)
Smooth clipping to bounded interval.
class SoftclipTransform(Transform)
Mathematical Formulation:
f(x) = \frac{x}{1 + |x/B|}, \quad \mathbb{R} \to (-B, B)
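A minimal pure-PyTorch sketch of the formula itself, assuming B = 3: small inputs pass through almost unchanged, while outputs stay strictly inside the bound.

```python
import torch

B = 3.0

def softclip(x):
    # f(x) = x / (1 + |x / B|): identity-like near 0, saturates toward ±B
    return x / (1 + torch.abs(x / B))

x = torch.tensor([-1000.0, -1.0, 0.0, 1.0, 1000.0])
y = softclip(x)

assert torch.all(y.abs() < B)  # strictly inside (-B, B), even for huge inputs
assert torch.allclose(y[2], torch.tensor(0.0))  # fixed point at the origin
```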
Usage Example:
import torch
import zuko
import matplotlib.pyplot as plt

transform = zuko.transforms.SoftclipTransform(bound=3.0)

x = torch.linspace(-10, 10, 200).unsqueeze(-1)
y = transform(x)

plt.figure(figsize=(10, 4))

plt.subplot(1, 2, 1)
plt.plot(x.squeeze(), y.squeeze())
plt.axhline(3.0, color='r', linestyle='--', alpha=0.5)
plt.axhline(-3.0, color='r', linestyle='--', alpha=0.5)
plt.xlabel('x')
plt.ylabel('y')
plt.title('Softclip Transformation')
plt.grid(True, alpha=0.3)

plt.subplot(1, 2, 2)
ladj = transform.log_abs_det_jacobian(x, y)
plt.plot(x.squeeze(), ladj.squeeze())
plt.xlabel('x')
plt.ylabel('log |df/dx|')
plt.title('Log Determinant')
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()
Circular shift on bounded interval.
class CircularShiftTransform(Transform)
Mathematical Formulation:
f(x) = (x \bmod 2B) - B, \quad x \in [-B, B]
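A pure-PyTorch sketch of the formula, assuming B = 2. Because the shift is half the period 2B, applying it twice moves every point by a full period and returns the input (up to the identification of the endpoints ±B):

```python
import torch

B = 2.0

def circular_shift(x):
    # f(x) = (x mod 2B) - B
    return torch.remainder(x, 2 * B) - B

x = torch.tensor([-1.5, 0.0, 1.5, 1.9])
y = circular_shift(x)

# Shifting twice is the identity on the circle of circumference 2B
assert torch.allclose(circular_shift(y), x)
```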
Usage Example:
import torch
import zuko

transform = zuko.transforms.CircularShiftTransform(bound=2.0)

x = torch.tensor([[-1.5], [0.0], [1.5], [1.9]])
y = transform(x)
print(f"x: {x.squeeze()}")
print(f"y: {y.squeeze()}")

# The shift is half the period, so the inverse is the same circular operation
x_reconstructed = transform.inv(y)
assert torch.allclose(x, x_reconstructed)
Signed power transformation.
class SignedPowerTransform(Transform)
Mathematical Formulation:
f(x) = \operatorname{sign}(x)\,|x|^{\exp(\alpha)}
The unconstrained exponent α, with shape (*,).
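Since exp(α) > 0, the map is strictly increasing and invertible, and its inverse is a signed power with the reciprocal exponent. A plain-PyTorch sketch:

```python
import torch

def signed_power(x, p):
    # sign(x) * |x|^p, odd and monotone for p > 0
    return torch.sign(x) * torch.abs(x) ** p

p = torch.exp(torch.tensor(0.5))  # exponent exp(alpha) for alpha = 0.5
x = torch.linspace(-3, 3, 7)
y = signed_power(x, p)

x_rec = signed_power(y, 1 / p)  # the inverse uses exponent 1 / exp(alpha)
assert torch.allclose(x, x_rec, atol=1e-5)
```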
Usage Example:
import torch
import zuko

# Create signed power transforms with different exponents
alpha = torch.tensor([0.0, 0.5, -0.5])  # exp(alpha) ≈ [1.0, 1.65, 0.61]
transform = zuko.transforms.SignedPowerTransform(alpha)

x = torch.randn(32, 3)
y, ladj = transform.call_and_ladj(x)
print(f"Output: {y.shape}, LADJ: {ladj.shape}")  # [32, 3], [32, 3]

# Visualize the effect of different exponents
import matplotlib.pyplot as plt

alphas = [torch.tensor([a]) for a in [-1.0, -0.5, 0.0, 0.5, 1.0]]
x_test = torch.linspace(-3, 3, 100).unsqueeze(-1)

plt.figure(figsize=(10, 6))
for alpha in alphas:
    t = zuko.transforms.SignedPowerTransform(alpha)
    y_test = t(x_test)
    plt.plot(x_test.squeeze(), y_test.squeeze(), label=f'α={alpha.item():.1f}')
plt.plot(x_test.squeeze(), x_test.squeeze(), 'k--', alpha=0.3, label='Identity')
plt.xlabel('x')
plt.ylabel('y')
plt.title('Signed Power Transformations')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()
Complete Flow Example
Here’s a complete example combining multiple utility transforms:
import torch
import torch.nn as nn
import zuko

class AffineNet(nn.Module):
    """Conditioner network producing shift and scale for a coupling layer."""

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, 64),
            nn.ReLU(),
            nn.Linear(64, out_dim * 2),
        )

    def forward(self, x_a):
        shift, scale = self.net(x_a).chunk(2, dim=-1)
        return zuko.transforms.MonotonicAffineTransform(shift, scale)

class CompleteFlow(nn.Module):
    """Flow combining various utility transforms."""

    def __init__(self, features: int, layers: int = 4):
        super().__init__()
        # The conditioner networks must live in a ModuleList so that
        # flow.parameters() actually sees them during training
        self.nets = nn.ModuleList()
        self.transforms = []
        for i in range(layers):
            # Coupling layer: alternate which half is conditioned on
            mask = torch.zeros(features, dtype=torch.bool)
            if i % 2 == 0:
                mask[:features // 2] = True
            else:
                mask[features // 2:] = True
            net = AffineNet(int(mask.sum()), int((~mask).sum()))
            self.nets.append(net)
            coupling = zuko.transforms.CouplingTransform(meta=net, mask=mask)
            self.transforms.append(coupling)
            # Random permutation (except after the last layer)
            if i < layers - 1:
                perm = torch.randperm(features)
                self.transforms.append(zuko.transforms.PermutationTransform(perm))
        # Compose all transforms
        self.flow = zuko.transforms.ComposedTransform(*self.transforms)
        # Base distribution parameters
        self.register_buffer('base_loc', torch.zeros(features))
        self.register_buffer('base_scale', torch.ones(features))

    def forward(self, x):
        z, ladj = self.flow.call_and_ladj(x)
        base_dist = torch.distributions.Normal(self.base_loc, self.base_scale)
        return base_dist.log_prob(z).sum(dim=-1) + ladj

    def sample(self, n_samples):
        base_dist = torch.distributions.Normal(self.base_loc, self.base_scale)
        z = base_dist.sample((n_samples,))
        return self.flow.inv(z)

# Create and use flow
flow = CompleteFlow(features=10, layers=4)

# Training
optimizer = torch.optim.Adam(flow.parameters(), lr=1e-3)
data = torch.randn(128, 10)  # Your data here

for epoch in range(100):
    optimizer.zero_grad()
    loss = -flow(data).mean()
    loss.backward()
    optimizer.step()
    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

# Sampling
with torch.no_grad():
    samples = flow.sample(1000)
    print(f"Generated samples: {samples.shape}")  # [1000, 10]
Summary Table
| Transform       | Operation              | Domain | Codomain  | LADJ             |
|-----------------|------------------------|--------|-----------|------------------|
| Identity        | f(x) = x               | ℝ      | ℝ         | 0                |
| Additive        | f(x) = x + b           | ℝ      | ℝ         | 0                |
| MonotonicAffine | f(x) = eᵃx + b         | ℝ      | ℝ         | a                |
| Permutation     | Reorder elements       | ℝᴰ     | ℝᴰ        | 0                |
| Rotation        | f(x) = Rx              | ℝᴰ     | ℝᴰ        | 0                |
| LULinear        | f(x) = LUx             | ℝᴰ     | ℝᴰ        | ∑ᵢ log \|U_ii\|  |
| Softclip        | Smooth clipping        | ℝ      | (-B, B)   | Non-constant     |
| SignedPower     | sign(x)·\|x\|^exp(α)   | ℝ      | ℝ         | Non-constant     |
References
NICE: Non-linear Independent Components Estimation (Dinh et al., 2014)