The zuko.utils module provides utility functions for broadcasting, root finding, numerical integration, and ODE solving. These utilities support gradient propagation through implicit differentiation and automatic differentiation.
broadcast
Broadcasts tensors together while preserving specified trailing dimensions. The term broadcasting describes how PyTorch treats tensors with different shapes during arithmetic operations. This function extends broadcasting to selectively preserve trailing dimensions.
The tensors to broadcast
The number(s) of trailing dimensions not to broadcast
Example
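The example here was lost in extraction. The sketch below illustrates the idea in plain Python on shape tuples; the real zuko.utils.broadcast operates on tensors, and the helper name broadcast_shapes and the per-shape ignore list are assumptions made for illustration:

```python
def broadcast_shapes(shapes, ignore):
    """Compute a common broadcast shape for several shapes while
    leaving the last ignore[i] dimensions of each shape untouched.

    A plain-Python sketch of the trailing-dimension-preserving
    broadcast; the real function also accepts a single integer
    applied to all tensors.
    """
    # Split each shape into a broadcastable head and a preserved tail.
    heads = [s[:len(s) - i] for s, i in zip(shapes, ignore)]
    tails = [s[len(s) - i:] for s, i in zip(shapes, ignore)]

    # Standard right-aligned broadcasting over the heads.
    width = max(len(h) for h in heads)
    common = []
    for k in range(width):
        dims = {h[len(h) - 1 - k] for h in heads if k < len(h)}
        dims.discard(1)  # size-1 dimensions stretch to match
        if len(dims) > 1:
            raise ValueError("shapes are not broadcastable")
        common.append(dims.pop() if dims else 1)
    common.reverse()

    # Reattach each preserved tail to the common broadcast head.
    return [tuple(common) + t for t in tails]

# (3, 1, 5) and (1, 4, 5), each keeping its last dimension out of
# broadcasting, share the common head (3, 4).
out = broadcast_shapes([(3, 1, 5), (1, 4, 5)], ignore=[1, 1])
```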
bisection
Applies the bisection method to find x between the bounds a and b such that f(x) is close to y. Gradients are propagated through y and the parameters of f via implicit differentiation.
A univariate function f
The target value
The lower bound a such that f(a) ≤ y
The upper bound b such that y ≤ f(b)
The number of bisection iterations
The parameters of f
Example
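No example survives extraction here. The sketch below shows the plain (non-differentiable) bisection iteration in pure Python, without the implicit-differentiation machinery of the real zuko.utils.bisection:

```python
def bisection(f, y, a, b, n=16):
    """Find x in [a, b] with f(x) close to y, assuming f is
    continuous and f(a) <= y <= f(b).

    Each iteration halves the bracket, so the residual interval
    after n steps has width (b - a) / 2**n.
    """
    for _ in range(n):
        c = (a + b) / 2
        if f(c) < y:
            a = c  # the solution lies in the upper half
        else:
            b = c  # the solution lies in the lower half
    return (a + b) / 2

# Example: solve x**3 = 2 on [0, 2], i.e. find the cube root of 2.
x = bisection(lambda x: x**3, 2.0, 0.0, 2.0, n=32)
```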
The bisection method requires that f is continuous on [a, b] and that y lies between f(a) and f(b). Gradients are computed via implicit differentiation without storing intermediate bisection states.
gauss_legendre
Estimates the definite integral of a function f from a to b using an n-point Gauss-Legendre quadrature:

∫_a^b f(x) dx ≈ (b − a)/2 · Σ_{i=1}^{n} w_i f((b − a)/2 · x_i + (a + b)/2)

A univariate function f
The lower integration limit
The upper integration limit
The number of quadrature points
The parameters of f
Example
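The example here was lost in extraction. The snippet below is a pure-Python sketch with the 3-point nodes and weights hard-coded; the real zuko.utils.gauss_legendre computes them for any n and works on tensors:

```python
import math

def gauss_legendre_3(f, a, b):
    """Estimate the integral of f over [a, b] with the 3-point
    Gauss-Legendre rule.

    Nodes and weights are those of the degree-3 rule on [-1, 1],
    which integrates polynomials up to degree 5 exactly.
    """
    nodes = (-math.sqrt(3 / 5), 0.0, math.sqrt(3 / 5))
    weights = (5 / 9, 8 / 9, 5 / 9)
    half = (b - a) / 2
    mid = (a + b) / 2
    # Affine map from [-1, 1] to [a, b], then the weighted sum.
    return half * sum(w * f(half * x + mid) for x, w in zip(nodes, weights))

# Exact for x**5: the integral over [0, 1] is 1/6.
approx = gauss_legendre_3(lambda x: x**5, 0.0, 1.0)
```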
Higher values of n provide more accurate integration at the cost of more function evaluations. For smooth functions, n = 16 typically provides good accuracy.
odeint
Integrates a system of first-order ordinary differential equations (ODEs)

dx/dt = f(t, x(t))

from t0 to t1 using the adaptive Dormand-Prince method. The output is the final state

x(t1) = x(t0) + ∫_{t0}^{t1} f(t, x(t)) dt

Gradients are propagated through x, t0, t1, and the parameters of f via the adaptive checkpoint adjoint (ACA) method.
A system of first-order ODEs f
The initial state
The initial integration time
The final integration time
The parameters of f
The absolute tolerance for adaptive stepping
The relative tolerance for adaptive stepping
Example
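No example survives extraction here. The sketch below uses a fixed-step classical RK4 integrator as a simple stand-in for the adaptive Dormand-Prince scheme of the real zuko.utils.odeint, and operates on plain floats rather than tensors:

```python
import math

def odeint_rk4(f, x, t0, t1, steps=100):
    """Integrate dx/dt = f(t, x) from t0 to t1 with the classical
    fixed-step fourth-order Runge-Kutta scheme.

    A stand-in for the adaptive Dormand-Prince method: no step-size
    control, no adjoint gradients.
    """
    h = (t1 - t0) / steps
    t = t0
    for _ in range(steps):
        k1 = f(t, x)
        k2 = f(t + h / 2, x + h / 2 * k1)
        k3 = f(t + h / 2, x + h / 2 * k2)
        k4 = f(t + h, x + h * k3)
        x = x + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        t += h
    return x

# dx/dt = x over [0, 1] maps x(0) = 1 to x(1) = e.
x1 = odeint_rk4(lambda t, x: x, 1.0, 0.0, 1.0)
```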
Use Cases
ODE integration is used for:
- Continuous normalizing flows (CNFs), where transformations are defined via ODEs
- Neural ODEs: Parameterizing transformations with neural networks
- Time-continuous dynamics: Modeling physical systems
References:
- Neural Ordinary Differential Equations (Chen et al., 2018) - arxiv.org/abs/1806.07366
- Adaptive Checkpoint Adjoint Method (Zhuang et al., 2020) - arxiv.org/abs/2006.02493
Partial
A version of functools.partial that is a torch.nn.Module, allowing partial application of functions with automatic parameter/buffer registration.
An arbitrary callable. If f is a module, it is registered as a submodule
*args
The positional arguments passed to f
buffer
Whether tensor arguments are registered as buffers (True) or parameters (False)
**kwargs
The keyword arguments passed to f
Example
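The example was lost in extraction. The snippet below sketches only the call semantics using functools.partial; the real Partial additionally registers tensor arguments on the module so they move with .to() and show up in .parameters() or .buffers():

```python
from functools import partial

# Partial(f, *args, **kwargs) behaves like functools.partial at call
# time: stored arguments are applied first, call-time arguments follow.
def affine(x, weight, bias):
    """A toy function standing in for an arbitrary callable f."""
    return weight * x + bias

f = partial(affine, weight=2.0, bias=1.0)
y = f(3.0)  # 2.0 * 3.0 + 1.0
```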
unpack
Unpacks a flattened tensor into multiple tensors with specified shapes.
A packed tensor
A sequence of shapes whose total number of elements matches that of the packed tensor
Example
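No example survives here. The sketch below shows the splitting logic on a plain Python list; the real zuko.utils.unpack slices a tensor along its last dimension and reshapes each piece to the requested shape:

```python
import math

def unpack(x, shapes):
    """Split a flat sequence x into consecutive pieces whose sizes
    are the products of the given shapes.

    A plain-Python sketch: each piece would additionally be reshaped
    to its target shape in the tensor version.
    """
    pieces, i = [], 0
    for shape in shapes:
        size = math.prod(shape)
        pieces.append(x[i:i + size])  # slice of `size` elements
        i += size
    return pieces

# A network predicting 2 means and 2 log-scales packs them as 4 numbers.
mu, log_sigma = unpack([0.1, 0.2, 1.0, 2.0], [(2,), (2,)])
```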
This utility is commonly used in normalizing flows where multiple parameters (means, scales, etc.) are predicted by a single neural network and need to be unpacked into their respective shapes.
