Overview

A Gaussian mixture model (GMM) represents a distribution as a weighted sum of Gaussian components. While not technically a normalizing flow, the GMM is included in Zuko as a simple, interpretable density estimator.
GMM is located in zuko.mixtures but is also available as zuko.flows.GMM for backwards compatibility.

Reference

Wikipedia: Gaussian Mixture Model
https://wikipedia.org/wiki/Mixture_model#Gaussian_mixture_model

Class Definition

zuko.flows.GMM(
    features: int,
    context: int = 0,
    components: int = 2,
    covariance_type: str = "full",
    tied: bool = False,
    epsilon: float = 1e-6,
    **kwargs
)

Parameters

features
int
required
The number of features in the data.
context
int
default:"0"
The number of context features for conditional density estimation.
components
int
default:"2"
The number of Gaussian components K in the mixture.
covariance_type
str
default:"full"
The type of covariance matrix parameterization:
  • "full": Full covariance matrix (most expressive)
  • "diagonal": Diagonal covariance (axis-aligned)
  • "spherical": Single variance parameter (isotropic)
tied
bool
default:"False"
Whether to tie the covariance parameters across components. If True, all components share the same covariance structure.
epsilon
float
default:"1e-6"
A numerical stability term added to variances.
**kwargs
dict
Additional keyword arguments passed to the MLP constructor (for conditional GMMs):
  • hidden_features: Hidden layer sizes (default: [64, 64])
  • activation: Activation function

Usage Example

import torch
import zuko

# Create an unconditional GMM
gmm = zuko.flows.GMM(
    features=2,
    components=5,
    covariance_type="full"
)

# Sample from the GMM
dist = gmm()
samples = dist.sample((1000,))
print(samples.shape)  # torch.Size([1000, 2])

# Compute log probabilities
log_prob = dist.log_prob(samples)
print(log_prob.shape)  # torch.Size([1000])

Conditional GMM

# Create a conditional GMM
gmm = zuko.flows.GMM(
    features=2,
    context=5,
    components=3,
    covariance_type="full",
    hidden_features=[128, 128]
)

context = torch.randn(5)
dist = gmm(context)
samples = dist.sample((100,))

Training Example

import torch.optim as optim

gmm = zuko.flows.GMM(
    features=3,
    components=5,
    covariance_type="full"
)

optimizer = optim.Adam(gmm.parameters(), lr=1e-3)

for epoch in range(100):
    for x in dataloader:
        optimizer.zero_grad()
        loss = -gmm().log_prob(x).mean()
        loss.backward()
        optimizer.step()

Initialization with K-Means

import torch

# Create GMM
gmm = zuko.flows.GMM(
    features=2,
    components=3,
    covariance_type="full"
)

# Initialize with k-means clustering
data = torch.randn(1000, 2)
gmm.initialize(data, strategy="kmeans")

# Now train
optimizer = optim.Adam(gmm.parameters(), lr=1e-3)
# ... training loop ...

Methods

forward(c=None)

Returns a Gaussian mixture distribution. Arguments:
  • c (Tensor, optional): Context tensor of shape (*, context)
Returns:
  • Mixture: A mixture distribution with:
    • sample(shape): Sample from the mixture
    • log_prob(x): Compute log probability of samples
    • component_distribution: The underlying Gaussian components

initialize(x, strategy)

Initializes the GMM components using clustering. Arguments:
  • x (Tensor): Feature samples with shape (N, features)
  • strategy (str): Clustering strategy:
    • "random": Random initialization
    • "kmeans": K-means clustering
    • "kmeans++": K-means++ initialization
Example:
data = torch.randn(1000, 5)
gmm.initialize(data, strategy="kmeans++")

When to Use GMM

Good for:
  • Simple, interpretable density estimation
  • Clustering applications
  • Low-dimensional data (< 20 features)
  • When you know the number of modes
  • Baseline comparisons
  • Fast inference
Consider alternatives if:
  • You need high expressivity (use NSF or NAF)
  • You have high-dimensional data (use flows)
  • You don’t know the number of components
  • Data has complex, non-Gaussian structure

Tips

  1. Initialize properly: Use k-means initialization for better convergence.
  2. Choose components: Start with 3-10 components. Use cross-validation to select.
  3. Covariance type: Use "full" for small data, "diagonal" for medium, "spherical" for large/high-dimensional data.
  4. Regularization: The epsilon parameter prevents singular covariances.

Model Equation

GMM represents the distribution as:
p(x | c) = sum_{i=1}^K w_i(c) * N(x | μ_i(c), Σ_i(c))
where:
  • K is the number of components
  • w_i are mixing weights (sum to 1)
  • N(μ_i, Σ_i) are Gaussian components
  • c is optional context
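
The sum above can be checked numerically with torch.distributions, using hand-picked weights, means, and covariances (a standalone sketch; the trained GMM returns an equivalent mixture distribution):

```python
import torch
from torch.distributions import Categorical, MixtureSameFamily, MultivariateNormal

# K = 2 components in 2D with hand-picked parameters
weights = torch.tensor([0.3, 0.7])
means = torch.tensor([[0.0, 0.0], [2.0, 2.0]])
covs = torch.stack([torch.eye(2), 0.5 * torch.eye(2)])

mixture = MixtureSameFamily(
    Categorical(probs=weights),
    MultivariateNormal(means, covariance_matrix=covs),
)

x = torch.tensor([1.0, 1.0])

# Evaluate the sum explicitly: p(x) = sum_i w_i * N(x | mu_i, Sigma_i)
manual = sum(
    w * MultivariateNormal(m, covariance_matrix=c).log_prob(x).exp()
    for w, m, c in zip(weights, means, covs)
)

assert torch.allclose(mixture.log_prob(x).exp(), manual)
print(float(manual))
```

The two evaluations agree, confirming that the mixture density is just the weighted sum of the component densities.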

Covariance Types

Full Covariance

gmm = zuko.flows.GMM(features=3, components=5, covariance_type="full")
  • Most expressive
  • O(features^2) parameters per component
  • Can model correlations
  • Best for low-dimensional data

Diagonal Covariance

gmm = zuko.flows.GMM(features=3, components=5, covariance_type="diagonal")
  • Medium expressivity
  • O(features) parameters per component
  • Assumes features are independent
  • Good for medium-dimensional data

Spherical Covariance

gmm = zuko.flows.GMM(features=3, components=5, covariance_type="spherical")
  • Least expressive
  • O(1) parameters per component
  • Isotropic Gaussians (same variance in all directions)
  • Good for high-dimensional data
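
For intuition about the O(·) counts above, the free covariance parameters per component can be tallied directly. The helper below is illustrative (it is not zuko's internal bookkeeping, and it assumes a Cholesky parameterization for the full case):

```python
def covariance_params(features: int, covariance_type: str) -> int:
    """Count free covariance parameters per component (illustrative)."""
    if covariance_type == "full":
        # Lower-triangular Cholesky factor: d * (d + 1) / 2 entries
        return features * (features + 1) // 2
    if covariance_type == "diagonal":
        return features  # one variance per feature
    if covariance_type == "spherical":
        return 1  # a single shared variance
    raise ValueError(covariance_type)

for ctype in ("full", "diagonal", "spherical"):
    print(ctype, covariance_params(10, ctype))
# full 55, diagonal 10, spherical 1
```

At 10 features the full parameterization already needs 55 covariance entries per component, which is why diagonal or spherical covariances are preferred as dimensionality grows.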

Tied vs. Untied Covariances

Untied (default)

gmm = zuko.flows.GMM(features=3, components=5, tied=False)
Each component has its own covariance matrix.

Tied

gmm = zuko.flows.GMM(features=3, components=5, tied=True)
All components share the same covariance structure (only means differ).

Initialization Strategies

Random

gmm.initialize(data, strategy="random")
Randomly select data points as initial centers.

K-Means

gmm.initialize(data, strategy="kmeans")
Run k-means algorithm for initialization (recommended).

K-Means++

gmm.initialize(data, strategy="kmeans++")
Use k-means++ for better initialization (best).
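
To illustrate what the k-means++ rule does, here is a minimal seeding sketch in torch: each new center is sampled with probability proportional to its squared distance from the nearest center chosen so far. This is an illustration of the algorithm, not zuko's implementation, and the function name is made up:

```python
import torch

def kmeans_pp_centers(x: torch.Tensor, k: int) -> torch.Tensor:
    """Pick k initial centers with the k-means++ seeding rule."""
    n = x.shape[0]
    # First center: a uniformly random data point
    centers = [x[torch.randint(n, (1,))].squeeze(0)]
    for _ in range(k - 1):
        # Squared distance of every point to its nearest chosen center
        d2 = torch.stack([((x - c) ** 2).sum(-1) for c in centers]).amin(0)
        # Sample the next center proportionally to that distance
        idx = torch.multinomial(d2, 1)
        centers.append(x[idx].squeeze(0))
    return torch.stack(centers)

x = torch.randn(500, 2)
print(kmeans_pp_centers(x, 3).shape)  # torch.Size([3, 2])
```

Spreading the initial centers apart in this way makes it less likely that two components start on the same mode.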

Advanced Usage

Model Selection

import torch
from torch.utils.data import DataLoader

# Try different numbers of components
for K in [3, 5, 7, 10]:
    gmm = zuko.flows.GMM(features=5, components=K)
    gmm.initialize(train_data, strategy="kmeans++")
    
    optimizer = optim.Adam(gmm.parameters(), lr=1e-3)
    # ... train ...
    
    # Evaluate on validation set
    val_log_prob = gmm().log_prob(val_data).mean()
    print(f"K={K}: {val_log_prob:.4f}")

Extract Cluster Assignments

import torch

gmm = zuko.flows.GMM(features=2, components=5)
# ... train ...

# Get most likely component for each point
data = torch.randn(100, 2)
dist = gmm()

# Compute component probabilities
component_logits = dist.component_distribution.log_prob(data.unsqueeze(1))
component_probs = torch.softmax(component_logits + dist.mixture_distribution.logits, dim=-1)

# Hard assignment
cluster_assignment = component_probs.argmax(dim=-1)
print(cluster_assignment)  # Shape: [100]

Conditional GMM for Regression

# Use GMM for probabilistic regression
gmm = zuko.flows.GMM(
    features=1,        # Target variable
    context=10,        # Input features
    components=5,
    covariance_type="spherical"
)

# Train
for x_input, y_target in dataloader:
    optimizer.zero_grad()
    loss = -gmm(x_input).log_prob(y_target).mean()
    loss.backward()
    optimizer.step()

# Predict
x_new = torch.randn(10)
predictive_dist = gmm(x_new)
y_samples = predictive_dist.sample((1000,))
y_mean = y_samples.mean()
y_std = y_samples.std()

Visualization

import matplotlib.pyplot as plt
import numpy as np
import torch

# Train 2D GMM
gmm = zuko.flows.GMM(features=2, components=3, covariance_type="full")
# ... train ...

# Create grid
x = np.linspace(-5, 5, 100)
y = np.linspace(-5, 5, 100)
X, Y = np.meshgrid(x, y)
pos = np.stack([X, Y], axis=-1)
pos_tensor = torch.tensor(pos, dtype=torch.float32)

# Compute log probabilities
with torch.no_grad():
    log_prob = gmm().log_prob(pos_tensor)
    prob = log_prob.exp()

# Plot
plt.figure(figsize=(8, 6))
plt.contourf(X, Y, prob.numpy(), levels=20, cmap='viridis')
plt.colorbar(label='Density')
plt.title('GMM Density')
plt.xlabel('x1')
plt.ylabel('x2')
plt.show()

Comparison with Flows

Property          GMM            NSF/MAF (Flows)
Expressivity      Low-Medium     High
Interpretability  High           Low
Training speed    Fast           Medium-Slow
Inference speed   Fast           Medium
Scalability       Low (< 20D)    High (100s of D)
Clustering        Yes            No
Flexibility       Limited        High

Applications

Clustering

# Use GMM for soft clustering
gmm = zuko.flows.GMM(features=10, components=5)
# ... train ...

# Get soft cluster probabilities (responsibilities)
data = torch.randn(1000, 10)
dist = gmm()
log_resp = dist.component_distribution.log_prob(data.unsqueeze(1))
cluster_probs = torch.softmax(log_resp + dist.mixture_distribution.logits, dim=-1)

Anomaly Detection

# Train on normal data
gmm = zuko.flows.GMM(features=5, components=3)
# ... train on normal data ...

# Detect anomalies: points with low log-probability under the model
test_data = torch.randn(100, 5)
log_prob = gmm().log_prob(test_data)
anomalies = log_prob < threshold  # threshold: e.g. a low quantile of log_prob on held-out normal data
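
One common way to choose the threshold is a low quantile of log-probabilities under the fitted model. The sketch below uses a hand-built torch mixture as a stand-in for the trained `gmm()`:

```python
import torch
from torch.distributions import Categorical, Independent, MixtureSameFamily, Normal

# Stand-in for a trained model (illustrative): in practice, use the
# distribution returned by a trained `gmm()`.
dist = MixtureSameFamily(
    Categorical(probs=torch.tensor([0.5, 0.5])),
    Independent(Normal(torch.tensor([[-2.0] * 5, [2.0] * 5]), torch.ones(2, 5)), 1),
)

# Score "normal" data and take the 1st percentile as the cutoff
normal_scores = dist.log_prob(dist.sample((10_000,)))
threshold = torch.quantile(normal_scores, 0.01)

test_data = torch.randn(100, 5) * 10  # mostly far from both modes
anomalies = dist.log_prob(test_data) < threshold
print(anomalies.float().mean())  # fraction flagged as anomalous
```

With this rule, roughly 1% of genuinely normal data is flagged by construction, so the quantile directly controls the false-positive rate.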

Generative Modeling

# Simple generative model
gmm = zuko.flows.GMM(features=2, components=10)
# ... train ...

# Generate new samples
samples = gmm().sample((1000,))

Limitations

Key limitations:
  1. Fixed components: Must specify number of components in advance
  2. Gaussian assumption: Each component is Gaussian
  3. Low capacity: Limited expressivity compared to flows
  4. Scalability: Not suitable for very high-dimensional data
  5. Local optima: Likelihood maximization (whether by gradient descent, as here, or by classical EM) can get stuck in poor local optima