The zuko.bayesian module provides utilities for Bayesian deep learning, enabling variational inference over model parameters using mean-field Gaussian posteriors.

BayesianModel

Creates a Bayesian wrapper around a base model that maintains a variational posterior over parameters. The posterior is a mean-field Gaussian factorization, $q(\theta) = \prod_i \mathcal{N}(\theta_i \mid \mu_i, \sigma_i^2)$, where $\mu_i$ and $\log \sigma_i^2$ are learned variational parameters.
base (nn.Module, required): A base PyTorch model.
init_logvar (float, default: -9.0): Initial value for the log-variance parameters, which controls the initialization uncertainty.
include_params (Sequence[str], default: ('')): List of parameter name prefixes to include in the posterior. Use * to match alphanumeric strings and ** to match dot-separated paths. By default, all parameters are included.
exclude_params (Sequence[str], default: ()): List of parameter name prefixes to exclude from the posterior.
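Under the hood, samples from this mean-field posterior follow the reparameterization $\theta_i = \mu_i + \sigma_i \epsilon_i$ with $\epsilon_i \sim \mathcal{N}(0, 1)$. A minimal sketch of that sampling step for a single weight tensor (the variable names mu and logvar are illustrative, not the module's internals):

```python
import torch

# Hypothetical variational parameters for one weight tensor:
# mu holds the posterior means, logvar the log-variances (cf. init_logvar = -9.0).
mu = torch.zeros(64, 10)
logvar = torch.full((64, 10), -9.0)

# Reparameterized sample: theta = mu + sigma * eps, eps ~ N(0, 1).
# Gradients can flow to mu and logvar because eps is independent noise.
eps = torch.randn_like(mu)
theta = mu + torch.exp(0.5 * logvar) * eps

print(theta.shape)  # torch.Size([64, 10])
```

With init_logvar = -9.0, the initial standard deviation is exp(-4.5) ≈ 0.011, so sampled parameters start out tightly concentrated around their means.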

Methods

sample_params

Returns model parameters sampled from the posterior. Returns: Dictionary mapping parameter names to sampled tensors
params = bayesian_model.sample_params()
# params = {'layer1.weight': tensor(...), 'layer1.bias': tensor(...), ...}

sample_model

Returns a standalone model sampled from the posterior. The returned model is a deep copy of the base model with sampled parameters. It can be used independently but does not propagate gradients to the Bayesian model. Returns: A sampled model instance
sampled = bayesian_model.sample_model()
y = sampled(x)  # Use like any PyTorch model
Warning: sample_model() should not be used during training as gradients do not flow back to the variational parameters. Use reparameterize() instead for training.

reparameterize

Context manager that temporarily reparameterizes the base model from the posterior. Within this context, the base model behaves deterministically (same inputs → same outputs) and gradients flow through the variational parameters.
local_trick (bool, default: False): Whether to use the local reparameterization trick for linear layers, which reduces the variance of gradient estimates.
Yields: The reparameterized base model
with bayesian_model.reparameterize() as model:
    y = model(x)  # Gradients flow to variational parameters
    loss = criterion(y, target)
    loss.backward()
Local Reparameterization Trick: when local_trick=True, instead of sampling weights and then computing outputs, the method samples activations directly from their induced distribution, which reduces gradient variance. Reference: Variational Dropout and the Local Reparameterization Trick (Kingma et al., 2015), arxiv.org/abs/1506.02557
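For a linear layer with a factorized Gaussian posterior over weights, the pre-activations are themselves Gaussian, so they can be sampled directly. A self-contained sketch of the idea (illustrative tensors, not the module's implementation):

```python
import torch

# Posterior over a linear layer's weights: W ~ N(mu_W, sigma_W^2), elementwise.
mu_w = torch.randn(64, 10) * 0.1
logvar_w = torch.full((64, 10), -9.0)
x = torch.randn(32, 10)  # a batch of inputs

# Standard trick: sample W once, then compute y = x @ W.T
# (a single noise draw is shared by every example in the batch).
w = mu_w + torch.exp(0.5 * logvar_w) * torch.randn_like(mu_w)
y_standard = x @ w.t()

# Local trick: the pre-activation y is Gaussian with
#   mean = x @ mu_W.T   and   var = x^2 @ sigma_W^2.T,
# so sample it directly, with independent noise per example.
# This lowers the variance of minibatch gradient estimates.
y_mean = x @ mu_w.t()
y_var = (x ** 2) @ torch.exp(logvar_w).t()
y_local = y_mean + y_var.sqrt() * torch.randn_like(y_mean)
```

Both paths produce samples with the same per-example distribution; only the correlation structure across the batch (and hence the gradient variance) differs.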

kl_divergence

Computes the KL divergence between the posterior and a Gaussian prior: $D_{KL}(q(\theta) \| p(\theta)) = \sum_i D_{KL}(\mathcal{N}(\mu_i, \sigma_i^2) \| \mathcal{N}(0, \sigma_{prior}^2))$
prior_var (float, default: 1.0): The variance $\sigma_{prior}^2$ of the Gaussian prior.
Returns: The KL divergence as a scalar tensor
kl = bayesian_model.kl_divergence(prior_var=1.0)
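Each term in the sum has a closed form for two Gaussians: $D_{KL}(\mathcal{N}(\mu, \sigma^2) \| \mathcal{N}(0, \sigma_{prior}^2)) = \tfrac{1}{2}\left(\tfrac{\sigma^2 + \mu^2}{\sigma_{prior}^2} - 1 - \log\tfrac{\sigma^2}{\sigma_{prior}^2}\right)$. A standalone sketch of that computation (gaussian_kl is a hypothetical helper, not part of zuko):

```python
import torch

def gaussian_kl(mu, logvar, prior_var=1.0):
    """Closed-form KL(N(mu, sigma^2) || N(0, prior_var)), summed over elements."""
    var = torch.exp(logvar)
    return 0.5 * torch.sum(
        (var + mu ** 2) / prior_var
        - 1.0
        - logvar
        + torch.log(torch.tensor(prior_var))
    )

# Sanity check: when the posterior equals the prior, the KL is zero.
mu = torch.zeros(5)
logvar = torch.zeros(5)  # sigma^2 = 1 = prior_var
print(gaussian_kl(mu, logvar).item())  # 0.0
```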

Usage Example

Basic Variational Inference

import torch
import torch.nn as nn
from zuko.bayesian import BayesianModel

# Create a base neural network
base_net = nn.Sequential(
    nn.Linear(10, 64),
    nn.ReLU(),
    nn.Linear(64, 1)
)

# Wrap with Bayesian inference
bayesian_net = BayesianModel(base_net, init_logvar=-9.0)

# Training loop with ELBO
optimizer = torch.optim.Adam(bayesian_net.parameters(), lr=1e-3)

for x, y in dataloader:
    optimizer.zero_grad()
    
    # Reparameterize for one forward pass
    with bayesian_net.reparameterize() as model:
        pred = model(x)
        nll = nn.functional.mse_loss(pred, y)
    
    # ELBO = negative log-likelihood + KL divergence
    kl = bayesian_net.kl_divergence() / len(dataloader.dataset)
    loss = nll + kl
    
    loss.backward()
    optimizer.step()

Multiple Posterior Samples

# Ensemble predictions via multiple samples
num_samples = 10
predictions = []

for _ in range(num_samples):
    with torch.no_grad(), bayesian_net.reparameterize() as model:
        pred = model(x)
        predictions.append(pred)

# Compute mean and uncertainty
mean_pred = torch.stack(predictions).mean(dim=0)
std_pred = torch.stack(predictions).std(dim=0)

print(f"Prediction: {mean_pred} ± {std_pred}")

Selective Parameter Inference

# Only infer over final layer parameters
bayesian_net = BayesianModel(
    base_net,
    include_params=["2.*"],  # Only layer 2 parameters
    exclude_params=[]
)

# Or exclude specific parameters
bayesian_net = BayesianModel(
    base_net,
    include_params=[""],  # All parameters
    exclude_params=["*.bias"]  # Except biases
)

Local Reparameterization Trick

# Use local trick for lower variance gradients
for x, y in dataloader:
    optimizer.zero_grad()
    
    with bayesian_net.reparameterize(local_trick=True) as model:
        pred = model(x)
        nll = nn.functional.mse_loss(pred, y)
    
    kl = bayesian_net.kl_divergence() / len(dataloader.dataset)
    loss = nll + kl
    
    loss.backward()
    optimizer.step()

Bayesian Normalizing Flows

Combining Bayesian inference with normalizing flows enables uncertainty quantification over flow parameters:
from zuko.flows import NSF
from zuko.bayesian import BayesianModel

# Create a normalizing flow
base_flow = NSF(features=2, context=3, transforms=5)

# Wrap with Bayesian inference
bayesian_flow = BayesianModel(base_flow, init_logvar=-9.0)

# Training with context
for x, c in dataloader:
    optimizer.zero_grad()
    
    with bayesian_flow.reparameterize() as flow:
        # Get conditional distribution
        dist = flow(c)
        nll = -dist.log_prob(x).mean()
    
    kl = bayesian_flow.kl_divergence() / len(dataloader.dataset)
    loss = nll + kl
    
    loss.backward()
    optimizer.step()

# Sample multiple flow instances for uncertainty
sampled_flows = [bayesian_flow.sample_model() for _ in range(10)]

Pattern Matching

The include_params and exclude_params arguments support glob-like patterns:
  • "layer1" - Matches parameters starting with “layer1”
  • "*.weight" - Matches all weight parameters (single wildcard * matches [a-zA-Z0-9_]+)
  • "encoder.**" - Matches all parameters under “encoder” module (double wildcard ** matches [a-zA-Z0-9_\.]+)
  • "encoder.*.bias" - Matches biases in direct children of encoder
# Complex model with selective inference
model = nn.ModuleDict({
    'encoder': nn.Sequential(nn.Linear(10, 64), nn.ReLU(), nn.Linear(64, 32)),
    'decoder': nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10))
})

# Only infer over encoder weights, not biases
bayesian_model = BayesianModel(
    model,
    include_params=["encoder.**"],
    exclude_params=["encoder.*.bias"]
)
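The wildcard rules above can be mimicked by compiling each pattern to a prefix-matching regular expression (* → [a-zA-Z0-9_]+, ** → [a-zA-Z0-9_\.]+). A hypothetical helper for illustration; zuko's actual matcher may differ:

```python
import re

def compile_pattern(pattern: str) -> re.Pattern:
    """Translate a glob-like pattern into a prefix-matching regex."""
    out = []
    i = 0
    while i < len(pattern):
        if pattern.startswith("**", i):
            out.append(r"[a-zA-Z0-9_\.]+")  # crosses dot-separated path segments
            i += 2
        elif pattern[i] == "*":
            out.append(r"[a-zA-Z0-9_]+")    # stays within one path segment
            i += 1
        else:
            out.append(re.escape(pattern[i]))
            i += 1
    return re.compile("".join(out))

# Prefix matching, as used by include_params / exclude_params:
print(bool(compile_pattern("*.weight").match("layer1.weight")))         # True
print(bool(compile_pattern("encoder.*.bias").match("encoder.0.bias")))  # True
print(bool(compile_pattern("encoder.*.bias").match("encoder.0.weight")))  # False
```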

Notes

ELBO Optimization: the evidence lower bound (ELBO) for variational inference is $\mathcal{L} = \mathbb{E}_{q(\theta)}[\log p(y|x, \theta)] - D_{KL}(q(\theta) \| p(\theta))$. In practice, maximize the ELBO by minimizing nll + kl_divergence() / num_datapoints; the KL term is scaled by the dataset size to balance the two terms appropriately.
Memory Efficiency: the reparameterize() context manager temporarily replaces parameters without creating permanent copies, making it memory-efficient for large models.
