This guide covers the fundamentals of training normalizing flows in Zuko, including setting up training loops, loss functions, and monitoring training progress.
Overview
Training a normalizing flow consists of maximizing the log-likelihood of the data under the flow distribution. This is equivalent to minimizing the negative log-likelihood (NLL) loss.
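The equivalence is easy to see on a toy example that stands in for a flow: a plain Gaussian with a learnable mean, fitted by minimizing the NLL of a few made-up observations (no Zuko needed).

```python
import torch

# Toy stand-in for a flow: a Gaussian with a learnable mean.
mu = torch.zeros(1, requires_grad=True)
x = torch.tensor([2.0, 2.2, 1.8])  # observed data (illustrative values)

# Negative log-likelihood of the data under the model
nll = -torch.distributions.Normal(mu, 1.0).log_prob(x).mean()
nll.backward()

# The gradient of the NLL points away from the sample mean (2.0), so a
# gradient *descent* step on the NLL moves mu toward the data, i.e. it
# maximizes the log-likelihood.
```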
Basic Training Loop
Here’s a complete example of training a flow on a dataset:
import torch
import zuko
# Create a neural spline flow
flow = zuko.flows.NSF(
    features=3,                   # number of sample features
    context=5,                    # number of context features
    transforms=3,                 # number of transformations
    hidden_features=[128, 128],   # hidden layer sizes
)
# Setup optimizer
optimizer = torch.optim.Adam(flow.parameters(), lr=1e-3)
# Training loop
for epoch in range(num_epochs):
    for x, c in trainloader:
        # Compute negative log-likelihood
        loss = -flow(c).log_prob(x).mean()

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
Calling flow(c) returns a distribution object built from the current parameters. Because each call rebuilds the distribution, call flow(c) again after every optimizer step rather than reusing an old distribution object.
Loss Functions
Negative Log-Likelihood
The primary loss function for training normalizing flows is the negative log-likelihood:
loss = -flow(c).log_prob(x) # Shape: (batch_size,)
loss = loss.mean() # Scalar loss
For unconditional flows, create the flow with context=0 and call it without a context:
flow = zuko.flows.MAF(features=5, context=0)
loss = -flow().log_prob(x).mean()
Forward KL Divergence
When training a flow to approximate a target distribution where you can evaluate densities:
# Sample from target distribution
x = target_distribution.sample((batch_size,))
# Forward KL: E_p[-log q]
loss = -flow(c).log_prob(x).mean()
Reverse KL Divergence
When you can only evaluate the unnormalized target density:
# Sample from flow
dist = flow(c)
x = dist.rsample() # Use rsample for reparameterization
# Reverse KL: E_q[log q - log p]
log_q = dist.log_prob(x)
log_p = target_log_prob(x) # Unnormalized is fine
loss = (log_q - log_p).mean()
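As a self-contained illustration of this loss (using a plain Gaussian in place of a flow, and a made-up unnormalized target centered at 3), minimizing the reverse KL recovers the target:

```python
import torch

torch.manual_seed(0)

def target_log_prob(x):
    # Unnormalized log-density of a standard normal shifted to mean 3
    return -0.5 * (x - 3.0) ** 2

loc = torch.zeros(1, requires_grad=True)   # "flow" is just N(loc, 1)
opt = torch.optim.Adam([loc], lr=0.05)

for _ in range(500):
    q = torch.distributions.Normal(loc, 1.0)
    x = q.rsample((256,))                  # rsample keeps the graph
    loss = (q.log_prob(x) - target_log_prob(x)).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()

print(loc.item())  # close to 3.0
```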
Use rsample() instead of sample() when you need gradients to flow through the samples. This is essential for reverse KL and variational inference.
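The difference between the two is easy to check with a plain torch.distributions.Normal (no flow needed): rsample() produces samples attached to the autograd graph, sample() does not.

```python
import torch

loc = torch.zeros(3, requires_grad=True)
dist = torch.distributions.Normal(loc, 1.0)

x_r = dist.rsample()  # reparameterized: x = loc + eps * scale
x_s = dist.sample()   # detached draw, no autograd graph

print(x_r.grad_fn is not None)  # True: gradients can flow back to loc
print(x_s.grad_fn is None)      # True: backward() through x_s would fail
```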
Optimizer Setup
Adam Optimizer
Adam is the most commonly used optimizer for training flows:
optimizer = torch.optim.Adam(
    flow.parameters(),
    lr=1e-3,
    weight_decay=1e-5,  # optional L2 regularization
)
Learning Rate Scheduling
Use a learning rate scheduler to improve convergence:
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=10,
)
for epoch in range(num_epochs):
    # Training loop
    train_loss = train_epoch(flow, trainloader, optimizer)

    # Update learning rate
    scheduler.step(train_loss)
Alternatively, use cosine annealing:
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=num_epochs,
)
Validation and Monitoring
Validation Loop
Monitor performance on a validation set:
def validate(flow, valloader):
    flow.eval()
    total_loss = 0
    with torch.no_grad():
        for x, c in valloader:
            loss = -flow(c).log_prob(x).mean()
            total_loss += loss.item()
    flow.train()
    return total_loss / len(valloader)
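The validation loss is also the natural signal for early stopping. A minimal illustrative helper (not part of Zuko; the class name and patience value are placeholders):

```python
class EarlyStopping:
    """Signal a stop when validation loss has not improved for `patience` checks."""

    def __init__(self, patience=20):
        self.patience = patience
        self.best = float('inf')
        self.counter = 0

    def step(self, val_loss):
        if val_loss < self.best:
            self.best = val_loss   # new best: reset the counter
            self.counter = 0
        else:
            self.counter += 1      # no improvement this epoch
        return self.counter >= self.patience  # True -> stop training
```

Call `stopper.step(val_loss)` once per epoch and break out of the training loop when it returns True.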
Training with Progress Monitoring
Initialize tracking
Track losses and metrics throughout training:

train_losses = []
val_losses = []

Training epoch
Collect losses during training:

for epoch in range(num_epochs):
    epoch_losses = []
    for x, c in trainloader:
        loss = -flow(c).log_prob(x).mean()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        epoch_losses.append(loss.item())

    train_loss = sum(epoch_losses) / len(epoch_losses)
    train_losses.append(train_loss)

Validation
Evaluate on the validation set (still inside the epoch loop):

    val_loss = validate(flow, valloader)
    val_losses.append(val_loss)
    print(f"Epoch {epoch}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")
Complete Training Example
Here’s a complete training script with all best practices:
import torch
import torch.utils.data as data
import zuko
# Prepare data
train_dataset = ... # Your dataset
val_dataset = ...
trainloader = data.DataLoader(train_dataset, batch_size=256, shuffle=True)
valloader = data.DataLoader(val_dataset, batch_size=256)
# Create flow
flow = zuko.flows.NSF(
    features=10,
    context=0,
    transforms=5,
    hidden_features=[256, 256],
).cuda()
# Setup training
optimizer = torch.optim.Adam(flow.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, patience=10, factor=0.5
)
# Training loop
best_val_loss = float('inf')
for epoch in range(100):
    # Training
    flow.train()
    train_losses = []
    for x, _ in trainloader:
        x = x.cuda()
        loss = -flow().log_prob(x).mean()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        train_losses.append(loss.item())

    # Validation
    flow.eval()
    val_losses = []
    with torch.no_grad():
        for x, _ in valloader:
            x = x.cuda()
            loss = -flow().log_prob(x).mean()
            val_losses.append(loss.item())

    train_loss = sum(train_losses) / len(train_losses)
    val_loss = sum(val_losses) / len(val_losses)

    # Learning rate scheduling
    scheduler.step(val_loss)

    # Save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(flow.state_dict(), 'best_flow.pth')

    print(f"Epoch {epoch}: Train = {train_loss:.4f}, Val = {val_loss:.4f}")
Tips and Best Practices
Gradient Clipping: For stable training, especially with deep flows, clip gradients between the backward pass and the optimizer step:

loss.backward()
torch.nn.utils.clip_grad_norm_(flow.parameters(), max_norm=1.0)
optimizer.step()
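clip_grad_norm_ rescales the gradients in place and returns the pre-clipping norm, which is worth logging to spot exploding gradients. A small check with toy numbers chosen for illustration:

```python
import torch

w = torch.nn.Parameter(torch.ones(4))
loss = (10.0 * w).sum()
loss.backward()                      # each gradient entry is 10, so norm = 20

total_norm = torch.nn.utils.clip_grad_norm_([w], max_norm=1.0)
print(float(total_norm))             # 20.0: the norm *before* clipping
print(float(w.grad.norm()))          # ~1.0: gradients rescaled in place
```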
Batch Size: Larger batch sizes (256-512) typically work better for flow training as they provide more stable gradient estimates.
Initialization: Most Zuko flows have good default initialization. If training is unstable, try reducing the learning rate first.
Always remember to rebuild the distribution after optimizer steps:

# Wrong
dist = flow(c)
for step in range(num_steps):
    loss = -dist.log_prob(x).mean()  # uses stale parameters!
    ...

# Correct
for step in range(num_steps):
    loss = -flow(c).log_prob(x).mean()  # fresh distribution
    ...
Next Steps