Overview
Lazy transforms are PyTorch modules that build and return a torch.distributions.Transform within their forward pass, given an optional context. They are the transformation counterpart to lazy distributions and are essential for building conditional normalizing flows.
Lazy transforms solve the same problem as lazy distributions: PyTorch’s Transform objects are not nn.Module instances. Lazy transforms wrap transformations as modules, enabling conditional transformations with trainable parameters.
LazyTransform
The abstract base class for all lazy transformations.
class LazyTransform(nn.Module, abc.ABC)
A lazy transformation builds and returns a transformation y = f(x | c) within its forward pass, given a context c.
Methods
forward
def forward(c: Tensor | None = None) -> Transform
Builds and returns a conditional transformation.
c (Tensor | None, optional): A context tensor. If None, the transformation is unconditional.
Returns: A torch.distributions.Transform object representing y = f(x | c).
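The returned object is an ordinary torch.distributions.Transform, so it supports application, inversion, and log-det-Jacobian evaluation. For illustration, a plain AffineTransform stands in for the built transformation here:

```python
import torch
from torch.distributions.transforms import AffineTransform

# Any Transform returned by a lazy transform supports these operations;
# a plain AffineTransform stands in for f here.
f = AffineTransform(loc=torch.zeros(2), scale=torch.full((2,), 2.0))

x = torch.randn(2)
y = f(x)                             # forward:  y = 2 * x
x_back = f.inv(y)                    # inverse:  x = y / 2
ladj = f.log_abs_det_jacobian(x, y)  # log |df/dx| = log 2, elementwise
```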
Properties
inv
@property
def inv(self) -> LazyTransform
Returns a lazy inverse transformation x = f^{-1}(y | c).
Returns: A LazyInverse instance that wraps this transformation.
To create a custom lazy transform, subclass LazyTransform and implement the forward method:
```python
import torch
import torch.nn as nn
from torch.distributions.transforms import AffineTransform

from zuko.lazy import LazyTransform


class ConditionalAffine(LazyTransform):
    def __init__(self, features: int, context: int):
        super().__init__()
        self.scale_net = nn.Sequential(
            nn.Linear(context, features),
            nn.Softplus(),  # ensures a positive scale
        )
        self.shift_net = nn.Linear(context, features)

    def forward(self, c: torch.Tensor | None = None) -> AffineTransform:
        if c is None:
            raise ValueError("Context is required")
        scale = self.scale_net(c)
        shift = self.shift_net(c)
        return AffineTransform(loc=shift, scale=scale)


# Usage
transform = ConditionalAffine(features=3, context=5)
context = torch.randn(10, 5)  # batch of 10 contexts
f = transform(context)        # returns an AffineTransform
x = torch.randn(10, 3)
y = f(x)                      # apply: y = f(x | c)
x_reconstructed = f.inv(y)    # invert: x = f^{-1}(y | c)
```
UnconditionalTransform
A convenience class for creating unconditional lazy transforms from transformation constructors.
class UnconditionalTransform(LazyTransform)
UnconditionalTransform wraps any transformation constructor and registers its arguments as buffers or parameters.
Constructor
```python
def __init__(
    f: Callable[..., Transform],
    *args,
    buffer: bool = False,
    **kwargs,
)
```
- f (Callable[..., Transform], required): A transformation constructor (e.g., torch.distributions.transforms.ExpTransform or a custom transform class). If f is a module, it is registered as a submodule.
- *args: Positional arguments passed to f. Tensor arguments are registered as buffers or parameters.
- buffer (bool, default False): Whether tensor arguments are registered as buffers (not trainable) or as parameters (trainable).
- **kwargs: Keyword arguments passed to f. Tensor arguments are registered as buffers or parameters.
Methods
forward
def forward(c: Tensor | None = None) -> Transform
Returns the transformation by calling f(*args, **kwargs). The context argument c is always ignored.
c (Tensor | None, optional): A context tensor; ignored for unconditional transformations.
Returns: f(*args, **kwargs), the constructed transformation.
Examples
```python
import torch

from zuko.lazy import UnconditionalTransform
from zuko.transforms import ExpTransform

# Create an unconditional exponential transformation
t = UnconditionalTransform(ExpTransform)

# Get the transformation (context is ignored)
transform = t()
print(transform)  # ExpTransform()

# Apply the transformation
x = torch.randn(3)
y = transform(x)
print(y)  # e.g., tensor([4.6692, 0.7457, 0.1132])
```
```python
import torch

from zuko.lazy import Flow, UnconditionalTransform
from zuko.flows.autoregressive import MaskedAutoregressiveTransform
from zuko.transforms import RotationTransform

# Create a rotation with a trainable rotation matrix
rotation_matrix = torch.randn(3, 3)
t = UnconditionalTransform(
    RotationTransform,
    rotation_matrix,
    buffer=False,  # make the matrix trainable
)

# Use in a flow
flow = Flow(
    transform=[
        MaskedAutoregressiveTransform(3, 5, hidden_features=(64, 64)),
        t,  # unconditional rotation
        MaskedAutoregressiveTransform(3, 5, hidden_features=(64, 64)),
    ],
    base=my_base_distribution,  # any LazyDistribution, e.g. a standard normal base
)
```
```python
import torch
from torch.distributions.transforms import AffineTransform

from zuko.lazy import UnconditionalTransform

# Create a fixed affine transformation (not trainable)
loc = torch.zeros(5)
scale = torch.ones(5)
t = UnconditionalTransform(
    AffineTransform,
    loc=loc,
    scale=scale,
    buffer=True,  # fixed parameters
)

# The transformation is now part of a module
transform = t()
x = torch.randn(10, 5)
y = transform(x)  # y = loc + scale * x
```
When buffer=True, tensor arguments are registered as buffers (not trainable). When buffer=False, they become trainable parameters. Choose based on whether you want the transformation to be learned during training.
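The difference shows up directly in the module's parameters(). The registration logic can be sketched as follows (illustrative only; zuko's actual implementation differs in details):

```python
import torch
import torch.nn as nn
from torch.distributions.transforms import AffineTransform


class MiniUnconditional(nn.Module):
    """Sketch of UnconditionalTransform's tensor registration (illustrative)."""

    def __init__(self, f, *args, buffer: bool = False):
        super().__init__()
        self.f = f
        self.names = []
        for i, arg in enumerate(args):
            name = f"arg{i}"
            if buffer:
                self.register_buffer(name, arg)         # fixed, not trained
            else:
                setattr(self, name, nn.Parameter(arg))  # trainable
            self.names.append(name)

    def forward(self, c=None):
        # The context is ignored; rebuild the transform from stored tensors.
        return self.f(*(getattr(self, name) for name in self.names))


fixed = MiniUnconditional(AffineTransform, torch.zeros(3), torch.ones(3), buffer=True)
train = MiniUnconditional(AffineTransform, torch.zeros(3), torch.ones(3), buffer=False)
print(len(list(fixed.parameters())), len(list(train.parameters())))  # 0 2
```

With buffer=True the tensors still move with `.to(device)` and appear in `state_dict`, but an optimizer never sees them.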
LazyInverse
Creates a lazy inverse transformation from an existing lazy transformation.
class LazyInverse(LazyTransform)
LazyInverse wraps a lazy transformation and returns its inverse when called. You typically don’t need to instantiate this class directly; use the .inv property instead.
Constructor
def __init__(transform: LazyTransform)
transform (LazyTransform): A lazy transformation y = f(x | c) to invert.
Methods
forward
def forward(c: Tensor | None = None) -> Transform
Returns the inverse of the wrapped transformation.
c (Tensor | None, optional): A context tensor passed to the underlying transformation.
Returns: The inverse transformation x = f^{-1}(y | c).
Properties
inv
@property
def inv(self) -> LazyTransform
Returns the original (non-inverted) lazy transformation.
Returns: The wrapped transformation.
Example
```python
import torch

from zuko.flows.autoregressive import MaskedAutoregressiveTransform

# Create a lazy transformation
forward_transform = MaskedAutoregressiveTransform(
    features=3,
    context=5,
    hidden_features=(64, 64),
)

# Get the inverse using the .inv property
inverse_transform = forward_transform.inv

# Use both directions
context = torch.randn(10, 5)
x = torch.randn(10, 3)

# Forward: x -> y
f = forward_transform(context)
y = f(x)

# Inverse: y -> x
f_inv = inverse_transform(context)
x_reconstructed = f_inv(y)
print(torch.allclose(x, x_reconstructed, atol=1e-5))  # True (within numerical precision)

# Double inverse returns the original
assert forward_transform.inv.inv is forward_transform
```
The .inv property creates a LazyInverse wrapper. Accessing .inv.inv returns the original transformation, not a double-wrapped version.
LazyComposedTransform
Creates a lazy composed transformation from a sequence of lazy transformations.
class LazyComposedTransform(LazyTransform)
This class composes multiple lazy transformations into a single lazy transformation y = f_n ∘ ⋯ ∘ f_0(x | c).
You rarely need to use LazyComposedTransform directly. The Flow class automatically creates a composed transformation when you pass a list of transformations.
Constructor
def __init__(*transforms: LazyTransform)
*transforms (LazyTransform): A sequence of lazy transformations f_i to compose.
Methods
forward
def forward(c: Tensor | None = None) -> Transform
Returns the composed transformation.
c (Tensor | None, optional): A context tensor passed to all component transformations.
Returns: A ComposedTransform representing y = f_n ∘ ⋯ ∘ f_0(x | c).
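Conceptually, forward calls each component with the context and chains the results. A minimal sketch using torch's ComposeTransform (zuko returns its own ComposedTransform, so details differ):

```python
import torch
import torch.nn as nn
from torch.distributions.transforms import AffineTransform, ComposeTransform, ExpTransform


class ConstLazy(nn.Module):
    """A trivial lazy transform that ignores its context."""

    def __init__(self, t):
        super().__init__()
        self.t = t

    def forward(self, c=None):
        return self.t


class MiniComposed(nn.Module):
    """Sketch: build each component with the context, then compose them."""

    def __init__(self, *transforms):
        super().__init__()
        self.transforms = nn.ModuleList(transforms)

    def forward(self, c=None):
        # f_0 is applied first, then f_1, and so on.
        return ComposeTransform([t(c) for t in self.transforms])


composed = MiniComposed(ConstLazy(ExpTransform()), ConstLazy(AffineTransform(0.0, 2.0)))
x = torch.randn(3)
y = composed()(x)  # y = 2 * exp(x)
```

Storing the components in an nn.ModuleList is what keeps their parameters visible to the composed module.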
Example
```python
import torch

from zuko.lazy import LazyComposedTransform, UnconditionalTransform
from zuko.flows.autoregressive import MaskedAutoregressiveTransform
from zuko.transforms import RotationTransform

# Create multiple transformations
t1 = MaskedAutoregressiveTransform(3, 5, hidden_features=(64, 64))
t2 = UnconditionalTransform(RotationTransform, torch.randn(3, 3))
t3 = MaskedAutoregressiveTransform(3, 5, hidden_features=(64, 64))

# Compose them manually
composed = LazyComposedTransform(t1, t2, t3)

# Apply the composition
context = torch.randn(10, 5)
x = torch.randn(10, 3)
f = composed(context)
y = f(x)  # y = t3(t2(t1(x|c)|c)|c)

# Or use Flow, which does this automatically
from zuko.lazy import Flow

flow = Flow(
    transform=[t1, t2, t3],  # automatically composed
    base=my_base_distribution,  # any LazyDistribution, e.g. a standard normal base
)
```
When composing transformations, the order matters. The composition f_n ∘ ⋯ ∘ f_0 means f_0 is applied first, then f_1, and so on.
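For instance, with torch's ComposeTransform (whose first element is likewise applied first), exponentiating before shifting differs from shifting before exponentiating:

```python
import torch
from torch.distributions.transforms import AffineTransform, ComposeTransform, ExpTransform

shift = AffineTransform(loc=1.0, scale=1.0)  # x -> x + 1
exp = ExpTransform()                         # x -> exp(x)

x = torch.tensor([0.0, 1.0])
a = ComposeTransform([exp, shift])(x)  # exp first:   exp(x) + 1
b = ComposeTransform([shift, exp])(x)  # shift first: exp(x + 1)
print(a)  # tensor([2.0000, 3.7183])
print(b)  # tensor([2.7183, 7.3891])
```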
See Also
- LazyDistribution - The distribution counterpart to lazy transforms
- Flow - Combines lazy transforms and distributions into normalizing flows
- Transforms - Built-in transformation implementations