What is Zuko?
Zuko is a Python package that implements normalizing flows in PyTorch. Normalizing flows are a powerful class of generative models that learn complex probability distributions by transforming simple base distributions through a sequence of invertible transformations.
Why Zuko?
PyTorch provides excellent `Distribution` and `Transform` classes for probabilistic programming. However, these classes have significant limitations when building normalizing flows:
- Not GPU-compatible: `Distribution` and `Transform` are not subclasses of `torch.nn.Module`, which means you cannot move their internal tensors to the GPU with `.to('cuda')` or retrieve their parameters with `.parameters()`
- No conditional distributions: the concepts of conditional distribution and conditional transformation, which are essential for probabilistic inference, cannot be expressed in standard PyTorch
- Limited trainability: without being modules, these classes cannot easily participate in gradient-based optimization
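The first limitation is easy to see with plain PyTorch. A quick sketch (using only `torch.distributions`, nothing Zuko-specific):

```python
import torch
from torch.distributions import Normal

# A plain PyTorch distribution wraps tensors, but it is not a Module:
dist = Normal(torch.zeros(3), torch.ones(3))

# Because it is not a torch.nn.Module, it exposes no .parameters()
# method and no .to('cuda') that would move its internal tensors.
print(isinstance(dist, torch.nn.Module))  # False
```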
Core Concepts
LazyDistribution
A `LazyDistribution` is any `torch.nn.Module` whose forward pass returns a `Distribution`. This design delays the creation of the distribution until a context is provided, which enables:
- Conditional distributions p(x | c) whose parameters depend on a context c
- Trainable parameters that can be optimized with standard PyTorch optimizers
- GPU compatibility through standard `.to('cuda')` calls
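The lazy pattern can be sketched with plain PyTorch. This is an illustrative toy (the class name `ConditionalGaussian` is made up here, not Zuko's API): a `Module` whose forward pass builds a `Distribution` from a context `c`.

```python
import torch
import torch.nn as nn
from torch.distributions import Normal

class ConditionalGaussian(nn.Module):  # hypothetical name, for illustration
    def __init__(self, context_dim: int, features: int):
        super().__init__()
        # A small hypernetwork maps the context to distribution parameters.
        self.hyper = nn.Linear(context_dim, 2 * features)

    def forward(self, c: torch.Tensor) -> Normal:
        loc, log_scale = self.hyper(c).chunk(2, dim=-1)
        return Normal(loc, log_scale.exp())

lazy = ConditionalGaussian(context_dim=4, features=2)
c = torch.randn(4)
dist = lazy(c)                 # the distribution is created only now
x = dist.sample()
print(dist.log_prob(x).shape)  # torch.Size([2])
```

Because the module owns the hypernetwork's weights, `lazy.parameters()` and `lazy.to('cuda')` work exactly as for any other PyTorch module.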
LazyTransform
Similarly, a `LazyTransform` is any `torch.nn.Module` whose forward pass returns a `Transform`. This allows for conditional transformations whose parameters depend on a context.
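The same sketch works for transforms. Again an illustrative toy under the names assumed here (`ConditionalAffine` is not Zuko's API): a `Module` whose forward pass returns a `Transform` parameterized by the context.

```python
import torch
import torch.nn as nn
from torch.distributions.transforms import AffineTransform

class ConditionalAffine(nn.Module):  # hypothetical name, for illustration
    def __init__(self, context_dim: int, features: int):
        super().__init__()
        # The context determines the shift and scale of the transformation.
        self.hyper = nn.Linear(context_dim, 2 * features)

    def forward(self, c: torch.Tensor) -> AffineTransform:
        shift, log_scale = self.hyper(c).chunk(2, dim=-1)
        return AffineTransform(loc=shift, scale=log_scale.exp())

lazy_t = ConditionalAffine(context_dim=3, features=2)
t = lazy_t(torch.randn(3))  # a Transform, built from the context
x0 = torch.randn(2)
y = t(x0)                   # forward transformation
x_rec = t.inv(y)            # transforms are invertible
```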
Normalizing Flows
A normalizing flow in Zuko is built by combining:
- A sequence of `LazyTransform` objects (the invertible transformations)
- A `LazyDistribution` base distribution (typically a simple Gaussian)

When you pass a context c to the flow, it returns a conditional distribution p(x | c) that can be:
- Evaluated: compute `log_prob(x)` for density estimation
- Sampled: generate samples x ~ p(x | c)
- Optimized: train the flow parameters to maximize likelihood
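This recipe can be sketched end to end with plain PyTorch building blocks (illustrative only; Zuko's actual flow classes are richer, and `TinyConditionalFlow` is a name made up here): a hypernetwork produces the parameters of an invertible transformation, which is composed with a Gaussian base distribution.

```python
import torch
import torch.nn as nn
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import AffineTransform

class TinyConditionalFlow(nn.Module):  # hypothetical name, for illustration
    def __init__(self, features: int, context: int):
        super().__init__()
        self.hyper = nn.Linear(context, 2 * features)

    def forward(self, c: torch.Tensor) -> TransformedDistribution:
        shift, log_scale = self.hyper(c).chunk(2, dim=-1)
        base = Normal(torch.zeros_like(shift), torch.ones_like(shift))
        transform = AffineTransform(loc=shift, scale=log_scale.exp())
        return TransformedDistribution(base, [transform])

flow = TinyConditionalFlow(features=2, context=3)
optimizer = torch.optim.Adam(flow.parameters(), lr=1e-2)

c = torch.randn(8, 3)  # contexts
x = torch.randn(8, 2)  # observations

dist = flow(c)                            # conditional distribution p(x | c)
loss = -dist.log_prob(x).sum(-1).mean()   # Evaluated: negative log-likelihood
loss.backward()                           # Optimized: one gradient step
optimizer.step()

sample = flow(c).sample()                 # Sampled: x ~ p(x | c)
```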
PyTorch Integration Benefits
By building on PyTorch's ecosystem, Zuko provides:
Seamless Training
Use standard PyTorch optimizers, learning rate schedulers, and training loops
GPU Acceleration
Move flows to GPU with `.to('cuda')` just like any other PyTorch module
Automatic Differentiation
Leverage PyTorch’s autograd for efficient gradient computation
Composability
Combine flows with other PyTorch modules in larger architectures
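Composability follows directly from the module design. A minimal sketch (names like `DensityModel` are illustrative, not Zuko's API): an encoder produces the context that a conditional density head consumes, all inside one larger module.

```python
import torch
import torch.nn as nn
from torch.distributions import Normal

class DensityModel(nn.Module):  # hypothetical name, for illustration
    def __init__(self):
        super().__init__()
        # Any feature extractor can produce the context vector.
        self.encoder = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 4))
        # A conditional density head: predicts loc and scale of p(x | c).
        self.head = nn.Linear(4, 2 * 2)

    def forward(self, obs: torch.Tensor) -> Normal:
        c = self.encoder(obs)                       # context from raw input
        loc, log_scale = self.head(c).chunk(2, -1)
        return Normal(loc, log_scale.exp())         # conditional distribution

model = DensityModel()
dist = model(torch.randn(5, 16))
print(dist.sample().shape)  # torch.Size([5, 2])
```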
Key Features
- 12+ Pre-built Flows: Including NSF, MAF, RealNVP, CNF, and more modern architectures
- Conditional Modeling: Built-in support for context-dependent distributions
- Custom Flows: Easy-to-understand API for building your own flow architectures
- Type Safety: Full type hints for better IDE support and code quality
- Research-Friendly: Clean implementations that are easy to understand and extend
Design Philosophy
Zuko’s design prioritizes:
- Simplicity: The lazy distribution/transform pattern is easy to understand
- Flexibility: Build custom flows or use pre-built architectures
- Correctness: Implementations closely follow the original papers
- Performance: Leverages PyTorch’s optimized operations
What’s Next?
Installation
Install Zuko and its dependencies
Quickstart
Train your first normalizing flow in minutes
