
Overview

A Continuous Normalizing Flow (CNF) uses an ordinary differential equation (ODE) to define a continuous-time transformation. Instead of stacking discrete transformation layers, a CNF learns a continuous dynamics function that transports samples from the base distribution to the target distribution.

References

Neural Ordinary Differential Equations (Chen et al., 2018)
https://arxiv.org/abs/1806.07366
FFJORD: Free-form Continuous Dynamics for Scalable Reversible Generative Models (Grathwohl et al., 2018)
https://arxiv.org/abs/1810.01367

Class Definition

zuko.flows.CNF(
    features: int,
    context: int = 0,
    freqs: int = 3,
    atol: float = 1e-6,
    rtol: float = 1e-5,
    exact: bool = True,
    **kwargs
)

Parameters

features
int
required
The number of features in the data.
context
int
default:"0"
The number of context features for conditional density estimation.
freqs
int
default:"3"
The number of time embedding frequencies. Higher values provide richer time representations.
atol
float
default:"1e-6"
The absolute integration tolerance for the ODE solver. Lower values increase accuracy but slow computation.
rtol
float
default:"1e-5"
The relative integration tolerance for the ODE solver.
exact
bool
default:"True"
Whether to calculate the exact log-determinant of the Jacobian (True) or use an unbiased stochastic estimate (False). Exact is more accurate but slower.
**kwargs
dict
Additional keyword arguments passed to the MLP constructor:
  • hidden_features: Hidden layer sizes (default: [64, 64])
  • activation: Activation function (default: ELU)

Usage Example

import torch
import zuko

# Create an unconditional CNF
flow = zuko.flows.CNF(
    features=5,
    hidden_features=[128, 128],
    freqs=5
)

# Sample from the flow
dist = flow()
samples = dist.sample((1000,))
print(samples.shape)  # torch.Size([1000, 5])

# Compute log probabilities
log_prob = dist.log_prob(samples)
print(log_prob.shape)  # torch.Size([1000])

Conditional Flow

# Create a conditional CNF
flow = zuko.flows.CNF(
    features=3,
    context=5,
    freqs=5,
    hidden_features=[256, 256]
)

context = torch.randn(5)
dist = flow(context)
samples = dist.sample((100,))

Fast Training with Stochastic Trace

# Use stochastic trace estimation for faster training
flow = zuko.flows.CNF(
    features=20,
    exact=False,  # Stochastic trace estimation
    atol=1e-5,
    rtol=1e-4,
    hidden_features=[256, 256]
)

Training Example

import torch.optim as optim

flow = zuko.flows.CNF(
    features=10,
    freqs=5,
    hidden_features=[256, 256, 256]
)

optimizer = optim.Adam(flow.parameters(), lr=1e-3)

for epoch in range(100):
    for x in dataloader:
        optimizer.zero_grad()
        
        loss = -flow().log_prob(x).mean()
        loss.backward()
        optimizer.step()
    
    print(f"Epoch {epoch}, Loss: {loss.item():.4f}")
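The loop above assumes a dataloader is already defined. A minimal setup with synthetic data, using standard PyTorch utilities (the 512-sample dataset here is made up for illustration), might look like:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic training data: 512 samples with 10 features,
# matching features=10 in the flow above
data = torch.randn(512, 10)
dataloader = DataLoader(TensorDataset(data), batch_size=64, shuffle=True)

# Note: TensorDataset yields one-element tuples, so unpack each batch
for (x,) in dataloader:
    print(x.shape)  # torch.Size([64, 10])
    break
```

In practice, replace the synthetic tensor with your own data; the unpacking is only needed because TensorDataset wraps each batch in a tuple.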

Methods

forward(c=None)

Returns a normalizing flow distribution. Arguments:
  • c (Tensor, optional): Context tensor of shape (*, context)
Returns:
  • NormalizingFlow: A distribution with:
    • sample(shape): Sample from the distribution
    • log_prob(x): Compute log probability of samples
    • rsample(shape): Reparameterized sampling

When to Use CNF

Good for:
  • Research and experimentation
  • Problems requiring theoretically unlimited expressivity
  • Continuous-time modeling
  • Cases where discrete layers are limiting
  • Irregular time series data
Consider alternatives if:
  • You need fast training (use MAF or NSF)
  • You need fast sampling (use RealNVP)
  • You have limited compute resources
  • You want simpler, more interpretable models

Tips

  1. Start with stochastic trace: Set exact=False for faster training, especially with high-dimensional data.
  2. Tune tolerances: Decrease atol and rtol for better accuracy, increase for faster computation.
  3. More time frequencies: Use freqs=5 or higher for complex temporal dynamics.
  4. Deep networks: CNF benefits from deeper networks (3-4 layers with 256+ hidden units).
  5. Use ELU activation: Default ELU works well for ODEs.

Architecture Details

CNF models continuous dynamics:
  • Base distribution: Diagonal Gaussian N(0, I)
  • ODE: dx/dt = f(x, t, c) where f is a neural network
  • Time embedding: Sinusoidal embeddings of time t
  • Integration: From t=0 to t=1 using adaptive ODE solvers
  • Log determinant: Computed via trace of Jacobian
The transformation is:
x(t=1) = x(t=0) + ∫[0 to 1] f(x(t), t, c) dt
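The integral above can be approximated numerically. A toy explicit-Euler discretization (illustrative only; zuko uses adaptive solvers) makes the transformation concrete:

```python
import torch

def euler_flow(x0, f, steps=1000):
    """Approximate x(1) = x(0) + integral of f(x(t), t) dt with explicit Euler steps."""
    x, dt = x0.clone(), 1.0 / steps
    for i in range(steps):
        x = x + f(x, i * dt) * dt
    return x

# Toy linear dynamics dx/dt = -x, whose exact solution is x(1) = x(0) * exp(-1)
x0 = torch.randn(4, 3)
x1 = euler_flow(x0, lambda x, t: -x)
print(torch.allclose(x1, x0 * torch.exp(torch.tensor(-1.0)), atol=1e-3))  # True
```

An adaptive solver does the same thing but chooses step sizes to meet the atol/rtol tolerances instead of using a fixed grid.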

Continuous vs. Discrete Flows

| Property         | CNF (Continuous)        | MAF/NSF (Discrete)       |
|------------------|-------------------------|--------------------------|
| Layers           | Continuous dynamics     | Discrete transformations |
| Expressivity     | Theoretically unlimited | Limited by layers        |
| Training speed   | Slow                    | Fast to medium           |
| Sampling speed   | Slow                    | Slow (MAF/NSF)           |
| Memory           | Higher (ODE solver)     | Lower                    |
| Interpretability | Dynamics                | Transformations          |

ODE Solver Details

CNF uses adaptive ODE solvers:
  • Forward pass: Integrate from t=0 to t=1
  • Inverse pass: Integrate from t=1 to t=0 (reverse ODE)
  • Solver: Adaptive step-size Runge-Kutta methods
  • Gradients: Computed via adjoint method (memory efficient)

Tolerances

  • atol=1e-6, rtol=1e-5: High accuracy (default)
  • atol=1e-5, rtol=1e-4: Balanced
  • atol=1e-4, rtol=1e-3: Fast but less accurate

Trace Estimation

The quantity that must be computed is the instantaneous change of the log-density, d/dt log p(x(t)) = -trace(df/dx); integrating this trace from t=0 to t=1 yields the log-determinant of the overall transformation. The trace itself can be computed two ways:

Exact (exact=True):
trace(df/dx)  # Exact computation, cost grows with dimensionality
Stochastic (exact=False):
E[v^T (df/dx) v]  # Hutchinson's estimator, one vector-Jacobian product per draw
where v ~ N(0, I) is a random vector; the estimate is unbiased but noisy.
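Hutchinson's estimator is easy to verify on a function with a known Jacobian. A self-contained sketch in plain PyTorch (not zuko's internal implementation):

```python
import torch

def hutchinson_trace(f, x, samples=2000):
    """Estimate trace(df/dx) at x as the average of v^T (df/dx) v over v ~ N(0, I)."""
    x = x.clone().requires_grad_(True)
    y = f(x)
    total = 0.0
    for _ in range(samples):
        v = torch.randn_like(x)
        # Vector-Jacobian product v^T (df/dx) via reverse-mode autodiff
        (vjp,) = torch.autograd.grad(y, x, v, retain_graph=True)
        total = total + torch.dot(vjp, v)
    return total / samples

torch.manual_seed(0)
A = torch.randn(5, 5)
x = torch.randn(5)

# For the linear map f(x) = A @ x, the Jacobian is A itself
est = hutchinson_trace(lambda x: A @ x, x)
print(est.item(), A.trace().item())  # estimate is close to the exact trace
```

Each draw costs one vector-Jacobian product regardless of dimension, which is why the stochastic mode scales better than the exact trace in high dimensions.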

Time Embedding

CNF embeds time using sinusoidal features:
t_embed = [cos(k * pi * t) for k in 1..freqs] + [sin(k * pi * t) for k in 1..freqs]
This provides the network with information about where in the flow we are.
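A direct reading of this embedding in plain PyTorch (the exact layout inside zuko may differ slightly):

```python
import math
import torch

def time_embedding(t, freqs=3):
    """Sinusoidal time features: cos(k*pi*t) and sin(k*pi*t) for k = 1..freqs."""
    k = torch.arange(1, freqs + 1) * math.pi
    return torch.cat([torch.cos(k * t), torch.sin(k * t)])

# At t=0 every cosine is 1 and every sine is 0
emb = time_embedding(torch.tensor(0.0))
print(emb)  # tensor([1., 1., 1., 0., 0., 0.])
```

The embedding has 2 * freqs entries, which is why increasing freqs enlarges the input layer of the dynamics network.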

Advanced Usage

Custom Network Architecture

import torch.nn as nn

flow = zuko.flows.CNF(
    features=10,
    freqs=7,
    hidden_features=[512, 512, 512, 512],
    activation=nn.Tanh,
    atol=1e-7,
    rtol=1e-6
)

High-Dimensional Data

# For high dimensions, use stochastic trace
flow = zuko.flows.CNF(
    features=100,
    exact=False,  # Much faster for high dimensions
    freqs=5,
    hidden_features=[512, 512]
)

Manual Construction

from zuko.flows.continuous import FFJTransform
from zuko.distributions import DiagNormal
import torch

# Build CNF manually
transform = FFJTransform(
    features=10,
    context=5,
    freqs=5,
    atol=1e-6,
    rtol=1e-5,
    exact=True,
    hidden_features=[256, 256]
)

base = DiagNormal(
    loc=torch.zeros(10),
    scale=torch.ones(10)
)

Computational Considerations

CNF is computationally expensive:
  • Forward pass: Requires ODE integration
  • Backward pass: Uses adjoint method (memory efficient)
  • Function evaluations: Adaptive solver makes multiple network calls
  • Memory: Stores intermediate states during integration
Optimization strategies:
  1. Use exact=False for high-dimensional data
  2. Increase tolerances (atol=1e-5, rtol=1e-4)
  3. Use smaller networks
  4. Reduce time embedding frequencies
  5. Use mixed precision training

Applications

Time Series

CNF's continuous-time formulation makes it a natural fit for irregularly sampled time series, since the learned dynamics f(x, t) are defined at every t rather than only at discrete layer indices:
flow = zuko.flows.CNF(features=10)

# The dynamics are defined for all t in [0, 1],
# not just at fixed layer boundaries
Continuous Processes

Model continuous physical processes:
flow = zuko.flows.CNF(
    features=3,  # Position in 3D
    context=6,   # Velocity and forces
    freqs=10
)

Research

Explore theoretical limits:
# CNF can approximate any diffeomorphism
flow = zuko.flows.CNF(
    features=data_dim,
    hidden_features=[1024, 1024, 1024],
    exact=True,
    atol=1e-8,
    rtol=1e-7
)

Comparison with Other Flows

| Property     | CNF        | NAF       | NSF     | RealNVP    |
|--------------|------------|-----------|---------|------------|
| Type         | Continuous | Neural    | Spline  | Coupling   |
| Training     | Very slow  | Slow      | Medium  | Fast       |
| Sampling     | Slow       | Slow      | Slow    | Fast       |
| Expressivity | Unlimited  | Very high | High    | Medium     |
| Memory       | High       | High      | Medium  | Low        |
| Use case     | Research   | Complex   | General | Production |
