This guide shows you how to integrate normalizing flows into Variational Autoencoders (VAEs) to create more expressive generative models.

VAE Fundamentals

A Variational Autoencoder consists of three components:
  1. Encoder $q_\psi(z \mid x)$: maps data to the latent space
  2. Decoder $p_\phi(x \mid z)$: reconstructs data from latent codes
  3. Prior $p_\phi(z)$: a distribution over latent codes
The training objective is to maximize the Evidence Lower Bound (ELBO):

$$\text{ELBO}(x) = \mathbb{E}_{q_\psi(z|x)} \left[ \log p_\phi(x|z) + \log p_\phi(z) - \log q_\psi(z|x) \right]$$
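For context, maximizing the ELBO is justified because it lower-bounds the marginal log-likelihood. A short derivation, using the same notation as above:

```latex
\log p_\phi(x)
  = \mathbb{E}_{q_\psi(z|x)}\!\left[ \log \frac{p_\phi(x|z)\, p_\phi(z)}{q_\psi(z|x)} \right]
    + \mathrm{KL}\!\left( q_\psi(z|x) \,\|\, p_\phi(z|x) \right)
  = \text{ELBO}(x) + \mathrm{KL}\!\left( q_\psi(z|x) \,\|\, p_\phi(z|x) \right)
  \geq \text{ELBO}(x)
```

Since the KL divergence is non-negative, raising the ELBO both pushes up $\log p_\phi(x)$ and pulls the approximate posterior $q_\psi(z|x)$ toward the true posterior $p_\phi(z|x)$.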

Why Use Flow Priors?

Standard VAEs use a simple Gaussian prior $p(z) = \mathcal{N}(0, I)$. Normalizing flows provide a more expressive prior, allowing the model to learn complex latent structure.
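To make the contrast concrete, here is a minimal sketch, using only torch, of what the standard Gaussian prior looks like as a "lazy" module: calling it returns a distribution object, which is the same calling convention the flow prior follows later. The class name `StandardNormalPrior` is our own illustration, not a library API.

```python
import torch
import torch.nn as nn
from torch.distributions import Independent, Normal

class StandardNormalPrior(nn.Module):
    """Lazy standard Gaussian prior: calling the module returns a distribution."""

    def __init__(self, features: int):
        super().__init__()
        # Buffers move with the module (e.g. .cuda()) but are not trained.
        self.register_buffer("loc", torch.zeros(features))
        self.register_buffer("scale", torch.ones(features))

    def forward(self) -> Independent:
        # Independent(..., 1) treats the last dimension as one event,
        # so log_prob sums over all latent features.
        return Independent(Normal(self.loc, self.scale), 1)

prior = StandardNormalPrior(16)
z = prior().sample((4,))       # 4 latent codes of dimension 16
log_p = prior().log_prob(z)    # one scalar log-density per code
print(z.shape, log_p.shape)    # torch.Size([4, 16]) torch.Size([4])
```

A flow prior is a drop-in replacement: it obeys the same call-then-query interface but has trainable parameters that reshape the density.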

Implementing the ELBO

First, let’s implement a general ELBO module:
import torch
import torch.nn as nn
from torch import Tensor
import zuko

class ELBO(nn.Module):
    def __init__(
        self,
        encoder: zuko.lazy.LazyDistribution,
        decoder: zuko.lazy.LazyDistribution,
        prior: zuko.lazy.LazyDistribution,
    ):
        super().__init__()
        
        self.encoder = encoder
        self.decoder = decoder
        self.prior = prior
    
    def forward(self, x: Tensor) -> Tensor:
        # Encode
        q = self.encoder(x)
        z = q.rsample()  # Reparameterization trick
        
        # ELBO = E[log p(x|z)] + E[log p(z)] - E[log q(z|x)]
        log_p_x_given_z = self.decoder(z).log_prob(x)
        log_p_z = self.prior().log_prob(z)
        log_q_z_given_x = q.log_prob(z)
        
        return log_p_x_given_z + log_p_z - log_q_z_given_x
The encoder uses rsample() for the reparameterization trick, enabling gradients to flow through the sampling operation.
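A quick way to see why rsample() matters: gradients reach the distribution's parameters through the reparameterized sample, which sample() would block. A minimal torch sketch:

```python
import torch
from torch.distributions import Normal

mu = torch.zeros(3, requires_grad=True)
q = Normal(mu, torch.ones(3))

z = q.rsample()     # z = mu + sigma * eps, differentiable w.r.t. mu
z.sum().backward()

print(mu.grad)      # gradient of ones: d(sum z)/d(mu) = 1
```

By contrast, `q.sample()` is computed without gradient tracking, so calling `backward()` on a loss built from it would not propagate into `mu`.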

Building VAE Components

Gaussian Encoder

A diagonal Gaussian encoder:
from torch.distributions import Independent, Normal

class GaussianEncoder(zuko.lazy.LazyDistribution):
    def __init__(self, features: int, context: int):
        super().__init__()
        
        self.hyper = nn.Sequential(
            nn.Linear(context, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Linear(1024, 2 * features)
        )
    
    def forward(self, c: Tensor):
        params = self.hyper(c)
        mu, log_sigma = params.chunk(2, dim=-1)
        return Independent(Normal(mu, log_sigma.exp()), 1)

Bernoulli Decoder

For binary data (e.g., MNIST):
from torch.distributions import Independent, Bernoulli

class BernoulliDecoder(zuko.lazy.LazyDistribution):
    def __init__(self, features: int, context: int):
        super().__init__()
        
        self.hyper = nn.Sequential(
            nn.Linear(context, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Linear(1024, features)
        )
    
    def forward(self, c: Tensor):
        # Parameterize by logits directly; this is more numerically stable
        # than applying a sigmoid and passing probabilities.
        return Independent(Bernoulli(logits=self.hyper(c)), 1)

Gaussian Decoder

For continuous data:
class GaussianDecoder(zuko.lazy.LazyDistribution):
    def __init__(self, features: int, context: int):
        super().__init__()
        
        self.hyper = nn.Sequential(
            nn.Linear(context, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Linear(1024, 2 * features)
        )
    
    def forward(self, c: Tensor):
        params = self.hyper(c)
        mu, log_sigma = params.chunk(2, dim=-1)
        return Independent(Normal(mu, log_sigma.exp()), 1)

VAE with Flow Prior

Now we can build a VAE with a normalizing flow prior:
import zuko.flows

# Define dimensions
latent_dim = 16
data_dim = 784  # e.g., flattened 28x28 MNIST images

# Encoder: q(z|x)
encoder = GaussianEncoder(features=latent_dim, context=data_dim)

# Decoder: p(x|z)  
decoder = BernoulliDecoder(features=data_dim, context=latent_dim)

# Prior: p(z) - Normalizing Flow
prior = zuko.flows.MAF(
    features=latent_dim,
    context=0,  # Unconditional
    transforms=3,
    hidden_features=(256, 256)
)

# Complete VAE
vae = ELBO(encoder, decoder, prior)
For the prior, MAF (Masked Autoregressive Flow) works well for moderate latent dimensions (up to ~50). For higher dimensions, consider NSF (Neural Spline Flow).

Training the VAE

1. Prepare data

import torch.utils.data as data
from torchvision.datasets import MNIST
from torchvision.transforms.functional import to_tensor

trainset = MNIST(
    root="./data",
    download=True,
    train=True,
    transform=to_tensor
)

trainloader = data.DataLoader(
    trainset,
    batch_size=256,
    shuffle=True
)

2. Training loop

vae = vae.cuda()
optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)

for epoch in range(64):
    epoch_losses = []
    
    for x, _ in trainloader:
        # Prepare data (binarize and flatten MNIST)
        x = x.round().flatten(-3).cuda()
        
        # Compute negative ELBO
        elbo = vae(x)
        loss = -elbo.mean()
        
        # Optimize
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        
        epoch_losses.append(loss.item())
    
    avg_loss = sum(epoch_losses) / len(epoch_losses)
    print(f"Epoch {epoch}: Loss = {avg_loss:.4f}")

3. Generate samples

After training, generate new samples:
vae.eval()

with torch.no_grad():
    # Sample from prior
    z = vae.prior().sample((16,))
    
    # Decode to data space
    x_generated = vae.decoder(z).mean
    
    # Reshape for visualization (if using images)
    images = x_generated.reshape(-1, 28, 28)

Complete MNIST Example

Here’s a complete working example for MNIST:
import torch
import torch.nn as nn
import torch.utils.data as data
from torch import Tensor
from torch.distributions import Bernoulli, Independent, Normal
from torchvision.datasets import MNIST
from torchvision.transforms.functional import to_tensor
import zuko

# ELBO Module
class ELBO(nn.Module):
    def __init__(self, encoder, decoder, prior):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.prior = prior
    
    def forward(self, x: Tensor) -> Tensor:
        q = self.encoder(x)
        z = q.rsample()
        return self.decoder(z).log_prob(x) + self.prior().log_prob(z) - q.log_prob(z)

# Encoder
class GaussianEncoder(zuko.lazy.LazyDistribution):
    def __init__(self, features: int, context: int):
        super().__init__()
        self.hyper = nn.Sequential(
            nn.Linear(context, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Linear(1024, 2 * features),
        )
    
    def forward(self, c: Tensor):
        mu, log_sigma = self.hyper(c).chunk(2, dim=-1)
        return Independent(Normal(mu, log_sigma.exp()), 1)

# Decoder
class BernoulliDecoder(zuko.lazy.LazyDistribution):
    def __init__(self, features: int, context: int):
        super().__init__()
        self.hyper = nn.Sequential(
            nn.Linear(context, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Linear(1024, features),
        )
    
    def forward(self, c: Tensor):
        return Independent(Bernoulli(logits=self.hyper(c)), 1)

# Data
trainset = MNIST(root="./data", download=True, train=True, transform=to_tensor)
trainloader = data.DataLoader(trainset, batch_size=256, shuffle=True)

# Model
encoder = GaussianEncoder(16, 784)
decoder = BernoulliDecoder(784, 16)
prior = zuko.flows.MAF(features=16, transforms=3, hidden_features=(256, 256))
vae = ELBO(encoder, decoder, prior).cuda()

# Train
optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)

for epoch in range(32):
    losses = []
    for x, _ in trainloader:
        x = x.round().flatten(-3).cuda()
        loss = -vae(x).mean()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        losses.append(loss.item())
    
    print(f"Epoch {epoch}: {sum(losses) / len(losses):.4f}")

# Generate
vae.eval()
with torch.no_grad():
    z = vae.prior().sample((16,))
    x = vae.decoder(z).mean.reshape(-1, 28, 28)

Advanced Variations

Conditional VAE

Add conditioning to the decoder:
class ConditionalDecoder(zuko.lazy.LazyDistribution):
    def __init__(self, features: int, latent: int, condition: int):
        super().__init__()
        self.hyper = nn.Sequential(
            nn.Linear(latent + condition, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Linear(1024, features)
        )
    
    def forward(self, z: Tensor, c: Tensor):
        zc = torch.cat([z, c], dim=-1)
        return Independent(Bernoulli(logits=self.hyper(zc)), 1)

# Modify the ELBO to pass the condition through to the decoder
class ConditionalELBO(nn.Module):
    def __init__(self, encoder, decoder, prior):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.prior = prior
    
    def forward(self, x: Tensor, c: Tensor) -> Tensor:
        q = self.encoder(x)
        z = q.rsample()
        return (
            self.decoder(z, c).log_prob(x) +
            self.prior().log_prob(z) -
            q.log_prob(z)
        )

Beta-VAE

Weight the KL term for disentanglement:
class BetaELBO(nn.Module):
    def __init__(self, encoder, decoder, prior, beta=1.0):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.prior = prior
        self.beta = beta
    
    def forward(self, x: Tensor) -> Tensor:
        q = self.encoder(x)
        z = q.rsample()
        
        log_p_x = self.decoder(z).log_prob(x)
        kl = q.log_prob(z) - self.prior().log_prob(z)
        
        return log_p_x - self.beta * kl
Setting beta > 1 encourages more disentangled representations but may reduce reconstruction quality.

Hierarchical VAE

Multiple latent levels:
class HierarchicalELBO(nn.Module):
    def __init__(self, encoder_1, encoder_2, decoder, prior_1, prior_2):
        super().__init__()
        self.encoder_1 = encoder_1
        self.encoder_2 = encoder_2
        self.decoder = decoder
        self.prior_1 = prior_1
        self.prior_2 = prior_2
    
    def forward(self, x: Tensor) -> Tensor:
        # Encode x to z2
        q_z2 = self.encoder_2(x)
        z2 = q_z2.rsample()
        
        # Encode z2 to z1
        q_z1 = self.encoder_1(z2)
        z1 = q_z1.rsample()
        
        # Decode from z1
        log_p_x = self.decoder(z1).log_prob(x)
        
        # Priors: p(z1) and p(z2 | z1)
        log_p_z1 = self.prior_1().log_prob(z1)
        log_p_z2_given_z1 = self.prior_2(z1).log_prob(z2)
        
        # Approximate posteriors
        log_q_z1 = q_z1.log_prob(z1)
        log_q_z2 = q_z2.log_prob(z2)
        
        return log_p_x + log_p_z1 + log_p_z2_given_z1 - log_q_z1 - log_q_z2

Reconstruction vs. Generation

Reconstruction

Visualize reconstructions to check encoding quality:
vae.eval()
with torch.no_grad():
    x_original = test_images.cuda()
    
    # Encode and decode
    q = vae.encoder(x_original)
    z = q.mean  # use the posterior mean for a deterministic reconstruction
    x_recon = vae.decoder(z).mean
    
    # Compare
    print("Original vs Reconstructed")

Generation

Generate completely new samples:
with torch.no_grad():
    # Sample from learned prior
    z = vae.prior().sample((64,))
    x_gen = vae.decoder(z).mean

Interpolation

Interpolate between two datapoints:
with torch.no_grad():
    # Encode two images
    z1 = vae.encoder(x1).mean
    z2 = vae.encoder(x2).mean
    
    # Interpolate
    alphas = torch.linspace(0, 1, 10)
    z_interp = torch.stack([alpha * z1 + (1 - alpha) * z2 for alpha in alphas])
    
    # Decode
    x_interp = vae.decoder(z_interp).mean

Tips for Training VAEs

Posterior Collapse: If the KL term drops to zero early in training, the model is ignoring the latent code. Solutions:
  • Increase latent dimension
  • Use a more expressive prior (flow)
  • Implement KL annealing
KL Annealing: Gradually increase the weight on the KL term:
kl_weight = min(1.0, epoch / warmup_epochs)
loss = -log_p_x + kl_weight * kl
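The warm-up schedule can be factored into a small helper; the function name and the `warmup_epochs` default below are illustrative choices, not from any library:

```python
def kl_weight(epoch: int, warmup_epochs: int = 10) -> float:
    """Linearly ramp the KL weight from 0 to 1 over the first warmup_epochs."""
    return min(1.0, epoch / warmup_epochs)

print([kl_weight(e) for e in (0, 5, 10, 20)])  # [0.0, 0.5, 1.0, 1.0]
```

Inside the training loop, multiply only the KL part of the negative ELBO by `kl_weight(epoch)`, leaving the reconstruction term at full weight.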
Architecture Balance: Match encoder and decoder capacity. An overly powerful decoder can ignore the latent code.
For image data, remember to normalize/binarize appropriately. MNIST should be binarized with x.round() when using a Bernoulli decoder.
