Overview

The Conditional Flow Matching (CFM) module implements the core generative component of Matcha-TTS. It learns to transform noise into high-quality mel-spectrograms conditioned on encoder outputs, using optimal transport conditional flow matching.

Class Definition

CFM

from matcha.models.components.flow_matching import CFM

cfm = CFM(
    in_channels=160,
    out_channel=80,
    cfm_params=cfm_config,
    decoder_params=decoder_config,
    n_spks=1,
    spk_emb_dim=64
)

Constructor Parameters

in_channels (int, required): Input channels to the decoder (typically 2 * n_feats, since the noisy sample and the encoder mean are concatenated; see the sketch after this list)
out_channel (int, required): Output channels (number of mel-spectrogram features, typically 80)
cfm_params (object, required): CFM configuration parameters:
  • solver: ODE solver type ("euler")
  • sigma_min: Minimum noise level (default: 1e-4)
decoder_params (object, required): Decoder network parameters (U-Net architecture)
n_spks (int, default: 1): Number of speakers
spk_emb_dim (int, default: 64): Speaker embedding dimension
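
To see why in_channels is typically 2 * n_feats: the decoder estimates the flow from the noisy sample and the encoder mean stacked along the channel dimension. A minimal shape illustration (the concatenation itself happens inside the decoder):

import torch

n_feats = 80
x_t = torch.randn(2, n_feats, 100)  # noisy sample at timestep t
mu = torch.randn(2, n_feats, 100)   # encoder mean

decoder_input = torch.cat([x_t, mu], dim=1)
print(decoder_input.shape)  # torch.Size([2, 160, 100]) -> in_channels = 160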

Methods

forward()

Generates a mel-spectrogram using the flow matching process (inference mode).
@torch.inference_mode()
def forward(
    mu: torch.Tensor,
    mask: torch.Tensor,
    n_timesteps: int,
    temperature: float = 1.0,
    spks: torch.Tensor = None,
    cond: torch.Tensor = None
) -> torch.Tensor

Parameters

mu (torch.Tensor, required): Output of the text encoder (mean latent representation). Shape: (batch_size, n_feats, mel_timesteps)
mask (torch.Tensor, required): Output mask for valid frames. Shape: (batch_size, 1, mel_timesteps)
n_timesteps (int, required): Number of ODE solver steps. More steps = higher quality but slower. Typical range: 4-20
temperature (float, default: 1.0): Temperature for scaling the initial noise. Higher = more diverse output
spks (torch.Tensor, default: None): Speaker embeddings for multi-speaker models. Shape: (batch_size, spk_emb_dim)
cond (torch.Tensor, default: None): Additional conditioning (reserved for future use)

Returns

sample (torch.Tensor): Generated mel-spectrogram. Shape: (batch_size, n_feats, mel_timesteps)
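
Under the hood, forward() amounts to scaling Gaussian noise by the temperature, building a uniform time grid, and handing both to the ODE solver. A paraphrased sketch (not a drop-in copy of the implementation):

z = torch.randn_like(mu) * temperature  # initial noise x_0
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
sample = cfm.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)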

compute_loss()

Computes the conditional flow matching loss during training.
def compute_loss(
    x1: torch.Tensor,
    mask: torch.Tensor,
    mu: torch.Tensor,
    spks: torch.Tensor = None,
    cond: torch.Tensor = None
) -> tuple

Parameters

x1 (torch.Tensor, required): Target mel-spectrogram (ground truth). Shape: (batch_size, n_feats, mel_timesteps)
mask (torch.Tensor, required): Target mask for valid frames. Shape: (batch_size, 1, mel_timesteps)
mu (torch.Tensor, required): Output of the encoder (mean). Shape: (batch_size, n_feats, mel_timesteps)
spks (torch.Tensor, default: None): Speaker embeddings. Shape: (batch_size, spk_emb_dim)
cond (torch.Tensor, default: None): Additional conditioning (reserved for future use)

Returns

loss (torch.Tensor): Conditional flow matching loss (MSE between predicted and target flow)
y (torch.Tensor): Intermediate noisy sample at a random timestep. Shape: (batch_size, n_feats, mel_timesteps)
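
In practice, compute_loss slots into a standard PyTorch training step. A minimal sketch (the optimizer choice, learning rate, and batch layout are illustrative, not prescribed by the module):

optimizer = torch.optim.Adam(cfm.parameters(), lr=1e-4)

for mel, mel_mask, mu in dataloader:  # hypothetical batch layout
    loss, _ = cfm.compute_loss(x1=mel, mask=mel_mask, mu=mu)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()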

solve_euler()

Euler method ODE solver for the probability flow.
def solve_euler(
    x: torch.Tensor,
    t_span: torch.Tensor,
    mu: torch.Tensor,
    mask: torch.Tensor,
    spks: torch.Tensor,
    cond: torch.Tensor
) -> torch.Tensor

Parameters

x (torch.Tensor, required): Initial noise sample. Shape: (batch_size, n_feats, mel_timesteps)
t_span (torch.Tensor, required): Time steps for the ODE solver, from 0 to 1. Shape: (n_timesteps + 1,)
mu (torch.Tensor, required): Encoder output (conditioning). Shape: (batch_size, n_feats, mel_timesteps)
mask (torch.Tensor, required): Output mask. Shape: (batch_size, 1, mel_timesteps)
spks (torch.Tensor, required): Speaker embeddings
cond (torch.Tensor, required): Additional conditioning

Returns

output (torch.Tensor): Final generated mel-spectrogram after solving the ODE. Shape: (batch_size, n_feats, mel_timesteps)
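
Conceptually, the solver evaluates the learned vector field at the current state and takes a step of size dt along it. A schematic version of the loop (paraphrased; the actual implementation also collects intermediate states and returns the last one):

t, dt = t_span[0], t_span[1] - t_span[0]
for step in range(1, len(t_span)):
    dphi_dt = estimator(x, mask, mu, t, spks, cond)  # predicted velocity
    x = x + dt * dphi_dt                             # Euler update
    t = t + dt
    if step < len(t_span) - 1:
        dt = t_span[step + 1] - t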

Base Class

BASECFM

Abstract base class for conditional flow matching implementations.
from matcha.models.components.flow_matching import BASECFM

class CustomCFM(BASECFM):
    def __init__(self, n_feats, cfm_params, n_spks=1, spk_emb_dim=128):
        super().__init__(n_feats, cfm_params, n_spks, spk_emb_dim)
        # Custom estimator initialization

Parameters

n_feats (int, required): Number of mel-spectrogram features
cfm_params (object, required): CFM parameters including solver configuration
n_spks (int, default: 1): Number of speakers
spk_emb_dim (int, default: 128): Speaker embedding dimension
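
The key obligation of a subclass is to assign self.estimator, the network that predicts the vector field; the base class calls it as estimator(x, mask, mu, t, spks, cond). A minimal sketch (MyEstimator is a hypothetical stand-in; Matcha-TTS's own CFM plugs in its U-Net Decoder here):

class CustomCFM(BASECFM):
    def __init__(self, n_feats, cfm_params, n_spks=1, spk_emb_dim=128):
        super().__init__(n_feats, cfm_params, n_spks, spk_emb_dim)
        # Any module matching the estimator call signature will work
        self.estimator = MyEstimator(in_channels=2 * n_feats, out_channels=n_feats)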

Flow Matching Details

The conditional flow matching loss is computed as:
# Sample a random timestep per batch element, t ~ U(0, 1)
t = torch.rand([batch_size, 1, 1])

# Sample noise z ~ N(0, I)
z = torch.randn_like(x1)

# Interpolate between noise and target along the optimal-transport path
y = (1 - (1 - sigma_min) * t) * z + t * x1

# Target flow (velocity) the estimator is trained to predict
u = x1 - (1 - sigma_min) * z

# Loss: MSE between predicted and target flow
# (the implementation sums the squared error and normalizes by the
# number of valid frames given by the mask)
loss = F.mse_loss(estimator(y, mask, mu, t, spks), u)
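
In the continuous-time notation of the OT-CFM formulation used by Matcha-TTS, with x_0 ~ N(0, I) the sampled noise (z above) and x_1 the target mel, the snippet corresponds to:

\phi_t(x_0) = \bigl(1 - (1 - \sigma_{\min})\, t\bigr)\, x_0 + t\, x_1
u_t\bigl(\phi_t(x_0) \mid x_1\bigr) = x_1 - (1 - \sigma_{\min})\, x_0
\mathcal{L}_{\mathrm{CFM}} = \mathbb{E}_{t, x_0, x_1}\, \bigl\lVert v_\theta\bigl(\phi_t(x_0), t\bigr) - u_t \bigr\rVert^2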

Example Usage

import torch
from matcha.models.components.flow_matching import CFM
from types import SimpleNamespace

# Configure CFM
cfm_params = SimpleNamespace(
    solver="euler",
    sigma_min=1e-4
)

decoder_params = {
    "channels": (256, 256),
    "dropout": 0.05,
    "attention_head_dim": 64,
    "n_blocks": 1,
    "num_mid_blocks": 2,
    "num_heads": 2,
    "act_fn": "snake"
}

# Create CFM module
cfm = CFM(
    in_channels=160,  # 2 * 80 (concatenated with encoder output)
    out_channel=80,
    cfm_params=cfm_params,
    decoder_params=decoder_params,
    n_spks=1,
    spk_emb_dim=64
)

# Inference example
mu = torch.randn(2, 80, 100)  # Encoder output
mask = torch.ones(2, 1, 100)

output = cfm(
    mu=mu,
    mask=mask,
    n_timesteps=10,
    temperature=0.667
)

print(f"Generated mel shape: {output.shape}")  # (2, 80, 100)

# Training example
target_mel = torch.randn(2, 80, 100)
loss, _ = cfm.compute_loss(
    x1=target_mel,
    mask=mask,
    mu=mu
)

print(f"CFM loss: {loss.item():.4f}")

Source Reference

Implementation: matcha/models/components/flow_matching.py:121
