Overview

Wav2Vec2LlamaModel combines a Wav2Vec2 encoder with a Llama decoder for automatic speech recognition. It supports three model variants:
  • LLM_ASR: Standard encoder-decoder ASR
  • LLM_ASR_LID: ASR with language identification conditioning
  • ZERO_SHOT: Zero-shot learning with context examples

Constructor

Wav2Vec2LlamaModel(
    model_type: ModelType,
    model_dim: int,
    encoder_frontend: Wav2Vec2Frontend,
    encoder: TransformerEncoder,
    encoder_proj: nn.Module,
    text_frontend: StandardEmbedding,
    llama_decoder: TransformerLMDecoder,
    final_proj: nn.Module,
    target_vocab_info: VocabularyInfo,
    *,
    masker: Wav2Vec2Masker | None = None,
    max_generation_length: int = 8192,
    encoder_stacking: int = 1,
    lang_embeddings_p: float = 0.0,
    language_column_name: str = "lang",
    lang_embeddings: StandardEmbedding | None = None,
    lang_mapping: dict[str, int] | None = None,
    context_text_only: bool = False,
    beam_search_config: Wav2Vec2LlamaBeamSearchConfig = ...,
    streaming_config: Wav2Vec2LlamaStreamingConfig = ...,
    text_encoder: TokenEncoder | None = None,
    n_context_examples: int = 0,
    seed: int = 42
)

Core Parameters

model_type
ModelType
required
Model variant:
  • ModelType.LLM_ASR: Standard ASR
  • ModelType.LLM_ASR_LID: ASR with language ID
  • ModelType.ZERO_SHOT: Zero-shot with context
model_dim
int
required
Model dimension of the transformer decoder.
encoder_frontend
Wav2Vec2Frontend
required
Wav2Vec2 encoder frontend for feature extraction.
encoder
TransformerEncoder
required
Wav2Vec2 encoder.
encoder_proj
nn.Module
required
Projection layer from encoder outputs to decoder dimension.
text_frontend
StandardEmbedding
required
Text token embedding module.
llama_decoder
TransformerLMDecoder
required
Llama decoder-only model.
final_proj
nn.Module
required
Final projection layer from decoder to vocabulary logits.
target_vocab_info
VocabularyInfo
required
Vocabulary information including size and special token indices.

Optional Parameters

masker
Wav2Vec2Masker | None
default:"None"
Feature masker for Wav2Vec2 (used during training).
max_generation_length
int
default:"8192"
Maximum length of generated sequences in decoder.
encoder_stacking
int
default:"1"
Number of encoder frames to stack before feeding to decoder (for compression).
lang_embeddings_p
float
default:"0.0"
Probability of applying the language embedding (LID model); acts as a dropout rate on the language conditioning during training.
language_column_name
str
default:"lang"
Name of the batch metadata field containing language information.
lang_embeddings
StandardEmbedding | None
default:"None"
Language embedding module (required for LID model).
lang_mapping
dict[str, int] | None
default:"None"
Mapping from language codes to embedding indices.
context_text_only
bool
default:"False"
Whether to use text-only context (instead of audio+text).
beam_search_config
Wav2Vec2LlamaBeamSearchConfig
default:"Wav2Vec2LlamaBeamSearchConfig()"
Beam search configuration for decoding.
streaming_config
Wav2Vec2LlamaStreamingConfig
default:"Wav2Vec2LlamaStreamingConfig()"
Streaming configuration for audio longer than 30 seconds.
text_encoder
TokenEncoder | None
default:"None"
Text encoder for streaming mode.
n_context_examples
int
default:"0"
Number of context examples for zero-shot model.
seed
int
default:"42"
Random seed for reproducibility.

Models are typically loaded using load_model("omniASR_LLM_7B") rather than constructed directly.
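
To make encoder_stacking concrete, here is a minimal pure-Python sketch (stack_frames is a hypothetical helper, not the library's implementation): grouping k adjacent encoder frames shortens the sequence by a factor of k while widening each frame, so the decoder attends over fewer positions.

```python
def stack_frames(frames, stacking):
    """Concatenate groups of `stacking` adjacent frames (illustrative only)."""
    # Drop a trailing remainder that does not fill a complete group.
    usable = len(frames) - len(frames) % stacking
    return [
        sum(frames[i:i + stacking], [])  # concatenate one group's features
        for i in range(0, usable, stacking)
    ]

# 4 frames of dim 2 -> 2 frames of dim 4
stack_frames([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]], stacking=2)
```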

Forward Pass

model.forward(
    batch: Seq2SeqBatch,
    return_logits: bool = False,
    return_decoder_inputs: bool = False
) -> Tensor | Tuple[...]
batch
Seq2SeqBatch
required
Input batch containing source audio and target text.
return_logits
bool
default:"False"
Whether to return logits along with loss (for debugging).
return_decoder_inputs
bool
default:"False"
Whether to return decoder inputs for beam search (inference mode).

Return Values

  • Default: the scalar training loss (Tensor).
  • return_logits=True: a tuple whose leading elements are the loss and the vocabulary logits.
  • return_decoder_inputs=True: a tuple of (decoder_context, context_lens, audio_emb) used to seed beam search.

Model Architectures

Standard LLM-ASR

Input syntax:
audio [<lid> lang_id] <bos> text <eos>

from fairseq2.models.hub import load_model

model = load_model("omniASR_LLM_7B")
# ModelType.LLM_ASR_LID with language conditioning

Zero-Shot Model

Input syntax:
<context>
  (<context_example> ctx_audio <bos> ctx_text <eos> </context_example>) x N
</context>
target_audio <bos> target_text <eos>

model = load_model("omniASR_LLM_7B_ZS")
# ModelType.ZERO_SHOT with n_context_examples=10
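
The context layout above can be sketched as simple string assembly (purely illustrative; the model actually consumes audio and token embeddings, and build_zero_shot_prompt is a hypothetical helper):

```python
def build_zero_shot_prompt(context_examples, target_audio):
    # Assemble the zero-shot input layout: <context> wraps N
    # (<context_example> ... </context_example>) blocks, then the
    # target audio and <bos> follow.
    parts = ["<context>"]
    for ctx_audio, ctx_text in context_examples:
        parts.append(
            f"<context_example> {ctx_audio} <bos> {ctx_text} <eos> </context_example>"
        )
    parts.append("</context>")
    parts.append(f"{target_audio} <bos>")
    return "\n".join(parts)

prompt = build_zero_shot_prompt(
    [("ctx_audio_1", "hello"), ("ctx_audio_2", "world")], "target_audio"
)
```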

Streaming Model

Input syntax:
[lang <lang>]
(audio_segment_i <segment_marker> <bos> text_i <eos>) x N
Segment markers:
  • <regular_segment>: For intermediate segments
  • <last_segment>: For the final segment

model = load_model("omniASR_LLM_7B_Unlimited")
# ModelType.LLM_ASR_LID with streaming_config.is_streaming=True
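
A rough sketch of how long audio maps onto segment markers (split_segment_markers and the 30-second default are illustrative assumptions; the actual chunking is governed by streaming_config):

```python
import math

def split_segment_markers(total_seconds, segment_seconds=30.0):
    # Chunk the audio, tagging intermediate chunks <regular_segment>
    # and the final chunk <last_segment>, per the layout above.
    n = max(1, math.ceil(total_seconds / segment_seconds))
    return ["<regular_segment>"] * (n - 1) + ["<last_segment>"]

split_segment_markers(75.0)  # 75 s of audio -> three segments
```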

Embedding Methods

embed_audio

def embed_audio(
    seqs: Tensor,
    seq_lens: List[int]
) -> Tuple[Tensor, List[int]]
Runs encoder and frontend on audio tensors.
seqs
Tensor
required
Audio waveforms [batch_size, time].
seq_lens
List[int]
required
Actual sequence lengths.
embedded_seqs
Tensor
Embedded audio [batch_size, reduced_time, model_dim].
embedded_seq_lens
List[int]
Reduced sequence lengths after encoder.
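
As a back-of-the-envelope sketch of where the reduced lengths come from (the 320x factor is an assumption typical of wav2vec2-style frontends; the real value depends on the frontend's convolutional feature extractor and on encoder_stacking):

```python
def reduced_lengths(seq_lens, downsample=320, stacking=1):
    # 16 kHz waveforms downsampled by ~320 give roughly 50 frames/s;
    # encoder_stacking then divides the frame count again.
    # Purely illustrative arithmetic, not the library's length math.
    return [(n // downsample) // stacking for n in seq_lens]

reduced_lengths([16000, 32000], stacking=2)  # one and two seconds of audio
```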

embed_text

def embed_text(
    seqs: Tensor,
    dtype: torch.dtype
) -> Tensor
Embeds text tokens.
seqs
Tensor
required
Text token indices [batch_size, seq_len].
dtype
torch.dtype
required
Target dtype for embeddings.
embedded
Tensor
Text embeddings [batch_size, seq_len, model_dim].
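
Conceptually, embed_text is an embedding-table lookup followed by a cast to the requested dtype. A minimal pure-Python analogue of the lookup step (hypothetical, not the fairseq2 implementation):

```python
def embed_tokens(seqs, table):
    # Map [batch_size, seq_len] token indices to
    # [batch_size, seq_len, model_dim] embedding vectors.
    return [[table[tok] for tok in row] for row in seqs]

table = {0: [0.0, 0.0], 1: [1.0, 0.5], 2: [0.3, 0.9]}  # vocab of 3, model_dim 2
embed_tokens([[1, 2], [0, 1]], table)
```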

Training Example

import torch
from fairseq2.models.hub import load_model
from fairseq2.datasets import Seq2SeqBatch

# Load model
model = load_model("omniASR_LLM_7B")
model.train()

# Forward pass (training); `batch` is a prepared Seq2SeqBatch
loss = model(batch)
loss.backward()

# Forward pass with logits (debugging)
loss, logits, *_ = model(batch, return_logits=True)
print(f"Loss: {loss.item()}, Logits shape: {logits.shape}")

Inference Example

from omnilingual_asr.models.inference import ASRInferencePipeline

# Use high-level pipeline (recommended)
pipeline = ASRInferencePipeline("omniASR_LLM_7B")
transcriptions = pipeline.transcribe(["audio.wav"])

# Low-level model access
from fairseq2.models.hub import load_model

model = load_model("omniASR_LLM_7B")
model.eval()

# Get decoder inputs
decoder_context, context_lens, audio_emb = model(
    batch,
    return_decoder_inputs=True
)

# Use beam search for generation
# (typically done by ASRInferencePipeline)

Model Variants

Model Card                Type         Parameters  Features
omniASR_LLM_300M          LLM_ASR_LID  300M        Language conditioning
omniASR_LLM_1B            LLM_ASR_LID  1B          Language conditioning
omniASR_LLM_3B            LLM_ASR_LID  3B          Language conditioning
omniASR_LLM_7B            LLM_ASR_LID  7B          Language conditioning
omniASR_LLM_7B_ZS         ZERO_SHOT    7B          Zero-shot learning
omniASR_LLM_7B_Unlimited  LLM_ASR_LID  7B         Streaming (unlimited length)

Source Reference

See implementation at src/omnilingual_asr/models/wav2vec2_llama/model.py:43
