Whisper

The main Whisper model class that performs speech recognition. Inherits from torch.nn.Module.

Constructor

Whisper(dims: ModelDimensions)

dims (ModelDimensions, required): Model dimensions configuration containing the architecture parameters

Attributes

dims (ModelDimensions): The model dimensions used to initialize the encoder and decoder
encoder (AudioEncoder): The audio encoder that processes mel spectrograms into features
decoder (TextDecoder): The text decoder that generates text tokens from audio features
alignment_heads (torch.Tensor): Sparse boolean tensor indicating which attention heads to use for time alignment. By default, uses the heads in the last half of the decoder layers.

Properties

device (torch.device): The device (CPU/GPU) where the model parameters are stored
is_multilingual (bool): True if the model supports multiple languages (vocab size >= 51865)
num_languages (int): Number of languages supported by the model

Methods

embed_audio

def embed_audio(mel: torch.Tensor) -> torch.Tensor

Encode an audio mel spectrogram into features.

mel (torch.Tensor): Mel spectrogram with shape (batch_size, n_mels, n_ctx)
Returns (torch.Tensor): Encoded audio features with shape (batch_size, n_audio_ctx, n_audio_state)

logits

def logits(tokens: torch.Tensor, audio_features: torch.Tensor) -> torch.Tensor

Compute logits for next-token prediction.

tokens (torch.Tensor): Text tokens with shape (batch_size, seq_len)
audio_features (torch.Tensor): Encoded audio features from embed_audio()
Returns (torch.Tensor): Logits for the next token with shape (batch_size, seq_len, vocab_size)

forward

def forward(mel: torch.Tensor, tokens: torch.Tensor) -> Dict[str, torch.Tensor]

Full forward pass through encoder and decoder.

mel (torch.Tensor): Mel spectrogram input
tokens (torch.Tensor): Text tokens
Returns (torch.Tensor): Decoder output logits

set_alignment_heads

def set_alignment_heads(dump: bytes) -> None

Set custom alignment heads for cross-attention analysis.

dump (bytes): Base85-encoded, gzip-compressed boolean array specifying which attention heads to use

install_kv_cache_hooks

def install_kv_cache_hooks(cache: Optional[dict] = None) -> Tuple[Dict, List]

Install hooks for key-value caching to speed up autoregressive decoding.

cache (Optional[dict]): Existing cache dictionary to extend, or None to create a new cache

Returns:
    cache (Dict[nn.Module, torch.Tensor]): Dictionary mapping key/value projection modules to cached tensors
    hooks (List[RemovableHandle]): PyTorch hook handles that can be used to remove the hooks

transcribe

def transcribe(
    audio: Union[str, np.ndarray, torch.Tensor],
    **kwargs
) -> dict
Transcribe audio to text with timestamps and metadata. See the transcribe documentation for details.

decode

def decode(
    mel: torch.Tensor,
    options: DecodingOptions = DecodingOptions(),
    **kwargs
) -> Union[DecodingResult, List[DecodingResult]]
Decode mel spectrogram(s) to text. See the decode documentation for details.

detect_language

def detect_language(
    mel: torch.Tensor,
    tokenizer: Optional[Tokenizer] = None
) -> Tuple[torch.Tensor, List[dict]]
Detect the spoken language in audio. Returns language tokens and probability distributions.

ModelDimensions

Dataclass containing model architecture dimensions.
@dataclass
class ModelDimensions:
    n_mels: int
    n_audio_ctx: int
    n_audio_state: int
    n_audio_head: int
    n_audio_layer: int
    n_vocab: int
    n_text_ctx: int
    n_text_state: int
    n_text_head: int
    n_text_layer: int

Fields

n_mels (int, required): Number of mel filterbank channels (80 for most models, 128 for large-v3)
n_audio_ctx (int, required): Audio context length, the number of frames in the encoder output
n_audio_state (int, required): Hidden dimension size of the audio encoder
n_audio_head (int, required): Number of attention heads in the audio encoder
n_audio_layer (int, required): Number of layers in the audio encoder
n_vocab (int, required): Vocabulary size; determines whether the model is multilingual (>= 51865)
n_text_ctx (int, required): Text context length, the maximum sequence length for the decoder
n_text_state (int, required): Hidden dimension size of the text decoder
n_text_head (int, required): Number of attention heads in the text decoder
n_text_layer (int, required): Number of layers in the text decoder

Usage Examples

Loading a Model

import whisper

# Load pre-trained model
model = whisper.load_model("base")

# Check model properties
print(f"Is multilingual: {model.is_multilingual}")
print(f"Number of languages: {model.num_languages}")
print(f"Device: {model.device}")

Using Model Components

import torch
import whisper
from whisper.audio import load_audio, log_mel_spectrogram, pad_or_trim
from whisper.tokenizer import get_tokenizer

model = whisper.load_model("base")

# Load audio and pad/trim it to one 30-second chunk
audio = pad_or_trim(load_audio("audio.wav"))
mel = log_mel_spectrogram(audio)
mel = mel.unsqueeze(0).to(model.device)  # add a batch dimension

# Encode audio
audio_features = model.embed_audio(mel)

# Prepare the start-of-transcript token and get logits
tokenizer = get_tokenizer(model.is_multilingual, num_languages=model.num_languages)
tokens = torch.tensor([[tokenizer.sot]], device=model.device)
logits = model.logits(tokens, audio_features)

Creating Custom Model

from whisper.model import Whisper, ModelDimensions

# Define custom dimensions (tiny.en-sized example)
dims = ModelDimensions(
    n_mels=80,
    n_audio_ctx=1500,
    n_audio_state=384,
    n_audio_head=6,
    n_audio_layer=4,
    n_vocab=51864,  # English-only vocab; multilingual models use >= 51865
    n_text_ctx=448,
    n_text_state=384,
    n_text_head=6,
    n_text_layer=4,
)

# Create model instance
custom_model = Whisper(dims)

Notes

  • The Whisper class is typically loaded using whisper.load_model() rather than instantiated directly
  • Models are available in sizes tiny, base, small, medium, and large, plus English-only .en variants of the smaller sizes
  • Multilingual models have a vocab size >= 51865 and support 99 languages (100 for large-v3)
  • English-only models are more accurate on English but don’t support other languages
  • Use model.to(device) to move the model to GPU for faster inference
  • The encoder processes 30-second audio chunks at a time
  • KV cache hooks are used internally by the decode function for efficiency
