build_loaders

Creates PyTorch DataLoader objects for training and validation with configurable subsampling and reproducible shuffling.
def build_loaders(
    dataset_name: str,
    batch_size: int,
    train_subset: int | None,
    val_subset: int | None,
    seed: int = 42,
    num_workers: int = 2,
) -> tuple[DataLoader, DataLoader]
Parameters

dataset_name (str, required)
Name of the dataset to load. Supported values: "mnist", "fashion-mnist".

batch_size (int, required)
Number of samples per batch for both training and validation loaders.

train_subset (int | None, required)
Maximum number of training samples to use. Pass None to use the full training set. If specified, the first N samples are taken.

val_subset (int | None, required)
Maximum number of validation samples to use. Pass None to use the full validation set. If specified, the first N samples are taken.

seed (int, default: 42)
Random seed for the data loader's generator. Ensures reproducible shuffling of training batches.

num_workers (int, default: 2)
Number of worker processes for parallel data loading. Set to 0 for single-process loading.

Returns

train_loader (DataLoader)
Training data loader with shuffling enabled and a seeded random generator.

val_loader (DataLoader)
Validation data loader with shuffling disabled for consistent evaluation.
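The overall construction logic can be sketched as follows. This is a simplified illustration, not the library's actual implementation: it accepts pre-built datasets (here a synthetic TensorDataset shaped like MNIST, so no download is needed) and the helper name `build_loaders_sketch` is hypothetical.

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset

def build_loaders_sketch(train_ds, val_ds, batch_size,
                         train_subset, val_subset,
                         seed=42, num_workers=0):
    # Take the first N samples when a subset size is given.
    if train_subset is not None:
        train_ds = Subset(train_ds, range(min(train_subset, len(train_ds))))
    if val_subset is not None:
        val_ds = Subset(val_ds, range(min(val_subset, len(val_ds))))
    # A seeded generator makes the training shuffle reproducible.
    gen = torch.Generator().manual_seed(seed)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              generator=gen, num_workers=num_workers)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers)
    return train_loader, val_loader

# Synthetic stand-in for MNIST-shaped data (1000 samples, 1x28x28, 10 classes).
images = torch.randn(1000, 1, 28, 28)
labels = torch.randint(0, 10, (1000,))
ds = TensorDataset(images, labels)

train_loader, val_loader = build_loaders_sketch(
    ds, ds, batch_size=128, train_subset=256, val_subset=None
)
print(len(train_loader))  # 2 batches: ceil(256 / 128)
```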

Data Preprocessing

All datasets are preprocessed with the following transforms:
  1. ToTensor(): Converts PIL images to PyTorch tensors
  2. Normalize((0.5,), (0.5,)): Normalizes grayscale images to the [-1, 1] range
The resulting tensors have shape (1, 28, 28) with values in the range [-1, 1].
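Normalize computes (x - mean) / std per channel, so with mean 0.5 and std 0.5 the [0, 1] output of ToTensor() maps exactly onto [-1, 1]. A quick arithmetic check in plain Python:

```python
# Normalize((0.5,), (0.5,)) applies (x - 0.5) / 0.5 element-wise.
def normalize(x, mean=0.5, std=0.5):
    return (x - mean) / std

print(normalize(0.0))  # -1.0 (black pixel)
print(normalize(1.0))  #  1.0 (white pixel)
print(normalize(0.5))  #  0.0 (mid gray)
```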

Supported Datasets

mnist
torchvision.datasets.MNIST
Handwritten digits dataset (60,000 training + 10,000 test images, 28×28 grayscale, 10 classes).
fashion-mnist
torchvision.datasets.FashionMNIST
Fashion items dataset (60,000 training + 10,000 test images, 28×28 grayscale, 10 classes).
Datasets are automatically downloaded to the data/ directory if not already present; the first run may take a few moments while the files download.

Usage Example

from edge_opt.data import build_loaders

# Load full MNIST dataset
train_loader, val_loader = build_loaders(
    dataset_name="mnist",
    batch_size=128,
    train_subset=None,
    val_subset=None,
    seed=42,
    num_workers=2
)

print(f"Training batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")

# Iterate through data
for images, labels in train_loader:
    print(f"Batch shape: {images.shape}")  # (128, 1, 28, 28)
    print(f"Labels shape: {labels.shape}")  # (128,)
    break

Error Handling

If an unsupported dataset name is provided, the function raises a ValueError whose message lists the supported datasets:
ValueError: Unsupported dataset 'cifar10'. Use one of: ['mnist', 'fashion-mnist']
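The validation step presumably looks something like the sketch below (the helper name `resolve_dataset` and the constant are assumptions for illustration, not the library's actual internals):

```python
SUPPORTED_DATASETS = ["mnist", "fashion-mnist"]

def resolve_dataset(dataset_name: str) -> str:
    # Reject unknown names with a message listing the valid options.
    if dataset_name not in SUPPORTED_DATASETS:
        raise ValueError(
            f"Unsupported dataset '{dataset_name}'. "
            f"Use one of: {SUPPORTED_DATASETS}"
        )
    return dataset_name

try:
    resolve_dataset("cifar10")
except ValueError as e:
    print(e)  # Unsupported dataset 'cifar10'. Use one of: ['mnist', 'fashion-mnist']
```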

Performance Considerations

  • num_workers: Set to 0 for debugging (single-process). Use 2-4 for training on CPU, or 4-8 on systems with many cores.
  • Subsampling: Use train_subset and val_subset for rapid prototyping and hyperparameter search.
  • Batch size: Larger batches improve throughput but require more memory. Typical values: 64-256.
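When sizing subsets and batches, note that the number of batches per epoch is ceil(samples / batch_size), since DataLoader yields a final partial batch by default (drop_last=False):

```python
import math

def num_batches(n_samples: int, batch_size: int) -> int:
    # DataLoader with drop_last=False (the default) yields a final partial batch.
    return math.ceil(n_samples / batch_size)

print(num_batches(60000, 128))  # 469 batches for the full MNIST training set
print(num_batches(1000, 128))   # 8 batches when train_subset=1000
```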

Reproducibility

The seed parameter ensures reproducible shuffling of training data:
  • Same seed → same batch order across runs
  • Different seeds → different batch orders
  • Validation loader is never shuffled (no randomness)
For full reproducibility, also call set_deterministic(seed) from the Model module before creating data loaders.
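The seeding behavior can be demonstrated directly with torch.Generator and randperm, the same mechanism DataLoader uses internally to shuffle indices (the helper below is illustrative, not part of the library):

```python
import torch

def shuffle_order(seed: int, n: int = 10) -> list[int]:
    # A seeded generator produces a deterministic permutation of indices.
    gen = torch.Generator().manual_seed(seed)
    return torch.randperm(n, generator=gen).tolist()

order_a = shuffle_order(42)
order_b = shuffle_order(42)
print(order_a == order_b)  # True: same seed reproduces the same order
```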

Data Augmentation

Currently, no data augmentation is applied. To add augmentation, modify the transforms:
from torchvision import transforms

transform = transforms.Compose([
    transforms.RandomRotation(10),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

See Also

  • Config - Use ExperimentConfig to manage data loading parameters
  • Model - Pass data loaders to training loops with SmallCNN