This guide shows you how to integrate normalizing flows into Variational Autoencoders (VAEs) to create more expressive generative models.
VAE Fundamentals
A Variational Autoencoder consists of three components:
- Encoder $q_\psi(z \mid x)$: Maps data to latent space
- Decoder $p_\phi(x \mid z)$: Reconstructs data from latent codes
- Prior $p_\phi(z)$: Distribution over latent codes
The training objective is to maximize the Evidence Lower Bound (ELBO):
$$\mathrm{ELBO}(x) = \mathbb{E}_{q_\psi(z \mid x)}\big[\log p_\phi(x \mid z) + \log p_\phi(z) - \log q_\psi(z \mid x)\big]$$
Why Use Flow Priors?
Standard VAEs use a simple Gaussian prior $p(z) = \mathcal{N}(0, I)$. Normalizing flows provide a more expressive prior, allowing the model to learn complex latent structures.
Implementing the ELBO
First, let’s implement a general ELBO module:
import torch
import torch.nn as nn
from torch import Tensor
import zuko
class ELBO(nn.Module):
    """Evidence lower bound (ELBO) objective for a VAE.

    Wraps an encoder q(z|x), a decoder p(x|z) and a prior p(z) — each a
    lazy distribution, i.e. a module that returns a torch Distribution
    when called — and evaluates a one-sample Monte-Carlo ELBO estimate.
    """

    def __init__(
        self,
        encoder: "zuko.lazy.LazyDistribution",
        decoder: "zuko.lazy.LazyDistribution",
        prior: "zuko.lazy.LazyDistribution",
    ):
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder
        self.prior = prior

    def forward(self, x: Tensor) -> Tensor:
        """Return a single-sample ELBO estimate for each element of x."""
        posterior = self.encoder(x)
        # rsample() keeps the sample differentiable (reparameterization trick).
        latent = posterior.rsample()

        reconstruction_term = self.decoder(latent).log_prob(x)
        prior_term = self.prior().log_prob(latent)
        entropy_term = posterior.log_prob(latent)

        # ELBO = E[log p(x|z)] + E[log p(z)] - E[log q(z|x)]
        return reconstruction_term + prior_term - entropy_term
