This tutorial walks you through implementing a variational autoencoder (VAE) for the MNIST dataset with a normalizing flow as the prior.
Setup
import torch
import torch.nn as nn
import torch.utils.data as data
from torch import Tensor
from torch.distributions import Bernoulli, Distribution, Independent, Normal
from torchvision.datasets import MNIST
from torchvision.transforms.functional import to_pil_image, to_tensor
from tqdm import tqdm
import zuko
_ = torch.random.manual_seed(0)
Data
The MNIST dataset consists of 28 x 28 grayscale images representing handwritten digits (0 to 9).
trainset = MNIST(root="", download=True, train=True, transform=to_tensor)
trainloader = data.DataLoader(trainset, batch_size=256, shuffle=True)
Visualize some samples:
x = [trainset[i][0] for i in range(16)]
x = torch.cat(x, dim=-1)
to_pil_image(x)
Evidence Lower Bound (ELBO)
As usual with variational inference, we wish to find the parameters $\phi$ for which a model $p_\phi(x)$ is most similar to a target distribution $p(x)$, which leads to the objective:
$$\arg\max_\phi \, \mathbb{E}_{p(x)} \big[ \log p_\phi(x) \big]$$
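"Most similar" can be made precise. The expected log-likelihood decomposes as $\mathbb{E}_{p(x)}[\log p_\phi(x)] = -H(p) - \mathrm{KL}(p \,\|\, p_\phi)$. The entropy $H(p)$ does not depend on $\phi$, so maximizing the objective is the same as minimizing the forward KL divergence. The sketch below checks this identity on a hypothetical three-outcome distribution. The probabilities are made up for illustration.

```python
import math

# Hypothetical target distribution p(x) over three outcomes.
p = [0.5, 0.3, 0.2]

def expected_log_likelihood(p, q):
    """E_{p(x)}[log q(x)] for discrete distributions given as lists."""
    return sum(pi * math.log(qi) for pi, qi in zip(p, q))

def entropy(p):
    """Shannon entropy H(p) in nats."""
    return -sum(pi * math.log(pi) for pi in p)

def kl(p, q):
    """Forward KL divergence KL(p || q)."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))

# Hypothetical candidate models p_phi(x).
candidates = {
    "uniform": [1 / 3, 1 / 3, 1 / 3],
    "close": [0.45, 0.35, 0.2],
    "exact": [0.5, 0.3, 0.2],
}

for name, q in candidates.items():
    lhs = expected_log_likelihood(p, q)
    rhs = -entropy(p) - kl(p, q)
    print(f"{name:8s} E_p[log q] = {lhs:.4f}   -H(p) - KL(p||q) = {rhs:.4f}")
```

Both columns agree for every candidate. The exact model attains the largest value, where the KL term is zero.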
However, variational autoencoders have latent random variables $z$ and model the joint distribution of $z$ and $x$ as a factorization:
$$p_\phi(x, z) = p_\phi(x \mid z) \, p_\phi(z)$$
where $p_\phi(x \mid z)$ is the decoder (sometimes called the likelihood) and $p_\phi(z)$ is the prior. In this case, maximizing the log-evidence $\log p_\phi(x)$ becomes an issue, as the integral
$$p_\phi(x) = \int p_\phi(z, x) \, dz$$
is often intractable, not to mention its gradients. To solve this issue, VAEs introduce an encoder $q_\psi(z \mid x)$ (sometimes called the proposal or guide) to define a lower bound on the evidence (ELBO) for which unbiased Monte Carlo estimates of the gradients are available.
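$$
\begin{aligned}
\log p_\phi(x) & \geq \log p_\phi(x) - \mathrm{KL}\big( q_\psi(z \mid x) \,\|\, p_\phi(z \mid x) \big) \\
& = \log p_\phi(x) + \mathbb{E}_{q_\psi(z \mid x)} \left[ \log \frac{p_\phi(z \mid x)}{q_\psi(z \mid x)} \right] \\
& = \mathbb{E}_{q_\psi(z \mid x)} \left[ \log \frac{p_\phi(z, x)}{q_\psi(z \mid x)} \right] = \mathrm{ELBO}(x, \phi, \psi)
\end{aligned}
$$
The inequality holds because the KL divergence is non-negative. The last step uses $p_\phi(z \mid x) \, p_\phi(x) = p_\phi(z, x)$.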
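The derivation says that the gap between $\log p_\phi(x)$ and the ELBO is exactly $\mathrm{KL}(q_\psi(z \mid x) \,\|\, p_\phi(z \mid x))$. This can be checked numerically on a model where the ELBO, the evidence and the posterior all have closed forms. The sketch assumes a hypothetical 1-D linear-Gaussian model with $p(z) = \mathcal{N}(0, 1)$ and $p(x \mid z) = \mathcal{N}(z, s^2)$. It also assumes a Gaussian $q(z \mid x) = \mathcal{N}(m, v)$. None of this is part of the tutorial's MNIST model.

```python
import math

# Hypothetical 1-D linear-Gaussian model: p(z) = N(0, 1), p(x | z) = N(z, s2).
s2, x = 0.5, 1.2

def log_normal(x, mean, var):
    """Log-density of N(mean, var) at x."""
    return -0.5 * (math.log(2 * math.pi * var) + (x - mean) ** 2 / var)

def elbo(m, v):
    """Closed-form ELBO for q(z | x) = N(m, v) under the model above."""
    e_loglik = -0.5 * math.log(2 * math.pi * s2) - ((x - m) ** 2 + v) / (2 * s2)
    e_logprior = -0.5 * math.log(2 * math.pi) - (m ** 2 + v) / 2
    entropy = 0.5 * math.log(2 * math.pi * math.e * v)  # equals -E_q[log q(z | x)]
    return e_loglik + e_logprior + entropy

def kl_normal(m1, v1, m2, v2):
    """KL(N(m1, v1) || N(m2, v2)) for 1-D Gaussians."""
    return 0.5 * (math.log(v2 / v1) + (v1 + (m1 - m2) ** 2) / v2 - 1)

log_evidence = log_normal(x, 0.0, 1.0 + s2)    # marginal p(x) = N(0, 1 + s2)
post_m, post_v = x / (1 + s2), s2 / (1 + s2)   # exact posterior p(z | x)

m, v = 0.0, 1.0  # a deliberately poor choice of q(z | x)
gap = log_evidence - elbo(m, v)
print(f"log p(x) = {log_evidence:.4f}, ELBO = {elbo(m, v):.4f}")
print(f"gap = {gap:.4f}, KL(q || p(z|x)) = {kl_normal(m, v, post_m, post_v):.4f}")
```

The printed gap and the KL divergence coincide. Setting q to the exact posterior makes the ELBO equal to log p(x).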
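Importantly, if $p_\phi(x, z)$ and $q_\psi(z \mid x)$ are expressive enough, the bound can become tight, since the KL term vanishes when $q_\psi(z \mid x) = p_\phi(z \mid x)$. Maximizing the ELBO with respect to $\phi$ and $\psi$ then leads to the same model as maximizing the log-evidence.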
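Tightness can be observed by optimizing the encoder alone. The sketch again assumes the hypothetical 1-D linear-Gaussian model $p(z) = \mathcal{N}(0, 1)$, $p(x \mid z) = \mathcal{N}(z, s^2)$, with the decoder and prior held fixed. Only $\psi = (m, \log \sigma)$ of a Gaussian $q(z \mid x)$ is optimized, using plain gradient ascent with hand-derived gradients rather than autograd. The true posterior is Gaussian here, so this family is expressive enough and the ELBO should climb to $\log p(x)$.

```python
import math

# Hypothetical 1-D linear-Gaussian model: p(z) = N(0, 1), p(x | z) = N(z, s2).
s2, x = 0.5, 1.2

def elbo(m, log_sigma):
    """Closed-form ELBO for q(z | x) = N(m, exp(log_sigma) ** 2)."""
    v = math.exp(2 * log_sigma)
    return (
        -0.5 * math.log(2 * math.pi * s2) - ((x - m) ** 2 + v) / (2 * s2)  # E_q[log p(x | z)]
        - 0.5 * math.log(2 * math.pi) - (m ** 2 + v) / 2                 # E_q[log p(z)]
        + 0.5 * math.log(2 * math.pi * math.e * v)                       # entropy of q
    )

# Gradient ascent on the variational parameters psi = (m, log_sigma).
m, log_sigma, lr = 0.0, 0.0, 0.1
for _ in range(500):
    v = math.exp(2 * log_sigma)
    grad_m = (x - m) / s2 - m             # d ELBO / d m
    grad_log_sigma = 1 - v / s2 - v       # d ELBO / d log_sigma
    m += lr * grad_m
    log_sigma += lr * grad_log_sigma

log_evidence = -0.5 * (math.log(2 * math.pi * (1 + s2)) + x ** 2 / (1 + s2))
print(f"q mean {m:.4f} vs posterior mean {x / (1 + s2):.4f}")
print(f"ELBO {elbo(m, log_sigma):.6f} vs log p(x) {log_evidence:.6f}")
```

In the tutorial the same effect is pursued with neural networks and stochastic gradients, where tightness is only approximate.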
ELBO Implementation
class ELBO(nn.Module):
    def __init__(
        self,
        encoder: zuko.lazy.LazyDistribution,
        decoder: zuko.lazy.LazyDistribution,
        prior: zuko.lazy.LazyDistribution,
    ) -> None:
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.prior = prior
    def forward(self, x: Tensor) -> Tensor:
        q = self.encoder(x)  # q_psi(z | x)
        z = q.rsample()  # reparameterized sample, differentiable w.r.t. the encoder
        # Single-sample Monte Carlo estimate: log p(x | z) + log p(z) - log q(z | x)
        return self.decoder(z).log_prob(x) + self.prior().log_prob(z) - q.log_prob(z)
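The call q.rsample() draws z with the reparameterization trick, $z = \mu + \sigma \epsilon$ with $\epsilon \sim \mathcal{N}(0, 1)$. This lets gradients of the ELBO flow back into the encoder's outputs $\mu$ and $\sigma$. By contrast, q.sample() returns samples detached from the graph. The pure-Python sketch below has no autograd and uses the objective $z^2$ as a hypothetical stand-in for the ELBO terms. It shows that averaging pathwise derivatives of reparameterized samples recovers the exact gradient of an expectation, here $\nabla_{(\mu, \sigma)} \mathbb{E}[z^2] = (2\mu, 2\sigma)$.

```python
import random

# Reparameterization: z = mu + sigma * eps, eps ~ N(0, 1), so z ~ N(mu, sigma^2)
# and each sample is a differentiable function of (mu, sigma).
mu, sigma = 0.7, 0.4
rng = random.Random(0)
eps = [rng.gauss(0.0, 1.0) for _ in range(200_000)]
z = [mu + sigma * e for e in eps]

# Objective f(z) = z ** 2 with E[f] = mu ** 2 + sigma ** 2.
# Pathwise derivatives: df/dmu = 2 z * dz/dmu = 2 z, and df/dsigma = 2 z * eps.
grad_mu = sum(2 * zi for zi in z) / len(z)
grad_sigma = sum(2 * zi * e for zi, e in zip(z, eps)) / len(z)

print(f"grad mu:    {grad_mu:.3f} (exact {2 * mu:.3f})")
print(f"grad sigma: {grad_sigma:.3f} (exact {2 * sigma:.3f})")
```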
Model
We choose a (diagonal) Gaussian model as encoder, a Bernoulli model as decoder, and a Masked Autoregressive Flow (MAF) as prior. We use 16 features for the latent space.
Encoder and Decoder
class GaussianModel(zuko.lazy.LazyDistribution):
    def __init__(self, features: int, context: int) -> None:
        super().__init__()
        # MLP mapping the context to the Gaussian parameters (mu, log_sigma)
        self.hyper = nn.Sequential(
            nn.Linear(context, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Linear(1024, 2 * features),
        )
    def forward(self, c: Tensor) -> Distribution:
        phi = self.hyper(c)
        mu, log_sigma = phi.chunk(2, dim=-1)
        # Diagonal Gaussian; Independent(..., 1) treats the last dimension as one event
        return Independent(Normal(mu, log_sigma.exp()), 1)
class BernoulliModel(zuko.lazy.LazyDistribution):
    def __init__(self, features: int, context: int) -> None:
        super().__init__()
        # MLP mapping the context to one logit per pixel
        self.hyper = nn.Sequential(
            nn.Linear(context, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Linear(1024, features),
        )
    def forward(self, c: Tensor) -> Distribution:
        phi = self.hyper(c)
        rho = torch.sigmoid(phi)  # per-pixel probabilities in (0, 1)
        return Independent(Bernoulli(rho), 1)
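Both models wrap a per-feature distribution in Independent(..., 1), so log_prob sums the per-feature terms. That gives 16 Gaussian terms for the encoder and 784 Bernoulli terms for the decoder. The encoder outputs log σ and exponentiates it, which keeps σ positive without constraining the network.

The pure-Python sketch below mirrors both computations on hypothetical numbers. It also illustrates a numerical point. Computing Bernoulli log-probabilities from sigmoid probabilities breaks down for large logits, whereas the equivalent form x·l − log(1 + exp(l)) stays finite. torch.distributions.Bernoulli also accepts a logits= argument. Whether switching the decoder to it changes this tutorial's results has not been checked.

```python
import math

# Encoder side: Independent(Normal(mu, exp(log_sigma)), 1) for a single vector.
def diag_gaussian_log_prob(z, mu, log_sigma):
    """Sum over features of 1-D Gaussian log-densities with sigma = exp(log_sigma)."""
    return sum(
        -0.5 * ((zi - mi) / math.exp(li)) ** 2 - li - 0.5 * math.log(2 * math.pi)
        for zi, mi, li in zip(z, mu, log_sigma)
    )

# Hypothetical 6-number network output for a 3-feature latent: first half mu,
# second half log_sigma, mirroring phi.chunk(2, dim=-1).
phi = [0.1, -0.3, 0.5, 0.0, -1.0, 0.2]
mu, log_sigma = phi[:3], phi[3:]
z = [0.0, 0.0, 1.0]
enc_lp = diag_gaussian_log_prob(z, mu, log_sigma)

# Decoder side: Independent(Bernoulli(...), 1) over pixels.
def softplus(a):
    """Numerically stable log(1 + exp(a))."""
    return max(a, 0.0) + math.log1p(math.exp(-abs(a)))

def bernoulli_log_prob_naive(x, logit):
    """Sigmoid first, then logs of the probabilities (the route used above)."""
    rho = 1 / (1 + math.exp(-logit))
    return x * math.log(rho) + (1 - x) * math.log(1 - rho)

def bernoulli_log_prob_logits(x, logit):
    """Same quantity written in terms of the logit: x * logit - log(1 + exp(logit))."""
    return x * logit - softplus(logit)

def image_log_likelihood(pixels, logits):
    """Sum of per-pixel Bernoulli log-probabilities over a whole image."""
    return sum(bernoulli_log_prob_logits(x, l) for x, l in zip(pixels, logits))

print(f"encoder log q(z | x) = {enc_lp:.4f}")
for x, l in [(1, 2.0), (0, 2.0), (1, -3.0)]:
    print(bernoulli_log_prob_naive(x, l), bernoulli_log_prob_logits(x, l))

# With logit = 40 the sigmoid rounds to exactly 1.0 in double precision, so the naive
# route would evaluate log(0) for a 0-valued pixel; the logit form returns -40.
print(bernoulli_log_prob_logits(0, 40.0))
```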
Instantiate Models
encoder = GaussianModel(16, 784)
decoder = BernoulliModel(784, 16)
prior = zuko.flows.MAF(
features=16,
transforms=3,
hidden_features=(256, 256),
)
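Each MAF transform is an affine autoregressive map. Feature i is shifted and scaled by amounts computed only from the features before it. The Jacobian is therefore triangular, and its log-determinant is a simple sum, which keeps prior().log_prob(z) cheap inside the ELBO.

The sketch below illustrates that idea, not Zuko's implementation. It uses fixed toy conditioner functions in place of the masked networks a real MAF learns; the conditioner is hypothetical. It checks the analytic log-determinant against a finite-difference Jacobian.

```python
import math

def conditioner(prefix):
    """Hypothetical conditioner: (shift, log_scale) depend only on z[:i]."""
    total = sum(prefix)
    return 0.5 * total, 0.1 * math.tanh(total)

def forward(z):
    """Map z to base noise u with u_i = (z_i - t_i) * exp(-s_i); return (u, log|det du/dz|)."""
    u, log_det = [], 0.0
    for i, zi in enumerate(z):
        t, s = conditioner(z[:i])
        u.append((zi - t) * math.exp(-s))
        log_det -= s  # triangular Jacobian: sum of log diagonal entries
    return u, log_det

def log_prob(z):
    """Change of variables: log p(z) = log N(u; 0, I) + log|det du/dz|."""
    u, log_det = forward(z)
    return sum(-0.5 * ui ** 2 - 0.5 * math.log(2 * math.pi) for ui in u) + log_det

def det3(m):
    """Determinant of a 3 x 3 matrix given as nested lists."""
    return (m[0][0] * (m[1][1] * m[2][2] - m[1][2] * m[2][1])
            - m[0][1] * (m[1][0] * m[2][2] - m[1][2] * m[2][0])
            + m[0][2] * (m[1][0] * m[2][1] - m[1][1] * m[2][0]))

# Central finite-difference Jacobian du/dz at a test point.
z, h = [0.3, -1.2, 0.8], 1e-6
jac = [[0.0] * 3 for _ in range(3)]
for j in range(3):
    zp, zm = list(z), list(z)
    zp[j] += h
    zm[j] -= h
    up, um = forward(zp)[0], forward(zm)[0]
    for i in range(3):
        jac[i][j] = (up[i] - um[i]) / (2 * h)

print("analytic log-det:", forward(z)[1])
print("numeric log-det: ", math.log(abs(det3(jac))))
print("log p(z):", log_prob(z))
```

Zuko's MAF stacks several such layers, here transforms=3. The log-determinants of the stacked layers add up.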
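Because the decoder is a Bernoulli model, the data $x$ must be binary. This is why the training loop binarizes the pixel intensities with x.round(). With argument validation enabled (PyTorch's default), Bernoulli.log_prob rejects values outside $\{0, 1\}$.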
Training
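As explained earlier, our objective is to maximize the expected ELBO over the data distribution:
$$\arg\max_{\phi, \psi} \, \mathbb{E}_{p(x)} \big[ \mathrm{ELBO}(x, \phi, \psi) \big]$$
In practice, the expectation is approximated by the average over a mini-batch. The loss minimized by Adam is the negative of that average.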
device = "cuda" if torch.cuda.is_available() else "cpu"
elbo = ELBO(encoder, decoder, prior).to(device)
optimizer = torch.optim.Adam(elbo.parameters(), lr=1e-3)
for epoch in (bar := tqdm(range(64))):
    losses = []
    for x, _ in trainloader:
        x = x.round().flatten(-3).to(device)  # binarize, then flatten 1 x 28 x 28 -> 784
        loss = -elbo(x).mean()  # negative ELBO averaged over the mini-batch
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        losses.append(loss.detach())
    losses = torch.stack(losses)
    bar.set_postfix(loss=losses.mean().item())
Training progress:
100%|██████████| 64/64 [07:09<00:00, 6.71s/it, loss=65.8]
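On a GPU, training takes about 7 minutes. The reported loss is the negative ELBO in nats per image, averaged over the epoch's mini-batches. Since the ELBO is a lower bound, this estimates an upper bound on $-\log p_\phi(x)$. It should decrease substantially from its initial value.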
Generating Images
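After training, we can generate MNIST images by sampling latent variables from the prior and decoding them. For display, we show the mean of the Bernoulli decoder (the per-pixel probabilities) rather than binary samples.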
with torch.no_grad():  # no gradients are needed for generation
    z = prior().sample((16,))
    x = decoder(z).mean.reshape(-1, 28, 28)
to_pil_image(x.movedim(0, 1).reshape(28, -1))
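The last line arranges the 16 decoded images in one horizontal strip, the same layout as the earlier torch.cat(x, dim=-1) call. The bookkeeping behind movedim(0, 1).reshape(28, -1) can be mimicked with nested lists on a hypothetical batch of three 2 × 2 images:

```python
# Hypothetical small batch: n images of size h x w, stored as nested lists [n][h][w].
n, h, w = 3, 2, 2
images = [[[100 * k + 10 * r + c for c in range(w)] for r in range(h)] for k in range(n)]

# movedim(0, 1): [n][h][w] -> [h][n][w]
moved = [[images[k][r] for k in range(n)] for r in range(h)]

# reshape(h, -1): flatten the trailing [n][w] of each row into one row of length n * w
strip = [[pixel for img_row in moved[r] for pixel in img_row] for r in range(h)]

# Same result as concatenating the images side by side along the width axis
side_by_side = [sum((images[k][r] for k in range(n)), []) for r in range(h)]
print(strip == side_by_side)  # True
```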
The generated images should resemble MNIST digits, though they may not be perfect. Longer training or better hyperparameters can improve the quality.
Key Concepts
- Encoder: maps input images to a distribution over latent variables $q_\psi(z \mid x)$.
- Decoder: maps latent variables to a distribution over images $p_\phi(x \mid z)$.
- Prior: defines the distribution of latent variables $p_\phi(z)$. Using a normalizing flow as the prior makes it more expressive than a simple Gaussian.
- ELBO: the evidence lower bound provides a tractable objective for training.
Advantages of Flow-based Prior
Using a normalizing flow (MAF in this case) as the prior instead of a simple Gaussian distribution has several advantages:
- Expressivity: Flows can model complex, multi-modal distributions in the latent space
- Flexibility: The latent space can capture more intricate patterns in the data
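- Better generation: a simple Gaussian prior may assign mass to latent regions the encoder rarely uses, and samples drawn there tend to decode poorly. A flow prior can match the distribution of encoded latents more closely, which can improve sample quality.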
For more advanced VAE architectures, consider exploring hierarchical VAEs or using more powerful flows as encoders and decoders.