Overview

Omnilingual ASR uses a modular data pipeline architecture that separates storage (how data is read) from tasks (how data is processed). This makes it possible to mix and match storage backends with preprocessing pipelines.

MixtureParquetStorage

Parquet-based storage implementation with partition weighting and multilingual sampling.

Constructor

MixtureParquetStorage(
    path: Path,
    config: MixtureParquetStorageConfig
)
  • path (Path, required): Path to the parquet dataset directory with language/corpus partitions.
  • config (MixtureParquetStorageConfig, required): Storage configuration covering fragment streaming, loading, and weighting parameters.

Configuration: MixtureParquetStorageConfig

  • fragment_streaming (FragmentStreamingConfig, required): Controls how parquet fragments (row groups) are streamed:
      • seed: Random seed for shuffling
      • fragment_shuffle_window: Window size for fragment shuffling (-1 for global shuffling)
      • nb_epochs: Number of epochs (None for infinite)
  • fragment_loading (FragmentLoadingConfig, required): Controls how fragments are loaded into memory:
      • columns: Schema mapping (use LangASRSchema)
      • nb_prefetch: Number of fragments to prefetch
      • num_parallel_fragments: Number of parallel loading threads
      • cache: Enable caching of decoded audio
  • dataset_summary_path (str | None): Path to a TSV file with the corpus/language hour distribution used for weighted sampling.
  • beta_corpus (float | None): Beta parameter for corpus weighting: weight = (hours/total)^beta.
  • beta_language (float | None): Beta parameter for language weighting within each corpus.
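The beta weighting can be illustrated with a short sketch. This is not the library's implementation, just the formula above (weight = (hours/total)^beta) renormalized to sum to 1; a beta below 1 flattens the distribution and upsamples low-resource corpora or languages:

```python
# Illustrative sketch of temperature-style mixture weighting
# (hypothetical helper, not part of the omnilingual_asr API).
def mixture_weights(hours: dict[str, float], beta: float) -> dict[str, float]:
    total = sum(hours.values())
    # Raw weight per partition: (hours / total) ** beta
    raw = {name: (h / total) ** beta for name, h in hours.items()}
    # Renormalize so the sampling weights sum to 1.
    z = sum(raw.values())
    return {name: w / z for name, w in raw.items()}

weights = mixture_weights({"librispeech": 960.0, "low_resource": 10.0}, beta=0.5)
# With beta=0.5 the 10-hour corpus gets a much larger share than its
# roughly 1% proportional weight; with beta=1.0 weights are proportional.
```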
  • sync_mode (SyncMode, default: SyncMode.UNTIL_FIRST): Synchronization mode for distributed training:
      • UNTIL_FIRST: Stop when the first worker finishes (training)
      • UNTIL_LAST: Stop when the last worker finishes (validation)
  • sync_batches (bool, default: True): Whether to synchronize batch sizes across workers.

Methods

create_raw_data_pipeline

def create_raw_data_pipeline(
    split: str,
    gangs: Gangs
) -> DataPipelineBuilder
Creates the raw data pipeline for reading parquet files.
  • split (str, required): Split name (e.g., "train", "dev", "test"). Can include a corpus filter, e.g. "train_librispeech".
  • gangs (Gangs, required): Gang configuration for distributed reading.

Returns a DataPipelineBuilder yielding dictionaries with audio bytes, text, language, and corpus.

Example

from omnilingual_asr.datasets.storage import (
    MixtureParquetStorage,
    MixtureParquetStorageConfig,
    LangASRSchema
)
from fairseq2.data.parquet import (
    FragmentStreamingConfig,
    FragmentLoadingConfig
)
from fairseq2.datasets import SyncMode

# Configure storage
config = MixtureParquetStorageConfig(
    fragment_streaming=FragmentStreamingConfig(
        parquet_path="",
        seed=42,
        fragment_shuffle_window=-1,
        nb_epochs=None
    ),
    fragment_loading=FragmentLoadingConfig(
        columns=LangASRSchema(),
        cache=True
    ),
    dataset_summary_path="stats.tsv",
    beta_corpus=0.5,
    beta_language=0.5,
    sync_mode=SyncMode.UNTIL_FIRST
)

# Create storage
storage = MixtureParquetStorage(
    path=Path("data/asr_parquet"),
    config=config
)

# Create pipeline
builder = storage.create_raw_data_pipeline(
    split="train",
    gangs=gangs
)

AsrTask

ASR preprocessing pipeline including audio filtering, tokenization, and batching.

Constructor

AsrTask(config: AsrTaskConfig)
  • config (AsrTaskConfig, required): Task configuration for the preprocessing pipeline.

Configuration: AsrTaskConfig

Audio Processing

  • min_audio_len (int, default: 1): Minimum audio sequence length in samples. Shorter audio is filtered out.
  • max_audio_len (int, default: 800000): Maximum audio sequence length in samples (~50 s at 16 kHz). Longer audio is filtered out.
  • normalize_audio (bool, default: False): Whether to normalize audio to zero mean and unit variance.
  • use_fbank (bool, default: False): Whether to use filterbank features instead of raw waveforms.

SpecAugment

  • spec_aug_p (float | None, default: None): Probability of applying SpecAugment per sample.
  • spec_aug_freq_mask_param (int, default: 80): Maximum frequency mask length for SpecAugment.
  • spec_aug_time_mask_param (int, default: 80): Maximum time mask length for SpecAugment.

Text Processing

  • filter_long_text_threshold (int | None, default: None): Maximum text length in tokens. Longer sequences are filtered out.
  • remove_unknown (bool, default: False): Whether to remove unknown tokens from the text in place.
  • min_samples_per_char (int, default: 160): Minimum number of audio samples per character of text. Examples with faster speech than this ratio allows are filtered out.
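The speech-rate filter reduces to a simple ratio check, sketched below (a hedged illustration, not the library's code): an example is kept only if the audio provides at least min_samples_per_char samples for every character of the transcript. At 16 kHz, 160 samples correspond to 10 ms per character.

```python
# Illustrative sketch of the min_samples_per_char filter.
def passes_speech_rate_filter(num_samples: int, text: str,
                              min_samples_per_char: int = 160) -> bool:
    # Keep the example only if the audio is long enough for the text.
    return num_samples >= min_samples_per_char * len(text)

# A 1-second clip at 16 kHz (16,000 samples) supports up to 100 characters.
passes_speech_rate_filter(16_000, "hello world")  # → True
passes_speech_rate_filter(16_000, "x" * 101)      # → False
```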

Batching

  • batching_strategy (BatchingStrategy, default: LENGTH): Batching strategy:
      • LENGTH: Dynamic batching by total elements (recommended)
      • STATIC: Fixed batch size
  • batch_size (int, default: 8): Batch size for the STATIC batching strategy.
  • max_num_elements (int, default: 3200000): Maximum total elements per batch for the LENGTH strategy.
  • num_seqs_multiple_of (int, default: 8): The batch size must be a multiple of this value (for hardware optimization).
  • drop_remainder (bool, default: False): Whether to drop the last incomplete batch.
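The LENGTH strategy bounds the padded batch size rather than the sequence count. A small sketch (hypothetical helper, not the library API) of the underlying constraint: a padded batch occupies num_seqs × longest_seq elements, which must stay under max_num_elements.

```python
# Illustrative sketch of the LENGTH batching constraint.
def can_add_to_batch(batch_lens: list[int], new_len: int,
                     max_num_elements: int = 3_200_000) -> bool:
    lens = batch_lens + [new_len]
    # After padding, every sequence occupies max(lens) elements.
    return len(lens) * max(lens) <= max_num_elements

# Ten sequences padded to 320,000 samples exactly fill the default budget.
can_add_to_batch([320_000] * 9, 320_000)   # → True
can_add_to_batch([320_000] * 10, 320_000)  # → False
```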

Pipeline Settings

  • example_shuffle_window (int, default: 0): Sliding window size for shuffling examples before batching.
  • batch_shuffle_window (int, default: 1000): Sliding window size for shuffling batches.
  • num_prefetch (int, default: 4): Number of batches to prefetch in the background.
  • npc (int, default: 10): Number of parallel calls for data pipeline operations.

Methods

apply_processing_pipeline

def apply_processing_pipeline(
    builder: DataPipelineBuilder,
    gangs: Gangs,
    tokenizer: Tokenizer,
    dtype: torch.dtype
) -> DataPipelineBuilder
Applies the complete ASR preprocessing pipeline.
  • builder (DataPipelineBuilder, required): Input pipeline builder (typically from the storage layer).
  • gangs (Gangs, required): Gang configuration.
  • tokenizer (Tokenizer, required): Tokenizer for text encoding.
  • dtype (torch.dtype, required): Data type for audio tensors.

Returns a DataPipelineBuilder yielding Seq2SeqBatch objects.

Pipeline Stages

The ASR task pipeline processes data in the following order:
  1. Audio Filtering: Filter by length (min_audio_len, max_audio_len)
  2. Example Shuffling: Shuffle before batching (example_shuffle_window)
  3. Text Tokenization: Encode text with tokenizer
  4. Text Filtering: Filter empty text, unknown sequences, long text
  5. Batching: Bucket by audio length or static batch size
  6. Batch Shuffling: Shuffle batches (batch_shuffle_window)
  7. Audio Decoding: Decode audio bytes to waveforms
  8. Audio Processing: Normalize, convert to mono, optionally apply SpecAugment
  9. Feature Extraction: Extract fbank features (if use_fbank=True)
  10. Collation: Collate into padded batches
  11. Prefetching: Prefetch batches in background
  12. Seq2SeqBatch Conversion: Convert to final batch format

Example

from omnilingual_asr.datasets.tasks import AsrTask, AsrTaskConfig
import torch

# Configure task
config = AsrTaskConfig(
    min_audio_len=8000,
    max_audio_len=800_000,
    normalize_audio=True,
    spec_aug_p=0.5,  # 50% probability of SpecAugment
    spec_aug_freq_mask_param=27,
    spec_aug_time_mask_param=100,
    max_num_elements=3_200_000,
    num_seqs_multiple_of=8,
    example_shuffle_window=1000,
    batch_shuffle_window=100,
    num_prefetch=4
)

# Create task
task = AsrTask(config)

# Apply to storage pipeline
builder = storage.create_raw_data_pipeline("train", gangs)
builder = task.apply_processing_pipeline(
    builder,
    gangs=gangs,
    tokenizer=tokenizer,
    dtype=torch.bfloat16
)

# Finalize pipeline
pipeline = builder.and_return()

Audio Preprocessing

Audio Decoding

Audio bytes are decoded using fairseq2’s AudioDecoder:
from fairseq2.data.audio import AudioDecoder

audio_decoder = AudioDecoder(dtype=torch.float32)
# Decodes to dict: {"waveform": Tensor, "sample_rate": int}

Normalization

from torch import Tensor
from torch.nn.functional import layer_norm

def apply_audio_normalization(waveform: Tensor) -> Tensor:
    """Normalize to zero mean and unit variance over the whole waveform."""
    return layer_norm(waveform, waveform.shape)

Mono Conversion

from torch import Tensor

def convert_to_mono(waveform: Tensor) -> Tensor:
    """Convert multi-channel audio to mono by averaging the channels."""
    if waveform.dim() == 2:  # channel-last layout: [num_frames, num_channels]
        waveform = waveform.mean(dim=1)
    return waveform

SpecAugment

Applied with probability spec_aug_p:
  1. Convert waveform to spectrogram
  2. Apply frequency masking (random mask of length up to spec_aug_freq_mask_param)
  3. Apply time masking (random mask of length up to spec_aug_time_mask_param)
  4. Convert back to waveform
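The masking in steps 2-3 can be sketched in miniature. This is a deliberately simplified illustration on a plain [freq][time] magnitude matrix, not the library's implementation (which operates on tensors and round-trips through an STFT):

```python
import random

# Minimal sketch of SpecAugment's frequency and time masking: zero out
# one random frequency band and one random time band in a spectrogram.
def apply_masks(spec: list[list[float]], freq_mask_param: int,
                time_mask_param: int, rng: random.Random) -> list[list[float]]:
    n_freq, n_time = len(spec), len(spec[0])
    f = rng.randint(0, min(freq_mask_param, n_freq))   # mask width in bins
    f0 = rng.randint(0, n_freq - f)                    # mask start bin
    t = rng.randint(0, min(time_mask_param, n_time))   # mask width in frames
    t0 = rng.randint(0, n_time - t)                    # mask start frame
    for i in range(f0, f0 + f):          # frequency mask
        for j in range(n_time):
            spec[i][j] = 0.0
    for i in range(n_freq):              # time mask
        for j in range(t0, t0 + t):
            spec[i][j] = 0.0
    return spec
```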

Filterbank Features

If use_fbank=True:
from fairseq2.data.audio import WaveformToFbankConverter

fbank_converter = WaveformToFbankConverter(
    num_mel_bins=80,
    waveform_scale=2**15,
    channel_last=True,
    standardize=True,
    dtype=dtype
)

Schema Definitions

LangASRSchema

Column mapping for parquet datasets:
@dataclass
class LangASRSchema(NamedColumns):
    audio: str = "audio_bytes"      # Audio data column
    length: str = "audio_size"       # Audio length column
    text: str = "text"               # Transcription column
    split: str = "split"             # Split column
    lang: str = "language"           # Language column
    corpus: str = "corpus"           # Corpus column

Complete Example: Training Pipeline

from pathlib import Path
import torch
from fairseq2.data.tokenizers import load_tokenizer
from fairseq2.gang import FakeGang
from fairseq2.datasets import SyncMode

from omnilingual_asr.datasets.impl import MixtureParquetAsrDataset
from omnilingual_asr.datasets.storage import (
    MixtureParquetStorageConfig,
    LangASRSchema
)
from omnilingual_asr.datasets.tasks import AsrTaskConfig

# Storage configuration
storage_config = MixtureParquetStorageConfig(
    dataset_summary_path="stats.tsv",
    beta_corpus=0.5,
    beta_language=0.5,
    sync_mode=SyncMode.UNTIL_FIRST,
    sync_batches=True
)

# Task configuration
task_config = AsrTaskConfig(
    normalize_audio=True,
    spec_aug_p=0.5,
    max_num_elements=3_200_000,
    example_shuffle_window=1000,
    batch_shuffle_window=100
)

# Create dataset and reader
dataset = MixtureParquetAsrDataset.from_path(Path("data/asr"))
tokenizer = load_tokenizer("omniASR_tokenizer_v1")

reader = dataset.create_reader(
    split="train",
    tokenizer=tokenizer,
    gangs=FakeGang(device=torch.device("cuda")),
    dtype=torch.bfloat16,
    num_accumulate=1,
    storage_config=storage_config,
    task_config=task_config
)

# Training loop
for batches in reader:
    for batch in batches:
        # batch.source_seqs: [B, T] audio
        # batch.target_seqs: [B, S] text tokens
        loss = model(batch)
        loss.backward()

Source References

  • MixtureParquetStorage: src/omnilingual_asr/datasets/storage/mixture_parquet_storage.py:133
  • AsrTask: src/omnilingual_asr/datasets/tasks/asr_task.py:140
  • Audio utilities: src/omnilingual_asr/datasets/utils/audio.py
