Text Encoder
The text encoder is responsible for converting phoneme sequences into continuous representations and predicting the duration (number of mel-spectrogram frames) for each phoneme. It consists of two main components:
- Encoder: Transformer-based network that processes phoneme embeddings
- Duration Predictor: Predicts how long each phoneme should be spoken
Architecture Overview
Phoneme IDs
↓
[Embedding Layer]
↓
[Optional PreNet] (3 Conv-ReLU-Norm layers)
↓
[Concatenate Speaker Embedding] (if multi-speaker)
↓
[Transformer Encoder] (N layers)
↓
├──> [Projection: proj_m] ──> mu (mean mel features)
└──> [Duration Predictor] ──> logw (log durations)
TextEncoder Class
The main encoder class is defined in text_encoder.py:328:
class TextEncoder(nn.Module):
def __init__(
self,
encoder_type, # Type of encoder (e.g., "transformer")
encoder_params, # Encoder hyperparameters
duration_predictor_params, # Duration predictor hyperparameters
n_vocab, # Vocabulary size (phoneme count)
n_spks=1, # Number of speakers
spk_emb_dim=128, # Speaker embedding dimension
)
Key Attributes:
- n_vocab: Size of the phoneme vocabulary
- n_feats: Number of mel-spectrogram features (typically 80)
- n_channels: Hidden dimension of the encoder (typically 192)
- n_spks: Number of speakers (1 for single-speaker)
- spk_emb_dim: Dimension of speaker embeddings
Component Breakdown
1. Embedding Layer
Converts discrete phoneme IDs to continuous vectors (text_encoder.py:346-347):
self.emb = torch.nn.Embedding(n_vocab, self.n_channels)
torch.nn.init.normal_(self.emb.weight, 0.0, self.n_channels**-0.5)
- Embedding dimension: n_channels (typically 192)
- Initialized from a normal distribution with mean 0 and standard deviation 1/√n_channels
- Standard practice from the Transformer literature
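A quick numpy sketch (an illustration, not the actual implementation) shows why this initialization pairs with the √n_channels scaling applied in the forward pass: the scaled embeddings end up with roughly unit standard deviation.

```python
import numpy as np

n_vocab, n_channels = 150, 192
rng = np.random.default_rng(0)

# Initialize like torch.nn.init.normal_(weight, 0.0, n_channels**-0.5)
emb = rng.normal(0.0, n_channels**-0.5, size=(n_vocab, n_channels))

# The forward pass scales lookups by sqrt(n_channels)
scaled = emb * np.sqrt(n_channels)

print(emb.std())     # ≈ 1/sqrt(192) ≈ 0.072
print(scaled.std())  # ≈ 1.0
```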
2. PreNet (Optional)
An optional convolutional pre-processing network (text_encoder.py:349-359):
if encoder_params.prenet:
self.prenet = ConvReluNorm(
self.n_channels,
self.n_channels,
self.n_channels,
kernel_size=5,
n_layers=3,
p_dropout=0.5,
)
else:
self.prenet = lambda x, x_mask: x
ConvReluNorm (text_encoder.py:36) is a stack of:
- Conv1D → LayerNorm → ReLU → Dropout
- Repeated for n_layers (typically 3)
- Residual connection at the end
The PreNet adds local context modeling before the transformer encoder. The high dropout rate (0.5) acts as a regularizer and helps with generalization.
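The padding choice in those Conv1d layers (padding = kernel_size // 2) is what keeps the phoneme sequence length unchanged through the PreNet. A minimal sketch, using a hypothetical conv1d_same helper rather than the library code:

```python
import numpy as np

def conv1d_same(x, kernel):
    """1-D convolution with 'same' padding (padding = kernel_size // 2),
    as the PreNet's Conv1d layers use, so sequence length is preserved."""
    pad = len(kernel) // 2
    xp = np.pad(x, pad)
    return np.array([np.dot(xp[i:i + len(kernel)], kernel)
                     for i in range(len(x))])

x = np.arange(10, dtype=float)   # a length-10 "phoneme" sequence
k = np.ones(5) / 5               # kernel_size=5, here a moving average
y = conv1d_same(x, k)
print(len(y))  # 10 — length unchanged
```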
3. Transformer Encoder
The core encoder processes the phoneme sequence (text_encoder.py:361-368):
self.encoder = Encoder(
encoder_params.n_channels + (spk_emb_dim if n_spks > 1 else 0),
encoder_params.filter_channels,
encoder_params.n_heads,
encoder_params.n_layers,
encoder_params.kernel_size,
encoder_params.p_dropout,
)
Typical Configuration:
- n_channels: 192 (base hidden dimension)
- filter_channels: 768 (FFN inner dimension, 4× hidden size)
- n_heads: 2 (attention heads)
- n_layers: 6 (transformer blocks)
- kernel_size: 3 (for FFN convolutions)
- p_dropout: 0.1
Encoder Layer Structure
Each transformer layer (text_encoder.py:276) consists of:
for _ in range(self.n_layers):
# Multi-head self-attention with RoPE
self.attn_layers.append(
MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout)
)
self.norm_layers_1.append(LayerNorm(hidden_channels))
# Feed-forward network
self.ffn_layers.append(
FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout)
)
self.norm_layers_2.append(LayerNorm(hidden_channels))
Forward pass (text_encoder.py:314-325):
def forward(self, x, x_mask):
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
for i in range(self.n_layers):
x = x * x_mask
# Self-attention
y = self.attn_layers[i](x, x, attn_mask)
y = self.drop(y)
x = self.norm_layers_1[i](x + y) # Post-norm + residual
# Feed-forward
y = self.ffn_layers[i](x, x_mask)
y = self.drop(y)
x = self.norm_layers_2[i](x + y) # Post-norm + residual
return x * x_mask
This uses post-normalization (apply the residual connection first, then normalize), the scheme used in the original Transformer; pre-normalization is a common alternative in deeper models.
4. Rotary Position Embeddings (RoPE)
Instead of absolute position embeddings, the encoder uses RoPE (text_encoder.py:97):
class RotaryPositionalEmbeddings(nn.Module):
def __init__(self, d: int, base: int = 10_000):
# Θ = {θᵢ = 10000^(-2(i-1)/d), i ∈ [1, 2, ..., d/2]}
RoPE rotates query and key vectors based on their position:
- Applied in MultiHeadAttention (text_encoder.py:232-233)
- Only applied to half of the features
- Encodes relative positions rather than absolute
Benefits:
- Better extrapolation to longer sequences
- Encodes relative position information
- No learned parameters
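The relative-position property can be verified numerically. The sketch below is a simplified standalone RoPE (not the repository's implementation): it rotates feature pairs by position-dependent angles and checks that the query/key dot product depends only on the positional offset.

```python
import numpy as np

def rope(x, pos, base=10_000):
    """Rotate feature pairs of a d-dim vector by position-dependent
    angles (simplified sketch of rotary position embeddings)."""
    d = len(x)
    half = d // 2
    theta = base ** (-2.0 * np.arange(half) / d)   # θ_i = base^(-2i/d)
    a = pos * theta
    x1, x2 = x[:half], x[half:]
    return np.concatenate([x1 * np.cos(a) - x2 * np.sin(a),
                           x1 * np.sin(a) + x2 * np.cos(a)])

rng = np.random.default_rng(1)
q, k = rng.normal(size=8), rng.normal(size=8)

# The q/k dot product depends only on the relative offset (n - m):
s1 = rope(q, 3) @ rope(k, 5)    # positions 3 and 5, offset 2
s2 = rope(q, 10) @ rope(k, 12)  # positions 10 and 12, offset 2
print(np.isclose(s1, s2))  # True
```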
5. Multi-Head Attention
Custom attention implementation with RoPE (text_encoder.py:175):
class MultiHeadAttention(nn.Module):
def __init__(
self,
channels, # Input channels
out_channels, # Output channels
n_heads, # Number of attention heads (typically 2)
p_dropout=0.0,
)
Attention computation (text_encoder.py:226-246):
def attention(self, query, key, value, mask=None):
# Reshape to multi-head format
query = rearrange(query, "b (h c) t -> b h t c", h=self.n_heads)
key = rearrange(key, "b (h c) t -> b h t c", h=self.n_heads)
value = rearrange(value, "b (h c) t -> b h t c", h=self.n_heads)
# Apply rotary position embeddings
query = self.query_rotary_pe(query)
key = self.key_rotary_pe(key)
# Scaled dot-product attention
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e4)
p_attn = torch.nn.functional.softmax(scores, dim=-1)
p_attn = self.drop(p_attn)
output = torch.matmul(p_attn, value)
return output, p_attn
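The masking step deserves a closer look: filling masked scores with -1e4 before the softmax drives the corresponding attention weights to (essentially) zero. A single-head numpy sketch of the same computation, with assumed toy shapes:

```python
import numpy as np

def masked_attention(q, k, v, mask):
    """Scaled dot-product attention with masking, mirroring the
    computation in MultiHeadAttention (single-head numpy sketch)."""
    d = q.shape[-1]
    scores = q @ k.T / np.sqrt(d)               # (T, T)
    scores = np.where(mask == 0, -1e4, scores)  # mask out padding
    scores = scores - scores.max(axis=-1, keepdims=True)
    p = np.exp(scores)
    p = p / p.sum(axis=-1, keepdims=True)       # softmax over keys
    return p @ v, p

T, d = 4, 8
rng = np.random.default_rng(0)
q, k, v = (rng.normal(size=(T, d)) for _ in range(3))
mask = np.ones((T, T))
mask[:, 3] = 0  # pretend position 3 is padding

out, p = masked_attention(q, k, v, mask)
print(p.sum(axis=-1))   # each row sums to 1
print(p[:, 3].max())    # ~0: padding receives (almost) no attention
```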
6. Feed-Forward Network
Position-wise FFN with convolutions (text_encoder.py:255):
class FFN(nn.Module):
def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0):
self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size, padding=kernel_size // 2)
self.drop = torch.nn.Dropout(p_dropout)
def forward(self, x, x_mask):
x = self.conv_1(x * x_mask)
x = torch.relu(x)
x = self.drop(x)
x = self.conv_2(x * x_mask)
return x * x_mask
Typically: in_channels → filter_channels (4×) → out_channels
Example: 192 → 768 → 192
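For a sense of scale, the parameter count of this FFN with the typical sizes works out as follows (a Conv1d weight is out_channels × in_channels × kernel_size, plus a bias per output channel):

```python
# Parameter count for the conv-based FFN with the typical sizes
# (192 -> 768 -> 192, kernel_size=3)
in_ch, filt, out_ch, k = 192, 768, 192, 3

conv_1 = filt * in_ch * k + filt      # 192 -> 768: 443,136
conv_2 = out_ch * filt * k + out_ch   # 768 -> 192: 442,560
total = conv_1 + conv_2

print(total)  # 885,696 parameters per FFN block
```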
7. Output Projections
Two separate projections from encoder hidden state (text_encoder.py:370-376):
Mean Projection (mel-spectrogram features):
self.proj_m = torch.nn.Conv1d(
self.n_channels + (spk_emb_dim if n_spks > 1 else 0),
self.n_feats, # Typically 80
1
)
Duration Predictor:
self.proj_w = DurationPredictor(
self.n_channels + (spk_emb_dim if n_spks > 1 else 0),
duration_predictor_params.filter_channels_dp,
duration_predictor_params.kernel_size,
duration_predictor_params.p_dropout,
)
Duration Predictor
Predicts log-scaled duration for each phoneme (text_encoder.py:70):
class DurationPredictor(nn.Module):
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout):
self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
self.norm_1 = LayerNorm(filter_channels)
self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
self.norm_2 = LayerNorm(filter_channels)
self.proj = torch.nn.Conv1d(filter_channels, 1, 1) # Output: log duration
Architecture:
Input (encoder hidden state)
↓
Conv1D → ReLU → LayerNorm → Dropout
↓
Conv1D → ReLU → LayerNorm → Dropout
↓
Conv1D (1×1) → log(duration)
Forward pass (text_encoder.py:84-94):
def forward(self, x, x_mask):
x = self.conv_1(x * x_mask)
x = torch.relu(x)
x = self.norm_1(x)
x = self.drop(x)
x = self.conv_2(x * x_mask)
x = torch.relu(x)
x = self.norm_2(x)
x = self.drop(x)
x = self.proj(x * x_mask)
return x * x_mask
Why predict log-duration?
- Durations span several orders of magnitude (1 frame to 100+ frames)
- The log scale makes prediction more stable
- During inference, duration = exp(logw) recovers actual frame counts
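A stdlib sketch of the inference-time conversion (the exact rounding rule is an assumption here; rounding up with ceil is a common choice in duration-based TTS models):

```python
import math

logw = [0.0, 1.2, 2.5, -0.5]          # one predicted value per phoneme
w = [math.exp(lw) for lw in logw]     # exp(logw) -> durations in frames
frames = [math.ceil(d) for d in w]    # integer frame counts (assumed ceil)

print(frames)        # [1, 4, 13, 1]
print(sum(frames))   # total mel length: 19 frames
```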
Forward Pass
The complete forward pass (text_encoder.py:378):
def forward(self, x, x_lengths, spks=None):
# 1. Embed phonemes and scale
x = self.emb(x) * math.sqrt(self.n_channels)
x = torch.transpose(x, 1, -1) # (B, L) -> (B, C, L)
# 2. Create mask
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
# 3. Apply PreNet
x = self.prenet(x, x_mask)
# 4. Concatenate speaker embedding (multi-speaker)
if self.n_spks > 1:
x = torch.cat([x, spks.unsqueeze(-1).repeat(1, 1, x.shape[-1])], dim=1)
# 5. Transformer encoder
x = self.encoder(x, x_mask)
# 6. Project to mel features
mu = self.proj_m(x) * x_mask
# 7. Predict durations (detached from gradient)
x_dp = torch.detach(x)
logw = self.proj_w(x_dp, x_mask)
return mu, logw, x_mask
Outputs:
mu: Mean mel-spectrogram features, shape (B, n_feats, L)
logw: Log-durations for each phoneme, shape (B, 1, L)
x_mask: Sequence mask, shape (B, 1, L)
The duration predictor input is detached from gradients (torch.detach(x)) at line text_encoder.py:407. This prevents gradients from the duration loss from affecting the encoder representations, as the duration is supervised separately via MAS.
Multi-Speaker Support
For multi-speaker models, speaker embeddings are concatenated to the encoder output:
if self.n_spks > 1:
# Repeat speaker embedding across time dimension
x = torch.cat([x, spks.unsqueeze(-1).repeat(1, 1, x.shape[-1])], dim=1)
This increases the encoder channel dimension:
- Single-speaker: n_channels
- Multi-speaker: n_channels + spk_emb_dim
The projections and duration predictor account for this increased dimension.
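The shape arithmetic can be sketched in numpy (using the defaults above, n_channels=192 and spk_emb_dim=128; repeat mirrors the unsqueeze(-1).repeat(...) in the source):

```python
import numpy as np

B, C, T, S = 2, 192, 10, 128
x = np.zeros((B, C, T))   # encoder features
spk = np.ones((B, S))     # one speaker embedding per utterance

# unsqueeze(-1) + repeat across time, then concat on the channel axis
spk_tiled = np.repeat(spk[:, :, None], T, axis=2)  # (B, S, T)
x = np.concatenate([x, spk_tiled], axis=1)

print(x.shape)  # (2, 320, 10): channels grow to n_channels + spk_emb_dim
```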
Layer Normalization
Custom LayerNorm implementation (text_encoder.py:15):
class LayerNorm(nn.Module):
    def __init__(self, channels, eps=1e-4):
        self.eps = eps
        self.gamma = torch.nn.Parameter(torch.ones(channels))
        self.beta = torch.nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        shape = [1, -1] + [1] * (x.dim() - 2)  # broadcast gamma/beta over batch and time
        mean = torch.mean(x, 1, keepdim=True)
        variance = torch.mean((x - mean) ** 2, 1, keepdim=True)
        x = (x - mean) * torch.rsqrt(variance + self.eps)
        x = x * self.gamma.view(*shape) + self.beta.view(*shape)
        return x
Normalizes over the channel dimension, preserving batch and time dimensions.
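The same normalization in numpy, as a sanity check that each (batch, time) position ends up with zero mean and unit variance across channels:

```python
import numpy as np

def channel_layer_norm(x, gamma, beta, eps=1e-4):
    """Normalize over the channel axis (axis=1) of a (B, C, T) array,
    matching the custom LayerNorm above (numpy sketch)."""
    mean = x.mean(axis=1, keepdims=True)
    var = ((x - mean) ** 2).mean(axis=1, keepdims=True)
    x = (x - mean) / np.sqrt(var + eps)
    return x * gamma[None, :, None] + beta[None, :, None]

rng = np.random.default_rng(0)
x = rng.normal(3.0, 2.0, size=(2, 192, 10))
y = channel_layer_norm(x, np.ones(192), np.zeros(192))

print(np.allclose(y.mean(axis=1), 0.0, atol=1e-6))  # True
print(np.allclose(y.std(axis=1), 1.0, atol=1e-2))   # True
```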
Masking
Sequence masking is applied throughout to handle variable-length inputs:
x_mask = sequence_mask(x_lengths, x.size(2)) # (B, L)
x_mask = x_mask.unsqueeze(1) # (B, 1, L)
All operations multiply by mask: x * x_mask to zero out padding.
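A minimal numpy version of the sequence_mask helper (a sketch of its behavior, not the repository's code):

```python
import numpy as np

def sequence_mask(lengths, max_len):
    """True where position < sequence length, False on padding."""
    return np.arange(max_len)[None, :] < np.asarray(lengths)[:, None]

x_mask = sequence_mask([3, 5], max_len=5).astype(float)  # (B, L)
x_mask = x_mask[:, None, :]                              # (B, 1, L)

print(x_mask[0, 0])  # [1. 1. 1. 0. 0.]
print(x_mask.shape)  # (2, 1, 5)
```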
Embedding Scaling
Phoneme embeddings are scaled by √n_channels (text_encoder.py:397):
x = self.emb(x) * math.sqrt(self.n_channels)
This is standard practice from “Attention Is All You Need” to prevent embedding values from being too small relative to positional encodings.
Typical Hyperparameters
| Parameter | Typical Value | Description |
|---|---|---|
| n_vocab | 150-200 | Phoneme vocabulary size |
| n_channels | 192 | Encoder hidden dimension |
| n_feats | 80 | Mel-spectrogram features |
| filter_channels | 768 | FFN inner dimension (4× hidden) |
| n_heads | 2 | Number of attention heads |
| n_layers | 6 | Number of transformer blocks |
| kernel_size | 3 | Convolution kernel size |
| p_dropout | 0.1 | Dropout probability |
| filter_channels_dp | 256 | Duration predictor hidden dim |
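These hyperparameters could be grouped into config objects along these lines (a hypothetical sketch; the class and field names follow this page, not necessarily the repository's config schema):

```python
from dataclasses import dataclass

@dataclass
class EncoderParams:
    n_channels: int = 192
    filter_channels: int = 768
    n_heads: int = 2
    n_layers: int = 6
    kernel_size: int = 3
    p_dropout: float = 0.1
    prenet: bool = True

@dataclass
class DurationPredictorParams:
    filter_channels_dp: int = 256
    kernel_size: int = 3
    p_dropout: float = 0.1

enc = EncoderParams()
print(enc.filter_channels // enc.n_channels)  # 4: FFN is 4x the hidden size
```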