The encoder uses rsample() for the reparameterization trick, enabling gradients to flow through the sampling operation.
Building VAE Components
Gaussian Encoder
A diagonal Gaussian encoder:
from torch.distributions import Independent, Normal
class GaussianEncoder(zuko.lazy.LazyDistribution):
    """Amortized diagonal-Gaussian posterior q(z | x).

    A small MLP maps the observation to the mean and log standard
    deviation of an axis-aligned Gaussian over the latent space.
    """

    def __init__(self, features: int, context: int):
        super().__init__()

        # Two hidden layers; the head emits (mu, log_sigma) concatenated.
        self.hyper = nn.Sequential(
            nn.Linear(context, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Linear(1024, 2 * features),
        )

    def forward(self, c: Tensor):
        """Return the posterior distribution conditioned on c."""
        mu, log_sigma = self.hyper(c).chunk(2, dim=-1)
        # Independent(..., 1) makes the last axis the event dimension,
        # so log_prob sums over latent features.
        return Independent(Normal(mu, log_sigma.exp()), 1)
Bernoulli Decoder
For binary data (e.g., MNIST):
from torch.distributions import Independent, Bernoulli
class BernoulliDecoder(zuko.lazy.LazyDistribution):
    """Bernoulli observation model p(x | z) for binary data (e.g. MNIST).

    An MLP maps the latent code to per-pixel logits of independent
    Bernoulli variables.
    """

    def __init__(self, features: int, context: int):
        super().__init__()

        self.hyper = nn.Sequential(
            nn.Linear(context, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Linear(1024, features),
        )

    def forward(self, c: Tensor):
        """Return the observation distribution conditioned on the latent c."""
        logits = self.hyper(c)
        # Parameterize by logits instead of probs = torch.sigmoid(logits):
        # Bernoulli(logits=...) evaluates log_prob via a numerically stable
        # log-sigmoid, avoiding saturation for large |logits|. This also
        # matches the complete example later in this guide.
        return Independent(Bernoulli(logits=logits), 1)
Gaussian Decoder
For continuous data:
class GaussianDecoder(zuko.lazy.LazyDistribution):
    """Diagonal-Gaussian observation model p(x | z) for continuous data."""

    def __init__(self, features: int, context: int):
        super().__init__()

        layers = [
            nn.Linear(context, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Linear(1024, 2 * features),
        ]
        self.hyper = nn.Sequential(*layers)

    def forward(self, c: Tensor):
        """Return the observation distribution conditioned on c."""
        out = self.hyper(c)
        mu, log_sigma = out.chunk(2, dim=-1)
        scale = log_sigma.exp()
        return Independent(Normal(mu, scale), 1)
VAE with Flow Prior
Now we can build a VAE with a normalizing flow prior:
import zuko.flows

# Define dimensions
latent_dim = 16
data_dim = 784  # e.g., flattened 28x28 MNIST images

# Encoder: q(z|x) — maps a flattened image to a distribution over latents
encoder = GaussianEncoder(features=latent_dim, context=data_dim)

# Decoder: p(x|z) — maps a latent code to a distribution over pixels
decoder = BernoulliDecoder(features=data_dim, context=latent_dim)

# Prior: p(z) — a Masked Autoregressive Flow trained jointly with the VAE
prior = zuko.flows.MAF(
    features=latent_dim,
    context=0,  # Unconditional: the prior takes no conditioning input
    transforms=3,
    hidden_features=(256, 256)
)

# Complete VAE: the ELBO module ties all three components together
vae = ELBO(encoder, decoder, prior)
For the prior, MAF (Masked Autoregressive Flow) works well for moderate latent dimensions (up to ~50). For higher dimensions, consider NSF (Neural Spline Flow).
Training the VAE
Prepare data
import torch.utils.data as data
from torchvision.datasets import MNIST
from torchvision.transforms.functional import to_tensor
trainset = MNIST(
    root="./data",
    download=True,
    train=True,
    transform=to_tensor  # converts PIL images to float tensors in [0, 1]
)

trainloader = data.DataLoader(
    trainset,
    batch_size=256,
    shuffle=True  # reshuffle every epoch
)
Training loop
vae = vae.cuda()  # NOTE(review): assumes a CUDA device is available

optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)

for epoch in range(64):
    epoch_losses = []

    for x, _ in trainloader:
        # Prepare data (binarize and flatten MNIST):
        # round() maps pixels to {0, 1} to match the Bernoulli decoder;
        # flatten(-3) collapses (C, H, W) into 784 features
        x = x.round().flatten(-3).cuda()

        # Compute negative ELBO (we minimize, so negate the objective)
        elbo = vae(x)
        loss = -elbo.mean()

        # Optimize
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        epoch_losses.append(loss.item())

    avg_loss = sum(epoch_losses) / len(epoch_losses)
    print(f"Epoch {epoch}: Loss = {avg_loss:.4f}")
Generate samples
After training, generate new samples:

vae.eval()
with torch.no_grad():
    # Sample from prior
    z = vae.prior().sample((16,))

    # Decode to data space; .mean gives per-pixel Bernoulli probabilities
    x_generated = vae.decoder(z).mean

    # Reshape for visualization (if using images)
    images = x_generated.reshape(-1, 28, 28)
Complete MNIST Example
Here’s a complete working example for MNIST:
import torch
import torch.nn as nn
import torch.utils.data as data
from torch import Tensor
from torch.distributions import Bernoulli, Independent, Normal
from torchvision.datasets import MNIST
from torchvision.transforms.functional import to_tensor
import zuko
# ELBO Module
class ELBO(nn.Module):
    """One-sample Monte-Carlo estimate of the evidence lower bound."""

    def __init__(self, encoder, decoder, prior):
        super().__init__()

        self.encoder = encoder  # q(z|x)
        self.decoder = decoder  # p(x|z)
        self.prior = prior      # p(z)

    def forward(self, x: Tensor) -> Tensor:
        """Return ELBO(x) estimated with a single reparameterized sample."""
        posterior = self.encoder(x)
        z = posterior.rsample()  # differentiable sample

        log_px = self.decoder(z).log_prob(x)
        log_pz = self.prior().log_prob(z)
        log_qz = posterior.log_prob(z)

        return log_px + log_pz - log_qz
# Encoder
class GaussianEncoder(zuko.lazy.LazyDistribution):
    """Amortized diagonal-Gaussian posterior over the latent space."""

    def __init__(self, features: int, context: int):
        super().__init__()

        self.hyper = nn.Sequential(
            nn.Linear(context, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Linear(1024, 2 * features),
        )

    def forward(self, c: Tensor):
        """Map an observation c to q(z | c)."""
        params = self.hyper(c)
        loc, log_scale = params.chunk(2, dim=-1)
        # Last axis is the event dimension: log_prob sums over features.
        return Independent(Normal(loc, log_scale.exp()), 1)
# Decoder
class BernoulliDecoder(zuko.lazy.LazyDistribution):
    """Bernoulli observation model over binarized pixels."""

    def __init__(self, features: int, context: int):
        super().__init__()

        self.hyper = nn.Sequential(
            nn.Linear(context, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Linear(1024, features),
        )

    def forward(self, c: Tensor):
        """Map a latent code c to p(x | c) parameterized by pixel logits."""
        logits = self.hyper(c)
        return Independent(Bernoulli(logits=logits), 1)
# Data: MNIST digits as float tensors in [0, 1]
trainset = MNIST(root="./data", download=True, train=True, transform=to_tensor)
trainloader = data.DataLoader(trainset, batch_size=256, shuffle=True)

# Model: 16-dimensional latent space, 784 = 28*28 flattened pixels
encoder = GaussianEncoder(16, 784)
decoder = BernoulliDecoder(784, 16)
prior = zuko.flows.MAF(features=16, transforms=3, hidden_features=(256, 256))
vae = ELBO(encoder, decoder, prior).cuda()  # NOTE(review): requires a CUDA device

# Train
optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)
for epoch in range(32):
    losses = []

    for x, _ in trainloader:
        # Binarize pixels for the Bernoulli decoder; flatten (C, H, W) -> 784
        x = x.round().flatten(-3).cuda()

        loss = -vae(x).mean()  # negative ELBO, averaged over the batch
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        losses.append(loss.item())

    print(f"Epoch {epoch}: {sum(losses) / len(losses):.4f}")

# Generate
vae.eval()

with torch.no_grad():
    z = vae.prior().sample((16,))
    # .mean gives per-pixel Bernoulli probabilities, reshaped to images
    x = vae.decoder(z).mean.reshape(-1, 28, 28)
Advanced Variations
Conditional VAE
Add conditioning to the decoder:
class ConditionalDecoder(zuko.lazy.LazyDistribution):
    """Bernoulli decoder p(x | z, c) conditioned on the latent code z and
    an additional external condition c (e.g. a label embedding)."""

    def __init__(self, features: int, latent: int, condition: int):
        super().__init__()

        self.hyper = nn.Sequential(
            nn.Linear(latent + condition, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Linear(1024, features),
        )

    def forward(self, z: Tensor, c: Tensor):
        """Concatenate latent and condition, then decode to pixel logits."""
        joint = torch.cat([z, c], dim=-1)
        logits = self.hyper(joint)
        return Independent(Bernoulli(logits=logits), 1)
# Modify ELBO
class ConditionalELBO(nn.Module):
def forward(self, x: Tensor, c: Tensor) -> Tensor:
q = self.encoder(x)
z = q.rsample()
return (
self.decoder(z, c).log_prob(x) +
self.prior().log_prob(z) -
q.log_prob(z)
)
Beta-VAE
Weight the KL term for disentanglement:
class BetaELBO(nn.Module):
    """ELBO with a weighted KL term (beta-VAE).

    beta > 1 penalizes the KL divergence more heavily, which encourages
    disentangled latents at some cost to reconstruction quality.
    """

    def __init__(self, encoder, decoder, prior, beta=1.0):
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder
        self.prior = prior
        self.beta = beta

    def forward(self, x: Tensor) -> Tensor:
        """Return log p(x|z) - beta * KL, estimated with one sample."""
        posterior = self.encoder(x)
        z = posterior.rsample()

        reconstruction = self.decoder(z).log_prob(x)
        # Single-sample KL estimate: log q(z|x) - log p(z)
        kl_estimate = posterior.log_prob(z) - self.prior().log_prob(z)

        return reconstruction - self.beta * kl_estimate
Setting beta > 1 encourages more disentangled representations but may reduce reconstruction quality.
Hierarchical VAE
Multiple latent levels:
class HierarchicalELBO(nn.Module):
    """ELBO for a two-level hierarchical VAE.

    Inference model:  q(z2 | x) q(z1 | z2)
    Generative model: p(x | z1) p(z1) p(z2 | z1)
    (p(z1) p(z2|z1) is a valid factorization of the joint prior p(z1, z2).)
    """

    def __init__(self, encoder_1, encoder_2, decoder, prior_1, prior_2):
        # The original snippet omitted __init__, so the five components read
        # from self in forward() were never defined. Store them here.
        super().__init__()
        self.encoder_1 = encoder_1  # q(z1 | z2)
        self.encoder_2 = encoder_2  # q(z2 | x)
        self.decoder = decoder      # p(x | z1)
        self.prior_1 = prior_1      # p(z1)
        self.prior_2 = prior_2      # p(z2 | z1)

    def forward(self, x: Tensor) -> Tensor:
        """Return a one-sample ELBO estimate over both latent levels."""
        # Encode to z2
        q_z2 = self.encoder_2(x)
        z2 = q_z2.rsample()

        # Encode to z1 (conditioned on z2)
        q_z1 = self.encoder_1(z2)
        z1 = q_z1.rsample()

        # Decode
        log_p_x = self.decoder(z1).log_prob(x)

        # Prior
        log_p_z1 = self.prior_1().log_prob(z1)
        log_p_z2_given_z1 = self.prior_2(z1).log_prob(z2)

        # Posterior
        log_q_z1 = q_z1.log_prob(z1)
        log_q_z2 = q_z2.log_prob(z2)

        return log_p_x + log_p_z1 + log_p_z2_given_z1 - log_q_z1 - log_q_z2
Reconstruction vs. Generation
Reconstruction
Visualize reconstructions to check encoding quality:
vae.eval()

with torch.no_grad():
    # NOTE(review): test_images is assumed to be a batch of flattened,
    # binarized test digits — confirm how it is prepared upstream
    x_original = test_images.cuda()

    # Encode and decode
    q = vae.encoder(x_original)
    z = q.sample()  # no gradients needed here, so sample() suffices
    x_recon = vae.decoder(z).mean

    # Compare
    print("Original vs Reconstructed")
Generation
Generate completely new samples:
with torch.no_grad():
    # Sample from learned prior
    z = vae.prior().sample((64,))
    # Decode; .mean gives expected pixel values (Bernoulli probabilities)
    x_gen = vae.decoder(z).mean
Interpolation
Interpolate between two datapoints:
with torch.no_grad():
    # Encode two images to their posterior means
    z1 = vae.encoder(x1).mean
    z2 = vae.encoder(x2).mean

    # Interpolate from z1 (alpha = 0) to z2 (alpha = 1).
    # The original wrote alpha * z1 + (1 - alpha) * z2, which walks the
    # sequence backwards (x2 -> x1). torch.lerp(z1, z2, alpha) computes
    # z1 + alpha * (z2 - z1), so the frames run x1 -> x2 as alpha grows.
    alphas = torch.linspace(0, 1, 10)
    z_interp = torch.stack([torch.lerp(z1, z2, alpha) for alpha in alphas])

    # Decode each interpolated latent back to data space
    x_interp = vae.decoder(z_interp).mean
Tips for Training VAEs
Posterior Collapse: If the KL term drops to zero early in training, the model is ignoring the latent code. Solutions:
- Increase latent dimension
- Use a more expressive prior (flow)
- Implement KL annealing
KL Annealing: Gradually increase the weight on the KL term:

kl_weight = min(1.0, epoch / warmup_epochs)
loss = -log_p_x + kl_weight * kl
Architecture Balance: Match encoder and decoder capacity. An overly powerful decoder can ignore the latent code.
For image data, remember to normalize/binarize appropriately. MNIST should be binarized with x.round() when using a Bernoulli decoder.
Next Steps