This tutorial walks you through implementing a variational autoencoder (VAE) for the MNIST dataset with a normalizing flow as the prior.

Setup

import torch
import torch.nn as nn
import torch.utils.data as data

from torch import Tensor
from torch.distributions import Bernoulli, Distribution, Independent, Normal
from torchvision.datasets import MNIST
from torchvision.transforms.functional import to_pil_image, to_tensor
from tqdm import tqdm

import zuko

_ = torch.random.manual_seed(0)

Data

The MNIST dataset consists of 28 x 28 grayscale images representing handwritten digits (0 to 9).
trainset = MNIST(root="", download=True, train=True, transform=to_tensor)
trainloader = data.DataLoader(trainset, batch_size=256, shuffle=True)
Visualize some samples:
x = [trainset[i][0] for i in range(16)]
x = torch.cat(x, dim=-1)

to_pil_image(x)

Evidence Lower Bound (ELBO)

As usual with variational inference, we wish to find the parameters $\phi$ for which a model $p_\phi(x)$ is most similar to a target distribution $p(x)$, which leads to the objective

$$\arg\max_\phi \mathbb{E}_{p(x)} \big[ \log p_\phi(x) \big]$$

However, variational autoencoders have latent random variables $z$ and model the joint distribution of $z$ and $x$ as a factorization

$$p_\phi(x, z) = p_\phi(x \mid z) \, p_\phi(z)$$

where $p_\phi(x \mid z)$ is the decoder (sometimes called likelihood) and $p_\phi(z)$ the prior. In this case, maximizing the log-evidence $\log p_\phi(x)$ becomes an issue, as the integral

$$p_\phi(x) = \int p_\phi(x, z) \, \mathrm{d}z$$

is often intractable, not to mention its gradients. To solve this issue, VAEs introduce an encoder $q_\psi(z \mid x)$ (sometimes called proposal or guide) to define a lower bound of the evidence (ELBO) for which unbiased Monte Carlo estimates of the gradients are available. Since the KL divergence is non-negative,

$$\begin{align}
\log p_\phi(x) & \geq \log p_\phi(x) - \mathrm{KL} \big( q_\psi(z \mid x) \, \| \, p_\phi(z \mid x) \big) \\
& = \log p_\phi(x) + \mathbb{E}_{q_\psi(z \mid x)} \left[ \log \frac{p_\phi(z \mid x)}{q_\psi(z \mid x)} \right] \\
& = \mathbb{E}_{q_\psi(z \mid x)} \left[ \log \frac{p_\phi(z, x)}{q_\psi(z \mid x)} \right] = \mathrm{ELBO}(x, \phi, \psi)
\end{align}$$
Importantly, if $p_\phi(x, z)$ and $q_\psi(z \mid x)$ are expressive enough, the bound can become tight, and maximizing the ELBO with respect to $\phi$ and $\psi$ leads to the same model as maximizing the log-evidence.
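To make the estimator concrete, here is a minimal, self-contained sketch of a single-sample Monte Carlo estimate of the ELBO, with toy 2-D Gaussians standing in for the encoder, decoder, and prior (all names and shapes here are illustrative, not the tutorial's actual models):

```python
import torch
from torch.distributions import Independent, Normal

torch.manual_seed(0)

x = torch.randn(2)  # a fake 2-D observation

# Toy stand-ins for the three components of a VAE
q = Independent(Normal(torch.zeros(2), torch.ones(2)), 1)  # encoder q(z | x)
prior = Independent(Normal(torch.zeros(2), torch.ones(2)), 1)  # prior p(z)

z = q.rsample()  # reparameterized sample, so gradients can flow through z
decoder = Independent(Normal(z, torch.ones(2)), 1)  # decoder p(x | z)

# Single-sample Monte Carlo estimate of ELBO(x)
elbo = decoder.log_prob(x) + prior.log_prob(z) - q.log_prob(z)
```

This is exactly the quantity the `ELBO` module below computes, one term per factor of $\log \frac{p_\phi(z, x)}{q_\psi(z \mid x)}$.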

ELBO Implementation

class ELBO(nn.Module):
    def __init__(
        self,
        encoder: zuko.lazy.LazyDistribution,
        decoder: zuko.lazy.LazyDistribution,
        prior: zuko.lazy.LazyDistribution,
    ) -> None:
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder
        self.prior = prior

    def forward(self, x: Tensor) -> Tensor:
        q = self.encoder(x)
        z = q.rsample()

        return self.decoder(z).log_prob(x) + self.prior().log_prob(z) - q.log_prob(z)

Model

We choose a (diagonal) Gaussian model as encoder, a Bernoulli model as decoder, and a Masked Autoregressive Flow (MAF) as prior. We use 16 features for the latent space.

Encoder and Decoder

class GaussianModel(zuko.lazy.LazyDistribution):
    def __init__(self, features: int, context: int) -> None:
        super().__init__()

        self.hyper = nn.Sequential(
            nn.Linear(context, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Linear(1024, 2 * features),
        )

    def forward(self, c: Tensor) -> Distribution:
        phi = self.hyper(c)
        mu, log_sigma = phi.chunk(2, dim=-1)

        return Independent(Normal(mu, log_sigma.exp()), 1)


class BernoulliModel(zuko.lazy.LazyDistribution):
    def __init__(self, features: int, context: int) -> None:
        super().__init__()

        self.hyper = nn.Sequential(
            nn.Linear(context, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Linear(1024, features),
        )

    def forward(self, c: Tensor) -> Distribution:
        phi = self.hyper(c)
        rho = torch.sigmoid(phi)

        return Independent(Bernoulli(rho), 1)

Instantiate Models

encoder = GaussianModel(16, 784)
decoder = BernoulliModel(784, 16)

prior = zuko.flows.MAF(
    features=16,
    transforms=3,
    hidden_features=(256, 256),
)
Because the decoder is a Bernoulli model, the data $x$ should be binary.
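The training loop below binarizes by rounding; a quick sketch of what that does (values are illustrative), mapping grayscale intensities in $[0, 1]$ to $\{0, 1\}$:

```python
import torch

x = torch.tensor([0.1, 0.2, 0.8, 0.9])  # grayscale intensities in [0, 1]
x_bin = x.round()  # binarize, as done later in the training loop

# every value is now exactly 0.0 or 1.0, matching the Bernoulli support
```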

Training

As explained earlier, our objective is to maximize the ELBO for all $x$:

$$\arg\max_{\phi, \, \psi} \mathbb{E}_{p(x)} \big[ \mathrm{ELBO}(x, \phi, \psi) \big]$$
elbo = ELBO(encoder, decoder, prior).cuda()
optimizer = torch.optim.Adam(elbo.parameters(), lr=1e-3)

for epoch in (bar := tqdm(range(64))):
    losses = []

    for x, _ in trainloader:
        x = x.round().flatten(-3).cuda()

        loss = -elbo(x).mean()
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

        losses.append(loss.detach())

    losses = torch.stack(losses)

    bar.set_postfix(loss=losses.mean().item())
Training progress:
100%|██████████| 64/64 [07:09<00:00,  6.71s/it, loss=65.8]
The training takes about 7 minutes on a GPU. The loss should decrease significantly from the initial value.

Generating Images

After training, we can generate MNIST images by sampling latent variables from the prior and decoding them.
z = prior().sample((16,))
x = decoder(z).mean.reshape(-1, 28, 28)

to_pil_image(x.movedim(0, 1).reshape(28, -1))
The generated images should resemble MNIST digits, though they may not be perfect. Longer training or better hyperparameters can improve the quality.
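The `movedim`/`reshape` combination above tiles the 16 generated images side by side into one wide strip. A small shape check, with random tensors standing in for the decoder outputs, shows how:

```python
import torch

x = torch.rand(16, 28, 28)  # stand-in for 16 decoded images

# (16, 28, 28) -> (28, 16, 28): pixel rows become the leading dimension,
# so reshaping to (28, -1) concatenates the corresponding row of all
# 16 images into one row of the strip.
strip = x.movedim(0, 1).reshape(28, -1)
```

The first 28 columns of the strip are exactly the first image, the next 28 the second, and so on.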

Key Concepts

1. Encoder: maps input images to a distribution over latent variables $q_\psi(z \mid x)$.
2. Decoder: maps latent variables to a distribution over images $p_\phi(x \mid z)$.
3. Prior: defines the distribution of latent variables $p_\phi(z)$. Using a normalizing flow as prior makes it more expressive than a simple Gaussian.
4. ELBO: the evidence lower bound provides a tractable objective for training.

Advantages of Flow-based Prior

Using a normalizing flow (MAF in this case) as the prior instead of a simple Gaussian distribution has several advantages:
  1. Expressivity: Flows can model complex, multi-modal distributions in the latent space
  2. Flexibility: The latent space can capture more intricate patterns in the data
  3. Better generation: More sophisticated priors lead to better quality generated samples
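As a rough torch-only illustration of the expressivity point, compare a diagonal Gaussian with a two-component mixture, the latter being a crude stand-in for the kind of multi-modal latent density a flow can learn (the locations and weights here are arbitrary):

```python
import torch
from torch.distributions import Categorical, Independent, MixtureSameFamily, Normal

# Unimodal diagonal Gaussian prior
gauss = Independent(Normal(torch.zeros(2), torch.ones(2)), 1)

# Bimodal mixture, standing in for a flow's flexibility
mix = MixtureSameFamily(
    Categorical(torch.tensor([0.5, 0.5])),
    Independent(Normal(torch.tensor([[-3.0, 0.0], [3.0, 0.0]]), torch.ones(2, 2)), 1),
)

origin = torch.zeros(2)
mode = torch.tensor([3.0, 0.0])

# The Gaussian concentrates its density at the origin, while the mixture
# places most of its mass near its two component means.
```

A diagonal Gaussian can never place high density in two separated regions at once; a flow prior can.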
For more advanced VAE architectures, consider exploring hierarchical VAEs or using more powerful flows as encoders and decoders.
