Overview

Efficient batch processing is critical for training and inference with ASR models. Omnilingual ASR provides two batching strategies and various optimization techniques to maximize throughput while managing memory constraints.

Batching Strategies

The framework supports two batching strategies defined in /src/omnilingual_asr/datasets/utils/batching.py:13:

Static Batching

Fixed number of sequences per batch, regardless of sequence length.
from omnilingual_asr.datasets.utils.batching import BatchingStrategy

asr_task_config:
  batching_strategy: BatchingStrategy.STATIC
  batch_size: 8
  drop_remainder: False
Characteristics:
  • Each batch contains exactly batch_size examples
  • Simpler to reason about for debugging
  • Can lead to memory spikes with long sequences
  • Less efficient GPU utilization
Use cases:
  • Small datasets with uniform audio lengths
  • Debugging and development
  • Inference with fixed batch sizes
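The static strategy can be illustrated with a minimal sketch. This is not the framework's implementation, just the grouping rule it describes:

```python
def static_batches(examples, batch_size, drop_remainder=False):
    """Group examples into fixed-size batches, mirroring STATIC batching."""
    batches = []
    for i in range(0, len(examples), batch_size):
        batch = examples[i:i + batch_size]
        if len(batch) < batch_size and drop_remainder:
            continue  # drop the incomplete final batch
        batches.append(batch)
    return batches

# 10 examples with batch_size=4 -> batch sizes 4, 4, 2
print([len(b) for b in static_batches(list(range(10)), 4)])        # [4, 4, 2]
print([len(b) for b in static_batches(list(range(10)), 4, True)])  # [4, 4]
```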
Length Batching

Dynamic batching where each batch has a maximum total number of elements (audio samples).
asr_task_config:
  batching_strategy: BatchingStrategy.LENGTH
  min_audio_len: 32_000
  max_audio_len: 960_000
  max_num_elements: 960_000
  num_seqs_multiple_of: 8
  drop_remainder: False
Characteristics:
  • Batches contain variable numbers of sequences
  • Total elements per batch ≤ max_num_elements
  • Number of sequences is a multiple of num_seqs_multiple_of
  • More efficient memory usage
  • Better GPU utilization
Use cases:
  • Training on diverse datasets
  • Production workloads
  • Memory-constrained environments
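The core idea of LENGTH batching can be sketched as a greedy packer that caps each batch's total samples. This is illustrative only and does not reproduce fairseq2's bucketing algorithm:

```python
def length_batches(lengths, max_num_elements):
    """Greedy sketch of LENGTH batching: keep adding sequences until the
    batch's total number of samples would exceed max_num_elements."""
    batches, current, total = [], [], 0
    for n in sorted(lengths):  # shortest sequences first, as in bucketing
        if current and total + n > max_num_elements:
            batches.append(current)
            current, total = [], 0
        current.append(n)
        total += n
    if current:
        batches.append(current)
    return batches

print(length_batches([100, 200, 300, 400, 500], max_num_elements=600))
# [[100, 200, 300], [400], [500]] - short clips pack densely, long ones alone
```

Note how batch size varies: many short sequences share a batch, while a long sequence may occupy one by itself, keeping memory roughly constant.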

Configuration Parameters

max_num_elements

Maximum total audio samples across all sequences in a batch.
Type: Integer
Default: 3,200,000
Applies to: LENGTH batching
Example calculation:
  • Audio 1: 400,000 samples
  • Audio 2: 300,000 samples
  • Audio 3: 260,000 samples
  • Total: 960,000 ≤ max_num_elements ✓
asr_task_config:
  max_num_elements: 3_200_000  # ~200s at 16kHz
If max_num_elements % max_audio_len != 0, it will be rounded down automatically (see /src/omnilingual_asr/datasets/utils/batching.py:40-42).
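The rounding note above can be checked with a few lines of arithmetic (values are the hypothetical ones from this page, not defaults you must use):

```python
max_num_elements = 3_200_000
max_audio_len = 960_000

# Round down to the nearest multiple of max_audio_len, as described above
if max_num_elements % max_audio_len != 0:
    max_num_elements = (max_num_elements // max_audio_len) * max_audio_len

print(max_num_elements)  # 2_880_000 (3 * 960_000)
```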
num_seqs_multiple_of

Forces batch size to be a multiple of this value for hardware optimization.
Type: Integer
Default: 8
Applies to: LENGTH batching
Why it matters:
  • Modern GPUs perform best with batch sizes that are multiples of 8 or 16
  • Enables efficient tensor core utilization on NVIDIA GPUs
  • Aligns with FSDP (Fully Sharded Data Parallel) requirements
asr_task_config:
  num_seqs_multiple_of: 8  # Batch sizes: 8, 16, 24, ...
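Enforcing the multiple amounts to rounding a candidate batch size down, which can be sketched as:

```python
def round_to_multiple(batch_size, multiple=8):
    """Round a candidate batch size down to the nearest multiple (sketch,
    not the framework's code)."""
    return (batch_size // multiple) * multiple

print(round_to_multiple(27))  # 24 - a bucket that could fit 27 sequences uses 24
print(round_to_multiple(8))   # 8
```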
max_bucket_size

Limits the maximum number of sequences in any single bucket.
Type: Integer or None
Default: None
Applies to: LENGTH batching
Behavior:
  • Filters out buckets with more than max_bucket_size sequences
  • Prioritizes shorter sequences (fairseq2 buckets shortest sequences first)
  • Useful for preventing very small batches with long sequences
asr_task_config:
  max_bucket_size: 32  # No batch will have >32 sequences
Implementation: See /src/omnilingual_asr/datasets/utils/batching.py:50-56
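The filtering behavior can be sketched in one line. The `(num_seqs, seq_len)` bucket tuples here are hypothetical values for illustration:

```python
# Sketch: drop buckets whose sequence count exceeds max_bucket_size
buckets = [(8, 120_000), (16, 60_000), (40, 24_000), (64, 15_000)]
max_bucket_size = 32

filtered = [b for b in buckets if b[0] <= max_bucket_size]
print(filtered)  # [(8, 120000), (16, 60000)]
```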
drop_remainder

Whether to drop the last incomplete batch.
Type: Boolean
Default: False
Applies to: Both STATIC and LENGTH
Use cases:
  • Set to True for distributed training to ensure all workers have equal batches
  • Set to False for inference to process all data
asr_task_config:
  drop_remainder: True

Memory Optimization

Audio Length Filtering

Filter audio by length before batching to optimize memory usage:
asr_task_config:
  min_audio_len: 32_000    # 2 seconds at 16kHz
  max_audio_len: 960_000   # 60 seconds at 16kHz
Benefits:
  • Removes very short clips that don’t benefit training
  • Prevents OOM errors from extremely long sequences
  • Improves bucketing efficiency
Implementation: See /src/omnilingual_asr/datasets/tasks/asr_task.py:163-168
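The configured filter reduces to a simple range check on the number of samples. A minimal sketch (function name is illustrative, not the framework's):

```python
def keep_audio(num_samples, min_audio_len=32_000, max_audio_len=960_000):
    """Length filter applied before batching: keep 2-60 s clips at 16 kHz."""
    return min_audio_len <= num_samples <= max_audio_len

print(keep_audio(16_000))     # False (1 s, too short)
print(keep_audio(480_000))    # True  (30 s)
print(keep_audio(1_120_000))  # False (70 s, too long)
```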

Bucket Size Calculation

The framework uses fairseq2.data.data_pipeline.create_bucket_sizes to automatically calculate optimal bucket sizes:
bucket_sizes = create_bucket_sizes(
    min_seq_len=min_audio_len,
    max_seq_len=max_audio_len,
    max_num_elements=max_num_elements,
    num_seqs_multiple_of=num_seqs_multiple_of,
)
This creates buckets of increasing sequence lengths, where each bucket satisfies:
  • bucket_size * seq_len ≤ max_num_elements
  • bucket_size % num_seqs_multiple_of == 0
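The two constraints can be reproduced in plain Python. This is an illustrative reimplementation, not fairseq2's actual algorithm, and it halves the sequence length at each step purely for demonstration:

```python
def sketch_bucket_sizes(min_seq_len, max_seq_len, max_num_elements,
                        num_seqs_multiple_of):
    """Sketch: for each sequence length, pick the largest bucket size that
    satisfies bucket_size * seq_len <= max_num_elements and
    bucket_size % num_seqs_multiple_of == 0."""
    buckets = []
    seq_len = max_seq_len
    while seq_len >= min_seq_len:
        fit = max_num_elements // seq_len
        bucket_size = (fit // num_seqs_multiple_of) * num_seqs_multiple_of
        if bucket_size > 0:
            buckets.append((bucket_size, seq_len))
        seq_len //= 2
    return buckets

print(sketch_bucket_sizes(32_000, 960_000, 960_000, 8))
# [(8, 120000), (16, 60000)] - longer lengths can't fit a multiple of 8
```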

Example Configuration

asr_task_config:
  batching_strategy: LENGTH
  min_audio_len: 16_000
  max_audio_len: 1_600_000  # 100 seconds
  max_num_elements: 6_400_000
  num_seqs_multiple_of: 16
  batch_shuffle_window: 1000

Parallel Processing

Multi-Partition Loading

For training with mixture parquet datasets, partitions can be loaded in parallel:
mixture_parquet_storage_config:
  max_workers: 30  # Parallel partition loading threads
  pa_cpu_count: 20 # PyArrow CPU threads
How it works:
  1. Each language-corpus partition is loaded by a separate thread
  2. Partitions are sampled according to their weights
  3. Examples are prefetched in background
Implementation: See /src/omnilingual_asr/datasets/storage/mixture_parquet_storage.py:486-545

Fragment Prefetching

Prefetch parquet fragments in the background:
mixture_parquet_storage_config:
  fragment_loading:
    nb_prefetch: 1              # Number of fragments to prefetch
    num_parallel_fragments: 1   # Parallel fragment readers
    use_threads: False          # Use threads for reading

Batch Prefetching

Prefetch processed batches while the model trains:
asr_task_config:
  num_prefetch: 4  # Prefetch 4 batches in background
Trade-offs:
  • Higher values improve throughput but increase memory usage
  • Recommended: 2-4 for training, 1 for inference
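The prefetching idea reduces to a bounded queue filled by a background thread: the producer runs at most `num_prefetch` items ahead of the consumer. A minimal sketch, not the framework's implementation:

```python
import queue
import threading

def prefetch(iterable, num_prefetch=4):
    """Yield items from `iterable`, loading up to num_prefetch ahead
    in a background thread (bounded-queue sketch)."""
    q = queue.Queue(maxsize=num_prefetch)  # bound = memory/throughput knob
    sentinel = object()

    def producer():
        for item in iterable:
            q.put(item)  # blocks when the queue is full
        q.put(sentinel)

    threading.Thread(target=producer, daemon=True).start()
    while True:
        item = q.get()
        if item is sentinel:
            break
        yield item

print(list(prefetch(range(5), num_prefetch=2)))  # [0, 1, 2, 3, 4]
```

The `maxsize` bound is exactly the memory/throughput trade-off described above: a larger bound hides more loading latency but holds more batches in memory.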

Shuffling Strategies

Example-Level Shuffling

Shuffle individual examples before batching:
asr_task_config:
  example_shuffle_window: 10000
  seed: 2
Options:
  • example_shuffle_window: 1 - No shuffling
  • example_shuffle_window: 0 - Load entire dataset and shuffle (not recommended - OOM risk)
  • example_shuffle_window: N - Shuffle within sliding window of N examples
Never set example_shuffle_window: 0 for large datasets - it loads everything into memory and will cause OOM errors.
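Window-based shuffling keeps only `window` examples buffered at a time, which is why it avoids the OOM risk of a full shuffle. A sketch of the idea (not fairseq2's implementation):

```python
import random

def window_shuffle(items, window, seed=2):
    """Sliding-window shuffle: hold up to `window` items in a buffer and
    emit a random one each time the buffer fills (sketch)."""
    rng = random.Random(seed)
    buf, out = [], []
    for item in items:
        buf.append(item)
        if len(buf) >= window:
            out.append(buf.pop(rng.randrange(len(buf))))
    while buf:  # drain the remaining buffer at end of stream
        out.append(buf.pop(rng.randrange(len(buf))))
    return out

shuffled = window_shuffle(range(10), window=4)
print(shuffled)          # locally shuffled order
print(sorted(shuffled))  # [0, 1, ..., 9] - nothing lost or duplicated
```

With `window=1` the buffer emits each item immediately, which matches the "no shuffling" behavior noted above.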

Batch-Level Shuffling

Shuffle batches after bucketing:
asr_task_config:
  batch_shuffle_window: 1000
  seed: 2
Why shuffle batches:
  • Increases diversity of sequence lengths within an epoch
  • Prevents model from learning length-based patterns
  • Improves gradient stability
Implementation: See /src/omnilingual_asr/datasets/tasks/asr_task.py:336-346

Data Pipeline Architecture

The complete pipeline flow for LENGTH batching:
Parquet Storage
  → Filter by Audio Length (min/max)
  → Shuffle Examples (window-based)
  → Tokenize Text
  → Filter Unknown/Long Text
  → Bucket by Length
  → Shuffle Batches (window-based)
  → Decode Audio
  → Process to Waveform/Fbank
  → Collate to Seq2SeqBatch
  → Prefetch
  → Model

Performance Tips

Adjust max_num_elements based on available GPU memory:
# Rule of thumb: max_num_elements ≈ GPU_RAM_GB * 40_000
# A100 80GB: 3,200,000
# A100 40GB: 1,600,000
# V100 16GB: 640,000
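The rule of thumb above is a rough heuristic, not a guarantee; actual headroom depends on the model and precision. Checking the listed GPUs against it:

```python
# Heuristic from above: max_num_elements ~ GPU_RAM_GB * 40_000
for gpu, ram_gb in [("A100 80GB", 80), ("A100 40GB", 40), ("V100 16GB", 16)]:
    print(f"{gpu}: max_num_elements ~ {ram_gb * 40_000:,}")
# A100 80GB: 3,200,000 / A100 40GB: 1,600,000 / V100 16GB: 640,000
```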
Always use BatchingStrategy.LENGTH for training - it provides better memory efficiency and GPU utilization.
Choose num_seqs_multiple_of based on your hardware:
  • 8: Good default for most GPUs
  • 16: Better for A100 with tensor cores
  • Must be ≤ your minimum expected batch size
During training, log batch sizes to ensure efficient bucketing:
log.info(f"Batch size: {batch.source_seqs.shape[0]}")
log.info(f"Max sequence length: {batch.source_seqs.shape[1]}")
Start with low prefetch values and increase gradually:
# Conservative
num_prefetch: 2

# Aggressive (more memory)
num_prefetch: 8

Debugging Batch Issues

Check Bucket Sizes

from fairseq2.data.data_pipeline import create_bucket_sizes

bucket_sizes = create_bucket_sizes(
    min_seq_len=32_000,
    max_seq_len=960_000,
    max_num_elements=960_000,
    num_seqs_multiple_of=8,
)
print(bucket_sizes)
# [(8, 120000), (16, 60000), (24, 40000), ...]

Verify Memory Usage

import torch

# After creating a batch
print(f"Allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
print(f"Reserved: {torch.cuda.memory_reserved() / 1e9:.2f} GB")

Profile Data Loading

import time

start = time.time()
for i, batch in enumerate(dataloader):
    # Time spent waiting on the dataloader, not on model processing
    print(f"Batch {i}: loaded in {time.time() - start:.3f}s")
    # ... process batch ...
    start = time.time()
