Overview

The Decoder is a 1D U-Net architecture that serves as the estimator network for the conditional flow matching process. It predicts the velocity field needed to transform noise into mel-spectrograms.
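At synthesis time, the predicted velocity field is integrated over t in [0, 1] by an ODE solver to carry a noise sample to a mel-spectrogram. A minimal dependency-free sketch of fixed-step Euler integration (a scalar toy standing in for the decoder, not matcha's actual solver code):

```python
def euler_sample(velocity_fn, x0, n_steps=10):
    """Integrate dx/dt = v(x, t) from t = 0 to t = 1 with fixed-step Euler."""
    dt = 1.0 / n_steps
    x, t = x0, 0.0
    for _ in range(n_steps):
        x = x + dt * velocity_fn(x, t)
        t += dt
    return x

# Toy check: along a straight (optimal-transport) path from x0 to x1 the true
# velocity is the constant x1 - x0, so Euler lands on the target exactly.
x0, x1 = 0.5, 3.0
result = euler_sample(lambda x, t: x1 - x0, x0)  # ~3.0
```

In practice the velocity function is the decoder's forward pass evaluated on full tensors, and the number of solver steps trades quality against speed.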

Class Definition

Decoder

from matcha.models.components.decoder import Decoder

decoder = Decoder(
    in_channels=160,
    out_channels=80,
    channels=(256, 256),
    dropout=0.05,
    attention_head_dim=64,
    n_blocks=1,
    num_mid_blocks=2,
    num_heads=4,
    act_fn="snake",
    down_block_type="transformer",
    mid_block_type="transformer",
    up_block_type="transformer"
)

Constructor Parameters

- in_channels (int, required): Number of input channels (concatenated noise, mean, and optionally speaker embeddings)
- out_channels (int, required): Number of output channels (mel-spectrogram features, typically 80)
- channels (tuple, default (256, 256)): Hidden channel dimensions for each level of the U-Net; the length determines the depth
- dropout (float, default 0.05): Dropout probability in transformer blocks
- attention_head_dim (int, default 64): Dimension of each attention head
- n_blocks (int, default 1): Number of transformer/conformer blocks at each level
- num_mid_blocks (int, default 2): Number of middle (bottleneck) blocks
- num_heads (int, default 4): Number of attention heads
- act_fn (str, default "snake"): Activation function: "snake", "swish", "mish", or "gelu"
- down_block_type (str, default "transformer"): Block type in the downsampling path: "transformer" or "conformer"
- mid_block_type (str, default "transformer"): Block type in the middle (bottleneck): "transformer" or "conformer"
- up_block_type (str, default "transformer"): Block type in the upsampling path: "transformer" or "conformer"

Methods

forward()

Forward pass through the U-Net decoder.
def forward(
    self,
    x: torch.Tensor,
    mask: torch.Tensor,
    mu: torch.Tensor,
    t: torch.Tensor,
    spks: Optional[torch.Tensor] = None,
    cond: Optional[torch.Tensor] = None
) -> torch.Tensor

Parameters

- x (torch.Tensor, required): Noisy mel-spectrogram at timestep t. Shape: (batch_size, in_channels, time)
- mask (torch.Tensor, required): Mask for valid time steps. Shape: (batch_size, 1, time)
- mu (torch.Tensor, required): Mean from the encoder (conditioning). Shape: (batch_size, n_feats, time)
- t (torch.Tensor, required): Current timestep (0 to 1) for flow matching. Shape: (batch_size,)
- spks (torch.Tensor, default None): Speaker embeddings for multi-speaker models. Shape: (batch_size, spk_emb_dim)
- cond (torch.Tensor, default None): Additional conditioning (reserved for future use)

Returns

- output (torch.Tensor): Predicted velocity field (flow direction). Shape: (batch_size, out_channels, time)

Architecture Components

ResnetBlock1D

Residual block with time embedding.
from matcha.models.components.decoder import ResnetBlock1D

resnet = ResnetBlock1D(
    dim=256,
    dim_out=256,
    time_emb_dim=1024,
    groups=8
)
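The general shape of such a block, shown here as a generic sketch with hypothetical stand-in callables (not matcha's exact layer stack): project the time embedding, inject it between two convolutional stages, and add a shortcut from input to output.

```python
def resnet_block(x, t_emb, conv1, conv2, time_proj, shortcut):
    """Generic residual-block pattern with timestep injection (a sketch)."""
    h = conv1(x)
    h = h + time_proj(t_emb)  # inject timestep information mid-block
    h = conv2(h)
    return h + shortcut(x)    # residual connection

# Scalar toy wiring, just to show the data flow:
out = resnet_block(
    x=3.0, t_emb=5.0,
    conv1=lambda v: 2 * v,
    conv2=lambda v: v + 1,
    time_proj=lambda t: t,
    shortcut=lambda v: v,
)  # (2*3 + 5) + 1 + 3 == 15.0
```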

Downsample1D

Downsampling layer using strided convolution.
from matcha.models.components.decoder import Downsample1D

downsample = Downsample1D(dim=256)
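The strided convolution roughly halves the sequence length. Assuming the common kernel_size=3, stride=2, padding=1 configuration (the exact hyperparameters are an assumption here), the output length follows the standard Conv1d formula:

```python
def conv1d_out_len(n, kernel=3, stride=2, padding=1):
    """Output length of a 1-D convolution (standard PyTorch formula;
    the kernel/stride/padding defaults are an assumed configuration)."""
    return (n + 2 * padding - kernel) // stride + 1

conv1d_out_len(100)  # 50
conv1d_out_len(101)  # 51
```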

Upsample1D

Upsampling layer using transposed convolution or interpolation.
from matcha.models.components.decoder import Upsample1D

upsample = Upsample1D(
    channels=256,
    use_conv=False,
    use_conv_transpose=True,
    out_channels=256
)

Parameters

- channels (int, required): Number of input channels
- use_conv (bool, default False): Apply a convolution after interpolation
- use_conv_transpose (bool, default True): Use a transposed convolution for upsampling
- out_channels (int, default None): Number of output channels (defaults to the number of input channels)
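In transposed-convolution mode the layer roughly doubles the sequence length. Assuming the usual length-doubling configuration kernel_size=4, stride=2, padding=1 (an assumption, not confirmed hyperparameters), the output length follows the standard ConvTranspose1d formula:

```python
def conv_transpose1d_out_len(n, kernel=4, stride=2, padding=1):
    """Output length of a 1-D transposed convolution with output_padding=0
    (standard PyTorch formula; kernel/stride/padding are assumed values)."""
    return (n - 1) * stride - 2 * padding + kernel

conv_transpose1d_out_len(50)  # 100
```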

SinusoidalPosEmb

Sinusoidal positional embeddings for time encoding.
from matcha.models.components.decoder import SinusoidalPosEmb

time_emb = SinusoidalPosEmb(dim=256)

Parameters

- dim (int, required): Embedding dimension (must be even)
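The usual computation, shown as a dependency-free sketch of the standard sinusoidal scheme (not necessarily matcha's exact frequency spacing): the scalar timestep is multiplied by geometrically spaced frequencies, then passed through sin and cos.

```python
import math

def sinusoidal_pos_emb(t, dim):
    """Embed scalar t into `dim` values: sin/cos at geometrically spaced
    frequencies (a sketch; assumes dim is even and >= 4)."""
    half = dim // 2
    freqs = [math.exp(-math.log(10000.0) * i / (half - 1)) for i in range(half)]
    return [math.sin(t * f) for f in freqs] + [math.cos(t * f) for f in freqs]

emb = sinusoidal_pos_emb(0.5, 8)  # list of 8 floats
```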

TimestepEmbedding

MLP for processing time embeddings.
from matcha.models.components.decoder import TimestepEmbedding

time_mlp = TimestepEmbedding(
    in_channels=256,
    time_embed_dim=1024,
    act_fn="silu",
    out_dim=None,
    post_act_fn=None,
    cond_proj_dim=None
)

Parameters

- in_channels (int, required): Input dimension from the sinusoidal embeddings
- time_embed_dim (int, required): Output embedding dimension
- act_fn (str, default "silu"): Activation function
- out_dim (int, default None): Optional different output dimension
- post_act_fn (str, default None): Optional activation after the second linear layer
- cond_proj_dim (int, default None): Optional conditioning projection dimension

ConformerWrapper

Wrapper for Conformer blocks (alternative to Transformer).
from matcha.models.components.decoder import ConformerWrapper

conformer = ConformerWrapper(
    dim=256,
    dim_head=64,
    heads=4,
    ff_mult=4,
    conv_expansion_factor=2,
    conv_kernel_size=31,
    attn_dropout=0.1,
    ff_dropout=0.1,
    conv_dropout=0.1,
    conv_causal=False
)

U-Net Architecture

The decoder follows a U-Net structure:
  1. Time Embedding: Sinusoidal embeddings + MLP
  2. Input Processing: Concatenate x, mu, and optionally speaker embeddings
  3. Downsampling Path: ResNet blocks + Transformer/Conformer + Downsample
  4. Middle Blocks: Multiple ResNet + Transformer/Conformer blocks
  5. Upsampling Path: ResNet blocks + Transformer/Conformer + Upsample with skip connections
  6. Output: Final convolution to output channels
Input: concat(x, mu[, spks])        t → Time Embedding

[Down Block 1] → (skip connection) ─────┐
    ↓                                   │
[Down Block 2] → (skip connection) ──┐  │
    ↓                                │  │
[Middle Blocks]                      │  │
    ↓                                │  │
[Up Block 1] ←───────────────────────┘  │
    ↓                                   │
[Up Block 2] ←──────────────────────────┘
    ↓
Output (velocity field)
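The skip-connection bookkeeping above can be sketched as a stack: each down block pushes its output, and the matching up block pops it. The helpers below are hypothetical toys; the real blocks concatenate tensors along the channel axis rather than adding scalars.

```python
def unet_forward(x, down_blocks, mid_blocks, up_blocks):
    """Toy U-Net data flow: LIFO skips pair the last down block
    with the first up block."""
    skips = []
    for down in down_blocks:
        x = down(x)
        skips.append(x)   # push skip for the matching up block
    for mid in mid_blocks:
        x = mid(x)
    for up in up_blocks:
        x = up(x + skips.pop())  # toy "concat": add the matching skip
    return x

out = unet_forward(
    1.0,
    down_blocks=[lambda v: 2 * v, lambda v: 2 * v],
    mid_blocks=[lambda v: v + 1],
    up_blocks=[lambda v: v, lambda v: v],
)  # 11.0
```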

Example Usage

import torch
from matcha.models.components.decoder import Decoder

# Create decoder
decoder = Decoder(
    in_channels=160,  # 80 (x) + 80 (mu)
    out_channels=80,
    channels=(256, 256),
    dropout=0.05,
    attention_head_dim=64,
    n_blocks=1,
    num_mid_blocks=2,
    num_heads=2,
    act_fn="snake",
    down_block_type="transformer",
    mid_block_type="transformer",
    up_block_type="transformer"
)

# Example inputs
batch_size = 2
time_steps = 100

x = torch.randn(batch_size, 80, time_steps)  # Noisy mel
mu = torch.randn(batch_size, 80, time_steps)  # Encoder output
mask = torch.ones(batch_size, 1, time_steps)
t = torch.rand(batch_size)  # Timesteps

# Forward pass
velocity = decoder(
    x=x,
    mask=mask,
    mu=mu,
    t=t,
    spks=None,
    cond=None
)

print(f"Output velocity shape: {velocity.shape}")  # (2, 80, 100)
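At training time the decoder's output is regressed against a conditional flow-matching target. A scalar sketch of the optimal-transport path and its target velocity (the sigma_min-stabilized form is standard in CFM; treating it as matcha's exact loss construction is an assumption):

```python
def cfm_pair(x0, x1, t, sigma_min=1e-4):
    """Point on the conditional OT path at time t, and the (constant)
    target velocity there. Scalar toy; real inputs are tensors."""
    x_t = (1 - (1 - sigma_min) * t) * x0 + t * x1  # interpolant at time t
    u_t = x1 - (1 - sigma_min) * x0                # target velocity
    return x_t, u_t

# With sigma_min = 0 this reduces to the straight line (1-t)*x0 + t*x1
# with velocity x1 - x0:
x_t, u_t = cfm_pair(1.0, 3.0, t=0.5, sigma_min=0.0)  # (2.0, 2.0)
```

The decoder is trained so that its forward pass at (x_t, t) matches u_t under the mask.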

Configuration Examples

Lightweight Configuration

decoder = Decoder(
    in_channels=160,
    out_channels=80,
    channels=(128, 256),  # Smaller channels
    n_blocks=1,
    num_mid_blocks=1,
    num_heads=2,
    dropout=0.1
)

Heavy Configuration

decoder = Decoder(
    in_channels=160,
    out_channels=80,
    channels=(256, 512, 512),  # Deeper network
    n_blocks=2,  # More blocks per level
    num_mid_blocks=4,
    num_heads=8,
    dropout=0.05
)

Conformer-based Configuration

decoder = Decoder(
    in_channels=160,
    out_channels=80,
    channels=(256, 256),
    down_block_type="conformer",
    mid_block_type="conformer",
    up_block_type="conformer",
    n_blocks=1,
    num_heads=4
)

Source Reference

Implementation: matcha/models/components/decoder.py:200
