This guide covers the fundamentals of training normalizing flows in Zuko, including setting up training loops, loss functions, and monitoring training progress.

Overview

Training a normalizing flow amounts to maximizing the log-likelihood of the data under the flow's distribution. This is equivalent to minimizing the negative log-likelihood (NLL) loss.

Basic Training Loop

Here’s a typical training loop; it assumes a trainloader that yields (x, c) batches and a num_epochs count defined elsewhere:
import torch
import zuko

# Create a neural spline flow
flow = zuko.flows.NSF(
    features=3,      # number of sample features
    context=5,       # number of context features
    transforms=3,    # number of transformations
    hidden_features=[128, 128]  # hidden layer sizes
)

# Setup optimizer
optimizer = torch.optim.Adam(flow.parameters(), lr=1e-3)

# Training loop
for epoch in range(num_epochs):
    for x, c in trainloader:
        # Compute negative log-likelihood
        loss = -flow(c).log_prob(x).mean()
        
        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
Calling flow(c) returns a distribution object built from the current parameters. Remember to call flow(c) again after each optimizer step so that the distribution reflects the updated parameters.

Loss Functions

Negative Log-Likelihood

The primary loss function for training normalizing flows is the negative log-likelihood:
loss = -flow(c).log_prob(x)  # Shape: (batch_size,)
loss = loss.mean()           # Scalar loss
For unconditional flows, create the flow with context=0 and call it without a context:
flow = zuko.flows.MAF(features=5, context=0)
loss = -flow().log_prob(x).mean()

Forward KL Divergence

When training a flow to approximate a target distribution where you can evaluate densities:
# Sample from target distribution
x = target_distribution.sample((batch_size,))

# Forward KL: E_p[-log q]
loss = -flow(c).log_prob(x).mean()

Reverse KL Divergence

When you can only evaluate the unnormalized target density:
# Sample from flow
dist = flow(c)
x = dist.rsample()  # Use rsample for reparameterization

# Reverse KL: E_q[log q - log p]
log_q = dist.log_prob(x)
log_p = target_log_prob(x)  # Unnormalized is fine
loss = (log_q - log_p).mean()
Use rsample() instead of sample() when you need gradients to flow through the samples. This is essential for reverse KL and variational inference.

Optimizer Setup

Adam Optimizer

Adam is the most commonly used optimizer for training flows:
optimizer = torch.optim.Adam(
    flow.parameters(),
    lr=1e-3,
    weight_decay=1e-5  # Optional L2 regularization
)

Learning Rate Scheduling

Use a learning rate scheduler to improve convergence:
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=10
)

for epoch in range(num_epochs):
    # Training loop
    train_loss = train_epoch(flow, trainloader, optimizer)
    
    # Update learning rate
    scheduler.step(train_loss)
Alternatively, use cosine annealing:
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=num_epochs
)
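To see what the cosine schedule does, the sketch below records the learning rate across epochs; the dummy parameter exists only to construct an optimizer, and the training step is elided:

```python
import torch

# Dummy parameter just to construct an optimizer (illustrative)
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=1e-3)

num_epochs = 100
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

lrs = []
for epoch in range(num_epochs):
    lrs.append(optimizer.param_groups[0]["lr"])
    optimizer.step()   # training step would go here
    scheduler.step()
```

The learning rate starts at 1e-3 and decays along a cosine curve toward zero by epoch T_max.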

Validation and Monitoring

Validation Loop

Monitor performance on a validation set:
def validate(flow, valloader):
    flow.eval()
    total_loss = 0
    
    with torch.no_grad():
        for x, c in valloader:
            loss = -flow(c).log_prob(x).mean()
            total_loss += loss.item()
    
    flow.train()
    return total_loss / len(valloader)

Training with Progress Monitoring

1. Initialize tracking

Track losses and metrics throughout training:
train_losses = []
val_losses = []
2. Training epoch

Collect losses during training:
for epoch in range(num_epochs):
    epoch_losses = []
    
    for x, c in trainloader:
        loss = -flow(c).log_prob(x).mean()
        loss.backward()
        
        optimizer.step()
        optimizer.zero_grad()
        
        epoch_losses.append(loss.item())
    
    train_loss = sum(epoch_losses) / len(epoch_losses)
    train_losses.append(train_loss)
3. Validation

Evaluate on the validation set:
    val_loss = validate(flow, valloader)
    val_losses.append(val_loss)
    
    print(f"Epoch {epoch}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")

Complete Training Example

Here’s a complete training script with all best practices:
import torch
import torch.utils.data as data
import zuko

# Prepare data
train_dataset = ...  # Your dataset
val_dataset = ...

trainloader = data.DataLoader(train_dataset, batch_size=256, shuffle=True)
valloader = data.DataLoader(val_dataset, batch_size=256)

# Create flow
flow = zuko.flows.NSF(
    features=10,
    context=0,
    transforms=5,
    hidden_features=[256, 256]
).cuda()

# Setup training
optimizer = torch.optim.Adam(flow.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, patience=10, factor=0.5
)

# Training loop
best_val_loss = float('inf')

for epoch in range(100):
    # Training
    flow.train()
    train_losses = []
    
    for x, _ in trainloader:
        x = x.cuda()
        
        loss = -flow().log_prob(x).mean()
        loss.backward()
        
        optimizer.step()
        optimizer.zero_grad()
        
        train_losses.append(loss.item())
    
    # Validation
    flow.eval()
    val_losses = []
    
    with torch.no_grad():
        for x, _ in valloader:
            x = x.cuda()
            loss = -flow().log_prob(x).mean()
            val_losses.append(loss.item())
    
    train_loss = sum(train_losses) / len(train_losses)
    val_loss = sum(val_losses) / len(val_losses)
    
    # Learning rate scheduling
    scheduler.step(val_loss)
    
    # Save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(flow.state_dict(), 'best_flow.pth')
    
    print(f"Epoch {epoch}: Train = {train_loss:.4f}, Val = {val_loss:.4f}")

Tips and Best Practices

Gradient Clipping: For stable training, especially with deep flows, clip gradients between the backward pass and the optimizer step:
loss.backward()
torch.nn.utils.clip_grad_norm_(flow.parameters(), max_norm=1.0)
optimizer.step()
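In context, a full step with clipping looks like the sketch below; a plain linear layer stands in for the flow, and the loss is scaled up purely to force large gradients:

```python
import torch

torch.manual_seed(0)

# Stand-in model (a flow would work the same way)
model = torch.nn.Linear(8, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

x = torch.randn(32, 8)
loss = 100.0 * model(x).pow(2).mean()  # scaled up to produce large gradients

optimizer.zero_grad()
loss.backward()

# clip_grad_norm_ rescales gradients in place and returns
# the total norm measured *before* clipping
pre_clip = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()

# After clipping, the total gradient norm is at most max_norm
post_clip = torch.sqrt(sum(p.grad.norm() ** 2 for p in model.parameters()))
```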
Batch Size: Larger batch sizes (256-512) typically work better for flow training as they provide more stable gradient estimates.
Initialization: Most Zuko flows have good default initialization. If training is unstable, try reducing the learning rate first.
Always remember to rebuild the distribution after optimizer steps:
# Wrong
dist = flow(c)
for step in range(num_steps):
    loss = -dist.log_prob(x).mean()  # Uses old parameters!
    ...

# Correct
for step in range(num_steps):
    loss = -flow(c).log_prob(x).mean()  # Fresh distribution
    ...
