Whisper uses a Transformer-based encoder-decoder architecture optimized for speech recognition tasks. Understanding the architecture helps with model customization, debugging, and advanced use cases.
Each Whisper model is defined by a ModelDimensions dataclass that specifies the architecture parameters:
whisper/model.py
```python
@dataclass
class ModelDimensions:
    n_mels: int         # Number of mel frequency bins
    n_audio_ctx: int    # Audio context length
    n_audio_state: int  # Audio encoder hidden dimension
    n_audio_head: int   # Number of audio encoder attention heads
    n_audio_layer: int  # Number of audio encoder layers
    n_vocab: int        # Vocabulary size
    n_text_ctx: int     # Text context length
    n_text_state: int   # Text decoder hidden dimension
    n_text_head: int    # Number of text decoder attention heads
    n_text_layer: int   # Number of text decoder layers
```
These dimensions are automatically loaded from the model checkpoint and accessible via model.dims.
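As a concrete illustration, the sketch below instantiates a stand-in `ModelDimensions` dataclass with the values commonly reported for the multilingual "tiny" checkpoint. The numbers are assumptions for illustration, not read from a checkpoint; verify them against `model.dims` on your own model.

```python
from dataclasses import dataclass

@dataclass
class ModelDimensions:
    n_mels: int
    n_audio_ctx: int
    n_audio_state: int
    n_audio_head: int
    n_audio_layer: int
    n_vocab: int
    n_text_ctx: int
    n_text_state: int
    n_text_head: int
    n_text_layer: int

# Assumed values for the multilingual "tiny" model (illustrative only)
dims = ModelDimensions(
    n_mels=80, n_audio_ctx=1500, n_audio_state=384,
    n_audio_head=6, n_audio_layer=4, n_vocab=51865,
    n_text_ctx=448, n_text_state=384, n_text_head=6, n_text_layer=4,
)
print(dims.n_audio_state)  # 384
```

In the real library the same fields are populated from the checkpoint's stored dimensions, so `model.dims.n_audio_ctx` etc. always reflect the loaded model.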
The audio encoder processes mel spectrograms through:

- Convolution Layers: Two 1D convolutions with GELU activation
  - First conv: kernel_size=3, stride=1, padding=1 (preserves length)
  - Second conv: kernel_size=3, stride=2, padding=1 (downsamples by 2x)
- Positional Encoding: Fixed sinusoidal embeddings added to the features
- Transformer Blocks: Self-attention and feedforward layers
- Layer Normalization: A final normalization layer applied to the output
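The downsampling in the convolution stack can be checked with the standard 1D convolution output-length formula. This sketch assumes Whisper's usual 30-second input at 10 ms per mel frame (3000 frames); the helper function is illustrative, not part of the library.

```python
def conv1d_out_len(n: int, kernel_size: int, stride: int = 1, padding: int = 0) -> int:
    # Standard output-length formula for a 1D convolution
    return (n + 2 * padding - kernel_size) // stride + 1

n_frames = 3000  # 30 s of audio at 10 ms per mel frame (assumption)
after_conv1 = conv1d_out_len(n_frames, kernel_size=3, stride=1, padding=1)
after_conv2 = conv1d_out_len(after_conv1, kernel_size=3, stride=2, padding=1)

print(after_conv1, after_conv2)  # 3000 1500
```

The result, 1500 positions, matches `n_audio_ctx`, which is why the positional embedding shape assertion in the encoder's `forward` holds.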
```python
def forward(self, x: Tensor):
    """
    x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
        the mel spectrogram of the audio
    """
    x = F.gelu(self.conv1(x))
    x = F.gelu(self.conv2(x))
    x = x.permute(0, 2, 1)

    assert x.shape[1:] == self.positional_embedding.shape
    x = (x + self.positional_embedding).to(x.dtype)

    for block in self.blocks:
        x = block(x)

    x = self.ln_post(x)
    return x
```
Each text decoder block processes tokens through:

- Self-Attention: Attend to previous tokens (causal)
- Cross-Attention: Attend to encoder outputs
- Feedforward: Process through an MLP
- Output Projection: Project to vocabulary logits (via the tied token embedding matrix)
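The "causal" constraint in self-attention comes from an additive mask with `-inf` above the diagonal, so each position can only attend to itself and earlier positions. The numpy sketch below mirrors the idea behind the decoder's `mask` buffer (which in the real code is a torch tensor); the tiny context size is chosen only for illustration.

```python
import numpy as np

n_ctx = 5  # small context for illustration (the real decoder uses n_text_ctx)
# Additive causal mask: -inf above the diagonal blocks attention to future tokens
mask = np.triu(np.full((n_ctx, n_ctx), -np.inf), k=1)

print(mask[0])  # [  0. -inf -inf -inf -inf]
```

Adding this mask to the attention scores before the softmax zeroes out the probability of attending to future tokens.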
```python
def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
    """
    x : torch.LongTensor, shape = (batch_size, <= n_ctx)
        the text tokens
    xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
        the encoded audio features to be attended on
    """
    offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
    x = (
        self.token_embedding(x)
        + self.positional_embedding[offset : offset + x.shape[-1]]
    )
    x = x.to(xa.dtype)

    for block in self.blocks:
        x = block(x, xa, mask=self.mask, kv_cache=kv_cache)

    x = self.ln(x)
    logits = (
        x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
    ).float()

    return logits
```
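The `offset` line is what makes incremental decoding with a kv-cache work: once the cache holds keys for `t` previous positions, only the newest token is passed in, and its positional embedding is taken from position `t`. This numpy sketch demonstrates the slicing logic; `embed_step` is a hypothetical helper, not a library function.

```python
import numpy as np

n_ctx, n_state = 8, 4  # toy sizes for illustration
positional_embedding = np.random.randn(n_ctx, n_state)

def embed_step(new_tokens_len: int, cached_len: int) -> np.ndarray:
    # Mirrors the decoder's slicing: with `cached_len` keys already in the
    # kv-cache, the new tokens occupy positions [cached_len, cached_len + len)
    offset = cached_len
    return positional_embedding[offset : offset + new_tokens_len]

# First call: 3 prompt tokens, empty cache -> positions 0..2
first = embed_step(3, 0)
# Incremental call: 1 new token, 3 cached -> position 3
step = embed_step(1, 3)

assert first.shape == (3, n_state)
assert np.array_equal(step, positional_embedding[3:4])
```

Without the offset, every incremental step would reuse position 0 and the model would lose track of where in the sequence it is.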
Whisper uses specific decoder cross-attention heads, called alignment heads, for word-level timestamp alignment:
```python
# Set alignment heads from model checkpoint
model.set_alignment_heads(alignment_heads_bytes)

# Default: use the last half of the decoder layers
all_heads = torch.zeros(
    self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
)
all_heads[self.dims.n_text_layer // 2 :] = True
self.register_buffer("alignment_heads", all_heads.to_sparse())
```
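The default selection above can be sanity-checked with a small numpy sketch of the same boolean mask. The layer and head counts below are assumed from the "tiny" model (4 decoder layers, 6 heads each); only the selection logic mirrors the code above.

```python
import numpy as np

n_text_layer, n_text_head = 4, 6  # assumed "tiny"-model sizes, for illustration
# Default mask: every head in the last half of the decoder layers
all_heads = np.zeros((n_text_layer, n_text_head), dtype=bool)
all_heads[n_text_layer // 2 :] = True

print(int(all_heads.sum()))  # 12 heads selected (layers 2-3, 6 heads each)
```

When a checkpoint ships curated alignment heads, `set_alignment_heads` replaces this default with a hand-picked sparse mask, which typically yields more accurate word timestamps.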