Text Encoder

The text encoder is responsible for converting phoneme sequences into continuous representations and predicting the duration (number of mel-spectrogram frames) for each phoneme. It consists of two main components:
  1. Encoder: Transformer-based network that processes phoneme embeddings
  2. Duration Predictor: Predicts how long each phoneme should be spoken

Architecture Overview

Phoneme IDs

[Embedding Layer]

[Optional PreNet] (3 Conv-ReLU-Norm layers)

[Concatenate Speaker Embedding] (if multi-speaker)

[Transformer Encoder] (N layers)

     ├──> [Projection: proj_m] ──> mu (mean mel features)
     └──> [Duration Predictor] ──> logw (log durations)

TextEncoder Class

The main encoder class is defined in text_encoder.py:328:
class TextEncoder(nn.Module):
    def __init__(
        self,
        encoder_type,              # Type of encoder (e.g., "transformer")
        encoder_params,            # Encoder hyperparameters
        duration_predictor_params, # Duration predictor hyperparameters
        n_vocab,                   # Vocabulary size (phoneme count)
        n_spks=1,                  # Number of speakers
        spk_emb_dim=128,          # Speaker embedding dimension
    )
Key Attributes:
  • n_vocab: Size of phoneme vocabulary
  • n_feats: Number of mel-spectrogram features (typically 80)
  • n_channels: Hidden dimension of encoder (typically 192)
  • n_spks: Number of speakers (1 for single-speaker)
  • spk_emb_dim: Dimension of speaker embeddings

Component Breakdown

1. Embedding Layer

Converts discrete phoneme IDs to continuous vectors (text_encoder.py:346-347):
self.emb = torch.nn.Embedding(n_vocab, self.n_channels)
torch.nn.init.normal_(self.emb.weight, 0.0, self.n_channels**-0.5)
  • Embedding dimension: n_channels (typically 192)
  • Initialized with a normal distribution of mean 0 and standard deviation 1/√n_channels (the n_channels**-0.5 in the init call)
  • Standard practice from Transformer literature
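As a hedged sketch (n_vocab and n_channels are assumed example sizes, not values from a real config), the embedding setup plus the √n_channels scaling applied later in the forward pass looks like:

```python
import math

import torch

# Sketch of the embedding layer above; sizes are illustrative.
torch.manual_seed(0)
n_vocab, n_channels = 100, 192
emb = torch.nn.Embedding(n_vocab, n_channels)
torch.nn.init.normal_(emb.weight, 0.0, n_channels**-0.5)

phonemes = torch.tensor([[3, 17, 42, 8]])      # (B=1, L=4) phoneme IDs
x = emb(phonemes) * math.sqrt(n_channels)      # (1, 4, 192), scaled as in forward()
```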

2. PreNet (Optional)

An optional convolutional pre-processing network (text_encoder.py:349-359):
if encoder_params.prenet:
    self.prenet = ConvReluNorm(
        self.n_channels,
        self.n_channels,
        self.n_channels,
        kernel_size=5,
        n_layers=3,
        p_dropout=0.5,
    )
else:
    self.prenet = lambda x, x_mask: x
ConvReluNorm (text_encoder.py:36) is a stack of:
  • Conv1D → LayerNorm → ReLU → Dropout
  • Repeated for n_layers (typically 3)
  • Residual connection at the end
The PreNet adds local context modeling before the transformer encoder. The high dropout rate (0.5) acts as a regularizer and helps with generalization.
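A minimal, hedged sketch of such a stack (not the library's ConvReluNorm class — for brevity it omits the zero-initialized output projection) could look like:

```python
import torch

# Simplified ConvReluNorm-style stack: Conv1d -> LayerNorm -> ReLU -> Dropout,
# repeated n_layers times, with a residual connection at the end.
class TinyConvReluNorm(torch.nn.Module):
    def __init__(self, channels, kernel_size=5, n_layers=3, p_dropout=0.5):
        super().__init__()
        self.convs = torch.nn.ModuleList(
            [torch.nn.Conv1d(channels, channels, kernel_size, padding=kernel_size // 2)
             for _ in range(n_layers)]
        )
        self.norms = torch.nn.ModuleList(
            [torch.nn.LayerNorm(channels) for _ in range(n_layers)]
        )
        self.drop = torch.nn.Dropout(p_dropout)

    def forward(self, x, x_mask):
        residual = x
        for conv, norm in zip(self.convs, self.norms):
            x = conv(x * x_mask)
            x = norm(x.transpose(1, 2)).transpose(1, 2)  # normalize over channels
            x = self.drop(torch.relu(x))
        return (residual + x) * x_mask

m = TinyConvReluNorm(channels=8)
m.eval()                                   # dropout off for inspection
x = torch.randn(2, 8, 10)                  # (B, C, T)
x_mask = torch.ones(2, 1, 10)
x_mask[:, :, 7:] = 0                       # last three steps are padding
y = m(x, x_mask)                           # (2, 8, 10); padding stays zero
```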

3. Transformer Encoder

The core encoder processes the phoneme sequence (text_encoder.py:361-368):
self.encoder = Encoder(
    encoder_params.n_channels + (spk_emb_dim if n_spks > 1 else 0),
    encoder_params.filter_channels,
    encoder_params.n_heads,
    encoder_params.n_layers,
    encoder_params.kernel_size,
    encoder_params.p_dropout,
)
Typical Configuration:
  • n_channels: 192 (base hidden dimension)
  • filter_channels: 768 (FFN inner dimension, 4× hidden size)
  • n_heads: 2 (attention heads)
  • n_layers: 6 (transformer blocks)
  • kernel_size: 3 (for FFN convolutions)
  • p_dropout: 0.1

Encoder Layer Structure

Each transformer layer (text_encoder.py:276) consists of:
for _ in range(self.n_layers):
    # Multi-head self-attention with RoPE
    self.attn_layers.append(
        MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout)
    )
    self.norm_layers_1.append(LayerNorm(hidden_channels))
    
    # Feed-forward network
    self.ffn_layers.append(
        FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout)
    )
    self.norm_layers_2.append(LayerNorm(hidden_channels))
Forward pass (text_encoder.py:314-325):
def forward(self, x, x_mask):
    attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
    for i in range(self.n_layers):
        x = x * x_mask
        # Self-attention
        y = self.attn_layers[i](x, x, attn_mask)
        y = self.drop(y)
        x = self.norm_layers_1[i](x + y)  # Post-norm + residual
        # Feed-forward
        y = self.ffn_layers[i](x, x_mask)
        y = self.drop(y)
        x = self.norm_layers_2[i](x + y)  # Post-norm + residual
    return x * x_mask
This uses post-normalization (residual first, then normalize), as in the original Transformer. At this modest depth (6 layers) post-norm trains stably; very deep stacks more commonly switch to pre-normalization.
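A small sketch of the attn_mask broadcast in the forward pass: the outer product of the sequence mask with itself yields a (B, 1, T, T) pairwise mask, so any attention score touching a padded query or key is masked out.

```python
import torch

# Illustrative mask for a batch of one sequence whose last step is padding.
x_mask = torch.tensor([[[1.0, 1.0, 1.0, 0.0]]])          # (B=1, 1, T=4)
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)   # (1, 1, 4, 4)
```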

4. Rotary Position Embeddings (RoPE)

Instead of absolute position embeddings, the encoder uses RoPE (text_encoder.py:97):
class RotaryPositionalEmbeddings(nn.Module):
    def __init__(self, d: int, base: int = 10_000):
        # Θ = {θᵢ = 10000^(-2(i-1)/d), i ∈ [1, 2, ..., d/2]}
RoPE rotates query and key vectors based on their position:
  • Applied in MultiHeadAttention (text_encoder.py:232-233)
  • Only applied to half of the features
  • Encodes relative positions rather than absolute
Benefits:
  • Better extrapolation to longer sequences
  • Encodes relative position information
  • No learned parameters
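A self-contained sketch of the rotation (an assumed simplification, not the actual RotaryPositionalEmbeddings class): pairs of features are rotated by a position-dependent angle, and only the first d feature dimensions are touched, mirroring the half-features note above.

```python
import torch

# Rotate the first `d` feature dims of x by position-dependent angles.
def rope(x, d, base=10_000):
    # x: (batch, heads, seq_len, features)
    seq_len = x.shape[2]
    theta = base ** (-torch.arange(0, d, 2).float() / d)   # (d/2,) frequencies
    pos = torch.arange(seq_len).float()                    # (T,)
    angles = torch.einsum("t,f->tf", pos, theta)           # (T, d/2)
    cos, sin = angles.cos(), angles.sin()
    x_rope, x_pass = x[..., :d], x[..., d:]
    x1, x2 = x_rope[..., : d // 2], x_rope[..., d // 2 :]
    rotated = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
    return torch.cat([rotated, x_pass], dim=-1)

x = torch.randn(1, 2, 5, 16)
y = rope(x, d=8)
```

Two properties follow directly: position 0 is rotated by angle 0 (so it is unchanged), and rotation preserves each vector's norm, which is what lets RoPE encode relative rather than absolute position.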

5. Multi-Head Attention

Custom attention implementation with RoPE (text_encoder.py:175):
class MultiHeadAttention(nn.Module):
    def __init__(
        self,
        channels,           # Input/output channels
        out_channels,       # Output channels
        n_heads,           # Number of attention heads (typically 2)
        p_dropout=0.0,
    )
Attention computation (text_encoder.py:226-246):
def attention(self, query, key, value, mask=None):
    # Reshape to multi-head format
    query = rearrange(query, "b (h c) t -> b h t c", h=self.n_heads)
    key = rearrange(key, "b (h c) t -> b h t c", h=self.n_heads)
    value = rearrange(value, "b (h c) t -> b h t c", h=self.n_heads)
    
    # Apply rotary position embeddings
    query = self.query_rotary_pe(query)
    key = self.key_rotary_pe(key)
    
    # Scaled dot-product attention
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels)
    
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e4)
    
    p_attn = torch.nn.functional.softmax(scores, dim=-1)
    p_attn = self.drop(p_attn)
    output = torch.matmul(p_attn, value)
    
    return output, p_attn

6. Feed-Forward Network

Position-wise FFN with convolutions (text_encoder.py:255):
class FFN(nn.Module):
    def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0):
        super().__init__()
        self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
        self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size, padding=kernel_size // 2)
        self.drop = torch.nn.Dropout(p_dropout)

    def forward(self, x, x_mask):
        x = self.conv_1(x * x_mask)
        x = torch.relu(x)
        x = self.drop(x)
        x = self.conv_2(x * x_mask)
        return x * x_mask
Typically: in_channels → filter_channels (4× wider) → out_channels. Example: 192 → 768 → 192.

7. Output Projections

Two separate projections from encoder hidden state (text_encoder.py:370-376): Mean Projection (mel-spectrogram features):
self.proj_m = torch.nn.Conv1d(
    self.n_channels + (spk_emb_dim if n_spks > 1 else 0), 
    self.n_feats,  # Typically 80
    1
)
Duration Predictor:
self.proj_w = DurationPredictor(
    self.n_channels + (spk_emb_dim if n_spks > 1 else 0),
    duration_predictor_params.filter_channels_dp,
    duration_predictor_params.kernel_size,
    duration_predictor_params.p_dropout,
)

Duration Predictor

Predicts log-scaled duration for each phoneme (text_encoder.py:70):
class DurationPredictor(nn.Module):
    def __init__(self, in_channels, filter_channels, kernel_size, p_dropout):
        super().__init__()
        self.drop = torch.nn.Dropout(p_dropout)
        self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
        self.norm_1 = LayerNorm(filter_channels)
        self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
        self.norm_2 = LayerNorm(filter_channels)
        self.proj = torch.nn.Conv1d(filter_channels, 1, 1)  # Output: log duration
Architecture:
Input (encoder hidden state)

Conv1D → ReLU → LayerNorm → Dropout

Conv1D → ReLU → LayerNorm → Dropout

Conv1D (1×1) → log(duration)
Forward pass (text_encoder.py:84-94):
def forward(self, x, x_mask):
    x = self.conv_1(x * x_mask)
    x = torch.relu(x)
    x = self.norm_1(x)
    x = self.drop(x)
    x = self.conv_2(x * x_mask)
    x = torch.relu(x)
    x = self.norm_2(x)
    x = self.drop(x)
    x = self.proj(x * x_mask)
    return x * x_mask
Why predict log-duration?
  • Durations span several orders of magnitude (1 frame to 100+ frames)
  • Log scale makes the prediction more stable
  • During inference: duration = exp(logw) to get actual frame counts
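A small worked example of the inference-time conversion (the log-duration values are illustrative, and a length_scale factor is a common addition at this step):

```python
import torch

# Illustrative log-durations for four phonemes; the last one is padding.
logw = torch.tensor([[[0.0, 1.0, 2.3, -5.0]]])   # (B=1, 1, L=4)
x_mask = torch.tensor([[[1.0, 1.0, 1.0, 0.0]]])

w = torch.exp(logw) * x_mask    # durations in frames: [1.0, 2.72, 9.97, 0.0]
w_ceil = torch.ceil(w)          # whole frames per phoneme: [1, 3, 10, 0]
y_length = int(w_ceil.sum())    # total mel frames for the utterance
```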

Forward Pass

The complete forward pass (text_encoder.py:378):
def forward(self, x, x_lengths, spks=None):
    # 1. Embed phonemes and scale
    x = self.emb(x) * math.sqrt(self.n_channels)
    x = torch.transpose(x, 1, -1)  # (B, L, C) -> (B, C, L)
    
    # 2. Create mask
    x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
    
    # 3. Apply PreNet
    x = self.prenet(x, x_mask)
    
    # 4. Concatenate speaker embedding (multi-speaker)
    if self.n_spks > 1:
        x = torch.cat([x, spks.unsqueeze(-1).repeat(1, 1, x.shape[-1])], dim=1)
    
    # 5. Transformer encoder
    x = self.encoder(x, x_mask)
    
    # 6. Project to mel features
    mu = self.proj_m(x) * x_mask
    
    # 7. Predict durations (detached from gradient)
    x_dp = torch.detach(x)
    logw = self.proj_w(x_dp, x_mask)
    
    return mu, logw, x_mask
Outputs:
  • mu: Mean mel-spectrogram features, shape (B, n_feats, L)
  • logw: Log-durations for each phoneme, shape (B, 1, L)
  • x_mask: Sequence mask, shape (B, 1, L)
The duration predictor input is detached from the computation graph (torch.detach(x) at text_encoder.py:407). This prevents gradients from the duration loss from affecting the encoder representations, since durations are supervised separately via monotonic alignment search (MAS).
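A tiny sketch of why the detach matters for gradient flow (toy values, not the model's tensors):

```python
import torch

# Only the non-detached path contributes gradients to the parameter.
w = torch.nn.Parameter(torch.tensor(2.0))
x = w * 3.0                  # stands in for the encoder output
x_dp = torch.detach(x)       # duration-branch input, cut from the graph
loss = x + x_dp              # gradients flow through x only
loss.backward()
# w.grad is 3.0: the detached term contributed nothing
```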

Multi-Speaker Support

For multi-speaker models, speaker embeddings are concatenated to the encoder output:
if self.n_spks > 1:
    # Repeat speaker embedding across time dimension
    x = torch.cat([x, spks.unsqueeze(-1).repeat(1, 1, x.shape[-1])], dim=1)
This increases the encoder channel dimension:
  • Single-speaker: n_channels
  • Multi-speaker: n_channels + spk_emb_dim
The projections and duration predictor account for this increased dimension.
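A shape sketch of the concatenation (sizes are the typical values from this page):

```python
import torch

# Broadcast one speaker embedding per utterance across the time axis,
# then concatenate along channels.
n_channels, spk_emb_dim, T = 192, 128, 50
x = torch.randn(2, n_channels, T)       # encoder features, (B, C, T)
spks = torch.randn(2, spk_emb_dim)      # one embedding per utterance, (B, D)
spks_t = spks.unsqueeze(-1).repeat(1, 1, x.shape[-1])   # (B, D, T)
x = torch.cat([x, spks_t], dim=1)       # (B, C + D, T)
```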

Layer Normalization

Custom LayerNorm implementation (text_encoder.py:15):
class LayerNorm(nn.Module):
    def __init__(self, channels, eps=1e-4):
        super().__init__()
        self.eps = eps
        self.gamma = torch.nn.Parameter(torch.ones(channels))
        self.beta = torch.nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        mean = torch.mean(x, 1, keepdim=True)
        variance = torch.mean((x - mean) ** 2, 1, keepdim=True)
        x = (x - mean) * torch.rsqrt(variance + self.eps)
        shape = [1, -1] + [1] * (x.dim() - 2)  # broadcast gamma/beta over channels
        x = x * self.gamma.view(*shape) + self.beta.view(*shape)
        return x
Normalizes over the channel dimension, preserving batch and time dimensions.
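As a sanity check (a hedged comparison, not part of the codebase), the same channel-dim normalization can be reproduced with PyTorch's built-in layer_norm applied over a transposed (B, T, C) layout:

```python
import torch

# Channel-dim normalization (gamma=1, beta=0) vs. built-in layer_norm
# over the channel axis of the transposed tensor, eps matched at 1e-4.
x = torch.randn(2, 8, 10)                               # (B, C, T)
mean = x.mean(1, keepdim=True)
var = ((x - mean) ** 2).mean(1, keepdim=True)
y = (x - mean) * torch.rsqrt(var + 1e-4)

ref = torch.nn.functional.layer_norm(x.transpose(1, 2), (8,), eps=1e-4)
ref = ref.transpose(1, 2)
```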

Masking

Sequence masking is applied throughout to handle variable-length inputs:
x_mask = sequence_mask(x_lengths, x.size(2))  # (B, L)
x_mask = x_mask.unsqueeze(1)                  # (B, 1, L)
All operations multiply by mask: x * x_mask to zero out padding.
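A minimal sequence_mask sketch (the real helper lives elsewhere in the codebase; this is an assumed equivalent): True for valid time steps, False for padding.

```python
import torch

# Compare each time index against the per-sequence length.
def sequence_mask(lengths, max_length=None):
    if max_length is None:
        max_length = int(lengths.max())
    steps = torch.arange(max_length, device=lengths.device)
    return steps.unsqueeze(0) < lengths.unsqueeze(1)    # (B, max_length)

mask = sequence_mask(torch.tensor([2, 4]), max_length=5)
```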

Embedding Scaling

Phoneme embeddings are scaled by √n_channels (text_encoder.py:397):
x = self.emb(x) * math.sqrt(self.n_channels)
This is standard practice from “Attention Is All You Need” to prevent embedding values from being too small relative to positional encodings.

Typical Hyperparameters

Parameter             Typical Value   Description
n_vocab               150-200         Phoneme vocabulary size
n_channels            192             Encoder hidden dimension
n_feats               80              Mel-spectrogram features
filter_channels       768             FFN inner dimension (4× hidden)
n_heads               2               Number of attention heads
n_layers              6               Number of transformer blocks
kernel_size           3               Convolution kernel size
p_dropout             0.1             Dropout probability
filter_channels_dp    256             Duration predictor hidden dim
