
Introduction

Streaming datasets enable efficient data loading for large-scale training by:
  • Downloading data on-demand during training
  • Reducing local storage requirements
  • Supporting deterministic shuffling
  • Enabling infinite data iteration
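
The deterministic-shuffling idea above can be sketched with a seeded generator (a plain-Python illustration, not the streaming library's actual implementation): seeding on both a fixed seed and the epoch number makes every run reproducible while still giving each epoch its own order.

```python
import random

def shuffled_indices(num_samples: int, seed: int, epoch: int) -> list[int]:
    # Seeding with (seed, epoch) makes the order reproducible across runs
    # while still varying from one epoch to the next.
    rng = random.Random(f"{seed}:{epoch}")
    indices = list(range(num_samples))
    rng.shuffle(indices)
    return indices

epoch0 = shuffled_indices(8, seed=42, epoch=0)
epoch1 = shuffled_indices(8, seed=42, epoch=1)
print(epoch0)  # identical on every run with the same seed
print(epoch1)  # typically a different, but equally reproducible, order
```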

Why Streaming Datasets?

Scalability

Train on datasets larger than available disk space by streaming from cloud storage

Efficiency

Start training immediately without waiting for full dataset downloads

Flexibility

Deterministic shuffling across epochs with configurable seed values

Cost Savings

Reduce storage costs by caching only actively used data

MosaicML Streaming

MosaicML Streaming provides the MDS shard format and a StreamingDataset class for efficient training-data loading from local or cloud storage.

Installation

uv sync
# or
pip install mosaicml-streaming

Creating a Streaming Dataset

Convert your data to MDS (Mosaic Data Shard) format:
streaming-dataset/mock_data.py
from pathlib import Path
import numpy as np
from PIL import Image
from streaming import MDSWriter

def create_data(
    path_to_save: Path = Path("mds-dataset"), 
    size: int = 100_000
):
    # Define schema
    columns = {
        "image": "jpeg",
        "class": "int"
    }
    compression = "zstd"
    
    # Write dataset
    with MDSWriter(
        out=str(path_to_save), 
        columns=columns, 
        compression=compression
    ) as out:
        for _ in range(size):
            sample = {
                "image": Image.fromarray(
                    np.random.randint(0, 256, (32, 32, 3), np.uint8)
                ),
                "class": int(np.random.randint(10)),  # MDS "int" expects a Python int
            }
            out.write(sample)
Key parameters:
  • columns: Schema definition with field types
  • compression: Algorithm (zstd, gzip, snappy, none)
  • out: Local directory path

Supported Data Types

Type    Description       Example
int     Integer values    Labels, IDs
str     Text strings      Prompts, captions
bytes   Raw bytes         Custom encodings
jpeg    JPEG images       Photos
png     PNG images        Graphics
pkl     Pickled objects   Complex types
json    JSON objects      Metadata

Upload to Cloud Storage

1. Create a bucket

export AWS_ACCESS_KEY_ID=minioadmin
export AWS_SECRET_ACCESS_KEY=minioadmin
export AWS_ENDPOINT_URL=http://127.0.0.1:9000

aws s3api create-bucket --bucket datasets
2. Generate data

python streaming-dataset/mock_data.py create-data \
  --path-to-save random-data
3. Upload to S3

aws s3 cp --recursive random-data s3://datasets/random-data

Consuming Streaming Data

Load and train from remote storage:
streaming-dataset/mock_data.py
from pathlib import Path
from streaming import StreamingDataset
from torch.utils.data import DataLoader

def get_dataloader(
    remote: str = "s3://datasets/random-data",
    local_cache: Path = Path("cache")
):
    # Create streaming dataset
    dataset = StreamingDataset(
        local=str(local_cache),
        remote=remote,
        shuffle=True
    )
    
    print(f"Dataset: {dataset}")
    
    # Access individual samples
    sample = dataset[42]
    print(f"Image: {sample['image']}, Class: {sample['class']}")
    
    # Create DataLoader
    dataloader = DataLoader(dataset, batch_size=32, num_workers=4)
    print(f"DataLoader: {dataloader}")
    
    return dataloader

CLI Usage

Create and consume data:
# Create dataset locally
python streaming-dataset/mock_data.py create-data --path-to-save random-data

# Read back (the --remote value may be an S3 URI or a local path)
python streaming-dataset/mock_data.py get-dataloader --remote random-data

Training Integration

Integrate with PyTorch training loops:
from streaming import StreamingDataset
from torch.utils.data import DataLoader
import torch

# Setup dataset
# Note: samples from "jpeg" columns decode to PIL images; convert them to
# tensors (e.g. in a StreamingDataset subclass's __getitem__) before batching.
dataset = StreamingDataset(
    local="./cache",
    remote="s3://datasets/training-data",
    shuffle=True,
    shuffle_seed=42,  # Deterministic shuffling
    batch_size=64
)

# Create DataLoader
loader = DataLoader(
    dataset,
    batch_size=64,
    num_workers=8,
    pin_memory=True
)

# Training loop (model, criterion, optimizer assumed to be defined)
num_epochs = 10
for epoch in range(num_epochs):
    for batch in loader:
        images = batch['image']
        labels = batch['class']
        
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # Backward pass
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

Advanced Features

Control shuffle behavior across epochs:
dataset = StreamingDataset(
    local="./cache",
    remote="s3://data",
    shuffle=True,
    shuffle_algo='naive',  # or 'py1b', 'py1s', 'py2s'
    shuffle_seed=42,
    shuffle_block_size=1 << 18
)
Benefits:
  • Reproducible training runs
  • Different shuffle per epoch
  • Efficient block-level shuffling
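
The block-level idea can be sketched in plain Python (an illustration, not the library's actual `py1s`/`py2s` algorithms): shuffle the order of fixed-size blocks, then shuffle samples only within each block, so only a few shards need to be resident at once.

```python
import random

def block_shuffle(indices: list[int], block_size: int, seed: int) -> list[int]:
    rng = random.Random(seed)
    # Split the index sequence into contiguous blocks,
    # shuffle the block order, then shuffle within each block.
    blocks = [indices[i:i + block_size] for i in range(0, len(indices), block_size)]
    rng.shuffle(blocks)
    for block in blocks:
        rng.shuffle(block)
    return [i for block in blocks for i in block]

order = block_shuffle(list(range(1_000)), block_size=100, seed=42)
assert sorted(order) == list(range(1_000))  # still a full permutation
```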
Shard data across workers:
dataset = StreamingDataset(
    local="./cache",
    remote="s3://data",
    partition_algo='orig',  # Partition strategy
    num_canonical_nodes=4,  # Physical nodes
    batch_size=64
)
Automatically handles:
  • Data sharding per worker
  • Epoch boundaries
  • Sample uniqueness
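
A minimal sketch of the sharding idea (not the library's `orig` partition algorithm): each of N workers takes a strided slice of the sample indices, so the partitions are disjoint and together cover every sample exactly once.

```python
def partition(num_samples: int, num_workers: int, worker_id: int) -> list[int]:
    # Worker k takes samples k, k + num_workers, k + 2 * num_workers, ...
    return list(range(worker_id, num_samples, num_workers))

parts = [partition(10, num_workers=4, worker_id=w) for w in range(4)]
print(parts)  # [[0, 4, 8], [1, 5, 9], [2, 6], [3, 7]]
```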
Control local cache behavior:
dataset = StreamingDataset(
    local="./cache",
    remote="s3://data",
    cache_limit="100gb",  # Max cache size
    download_retry=3,     # Retry failed downloads
    download_timeout=120  # Timeout per download
)
Cache is managed via LRU eviction.
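
The LRU policy can be sketched with an OrderedDict (illustration only; the hypothetical cache below counts shards rather than bytes, whereas the library tracks shard files and sizes on disk):

```python
from collections import OrderedDict

class LRUShardCache:
    def __init__(self, capacity: int):
        self.capacity = capacity
        self.shards = OrderedDict()

    def get(self, name: str):
        if name in self.shards:
            self.shards.move_to_end(name)  # mark as recently used
            return self.shards[name]
        return None

    def put(self, name: str, data: bytes) -> None:
        self.shards[name] = data
        self.shards.move_to_end(name)
        while len(self.shards) > self.capacity:
            self.shards.popitem(last=False)  # evict least recently used

cache = LRUShardCache(capacity=2)
cache.put("shard.0", b"...")
cache.put("shard.1", b"...")
cache.get("shard.0")           # touch shard.0
cache.put("shard.2", b"...")   # evicts shard.1, the least recently used
assert cache.get("shard.1") is None
```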
Monitor download and processing, for example with Python's built-in cProfile:
import cProfile

profiler = cProfile.Profile()
with profiler:
    for batch in dataloader:
        process(batch)

profiler.print_stats(sort="cumtime")

Alternative Solutions

TFRecord

TensorFlow’s format for efficient serialization:
import tensorflow as tf

# Helpers to wrap raw values in protobuf Feature messages
def _bytes_feature(value: bytes) -> tf.train.Feature:
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _int64_feature(value: int) -> tf.train.Feature:
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

# Write
with tf.io.TFRecordWriter('data.tfrecord') as writer:
    for sample in dataset:
        example = tf.train.Example(
            features=tf.train.Features(
                feature={
                    'image': _bytes_feature(sample['image']),
                    'label': _int64_feature(sample['label']),
                }
            )
        )
        writer.write(example.SerializeToString())

# Read
dataset = tf.data.TFRecordDataset('data.tfrecord')
TFRecord Tutorial

Best Practices

Shard Size

  • Target 50-500MB per shard
  • Balance parallelism vs overhead
  • Consider download time
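
As a back-of-envelope check (with hypothetical figures), you can estimate samples per shard from the target shard size:

```python
# Hypothetical figures: ~3 KB per compressed sample,
# targeting ~64 MB shards (within the 50-500 MB range above).
bytes_per_sample = 3 * 1024
target_shard_bytes = 64 * 1024 * 1024

samples_per_shard = target_shard_bytes // bytes_per_sample
num_shards = -(-100_000 // samples_per_shard)  # ceiling division for 100k samples

print(samples_per_shard, num_shards)
```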

Compression

  • Use zstd for best ratio/speed
  • Disable for pre-compressed data (JPEG)
  • Test impact on throughput
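
The "disable for pre-compressed data" point can be illustrated with stdlib zlib (zstd behaves similarly in spirit): redundant bytes shrink dramatically, while random bytes, standing in for already-compressed JPEG data, barely shrink at all.

```python
import os
import zlib

redundant = b"sample " * 10_000      # text-like, highly redundant
incompressible = os.urandom(70_000)  # stands in for pre-compressed JPEG bytes

ratio_redundant = len(zlib.compress(redundant)) / len(redundant)
ratio_random = len(zlib.compress(incompressible)) / len(incompressible)

print(f"redundant: {ratio_redundant:.3f}, random: {ratio_random:.3f}")
assert ratio_redundant < 0.1  # big win
assert ratio_random > 0.9     # barely shrinks; compressing again just burns CPU
```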

Caching

  • Size cache > single epoch
  • Use fast local storage (SSD)
  • Monitor cache hit rate

Workers

  • Match to CPU cores
  • Increase for I/O-bound tasks
  • Profile to find optimal count
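
A common starting point (a heuristic to profile from, not a rule) is to derive `num_workers` from the host's core count, oversubscribing for I/O-bound pipelines:

```python
import os

def suggest_num_workers(io_bound: bool = False) -> int:
    cores = os.cpu_count() or 1
    # I/O-bound pipelines (e.g. remote shard downloads) can oversubscribe cores.
    return cores * 2 if io_bound else cores

print(suggest_num_workers(), suggest_num_workers(io_bound=True))
```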
