The `zuko.mixtures` module provides Gaussian mixture models (GMMs) for flexible density estimation, with support for different covariance structures and conditional modeling.
GMM
Creates a Gaussian mixture model that represents a distribution as a weighted sum of Gaussian components. It supports full, diagonal, and spherical covariance parameterizations, with optional conditioning on a context.
Arguments:
- `features`: The number of features (dimensionality of the data).
- `context`: The number of context features. If 0, creates an unconditional GMM; if greater than 0, the mixture parameters are predicted from the context by an MLP.
- `components`: The number of Gaussian components in the mixture.
- Covariance type: The type of covariance matrix parameterization:
  - 'full': full covariance matrices (most flexible).
  - 'diagonal': diagonal covariance matrices (axis-aligned).
  - 'spherical': a scalar variance (isotropic).
- `tied`: Whether to tie (share) the covariance parameters across all components.
- `epsilon`: A numerical stability term added to the variances.
- `**kwargs`: Keyword arguments passed to `zuko.nn.MLP` (only used when `context > 0`).
Methods
forward
Creates the mixture distribution.
- Context tensor whose last dimension is the number of context features. If None, the unconditional distribution is returned.
- Returns: a `torch.distributions.Distribution` object representing the mixture.
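The distribution returned by `forward` is a mixture p(x) = Σ_k w_k N(x | μ_k, Σ_k), where the weights w_k are a softmax over K logits. The sketch below builds such a mixture with PyTorch's own `MixtureSameFamily`; whether zuko uses this class internally is an assumption, and only the generic `sample`/`log_prob` interface is implied by the text above.

```python
import torch
from torch.distributions import Categorical, MixtureSameFamily, MultivariateNormal

torch.manual_seed(0)
K, D = 3, 2
logits = torch.zeros(K)                    # equal weights after softmax
means = torch.randn(K, D)                  # one mean per component
scale_tril = torch.eye(D).expand(K, D, D)  # Cholesky factors of the covariances

mixture = MixtureSameFamily(
    Categorical(logits=logits),                        # weights w_k
    MultivariateNormal(means, scale_tril=scale_tril),  # components N(mu_k, Sigma_k)
)

x = mixture.sample((5,))     # shape (5, D)
log_p = mixture.log_prob(x)  # shape (5,)
```

Internally, `log_prob` is a log-sum-exp over components of log w_k + log N(x | μ_k, Σ_k), which keeps the computation stable even when individual component densities underflow.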
initialize
Initializes the mixture components by clustering data samples.
- Feature samples, with shape (N, D), where N is the number of samples and D is the number of features.
- The clustering initialization strategy:
  - 'random': randomly selects component centers from the data.
  - 'kmeans': runs k-means clustering.
  - 'kmeans++': uses k-means++ initialization (usually the best choice).
Initialization strategies:
- random: fastest, but may produce poor initial components.
- kmeans: iteratively refines the component centers (7 iterations by default).
- kmeans++: seeds centers that are spread apart, so fitting usually converges faster.
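The two clustering strategies can be sketched in plain PyTorch: k-means++ seeding samples each new center with probability proportional to its squared distance from the nearest existing center, and k-means (Lloyd's algorithm) then alternates assignment and mean updates. The function names are hypothetical and the default of 7 iterations only mirrors the text above; how zuko implements these steps internally is an assumption.

```python
import torch

def kmeans_pp_seeds(x: torch.Tensor, k: int, generator: torch.Generator) -> torch.Tensor:
    """Hypothetical helper: k-means++ seeding of k centers from x with shape (N, D)."""
    first = torch.randint(x.shape[0], (1,), generator=generator)
    centers = x[first]  # (1, D)
    for _ in range(1, k):
        # Squared distance from every point to its nearest existing center
        d2 = torch.cdist(x, centers).pow(2).min(dim=1).values
        # Far-away points are proportionally more likely to become the next center
        nxt = torch.multinomial(d2, 1, generator=generator)
        centers = torch.cat([centers, x[nxt]], dim=0)
    return centers

def kmeans_refine(x: torch.Tensor, centers: torch.Tensor, iterations: int = 7) -> torch.Tensor:
    """Hypothetical helper: Lloyd iterations (assign to nearest center, move to cluster mean)."""
    for _ in range(iterations):
        assign = torch.cdist(x, centers).argmin(dim=1)
        centers = torch.stack([
            x[assign == j].mean(dim=0) if (assign == j).any() else centers[j]  # keep empty clusters
            for j in range(centers.shape[0])
        ])
    return centers

gen = torch.Generator().manual_seed(0)
x = torch.cat([torch.randn(200, 2, generator=gen) - 3, torch.randn(200, 2, generator=gen) + 3])
centers = kmeans_refine(x, kmeans_pp_seeds(x, 2, gen))  # shape (2, 2)
```

The resulting centers would serve as initial component means; random initialization corresponds to skipping both steps and picking k rows uniformly.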
Usage Examples
Unconditional Density Estimation
Conditional Density Estimation
Covariance Types Comparison
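The three covariance types differ in how many free numbers define each Σ_k. The sketch below maps raw parameters, shaped as in the parameter table, to Cholesky factors in plain PyTorch. The positivity transforms (exp of log-values) and the variable names are assumptions for illustration; zuko's exact parameterization may differ.

```python
import torch
from torch.distributions import MultivariateNormal

K, D = 2, 3

# Hypothetical raw (unconstrained) parameters
log_diag = torch.zeros(K, D)                      # full: log of the Cholesky diagonal
off_diag = 0.5 * torch.ones(K, D * (D - 1) // 2)  # full: strictly lower-triangular entries
log_var_diag = torch.zeros(K, D)                  # diagonal: one log-variance per axis
log_var_sph = torch.zeros(K, 1)                   # spherical: one log-variance per component

# Full: lower-triangular Cholesky factor with a positive diagonal
full = torch.diag_embed(torch.exp(log_diag))
rows, cols = torch.tril_indices(D, D, offset=-1)
full[..., rows, cols] = off_diag

# Diagonal: axis-aligned, off-diagonal entries stay zero
diagonal = torch.diag_embed(torch.exp(0.5 * log_var_diag))

# Spherical: the same standard deviation on every axis
spherical = torch.exp(0.5 * log_var_sph).unsqueeze(-1) * torch.eye(D)

trils = {"full": full, "diagonal": diagonal, "spherical": spherical}
means = torch.zeros(K, D)
x = torch.tensor([1.0, 1.0, 1.0])
log_p = {
    name: MultivariateNormal(means, scale_tril=tril).log_prob(x)  # shape (K,)
    for name, tril in trils.items()
}
```

Only the full factor produces correlations between features; the diagonal and spherical factors yield covariance matrices with zero off-diagonal entries.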
Tied Covariances
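With `tied=True`, one set of covariance parameters (leading dimension 1 in the parameter table) is shared by all K components. The plain-PyTorch sketch below shows the mechanism through broadcasting; it illustrates the idea rather than zuko's internal code.

```python
import torch
from torch.distributions import Categorical, MixtureSameFamily, MultivariateNormal

torch.manual_seed(0)
K, D = 4, 2
means = torch.randn(K, D)

# Tied diagonal covariance: a single (1, D) row shared by all K components
log_var = torch.zeros(1, D, requires_grad=True)
scale_tril = torch.diag_embed(torch.exp(0.5 * log_var))  # (1, D, D)

# MultivariateNormal broadcasts the (1, D, D) factor against the (K, D) means
components = MultivariateNormal(means, scale_tril=scale_tril)
mixture = MixtureSameFamily(Categorical(logits=torch.zeros(K)), components)

x = torch.randn(32, D)
loss = -mixture.log_prob(x).mean()
loss.backward()  # every component contributes to the single shared gradient
```

Tying trades flexibility for fewer parameters, which can help when some components see little data.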
Advanced Usage
Mixture Weights and Responsibilities
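The mixture weights are w_k = softmax(logits)_k, and the responsibility of component k for a point x_i is the posterior r_ik = w_k N(x_i | μ_k, Σ_k) / Σ_j w_j N(x_i | μ_j, Σ_j). The sketch below computes both in log space with PyTorch's `MixtureSameFamily`; if the distribution returned by the GMM is a different class, its attribute names for weights and components may differ (an assumption here).

```python
import torch
from torch.distributions import Categorical, MixtureSameFamily, MultivariateNormal

K, D = 3, 2
mixture = MixtureSameFamily(
    Categorical(logits=torch.tensor([0.0, 1.0, -1.0])),
    MultivariateNormal(
        torch.tensor([[-4.0, 0.0], [0.0, 0.0], [4.0, 0.0]]),
        scale_tril=torch.eye(D).expand(K, D, D),
    ),
)

# Mixture weights w_k = softmax(logits)_k
weights = mixture.mixture_distribution.probs  # (K,), sums to 1

x = torch.tensor([[-4.0, 0.0], [0.1, 0.0], [3.9, 0.2]])
# log w_k + log N(x_i | mu_k, Sigma_k); unsqueeze broadcasts points against components
log_joint = (
    torch.log_softmax(mixture.mixture_distribution.logits, dim=-1)
    + mixture.component_distribution.log_prob(x.unsqueeze(-2))
)  # (N, K)

# Responsibilities: posterior p(component k | x_i), each row sums to 1
responsibilities = torch.softmax(log_joint, dim=-1)
hard_assignments = responsibilities.argmax(dim=-1)
```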
Integration with Normalizing Flows
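One way to combine the two model families is to use a mixture as the base distribution of a flow, so the flow starts from a multimodal density. The sketch below shows this with plain `torch.distributions`, replacing a learned flow with a single affine bijection; it does not use zuko's flow classes, and whether zuko wires a GMM into its flows this way is an assumption.

```python
import torch
from torch.distributions import (
    AffineTransform, Categorical, MixtureSameFamily, MultivariateNormal, TransformedDistribution,
)

K, D = 2, 2
base = MixtureSameFamily(
    Categorical(logits=torch.zeros(K)),
    MultivariateNormal(torch.tensor([[-2.0, 0.0], [2.0, 0.0]]), scale_tril=torch.eye(D).expand(K, D, D)),
)

# Stand-in for a learned invertible transform: y = loc + scale * z (per feature)
loc = torch.tensor([1.0, -1.0])
scale = torch.tensor([2.0, 3.0])
flow = TransformedDistribution(base, [AffineTransform(loc, scale, event_dim=1)])

y = flow.sample((100,))   # samples pass through the GMM, then the transform
log_p = flow.log_prob(y)  # change of variables handled by torch
```

The density follows the change-of-variables rule log p_Y(y) = log p_Z((y - loc) / scale) - Σ log|scale|.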
Parameter Shapes
The GMM internally represents its parameters differently depending on its configuration, with components=K and features=D:

| Config | Parameter | Shape |
|---|---|---|
| All configurations | Logits | (K,) |
| All configurations | Means | (K, D) |
| Full covariance, tied=False | Diagonal | (K, D) |
| Full covariance, tied=False | Off-diagonal | (K, D*(D-1)/2) |
| Full covariance, tied=True | Diagonal | (1, D) |
| Full covariance, tied=True | Off-diagonal | (1, D*(D-1)/2) |
| Diagonal covariance, tied=False | Diagonal | (K, D) |
| Diagonal covariance, tied=True | Diagonal | (1, D) |
| Spherical covariance, tied=False | Variance | (K, 1) |
| Spherical covariance, tied=True | Variance | (1, 1) |
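The table translates directly into parameter counts. The helper below is hypothetical and simply encodes the shapes above; for a conditional GMM these quantities are presumably produced by the MLP rather than stored directly, so the counts apply to the unconditional model.

```python
import math

def gmm_param_shapes(features: int, components: int, covariance: str = "full", tied: bool = False) -> dict:
    """Hypothetical helper: parameter shapes from the table (D = features, K = components)."""
    D, K = features, components
    T = 1 if tied else K  # tied covariances keep a single leading row
    shapes = {"logits": (K,), "means": (K, D)}
    if covariance == "full":
        shapes["diagonal"] = (T, D)
        shapes["off_diagonal"] = (T, D * (D - 1) // 2)
    elif covariance == "diagonal":
        shapes["diagonal"] = (T, D)
    elif covariance == "spherical":
        shapes["variance"] = (T, 1)
    else:
        raise ValueError(f"unknown covariance type: {covariance!r}")
    return shapes

def count_params(shapes: dict) -> int:
    return sum(math.prod(shape) for shape in shapes.values())

for cov in ("full", "diagonal", "spherical"):
    for tied in (False, True):
        n = count_params(gmm_param_shapes(features=10, components=5, covariance=cov, tied=tied))
        print(f"{cov:9s} tied={tied!s:5s} -> {n} parameters")
```

For D = 10 and K = 5 this gives 330 (full), 105 (diagonal), and 60 (spherical) parameters untied, and 110, 65, and 56 tied; the D*(D-1)/2 off-diagonal term is what makes full covariances expensive as D grows.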
Notes
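When to use GMMs:
- Multimodal data: GMMs naturally handle multiple modes.
- Interpretable components: each Gaussian reads directly as a cluster with its own weight, mean, and covariance.
- Fast inference: simpler than flows, with faster sampling and density evaluation.
- Limited data: fewer parameters than complex flows.
When to prefer normalizing flows:
- Complex distributions: non-Gaussian shapes, heavy tails, intricate dependencies.
- High dimensions: flows scale better to many dimensions (a full covariance alone needs D*(D-1)/2 off-diagonal parameters per component).
- Exact likelihoods with high flexibility: flows provide tractable, exact densities for much richer shapes. GMM densities are also exact, but limited to sums of Gaussians.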
Numerical stability:
The epsilon parameter prevents numerical issues:
- It adds epsilon to the diagonal variance elements.
- It prevents singular covariance matrices.
- The default, 1e-6, works well for normalized data.
Increase epsilon if you encounter numerical errors.
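Why a small diagonal term helps: if a component collapses onto points that share one coordinate, its covariance has zero variance along that axis and a Cholesky factorization fails. The plain-PyTorch sketch below reproduces this and shows that adding epsilon to the diagonal restores a valid factorization; it illustrates the principle, not zuko's internal code.

```python
import torch

# Points that share one coordinate: zero variance along the second feature
t = torch.linspace(-1.0, 1.0, 50, dtype=torch.float64)
x = torch.stack([t, torch.full_like(t, 3.0)])  # (2 features, 50 samples), as torch.cov expects
cov = torch.cov(x)                             # [[var(t), 0], [0, 0]]: singular

_, info = torch.linalg.cholesky_ex(cov)
print("without epsilon, info =", info.item())  # non-zero: factorization failed

epsilon = 1e-6
_, info_eps = torch.linalg.cholesky_ex(cov + epsilon * torch.eye(2, dtype=cov.dtype))
print("with epsilon, info =", info_eps.item())  # zero: factorization succeeded
```

Because epsilon is added in absolute terms, it is negligible for standardized features but can dominate when features have very small scales, which is one reason normalized data pairs well with the default.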