
Overview

Lazy transforms are PyTorch modules that build and return a torch.distributions.Transform within their forward pass, given an optional context. They are the transformation counterpart to lazy distributions and are essential for building conditional normalizing flows.
Lazy transforms solve the same problem as lazy distributions: PyTorch’s Transform objects are not nn.Module instances. Lazy transforms wrap transformations as modules, enabling conditional transformations with trainable parameters.
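To see the gap concretely, here is a minimal, zuko-free sketch. A plain `Transform` holds tensors that PyTorch's module machinery cannot see, while a module that builds the transform in its `forward` exposes them as trainable parameters. The class name `LazyAffine` is illustrative, not part of zuko:

```python
import torch
import torch.nn as nn
from torch.distributions.transforms import AffineTransform

# A plain Transform is not an nn.Module: its tensors are invisible to
# optimizers, state_dict(), and .to(device).
scale = nn.Parameter(torch.ones(3))
t = AffineTransform(loc=torch.zeros(3), scale=scale)
print(isinstance(t, nn.Module))  # False

# A module that *builds* the transform in forward() restores all of that.
class LazyAffine(nn.Module):
    def __init__(self, features: int):
        super().__init__()
        self.log_scale = nn.Parameter(torch.zeros(features))

    def forward(self, c=None):
        # Rebuild the Transform from the current parameter values.
        return AffineTransform(
            loc=torch.zeros_like(self.log_scale),
            scale=self.log_scale.exp(),
        )

lazy = LazyAffine(3)
print(sum(p.numel() for p in lazy.parameters()))  # 3
```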

LazyTransform

The abstract base class for all lazy transformations.
class LazyTransform(nn.Module, abc.ABC)
A lazy transformation builds and returns a transformation y = f(x | c) within its forward pass, given a context c.

Methods

forward

def forward(c: Tensor | None = None) -> Transform
Builds and returns a conditional transformation.
c
Tensor | None
A context tensor. If None, the transformation is unconditional.
Returns: A torch.distributions.Transform object representing y = f(x | c).

Properties

inv

@property
def inv(self) -> LazyTransform
Returns a lazy inverse transformation x = f^{-1}(y | c). Returns: A LazyInverse instance that wraps this transformation.

Creating Custom Lazy Transforms

To create a custom lazy transform, subclass LazyTransform and implement the forward method:
import torch
import torch.nn as nn
from torch.distributions.transforms import AffineTransform
from zuko.lazy import LazyTransform

class ConditionalAffine(LazyTransform):
    def __init__(self, features: int, context: int):
        super().__init__()
        self.scale_net = nn.Sequential(
            nn.Linear(context, features),
            nn.Softplus()
        )
        self.shift_net = nn.Linear(context, features)
    
    def forward(self, c: torch.Tensor | None = None) -> AffineTransform:
        if c is None:
            raise ValueError("Context is required")
        
        scale = self.scale_net(c)
        shift = self.shift_net(c)
        
        return AffineTransform(loc=shift, scale=scale)

# Usage
transform = ConditionalAffine(features=3, context=5)
context = torch.randn(10, 5)  # Batch of 10 contexts
f = transform(context)  # Returns an AffineTransform

x = torch.randn(10, 3)
y = f(x)  # Apply transformation: y = f(x|c)
x_reconstructed = f.inv(y)  # Invert: x = f^{-1}(y|c)
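Because the networks live on the module, the transformation is trainable end to end. The following is a self-contained training-step sketch; the `ConditionalAffine` class above is restated as a plain `nn.Module` so the snippet runs without zuko, and the squared-norm loss is a placeholder objective, not a recommended one:

```python
import torch
import torch.nn as nn
from torch.distributions.transforms import AffineTransform

# Restatement of the conditional affine above (plain nn.Module for portability).
class ConditionalAffine(nn.Module):
    def __init__(self, features: int, context: int):
        super().__init__()
        self.scale_net = nn.Sequential(nn.Linear(context, features), nn.Softplus())
        self.shift_net = nn.Linear(context, features)

    def forward(self, c):
        return AffineTransform(loc=self.shift_net(c), scale=self.scale_net(c))

transform = ConditionalAffine(3, 5)
optimizer = torch.optim.Adam(transform.parameters(), lr=1e-3)

c = torch.randn(10, 5)
x = torch.randn(10, 3)

f = transform(c)        # rebuild the Transform from the current weights
y = f(x)                # y depends on scale_net / shift_net parameters
loss = (y ** 2).mean()  # placeholder objective
loss.backward()
optimizer.step()        # both sub-networks receive gradients
```

Note that the `Transform` must be rebuilt via `transform(c)` on every step, since each call bakes in the current parameter values.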

UnconditionalTransform

A convenience class for creating unconditional lazy transforms from transformation constructors.
class UnconditionalTransform(LazyTransform)
UnconditionalTransform wraps any transformation constructor and registers its arguments as buffers or parameters.

Constructor

def __init__(
    f: Callable[..., Transform],
    *args,
    buffer: bool = False,
    **kwargs
)
f
Callable[..., Transform]
required
A transformation constructor (e.g., torch.distributions.transforms.ExpTransform or a custom transform class). If f is a module, it is registered as a submodule.
args
Any
Positional arguments passed to f. Tensor arguments are registered as buffers or parameters.
buffer
bool
default: False
Whether tensor arguments are registered as buffers (not trainable) or parameters (trainable).
kwargs
Any
Keyword arguments passed to f. Tensor arguments are registered as buffers or parameters.

Methods

forward

def forward(c: Tensor | None = None) -> Transform
Returns the transformation by calling f(*args, **kwargs). The context argument c is always ignored.
c
Tensor | None
A context tensor. This argument is ignored for unconditional transformations.
Returns: f(*args, **kwargs) - the constructed transformation.
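The mechanics can be sketched without zuko. This is a rough, hypothetical reimplementation (the class name `UnconditionalWrapper` and its internals are assumptions, not zuko's actual code) showing how tensor arguments end up registered on the module and fed back to the constructor:

```python
import torch
import torch.nn as nn
from torch.distributions.transforms import AffineTransform

# Hypothetical sketch of an unconditional wrapper (not zuko's implementation).
class UnconditionalWrapper(nn.Module):
    def __init__(self, f, *args, buffer=False):
        super().__init__()
        self.f = f
        self.keys = []
        for i, a in enumerate(args):
            name = f"arg{i}"
            if buffer:
                self.register_buffer(name, a)   # tracked, but not trainable
            else:
                setattr(self, name, nn.Parameter(a))  # trainable
            self.keys.append(name)

    def forward(self, c=None):
        # c is ignored: the transformation is unconditional.
        return self.f(*(getattr(self, k) for k in self.keys))

w = UnconditionalWrapper(AffineTransform, torch.zeros(3), torch.ones(3), buffer=True)
t = w()
print(t(torch.ones(3)))  # identity affine: tensor([1., 1., 1.])
```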

Examples

Using with ExpTransform

import torch
from zuko.lazy import UnconditionalTransform
from zuko.transforms import ExpTransform

# Create an unconditional exponential transformation
t = UnconditionalTransform(ExpTransform)

# Get the transformation (context is ignored)
transform = t()
print(transform)  # ExpTransform()

# Apply the transformation
x = torch.randn(3)
y = transform(x)
print(y)  # e.g. tensor([4.6692, 0.7457, 0.1132]); values vary with the random input

Using with RotationTransform

import torch
from zuko.lazy import UnconditionalTransform
from zuko.transforms import RotationTransform

# Create a rotation with a trainable rotation matrix
rotation_matrix = torch.randn(3, 3)

t = UnconditionalTransform(
    RotationTransform,
    rotation_matrix,
    buffer=False  # Make the matrix trainable
)

# Use in a flow
from zuko.lazy import Flow
from zuko.flows.autoregressive import MaskedAutoregressiveTransform

flow = Flow(
    transform=[
        MaskedAutoregressiveTransform(3, 5, hidden_features=(64, 64)),
        t,  # Unconditional rotation
        MaskedAutoregressiveTransform(3, 5, hidden_features=(64, 64)),
    ],
    base=my_base_distribution  # placeholder: any LazyDistribution
)

Creating a fixed transformation

import torch
from torch.distributions.transforms import AffineTransform
from zuko.lazy import UnconditionalTransform

# Create a fixed affine transformation (not trainable)
loc = torch.zeros(5)
scale = torch.ones(5)

t = UnconditionalTransform(
    AffineTransform,
    loc=loc,
    scale=scale,
    buffer=True  # Fixed parameters
)

# The transformation is now part of a module
transform = t()
x = torch.randn(10, 5)
y = transform(x)  # y = scale * x + loc
When buffer=True, tensor arguments are registered as buffers (not trainable). When buffer=False, they become trainable parameters. Choose based on whether you want the transformation to be learned during training.
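The distinction is exactly PyTorch's buffer-versus-parameter split, which can be checked directly. A minimal sketch (the `Holder` class is illustrative, not part of zuko):

```python
import torch
import torch.nn as nn

# Illustrative module showing the two registration modes.
class Holder(nn.Module):
    def __init__(self, value: torch.Tensor, buffer: bool):
        super().__init__()
        if buffer:
            self.register_buffer("value", value)  # moves with .to(), never trained
        else:
            self.value = nn.Parameter(value)      # returned by .parameters()

fixed = Holder(torch.ones(5), buffer=True)
learned = Holder(torch.ones(5), buffer=False)

print(len(list(fixed.parameters())))    # 0 -> optimizer never sees it
print(len(list(learned.parameters())))  # 1 -> trainable
```

Both variants appear in `state_dict()`, so a buffered tensor is still saved and restored with the model; it just receives no gradient updates.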

LazyInverse

Creates a lazy inverse transformation from an existing lazy transformation.
class LazyInverse(LazyTransform)
LazyInverse wraps a lazy transformation and returns its inverse when called. You typically don’t need to instantiate this class directly; use the .inv property instead.

Constructor

def __init__(transform: LazyTransform)
transform
LazyTransform
required
A lazy transformation y = f(x | c) to invert.

Methods

forward

def forward(c: Tensor | None = None) -> Transform
Returns the inverse of the wrapped transformation.
c
Tensor | None
A context tensor passed to the underlying transformation.
Returns: The inverse transformation x = f^{-1}(y | c).

Properties

inv

@property
def inv(self) -> LazyTransform
Returns the original (non-inverted) lazy transformation. Returns: The wrapped transformation.

Example

import torch
from zuko.flows.autoregressive import MaskedAutoregressiveTransform

# Create a lazy transformation
forward_transform = MaskedAutoregressiveTransform(
    features=3,
    context=5,
    hidden_features=(64, 64)
)

# Get the inverse using the .inv property
inverse_transform = forward_transform.inv

# Use both directions
context = torch.randn(10, 5)
x = torch.randn(10, 3)

# Forward: x -> y
f = forward_transform(context)
y = f(x)

# Inverse: y -> x
f_inv = inverse_transform(context)
x_reconstructed = f_inv(y)

print(torch.allclose(x, x_reconstructed))  # True (within numerical precision)

# Double inverse returns the original
assert forward_transform.inv.inv is forward_transform
The .inv property creates a LazyInverse wrapper. Accessing .inv.inv returns the original transformation, not a double-wrapped version.

LazyComposedTransform

Creates a lazy composed transformation from a sequence of lazy transformations.
class LazyComposedTransform(LazyTransform)
This class composes multiple lazy transformations into a single lazy transformation y = f_n ∘ … ∘ f_0(x | c).
You rarely need to use LazyComposedTransform directly. The Flow class automatically creates a composed transformation when you pass a list of transformations.

Constructor

def __init__(*transforms: LazyTransform)
transforms
LazyTransform
required
A sequence of lazy transformations f_i to compose.

Methods

forward

def forward(c: Tensor | None = None) -> Transform
Returns the composed transformation.
c
Tensor | None
A context tensor passed to all component transformations.
Returns: A ComposedTransform representing y = f_n ∘ … ∘ f_0(x | c).

Example

import torch
from zuko.lazy import LazyComposedTransform, UnconditionalTransform
from zuko.flows.autoregressive import MaskedAutoregressiveTransform
from zuko.transforms import RotationTransform

# Create multiple transformations
t1 = MaskedAutoregressiveTransform(3, 5, hidden_features=(64, 64))
t2 = UnconditionalTransform(RotationTransform, torch.randn(3, 3))
t3 = MaskedAutoregressiveTransform(3, 5, hidden_features=(64, 64))

# Compose them manually
composed = LazyComposedTransform(t1, t2, t3)

# Apply the composition
context = torch.randn(10, 5)
x = torch.randn(10, 3)

f = composed(context)
y = f(x)  # y = t3(t2(t1(x|c)|c)|c)

# Or use Flow, which does this automatically
from zuko.lazy import Flow

flow = Flow(
    transform=[t1, t2, t3],  # Automatically composed
    base=my_base_distribution  # placeholder: any LazyDistribution
)
When composing transformations, the order matters. The composition f_n ∘ … ∘ f_0 means f_0 is applied first, then f_1, and so on.
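The ordering convention can be verified with torch's own `ComposeTransform` (zuko's ComposedTransform follows the same first-listed-applied-first convention):

```python
import torch
from torch.distributions.transforms import AffineTransform, ComposeTransform

f0 = AffineTransform(loc=1.0, scale=1.0)  # x -> x + 1
f1 = AffineTransform(loc=0.0, scale=2.0)  # x -> 2x

# The first transform in the list is applied first: f1(f0(x)).
f = ComposeTransform([f0, f1])

x = torch.tensor(3.0)
print(f(x))       # 2 * (3 + 1) = tensor(8.)
print(f1(f0(x)))  # same result
```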

See Also

  • LazyDistribution - The distribution counterpart to lazy transforms
  • Flow - Combines lazy transforms and distributions into normalizing flows
  • Transforms - Built-in transformation implementations
