This tutorial walks you through training a normalizing flow by gradient descent when data is unavailable, but an energy function $U(x)$ proportional to the density $p(x)$ is available.

Setup

import matplotlib.pyplot as plt
import torch

from torch import Tensor

import zuko

_ = torch.random.manual_seed(0)

Energy Function

We consider a simple multi-modal energy function: $\log U(x) = \sin(\pi x_1) - 2 \big( x_1^2 + x_2^2 - 2 \big)^2$
def log_energy(x: Tensor) -> Tensor:
    x1, x2 = x[..., 0], x[..., 1]
    return torch.sin(torch.pi * x1) - 2 * (x1**2 + x2**2 - 2) ** 2

Visualizing the Energy

Let’s visualize the energy landscape:
x1 = torch.linspace(-3, 3, 64)
x2 = torch.linspace(-3, 3, 64)

x = torch.stack(torch.meshgrid(x1, x2, indexing="xy"), dim=-1)

energy = log_energy(x).exp()
plt.figure(figsize=(4.8, 4.8))
plt.imshow(energy, origin="lower", extent=(-3, 3, -3, 3))
plt.show()
The energy function has multiple modes, making it an interesting target distribution.

Flow

We use a neural spline flow (NSF) as density estimator $q_\phi(x)$. However, we invert the transformation(s), which makes sampling more efficient: the inverse call of an autoregressive transformation is $D$ times slower than its forward call, where $D$ is the number of features.
flow = zuko.flows.NSF(features=2, transforms=3, hidden_features=(64, 64))
flow = zuko.flows.Flow(flow.transform.inv, flow.base)
flow
Inverting the transformation makes sampling more efficient but may make likelihood evaluation slower. This is a good trade-off when training with the reverse KL divergence.
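To see why the inverse of an autoregressive transformation costs $D$ network calls instead of one, here is a toy affine sketch (an illustrative assumption, not Zuko's implementation): parameters for feature $i$ depend only on features $< i$, so the forward pass is one vectorized call while the inverse must recover features sequentially.

```python
import torch

torch.manual_seed(0)

D = 5  # number of features

# Toy autoregressive affine transform: scale/shift for feature i depend
# only on features < i, enforced by a strictly lower-triangular matrix.
W = torch.tril(torch.randn(D, D), diagonal=-1)

def params(x):
    h = x @ W.T  # h[..., i] depends only on x[..., :i]
    return torch.exp(0.1 * torch.tanh(h)), h  # positive scale, shift

def forward(x):
    scale, shift = params(x)  # one vectorized pass
    return x * scale + shift

def inverse(y):
    x = torch.zeros_like(y)
    for _ in range(D):  # D sequential passes: each pass fixes one more feature
        scale, shift = params(x)
        x = (y - shift) / scale
    return x

x = torch.randn(3, D)
y = forward(x)
assert torch.allclose(inverse(y), x, atol=1e-5)
```

Each loop iteration fixes one more feature, so inverting takes $D$ evaluations of the parameter network versus one for the forward direction; `flow.transform.inv` simply swaps which direction is the cheap one.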

Objective

The objective is to minimize the Kullback-Leibler (KL) divergence between the modeled distribution $q_\phi(x)$ and the true data distribution $p(x)$:

$$\begin{align}
\arg \min_\phi & ~ \mathrm{KL} \big( q_\phi(x) \| p(x) \big) \\
= \arg \min_\phi & ~ \mathbb{E}_{q_\phi(x)} \left[ \log \frac{q_\phi(x)}{p(x)} \right] \\
= \arg \min_\phi & ~ \mathbb{E}_{q_\phi(x)} \big[ \log q_\phi(x) - \log U(x) \big]
\end{align}$$

The last step replaces $\log p(x)$ with $\log U(x)$: since $U(x)$ is only proportional to $p(x)$, this shifts the objective by a constant $\log Z$ that does not depend on $\phi$, so the minimizer is unchanged.
Note that this “reverse KL” objective is prone to mode collapse, especially for high-dimensional data.
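The expectation above is estimated by Monte Carlo from the model's own samples. To make that concrete, here is a small self-contained check with Gaussian stand-ins (not the tutorial's flow), chosen so the reverse KL has a closed form to compare against:

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)

q = Normal(torch.tensor(0.5), torch.tensor(1.2))  # model q_phi
p = Normal(torch.tensor(0.0), torch.tensor(1.0))  # target p

# Monte Carlo estimate of E_q[log q(x) - log p(x)] from samples of q.
x = q.sample((100_000,))
mc_kl = (q.log_prob(x) - p.log_prob(x)).mean()

exact = kl_divergence(q, p)
print(mc_kl.item(), exact.item())  # the two values agree closely

# With an unnormalized U proportional to p, replacing log p by log U shifts
# the loss by the constant log Z but leaves its gradient unchanged.
```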

Training

optimizer = torch.optim.Adam(flow.parameters(), lr=1e-3)

for epoch in range(8):
    losses = []

    for _ in range(256):
        x, log_prob = flow().rsample_and_log_prob((256,))  # faster than rsample + log_prob

        loss = log_prob.mean() - log_energy(x).mean()
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

        losses.append(loss.detach())

    losses = torch.stack(losses)

    print(f"({epoch})", losses.mean().item(), "±", losses.std().item())
Training output:
(0) -0.8076444268226624 ± 1.4574381113052368
(1) -1.5428426265716553 ± 0.1310734897851944
(2) -1.5719032287597656 ± 0.04953986406326294
(3) -1.5784311294555664 ± 0.022914621978998184
(4) -1.5850317478179932 ± 0.023105977103114128
(5) -1.5861082077026367 ± 0.022541742771863937
(6) -1.5803889036178589 ± 0.14612749218940735
(7) -1.5888274908065796 ± 0.017613010480999947
The rsample_and_log_prob method is more efficient than calling rsample and log_prob separately.
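Besides efficiency, the `rsample` part is essential: the reverse-KL loss differentiates through the samples themselves, which requires reparameterized sampling. A minimal `torch.distributions` sketch (not Zuko's API) of why that matters:

```python
import torch
from torch.distributions import Normal

mu = torch.tensor(0.0, requires_grad=True)
q = Normal(mu, torch.tensor(1.0))

x = q.rsample((1024,))  # x = mu + sigma * eps keeps the graph back to mu
loss = (x ** 2).mean()  # any loss that depends on the samples
loss.backward()

assert mu.grad is not None  # gradient reaches the distribution's parameters
# q.sample() would return detached samples and mu.grad would stay None.
```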

Sampling

After training, we can sample from the learned distribution:
samples = flow().sample((16384,))

plt.figure(figsize=(4.8, 4.8))
plt.hist2d(*samples.T, bins=64, range=((-3, 3), (-3, 3)))
plt.show()
The generated samples should capture the multi-modal structure of the energy function.

Key Differences

The key differences between the forward KL (training from data) and the reverse KL (training from an energy function) are:

1. Data vs energy: the forward KL requires access to data samples, while the reverse KL only needs an energy function.
2. Mode coverage: the forward KL tends to cover all modes but may spread probability mass thinly; the reverse KL may miss some modes but concentrates on the main ones.
3. Sampling efficiency: the reverse KL benefits from inverting the transformations, which makes sampling more efficient.
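To make the mode-coverage contrast concrete, here is a torch-only sketch (a hypothetical setup, not part of the tutorial) that fits a single Gaussian to a bimodal mixture under each objective:

```python
import torch
from torch.distributions import Categorical, MixtureSameFamily, Normal

torch.manual_seed(0)

# Bimodal target with well-separated modes at -3 and +3.
target = MixtureSameFamily(
    Categorical(torch.tensor([0.5, 0.5])),
    Normal(torch.tensor([-3.0, 3.0]), torch.tensor([0.5, 0.5])),
)

def fit(reverse: bool) -> Normal:
    mu = torch.tensor(2.0, requires_grad=True)  # start near one mode
    log_sigma = torch.tensor(0.0, requires_grad=True)
    optimizer = torch.optim.Adam([mu, log_sigma], lr=0.05)

    for _ in range(2000):
        q = Normal(mu, log_sigma.exp())
        if reverse:  # reverse KL: samples come from the model
            x = q.rsample((256,))
            loss = (q.log_prob(x) - target.log_prob(x)).mean()
        else:  # forward KL: maximum likelihood on target samples
            x = target.sample((256,))
            loss = -q.log_prob(x).mean()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    return Normal(mu.detach(), log_sigma.detach().exp())

forward_fit = fit(reverse=False)  # mean near 0, large std: covers both modes
reverse_fit = fit(reverse=True)   # mean near +3, small std: one mode only
```

The forward fit ends up broad and centered between the modes, while the reverse fit locks onto the nearby mode, mirroring points 1 and 2 above.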
