
Matcha-TTS Architecture

Matcha-TTS is a fast, probabilistic, non-autoregressive text-to-speech model that uses Conditional Flow Matching (CFM) for mel-spectrogram generation. The architecture consists of three main components:
  1. Text Encoder - Encodes phoneme sequences and predicts durations
  2. Duration Aligner - Uses Monotonic Alignment Search (MAS) to align text with mel frames
  3. Conditional Flow Matching Decoder - Generates high-quality mel-spectrograms

High-Level Architecture

Text Input (phonemes)

[Text Encoder]
   ↓         ↓
  mu_x     logw (durations)
   ↓         ↓
   └────[Alignment]────┐
            ↓          │
          mu_y    [Duration Loss]
            ↓          │
    [CFM Decoder]      │
            ↓          │
     Mel Spectrogram   │
            ↓          │
     [Prior Loss]      │
            ↓          │
      [CFM Loss]←──────┘

Core Components

MatchaTTS Class

The main model class defined in matcha_tts.py:23 brings together all components:
class MatchaTTS(BaseLightningClass):
    def __init__(
        self,
        n_vocab,      # Vocabulary size (phoneme count)
        n_spks,       # Number of speakers
        spk_emb_dim,  # Speaker embedding dimension
        n_feats,      # Number of mel frequency bins (80)
        encoder,      # Text encoder config
        decoder,      # Decoder config
        cfm,          # Flow matching config
        ...
    )
Key Parameters:
  • n_vocab: Size of phoneme vocabulary
  • n_spks: Number of speakers (1 for single-speaker models)
  • spk_emb_dim: Dimension of speaker embeddings (default: 64)
  • n_feats: Number of mel-spectrogram features (typically 80)
  • out_size: Segment size for training (enables larger batch sizes)

Model Initialization

The model initializes three key components in matcha_tts.py:55-71:

1. Speaker Embedding (Multi-speaker only)
if n_spks > 1:
    self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim)
2. Text Encoder
self.encoder = TextEncoder(
    encoder.encoder_type,
    encoder.encoder_params,
    encoder.duration_predictor_params,
    n_vocab,
    n_spks,
    spk_emb_dim,
)
3. CFM Decoder
self.decoder = CFM(
    in_channels=2 * encoder.encoder_params.n_feats,
    out_channel=encoder.encoder_params.n_feats,
    cfm_params=cfm,
    decoder_params=decoder,
    n_spks=n_spks,
    spk_emb_dim=spk_emb_dim,
)
The decoder input has 2 * n_feats channels because it concatenates the encoder output mu_y with the noisy sample during flow matching.
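A minimal shape sketch of why the channel count doubles (toy tensors, not the actual decoder call):

```python
import torch

batch, n_feats, n_frames = 2, 80, 100

# mu_y: aligned encoder output; x_t: the noisy sample at flow-matching time t.
mu_y = torch.randn(batch, n_feats, n_frames)
x_t = torch.randn(batch, n_feats, n_frames)

# The decoder network sees both, stacked along the channel dimension,
# which is why in_channels = 2 * n_feats.
decoder_input = torch.cat([x_t, mu_y], dim=1)
assert decoder_input.shape == (batch, 2 * n_feats, n_frames)
```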

Training Process

The training forward pass (matcha_tts.py:153) computes three losses:

1. Duration Loss

Compares predicted durations with those extracted by Monotonic Alignment Search:
logw_ = torch.log(1e-8 + torch.sum(attn.unsqueeze(1), -1)) * x_mask
dur_loss = duration_loss(logw, logw_, x_lengths)
Calculated in model.py:44:
def duration_loss(logw, logw_, lengths):
    loss = torch.sum((logw - logw_) ** 2) / torch.sum(lengths)
    return loss
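A runnable toy example of this loss, using hand-picked (hypothetical) log-durations:

```python
import torch

def duration_loss(logw, logw_, lengths):
    # MSE between predicted and MAS-derived log-durations,
    # normalized by the total number of phonemes in the batch.
    return torch.sum((logw - logw_) ** 2) / torch.sum(lengths)

logw = torch.tensor([[0.5, 1.0, 0.0]])   # predicted log-durations
logw_ = torch.tensor([[0.0, 1.0, 0.0]])  # MAS-extracted log-durations
lengths = torch.tensor([3])              # 3 phonemes in this utterance

loss = duration_loss(logw, logw_, lengths)  # (0.5^2) / 3
```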

2. Prior Loss

Measures the distance between encoder outputs and target mel-spectrograms (matcha_tts.py:239-242):
prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask)
prior_loss = prior_loss / (torch.sum(y_mask) * self.n_feats)
This encourages the encoder to produce outputs close to the target mel-spectrograms.
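The same computation on a tiny hand-built example (shapes and values are illustrative, not from the model):

```python
import math
import torch

n_feats = 2
y = torch.tensor([[[1.0, 2.0], [0.0, 1.0]]])  # target mel (B=1, n_feats=2, T=2)
mu_y = torch.zeros_like(y)                    # encoder output, here all zeros
y_mask = torch.ones(1, 1, 2)                  # all frames valid (B, 1, T)

# Per-element negative log-likelihood of y under N(mu_y, I),
# masked and averaged over valid frames and features.
prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask)
prior_loss = prior_loss / (torch.sum(y_mask) * n_feats)
```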

3. Flow Matching Loss

The main reconstruction loss from the CFM decoder (matcha_tts.py:237):
diff_loss, _ = self.decoder.compute_loss(x1=y, mask=y_mask, mu=mu_y, spks=spks, cond=cond)
The total training loss is: Total Loss = Duration Loss + Prior Loss + Flow Matching Loss

Inference Process

The synthesise method (matcha_tts.py:76) generates mel-spectrograms from text:

Step 1: Encode text and predict durations
mu_x, logw, x_mask = self.encoder(x, x_lengths, spks)
Step 2: Convert log-durations to frame counts
w = torch.exp(logw) * x_mask
w_ceil = torch.ceil(w) * length_scale  # length_scale controls speech rate
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
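A toy walkthrough of this conversion (the log-durations are invented; note that even near-zero durations ceil up to one frame):

```python
import torch

length_scale = 1.0
logw = torch.tensor([[[0.0, 1.0, -10.0]]])  # predicted log-durations (B, 1, text_len)
x_mask = torch.tensor([[[1.0, 1.0, 1.0]]])  # all three phonemes valid

w = torch.exp(logw) * x_mask            # frames per phoneme: [1.0, 2.718..., ~0]
w_ceil = torch.ceil(w) * length_scale   # integer frame counts: [1, 3, 1]
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()  # total frames: 5
```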
Step 3: Generate alignment map
attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)
Step 4: Align encoded text to mel-spectrogram length
mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
mu_y = mu_y.transpose(1, 2)
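The matmul in Step 4 can be sketched with a hand-built hard alignment (the phoneme-to-frame mapping below is invented for illustration):

```python
import torch

batch, n_feats, text_len, n_frames = 1, 80, 4, 10

mu_x = torch.randn(batch, n_feats, text_len)   # per-phoneme encoder outputs
attn = torch.zeros(batch, text_len, n_frames)  # hard monotonic alignment
# e.g. phoneme 0 -> frames 0-2, phoneme 1 -> 3-4, phoneme 2 -> 5-8, phoneme 3 -> 9
for ph, frames in enumerate([range(0, 3), range(3, 5), range(5, 9), range(9, 10)]):
    attn[0, ph, list(frames)] = 1.0

# Each mel frame receives the mu_x vector of its aligned phoneme.
mu_y = torch.matmul(attn.transpose(1, 2), mu_x.transpose(1, 2)).transpose(1, 2)
assert mu_y.shape == (batch, n_feats, n_frames)
```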
Step 5: Generate mel-spectrogram via flow matching
decoder_outputs = self.decoder(mu_y, y_mask, n_timesteps, temperature, spks)
Key Inference Parameters:
  • n_timesteps: Number of ODE solver steps (10-50, higher = better quality but slower)
  • temperature: Controls variance (1.0 = normal, greater than 1.0 = more diverse, less than 1.0 = more deterministic)
  • length_scale: Speech rate control (greater than 1.0 = slower, less than 1.0 = faster)
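To make n_timesteps concrete: the CFM decoder integrates an ODE from noise at t=0 to a mel-spectrogram at t=1 with a fixed-step Euler solver, and n_timesteps is the number of Euler steps. A toy sketch of such a solve (the constant velocity field below is invented to make the result checkable, not the learned one):

```python
import torch

def euler_solve(v_field, x0, n_timesteps):
    # Fixed-step Euler integration from t=0 to t=1: n_timesteps evaluations
    # of the velocity field, which is what the n_timesteps knob controls.
    x, dt = x0, 1.0 / n_timesteps
    for step in range(n_timesteps):
        t = step * dt
        x = x + dt * v_field(x, t)
    return x

torch.manual_seed(0)
x0 = torch.randn(2)               # stands in for the initial noise sample
mu = torch.tensor([1.0, -2.0])    # stands in for the target

# Under a straight (OT) probability path x_t = (1 - t) * x0 + t * mu,
# the target velocity is the constant mu - x0, so Euler lands exactly on mu.
velocity = lambda x, t: mu - x0
out = euler_solve(velocity, x0, n_timesteps=10)
```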

Monotonic Alignment Search (MAS)

MAS finds the optimal alignment between text and mel-spectrogram during training (matcha_tts.py:189-198):
# Compute log-likelihood of alignment
const = -0.5 * math.log(2 * math.pi) * self.n_feats
factor = -0.5 * torch.ones(mu_x.shape, dtype=mu_x.dtype, device=mu_x.device)
y_square = torch.matmul(factor.transpose(1, 2), y**2)
y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y)
mu_square = torch.sum(factor * (mu_x**2), 1).unsqueeze(-1)
log_prior = y_square - y_mu_double + mu_square + const

# Find maximum probability path
attn = monotonic_align.maximum_path(log_prior, attn_mask.squeeze(1))
This assumes Gaussian distributions and finds the monotonic alignment that maximizes the likelihood.
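The matrix expression above is a batched way of evaluating the isotropic Gaussian log-density log N(y_j; mu_i, I) for every (phoneme i, frame j) pair. A small numeric check against the direct formula (toy shapes):

```python
import math
import torch

n_feats, text_len, n_frames = 3, 2, 4
torch.manual_seed(0)
mu_x = torch.randn(1, n_feats, text_len)
y = torch.randn(1, n_feats, n_frames)

# Batched form, as in the excerpt above
const = -0.5 * math.log(2 * math.pi) * n_feats
factor = -0.5 * torch.ones(mu_x.shape)
y_square = torch.matmul(factor.transpose(1, 2), y**2)
y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y)
mu_square = torch.sum(factor * (mu_x**2), 1).unsqueeze(-1)
log_prior = y_square - y_mu_double + mu_square + const  # (1, text_len, n_frames)

# Direct form: log N(y_j; mu_i, I) for each pair
direct = torch.empty(1, text_len, n_frames)
for i in range(text_len):
    for j in range(n_frames):
        diff = y[0, :, j] - mu_x[0, :, i]
        direct[0, i, j] = -0.5 * (diff @ diff + n_feats * math.log(2 * math.pi))

assert torch.allclose(log_prior, direct, atol=1e-5)
```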

Segment-based Training

To enable larger batch sizes, the model can train on random segments (matcha_tts.py:208-230):
if not isinstance(out_size, type(None)):
    max_offset = (y_lengths - out_size).clamp(0)
    # Randomly sample an offset for each item in the batch
    # (offset_ranges is built from max_offset; construction elided here)
    out_offset = torch.LongTensor(
        [random.choice(range(start, end)) for start, end in offset_ranges]
    )
    # Extract matching segments from the mels and the alignment
    y_cut[i, :, :y_cut_length] = y_[:, cut_lower:cut_upper]
    attn_cut[i, :, :y_cut_length] = attn[i, :, cut_lower:cut_upper]
This “hack” from Grad-TTS allows training with larger batch sizes on limited GPU memory.
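A simplified sketch of the cropping idea (single item, no masking or batching):

```python
import random
import torch

random.seed(0)
torch.manual_seed(0)

out_size = 4                # segment length in mel frames
y = torch.randn(1, 80, 10)  # full mel-spectrogram (B, n_feats, T)
y_lengths = torch.tensor([10])

# Latest valid start frame, then a random segment of out_size frames.
max_offset = (y_lengths - out_size).clamp(0)
offset = random.randint(0, int(max_offset[0]))
segment = y[:, :, offset : offset + out_size]
assert segment.shape == (1, 80, out_size)
```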

Output Format

The synthesise method returns a dictionary with:
{
    "encoder_outputs": torch.Tensor,  # shape: (B, n_feats, T)
    "decoder_outputs": torch.Tensor,  # shape: (B, n_feats, T)  
    "attn": torch.Tensor,             # shape: (B, max_text_len, T)
    "mel": torch.Tensor,              # Denormalized mel-spectrogram
    "mel_lengths": torch.Tensor,      # shape: (B,)
    "rtf": float,                     # Real-time factor
}
Real-Time Factor (RTF) measures inference speed. RTF < 1.0 means faster than real-time. Calculated as:
rtf = inference_time * 22050 / (mel_frames * 256)
where 22050 is the sample rate and 256 is the hop length.
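The same calculation as a small helper (the function name is illustrative, not from the codebase):

```python
def real_time_factor(inference_time_s, mel_frames, sample_rate=22050, hop_length=256):
    # One mel frame covers hop_length samples, so the synthesized audio
    # lasts mel_frames * hop_length / sample_rate seconds.
    audio_seconds = mel_frames * hop_length / sample_rate
    return inference_time_s / audio_seconds

# 0.1 s to synthesize 500 frames (~5.8 s of audio) -> RTF well below 1.0
rtf = real_time_factor(0.1, 500)
```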

Multi-Speaker Support

For multi-speaker models, speaker embeddings are added to both encoder and decoder:
if self.n_spks > 1:
    spks = self.spk_emb(spks.long())  # Get speaker embedding
    # Concatenated to encoder output and decoder input
Speaker embeddings are repeated across the time dimension and concatenated to the feature channels.
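A toy sketch of that repeat-and-concatenate conditioning (shapes are illustrative):

```python
import torch

n_spks, spk_emb_dim, n_feats, n_frames = 4, 64, 80, 50

spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim)
spks = torch.tensor([2])   # speaker id for a batch of one
s = spk_emb(spks)          # (1, spk_emb_dim)

x = torch.randn(1, n_feats, n_frames)  # some feature map (B, C, T)
# Repeat the embedding across time, then concatenate on the channel axis.
s_rep = s.unsqueeze(-1).repeat(1, 1, n_frames)  # (1, spk_emb_dim, T)
x_cond = torch.cat([x, s_rep], dim=1)           # (1, n_feats + spk_emb_dim, T)
assert x_cond.shape == (1, n_feats + spk_emb_dim, n_frames)
```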
