The zuko.utils module provides utility functions for broadcasting, root finding, numerical integration, and ODE solving. These utilities support gradient propagation, via either automatic differentiation or implicit differentiation.

broadcast

Broadcasts tensors together while preserving specified trailing dimensions. The term broadcasting describes how PyTorch treats tensors with different shapes during arithmetic operations. This function extends broadcasting to selectively preserve trailing dimensions.
| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| `*tensors` | `Tensor` | required | The tensors to broadcast |
| `ignore` | `int \| Sequence[int]` | `0` | The number(s) of trailing dimensions not to broadcast |

Returns: The broadcasted tensors, as a list

Example

import torch
from zuko.utils import broadcast

# Broadcast with preserved dimensions
x = torch.rand(3, 1, 2)
y = torch.rand(4, 5)

x, y = broadcast(x, y, ignore=1)
print(x.shape)  # torch.Size([3, 4, 2])
print(y.shape)  # torch.Size([3, 4, 5])

# The last dimension of each tensor is preserved
# The leading dimensions are broadcasted together
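The shape logic can be modeled in pure Python (a simplified sketch of the documented behavior, not zuko's implementation): broadcast the leading shapes together, then reattach each tensor's preserved trailing dimensions.

```python
def broadcast_shapes(*shapes):
    """Right-aligned NumPy/PyTorch-style broadcasting of shapes."""
    ndim = max(len(s) for s in shapes)
    out = []
    for i in range(1, ndim + 1):
        dims = [s[-i] for s in shapes if len(s) >= i]
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            raise ValueError(f"incompatible sizes {dims}")
        out.append(sizes.pop() if sizes else 1)
    return tuple(reversed(out))

def broadcast_with_ignore(shapes, ignore):
    """Model of broadcast's shape behavior: split off the ignored
    trailing dimensions, broadcast the rest, then reattach them."""
    if isinstance(ignore, int):
        ignore = [ignore] * len(shapes)
    leading = [s[:len(s) - n] for s, n in zip(shapes, ignore)]
    common = broadcast_shapes(*leading)
    return [common + s[len(s) - n:] for s, n in zip(shapes, ignore)]

print(broadcast_with_ignore([(3, 1, 2), (4, 5)], ignore=1))
# [(3, 4, 2), (3, 4, 5)]
```

When `ignore` is a sequence, each tensor keeps its own number of trailing dimensions, which is how tensors of different ranks (e.g. vectors and matrices) can share batch dimensions.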

bisection

Applies the bisection method to find $x$ between the bounds $a$ and $b$ such that $f_\phi(x)$ is close to $y$. Gradients are propagated through $y$ and $\phi$ via implicit differentiation.
| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| `f` | `Callable[[Tensor], Tensor]` | required | A univariate function $f_\phi$ |
| `y` | `Tensor` | required | The target value $y$ |
| `a` | `float \| Tensor` | required | The lower bound such that $f_\phi(a) \leq y$ |
| `b` | `float \| Tensor` | required | The upper bound such that $y \leq f_\phi(b)$ |
| `n` | `int` | `16` | The number of bisection iterations |
| `phi` | `Iterable[Tensor]` | `()` | The parameters $\phi$ of $f_\phi$ |

Returns: The solution $x$

Example

import torch
from zuko.utils import bisection

# Find x such that cos(x) = 0
f = torch.cos
y = torch.tensor(0.0)
x = bisection(f, y, 2.0, 1.0, n=16)  # a=2.0, b=1.0: cos(2) <= 0 <= cos(1)
print(x)  # tensor(1.5708) ≈ π/2

# With gradients
y = torch.tensor(0.5, requires_grad=True)
x = bisection(torch.cos, y, torch.pi / 2, 0.0)  # cos(pi/2) <= y <= cos(0)
loss = x.sum()
loss.backward()
print(y.grad)  # Gradients flow through the root-finding
The bisection method requires that $f$ is continuous and that $f(a) \leq y \leq f(b)$. Gradients are computed via implicit differentiation, without storing intermediate bisection states.
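Stripped of tensors and gradients, the forward pass is the classical bisection loop. A minimal pure-Python sketch (assuming $f(a) \leq y \leq f(b)$, which also permits $a > b$ for decreasing $f$):

```python
import math

def bisection(f, y, a, b, n=16):
    """Halve the bracket [a, b] n times, keeping f(a) <= y <= f(b)."""
    for _ in range(n):
        c = (a + b) / 2
        if f(c) < y:
            a = c  # the solution lies between c and b
        else:
            b = c  # the solution lies between a and c
    return (a + b) / 2

x = bisection(math.cos, 0.0, 2.0, 1.0, n=32)
print(x)  # ≈ 1.5708 (π/2)
```

Each iteration halves the bracket, so `n` iterations locate the solution to within $|b - a| / 2^n$.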

gauss_legendre

Estimates the definite integral of a function $f_\phi(x)$ from $a$ to $b$ using an $n$-point Gauss-Legendre quadrature:

$$\int_a^b f_\phi(x) \, dx \approx (b - a) \sum_{i = 1}^n w_i \, f_\phi(x_i)$$
| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| `f` | `Callable[[Tensor], Tensor]` | required | A univariate function $f_\phi$ |
| `a` | `Tensor` | required | The lower integration limit $a$ |
| `b` | `Tensor` | required | The upper integration limit $b$ |
| `n` | `int` | `3` | The number of quadrature points |
| `phi` | `Iterable[Tensor]` | `()` | The parameters $\phi$ of $f_\phi$ |

Returns: The definite integral estimation

Example

import torch
from zuko.utils import gauss_legendre

# Integrate exp(-x^2) from -0.69 to 4.2
f = lambda x: torch.exp(-x**2)
a, b = torch.tensor([-0.69, 4.2])

result = gauss_legendre(f, a, b, n=16)
print(result)  # tensor(1.4807)

# With parametric function
w = torch.tensor(2.0, requires_grad=True)
f = lambda x: torch.exp(-w * x**2)
result = gauss_legendre(f, a, b, n=16, phi=(w,))
result.backward()
print(w.grad)  # Gradients flow through w
Higher values of `n` provide more accurate integration at the cost of more function evaluations. For smooth integrands, `n=16` typically provides good accuracy.
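The classical rule behind this function can be sketched in pure Python for $n = 3$, using the standard nodes and weights on $[-1, 1]$ and the usual $(b - a)/2$ scaling (the formula above absorbs the $1/2$ by normalizing the weights to sum to 1):

```python
import math

# 3-point Gauss-Legendre nodes and weights on [-1, 1]
NODES = (-math.sqrt(3 / 5), 0.0, math.sqrt(3 / 5))
WEIGHTS = (5 / 9, 8 / 9, 5 / 9)

def gauss_legendre_3(f, a, b):
    """Map the reference nodes to [a, b] and take the weighted sum.
    Exact for polynomials up to degree 2n - 1 = 5."""
    mid, half = (a + b) / 2, (b - a) / 2
    return half * sum(w * f(mid + half * t) for w, t in zip(WEIGHTS, NODES))

# ∫_0^1 x^5 dx = 1/6, integrated exactly by the 3-point rule
print(gauss_legendre_3(lambda x: x**5, 0.0, 1.0))  # 0.1666... (= 1/6)
```

Because the rule is a fixed weighted sum of function evaluations, it is trivially differentiable, which is why gradients flow through `phi` without any special machinery.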

odeint

Integrates a system of first-order ordinary differential equations (ODEs)

$$\frac{dx}{dt} = f_\phi(t, x)$$

from $t_0$ to $t_1$ using the adaptive Dormand-Prince method. The output is the final state

$$x(t_1) = x_0 + \int_{t_0}^{t_1} f_\phi(t, x(t)) \, dt$$

Gradients are propagated through $x_0$, $t_0$, $t_1$, and $\phi$ via the adaptive checkpoint adjoint (ACA) method.
| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| `f` | `Callable[[Tensor, Tensor], Tensor]` | required | A system of first-order ODEs $f_\phi(t, x)$ |
| `x` | `Tensor \| Sequence[Tensor]` | required | The initial state $x_0$ |
| `t0` | `float \| Tensor` | required | The initial integration time $t_0$ |
| `t1` | `float \| Tensor` | required | The final integration time $t_1$ |
| `phi` | `Iterable[Tensor]` | `()` | The parameters $\phi$ of $f_\phi$ |
| `atol` | `float` | `1e-6` | The absolute tolerance for adaptive stepping |
| `rtol` | `float` | `1e-5` | The relative tolerance for adaptive stepping |

Returns: The final state $x(t_1)$

Example

import torch
from zuko.utils import odeint

# Solve dx/dt = x @ A from t=0 to t=1
A = torch.randn(3, 3)
f = lambda t, x: x @ A
x0 = torch.randn(3)

x1 = odeint(f, x0, 0.0, 1.0)
print(x1)  # e.g. tensor([-1.4596,  0.5008,  1.5828]); values depend on the random A and x0

# With gradients
A = torch.randn(3, 3, requires_grad=True)
x0 = torch.randn(3, requires_grad=True)
f = lambda t, x: x @ A

x1 = odeint(f, x0, 0.0, 1.0, phi=(A,))
loss = x1.square().sum()
loss.backward()
print(A.grad)  # Gradients computed via adjoint method
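For intuition, the integration loop can be sketched with a fixed-step classical Runge-Kutta scheme in pure Python (a simplification: zuko uses the adaptive Dormand-Prince method with error control, and the ACA backward pass is omitted here):

```python
import math

def rk4(f, x0, t0, t1, steps=100):
    """Classic 4th-order Runge-Kutta with a fixed step size h.
    Each step combines four slope evaluations of f(t, x)."""
    h = (t1 - t0) / steps
    t, x = t0, x0
    for _ in range(steps):
        k1 = f(t, x)
        k2 = f(t + h / 2, x + h / 2 * k1)
        k3 = f(t + h / 2, x + h / 2 * k2)
        k4 = f(t + h, x + h * k3)
        x += h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        t += h
    return x

# dx/dt = -x with x(0) = 1 has the exact solution x(1) = e^{-1}
x1 = rk4(lambda t, x: -x, 1.0, 0.0, 1.0)
print(x1, math.exp(-1))  # ≈ 0.367879...
```

An adaptive method like Dormand-Prince additionally estimates the local error at each step and shrinks or grows `h` to keep it within `atol`/`rtol`, rather than using a fixed step count.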

Use Cases

ODE integration is used for:
  • Continuous normalizing flows (CNFs): transformations defined as the solution of an ODE
  • Neural ODEs: dynamics parameterized by a neural network
  • Time-continuous dynamics: modeling physical systems
from zuko.flows import CNF

# Continuous normalizing flow uses odeint internally
flow = CNF(features=3, context=2, hidden_features=[64, 64])
The adjoint method avoids storing intermediate integration states, making it memory-efficient for deep neural ODEs.

Partial

A version of functools.partial that is a torch.nn.Module, allowing partial application of functions with automatic parameter/buffer registration.
| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| `f` | `Callable` | required | An arbitrary callable. If `f` is a module, it is registered as a submodule |
| `*args` | — | — | The positional arguments passed to `f` |
| `buffer` | `bool` | `False` | Whether tensor arguments are registered as buffers (`True`) or parameters (`False`) |
| `**kwargs` | — | — | The keyword arguments passed to `f` |

Example

import torch
import torch.nn as nn
from zuko.utils import Partial

# Partial application with tensor
increment = Partial(torch.add, torch.tensor(1.0), buffer=True)
print(increment(torch.arange(3)))  # tensor([1., 2., 3.])

# Partial linear layer
weight = torch.randn(5, 3)
linear = Partial(torch.nn.functional.linear, weight=weight)
x = torch.rand(2, 3)
y = linear(x)  # shape: (2, 5)

# Partial distribution
f = torch.distributions.Normal
loc, scale = torch.zeros(3), torch.ones(3)
dist = Partial(f, loc, scale, buffer=True)
samples = dist().sample()  # Sample from Normal(0, 1)
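Conceptually, `Partial` behaves like `functools.partial`; a plain-Python sketch of the call logic (omitting the `torch.nn.Module` machinery that registers tensor arguments as parameters or buffers):

```python
import operator

class PartialSketch:
    """Store f plus bound arguments; apply remaining arguments on call.
    The real Partial is additionally a torch.nn.Module, so bound tensors
    participate in state_dict(), .to(device), and optimization."""

    def __init__(self, f, *args, **kwargs):
        self.f = f
        self.args = args
        self.kwargs = kwargs

    def __call__(self, *args, **kwargs):
        # Bound positionals come first, call-time kwargs win on conflict
        return self.f(*self.args, *args, **{**self.kwargs, **kwargs})

increment = PartialSketch(operator.add, 1)
print(increment(41))  # 42
```

The Module-based version matters in practice because a bound tensor (e.g. a weight passed to `torch.nn.functional.linear`) would otherwise be invisible to optimizers and device moves.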

unpack

Unpacks a flattened tensor into multiple tensors with specified shapes.
| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| `x` | `Tensor` | required | A packed tensor, with shape $(*, D)$ |
| `shapes` | `Sequence[Size]` | required | A sequence of shapes $S_i$ whose element counts sum to $D$ |

Returns: The unpacked tensors, with shapes $(*, S_i)$

Example

import torch
from zuko.utils import unpack

# A flat tensor with 1*2*3 + 4*5 = 26 elements
x = torch.randn(26)

# Unpack into original shapes
y, z = unpack(x, ((1, 2, 3), (4, 5)))
print(y.shape)  # torch.Size([1, 2, 3])
print(z.shape)  # torch.Size([4, 5])

# With batch dimension
x = torch.randn(8, 26)
y, z = unpack(x, ((1, 2, 3), (4, 5)))
print(y.shape)  # torch.Size([8, 1, 2, 3])
print(z.shape)  # torch.Size([8, 4, 5])
This utility is commonly used in normalizing flows where multiple parameters (means, scales, etc.) are predicted by a single neural network and need to be unpacked into their respective shapes.
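The splitting logic can be sketched in one dimension with plain Python (ignoring batch dimensions and reshaping; the real function splits along the tensor's last axis and reshapes each chunk):

```python
from math import prod

def unpack_flat(values, shapes):
    """Split a flat sequence into consecutive chunks whose lengths
    are the element counts prod(shape) of the requested shapes."""
    chunks, i = [], 0
    for shape in shapes:
        n = prod(shape)
        chunks.append(values[i:i + n])
        i += n
    return chunks

y, z = unpack_flat(list(range(26)), [(1, 2, 3), (4, 5)])
print(len(y), len(z))  # 6 20
```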
