The zuko.mixtures module provides Gaussian mixture models (GMMs) for flexible density estimation, with support for different covariance structures and conditional modeling.

GMM

Creates a Gaussian mixture model that represents distributions as weighted sums of Gaussian components:

p(X \mid c) = \sum_{i=1}^{K} w_i(c) \, \mathcal{N}(X \mid \mu_i(c), \Sigma_i(c))

Supports full, diagonal, and spherical covariance parameterizations, with optional context conditioning.
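For intuition, the same density can be assembled by hand with torch.distributions; a minimal sketch with arbitrary fixed weights, means, and covariances (not zuko's internal parameterization):

import torch
from torch.distributions import Categorical, MivariateNormal as _  # noqa: placeholder, see below
from torch.distributions import MultivariateNormal, MixtureSameFamily

K, D = 3, 2
weights = torch.softmax(torch.randn(K), dim=-1)  # mixture weights w_i
means = torch.randn(K, D)                        # component means mu_i
covs = torch.eye(D).expand(K, D, D)              # component covariances Sigma_i

mixture = MixtureSameFamily(
    Categorical(probs=weights),
    MultivariateNormal(means, covs),
)

x = torch.randn(D)
# Weighted sum of component densities equals the mixture density
manual = (weights * MultivariateNormal(means, covs).log_prob(x).exp()).sum().log()
assert torch.allclose(mixture.log_prob(x), manual)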
features
int
required
The number of features (dimensionality of the data)
context
int
default:"0"
The number of context features. If 0, creates an unconditional GMM. If > 0, mixture parameters are predicted from context via an MLP.
components
int
default:"2"
The number of Gaussian components K in the mixture
covariance_type
Literal['full', 'diagonal', 'spherical']
default:"'full'"
The type of covariance matrix parameterization:
  • 'full': Full covariance matrices (most flexible)
  • 'diagonal': Diagonal covariance matrices (axis-aligned)
  • 'spherical': Scalar variance (isotropic)
tied
bool
default:"False"
Whether to tie (share) covariance parameters across all components
epsilon
float
default:"1e-6"
A numerical stability term added to variances
**kwargs
Keyword arguments passed to zuko.nn.MLP (only used when context > 0)

Methods

forward

Creates the mixture distribution.
c
Tensor | None
default:"None"
The context tensor, with shape (*, C), where C is the number of context features. If None, returns the unconditional distribution.
Returns: A torch.distributions.Distribution object representing the mixture
import torch
from zuko.mixtures import GMM

# Unconditional GMM
gmm = GMM(features=2, components=3)
dist = gmm()  # No context needed

# Conditional GMM
cgmm = GMM(features=2, context=5, components=3)
c = torch.randn(10, 5)  # Context for 10 samples
dist = cgmm(c)  # Conditional distribution

initialize

Initializes the mixture components using clustering on data samples.
x
Tensor
required
Feature samples with shape (N, D), where N is the number of samples and D is the number of features
strategy
Literal['random', 'kmeans', 'kmeans++']
required
The clustering initialization strategy:
  • 'random': Randomly select component centers from data
  • 'kmeans': Run k-means clustering
  • 'kmeans++': Use k-means++ initialization (usually best)
import torch
from zuko.mixtures import GMM

# Create GMM
gmm = GMM(features=2, components=3)

# Initialize from data
data = torch.randn(1000, 2)
gmm.initialize(data, strategy='kmeans++')

# Now the model is initialized near the data
Initialization Strategies:
  • random: Fastest, but may lead to poor initial components
  • kmeans: Iteratively refines component centers (7 iterations by default)
  • kmeans++: Smart initialization that spreads centers apart, usually converges faster
For more details, see scikit-learn’s mixture documentation.
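As a quick sanity check, one can compare the initial negative log-likelihood under each strategy (the exact values will vary with the data and seed):

import torch
from zuko.mixtures import GMM

torch.manual_seed(0)
data = torch.randn(1000, 2)

for strategy in ('random', 'kmeans', 'kmeans++'):
    gmm = GMM(features=2, components=3)
    gmm.initialize(data, strategy=strategy)
    nll = -gmm().log_prob(data).mean()  # NLL before any gradient training
    print(f"{strategy}: initial NLL = {nll.item():.3f}")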

Usage Examples

Unconditional Density Estimation

import torch
from zuko.mixtures import GMM

# Create a 2D GMM with 4 components
gmm = GMM(features=2, components=4, covariance_type='full')

# Generate training data
data = torch.randn(1000, 2)

# Initialize components from data
gmm.initialize(data, strategy='kmeans++')

# Training loop
optimizer = torch.optim.Adam(gmm.parameters(), lr=1e-3)

for epoch in range(100):
    optimizer.zero_grad()
    
    # Get distribution
    dist = gmm()
    
    # Negative log-likelihood
    loss = -dist.log_prob(data).mean()
    
    loss.backward()
    optimizer.step()

# Sample from the fitted model
samples = gmm().sample((100,))

# Evaluate density
log_prob = gmm().log_prob(data)

Conditional Density Estimation

import torch
from zuko.mixtures import GMM

# Conditional GMM: model p(y | x)
gmm = GMM(
    features=1,      # 1D output
    context=3,       # 3D input
    components=5,
    covariance_type='diagonal',
    hidden_features=[64, 64]  # MLP architecture
)

# Training data
x = torch.randn(1000, 3)  # Inputs
y = torch.randn(1000, 1)  # Outputs

optimizer = torch.optim.Adam(gmm.parameters(), lr=1e-3)

for epoch in range(100):
    optimizer.zero_grad()
    
    # Conditional distribution
    dist = gmm(x)
    
    # Negative log-likelihood
    loss = -dist.log_prob(y).mean()
    
    loss.backward()
    optimizer.step()

# Conditional sampling
x_test = torch.randn(10, 3)
dist_test = gmm(x_test)
samples = dist_test.sample()  # Sample y given x_test
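Continuing the example above, the returned distribution can also score candidate outputs or estimate conditional statistics with standard torch.distributions semantics:

# Log-density of the drawn samples under their own conditionals
log_p = dist_test.log_prob(samples)  # (10,)

# Monte Carlo estimate of E[y | x_test]
y_mc = dist_test.sample((1000,))     # (1000, 10, 1)
print(y_mc.mean(dim=0))              # (10, 1)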

Covariance Types Comparison

import torch
from zuko.mixtures import GMM

data = torch.randn(1000, 3)

# Full covariance (most parameters)
gmm_full = GMM(features=3, components=4, covariance_type='full')
print(f"Full params: {sum(p.numel() for p in gmm_full.parameters())}")

# Diagonal covariance (medium parameters)
gmm_diag = GMM(features=3, components=4, covariance_type='diagonal')
print(f"Diagonal params: {sum(p.numel() for p in gmm_diag.parameters())}")

# Spherical covariance (fewest parameters)
gmm_sphere = GMM(features=3, components=4, covariance_type='spherical')
print(f"Spherical params: {sum(p.numel() for p in gmm_sphere.parameters())}")

# Full: Most flexible, can capture correlations
# Diagonal: Axis-aligned ellipsoids, no correlations
# Spherical: Circular/spherical components, single variance per component
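The printed counts can be verified by hand from the shapes in the Parameter Shapes table below; assuming that parameterization, for K = 4 components and D = 3 features:

K, D = 4, 3

logits = K                            # mixture logits
means = K * D                         # component means
full = K * (D + D * (D - 1) // 2)     # diagonal + off-diagonal terms
diagonal = K * D                      # per-dimension variances
spherical = K                         # one variance per component

print(logits + means + full)          # 40
print(logits + means + diagonal)      # 28
print(logits + means + spherical)     # 20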

Tied Covariances

import torch
from zuko.mixtures import GMM

# Tied covariances: all components share the same covariance structure
gmm_tied = GMM(
    features=2,
    components=5,
    covariance_type='full',
    tied=True  # Share covariance across components
)

data = torch.randn(1000, 2)
gmm_tied.initialize(data, strategy='kmeans++')

# This reduces parameters significantly
# Useful when components have similar shapes but different locations
print(f"Tied params: {sum(p.numel() for p in gmm_tied.parameters())}")

# Compare to untied
gmm_untied = GMM(features=2, components=5, covariance_type='full', tied=False)
print(f"Untied params: {sum(p.numel() for p in gmm_untied.parameters())}")

Advanced Usage

Mixture Weights and Responsibilities

import torch
from zuko.mixtures import GMM

# Create and train GMM
gmm = GMM(features=2, components=3)
data = torch.randn(1000, 2)
gmm.initialize(data, strategy='kmeans++')

# Get the distribution
dist = gmm()

# Access mixture components
mixture = dist  # This is a zuko.distributions.Mixture
components = mixture.component_distribution  # Base Gaussian components
logits = mixture.mixture_distribution.logits  # Log mixture weights

# Compute responsibilities (posterior probabilities)
log_probs = components.log_prob(data.unsqueeze(-2))  # (N, K)
log_weights = torch.log_softmax(logits, dim=-1)  # (K,)
log_responsibilities = log_weights + log_probs  # (N, K), unnormalized
responsibilities = torch.softmax(log_responsibilities, dim=-1)

print(f"Responsibilities shape: {responsibilities.shape}")  # (1000, 3)
print(f"Each row sums to 1: {responsibilities.sum(dim=-1)[0]}")

Integration with Normalizing Flows

import torch
from zuko.mixtures import GMM
from zuko.flows import NSF

# Use GMM as a flexible base distribution for flows
base_gmm = GMM(features=2, components=5, covariance_type='full')
flow = NSF(
    features=2,
    context=3,
    transforms=5
)

# Or create a conditional GMM for the base
conditional_base = GMM(features=2, context=3, components=5)

# The combination provides even more flexibility:
# the GMM captures multimodality, the flow adds flexibility via transformations
# (see the wiring sketch below)
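A sketch of one way to wire them together, assuming zuko's Flow constructor takes a transform and a base distribution and that NSF exposes its transform as an attribute (worth verifying against your installed zuko version):

import torch
from zuko.flows import Flow, NSF
from zuko.mixtures import GMM

nsf = NSF(features=2, context=3, transforms=5)
base_gmm = GMM(features=2, context=3, components=5)

# Swap the default base distribution for the conditional GMM
flow = Flow(transform=nsf.transform, base=base_gmm)

x = torch.randn(10, 2)
c = torch.randn(10, 3)
print(flow(c).log_prob(x))  # (10,)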

Parameter Shapes

The GMM internally represents its parameters differently based on configuration:
Config                              Parameter     Shape
components=K, features=D            Logits        (K,)
                                    Means         (K, D)
Full covariance, tied=False         Diagonal      (K, D)
                                    Off-diagonal  (K, D*(D-1)/2)
Full covariance, tied=True          Diagonal      (1, D)
                                    Off-diagonal  (1, D*(D-1)/2)
Diagonal covariance, tied=False     Diagonal      (K, D)
Diagonal covariance, tied=True      Diagonal      (1, D)
Spherical covariance, tied=False    Variance      (K, 1)
Spherical covariance, tied=True     Variance      (1, 1)
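These shapes can be checked on a live model via named_parameters (the parameter names themselves are internal and may vary across zuko versions):

from zuko.mixtures import GMM

gmm = GMM(features=3, components=4, covariance_type='full')

for name, param in gmm.named_parameters():
    print(name, tuple(param.shape))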

Notes

When to use GMMs:
  • Multimodal data: GMMs naturally handle multiple modes
  • Interpretable components: Each Gaussian has clear meaning
  • Fast inference: Simpler than flows, faster sampling and density evaluation
  • Limited data: Fewer parameters than complex flows
When to use flows instead:
  • Complex distributions: Non-Gaussian, heavy tails, intricate dependencies
  • High dimensions: Flows scale better to many dimensions
  • Complex densities with exact likelihoods: Flows stay exactly tractable while modeling far richer shapes than Gaussian mixtures
Numerical Stability: The epsilon parameter prevents numerical issues:
  • Adds epsilon to diagonal variance elements
  • Prevents singular covariance matrices
  • Default 1e-6 works well for normalized data
Increase epsilon if you encounter numerical errors.
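For example, with poorly scaled inputs, one might standardize the data and raise epsilon (the values here are illustrative):

import torch
from zuko.mixtures import GMM

data = 100 * torch.randn(1000, 2) + 50       # poorly scaled data
data = (data - data.mean(0)) / data.std(0)   # standardize to unit scale

gmm = GMM(features=2, components=3, epsilon=1e-4)  # larger stability term
gmm.initialize(data, strategy='kmeans++')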
