
Utility Functions

The torch.utils package provides various utility functions and helpers for PyTorch development.

Data Loading Utilities

DataLoader

torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    sampler=None,
    batch_sampler=None,
    num_workers=0,
    collate_fn=None,
    pin_memory=False,
    drop_last=False,
    timeout=0,
    worker_init_fn=None,
    multiprocessing_context=None,
    generator=None,
    prefetch_factor=2,
    persistent_workers=False
)
Data loader that combines a dataset and a sampler, and provides an iterable over the given dataset.
dataset (Dataset): Dataset from which to load the data.
batch_size (int, default: 1): How many samples per batch to load.
shuffle (bool, default: False): Set to True to have the data reshuffled at every epoch.
num_workers (int, default: 0): How many subprocesses to use for data loading. 0 means the data will be loaded in the main process.
pin_memory (bool, default: False): If True, the data loader will copy tensors into CUDA pinned memory before returning them.
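
The collate_fn parameter controls how individual samples are merged into a batch. As a minimal sketch, the function below (a hypothetical helper, not part of torch) stacks samples and also returns the batch size:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical collate function: stack the samples into batch tensors
# and additionally return the number of samples in the batch.
def collate_with_count(batch):
    xs = torch.stack([x for x, _ in batch])
    ys = torch.stack([y for _, y in batch])
    return xs, ys, len(batch)

data = torch.randn(8, 4)
labels = torch.randint(0, 2, (8,))
loader = DataLoader(
    TensorDataset(data, labels),
    batch_size=4,
    collate_fn=collate_with_count,
)

for xs, ys, n in loader:
    print(xs.shape, ys.shape, n)  # torch.Size([4, 4]) torch.Size([4]) 4
```

When collate_fn is None (the default), the loader uses a built-in collation that stacks tensors along a new batch dimension.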

Checkpoint Utilities

checkpoint

torch.utils.checkpoint.checkpoint(
    function,
    *args,
    use_reentrant=True,
    **kwargs
)
Checkpoint a model or part of a model. Checkpointing trades compute for memory: rather than storing all intermediate activations of the checkpointed part for the backward pass, it discards them during the forward pass and recomputes them during the backward pass.
function (callable): Describes what to run in the forward pass of the model or part of the model.
use_reentrant (bool, default: True): If True, use the original reentrant implementation of checkpointing; if False, use the newer non-reentrant variant.
Returns (Any): Output of running function on *args.
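
A minimal sketch of checkpointing a small function, assuming the non-reentrant variant (use_reentrant=False); the activations inside block are recomputed when backward runs, but the resulting gradients are identical to an uncheckpointed forward:

```python
import torch
from torch.utils.checkpoint import checkpoint

# A small block whose intermediate activations we do not want to store.
def block(x):
    return torch.relu(x @ x.t()).sum()

x = torch.randn(16, 16, requires_grad=True)

# Activations of `block` are discarded in forward and recomputed
# during backward; only the inputs and output are kept.
out = checkpoint(block, x, use_reentrant=False)
out.backward()
print(x.grad.shape)  # torch.Size([16, 16])
```

Checkpointing is most useful when applied to the memory-heavy middle sections of a deep network, where recomputation is cheap relative to the activation memory saved.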

Model Utilities

Module Utilities

torch.utils.set_module(obj, mod)
Set the __module__ attribute of a given object to mod.

Hooks

torch.utils.hooks.RemovableHandle
A handle which provides the capability to remove a hook.
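
You do not construct a RemovableHandle directly; hook-registration methods return one. A minimal sketch using a forward hook on an nn.Linear layer:

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 2)
seen = []

# register_forward_hook returns a RemovableHandle.
handle = layer.register_forward_hook(
    lambda mod, inp, out: seen.append(out.shape)
)

layer(torch.randn(3, 4))  # hook fires, records the output shape
handle.remove()           # the hook no longer fires after this
layer(torch.randn(3, 4))
print(len(seen))  # 1
```

Keeping the handle and calling remove() is the supported way to detach a hook; hooks left registered run on every forward pass.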

Bottleneck

torch.utils.bottleneck.main()
Entry point of the bottleneck analysis tool, which helps identify performance bottlenecks in a script. It is typically invoked from the command line as python -m torch.utils.bottleneck /path/to/script.py, rather than called directly.

TensorBoard

SummaryWriter

torch.utils.tensorboard.SummaryWriter(
    log_dir=None,
    comment='',
    purge_step=None,
    max_queue=10,
    flush_secs=120,
    filename_suffix=''
)
Writes entries directly to event files in the log_dir to be consumed by TensorBoard.
log_dir (str, default: None): Save directory location. Default is runs/CURRENT_DATETIME_HOSTNAME.
comment (str, default: ''): Suffix appended to the default log_dir. Has no effect if log_dir is provided.
flush_secs (int, default: 120): How often, in seconds, to flush the pending events to disk.

Logging Methods

add_scalar(tag, scalar_value, global_step=None, walltime=None)
Add scalar data to summary.
add_scalars(main_tag, tag_scalar_dict, global_step=None, walltime=None)
Add many scalar values to summary.
add_histogram(tag, values, global_step=None, bins='tensorflow', walltime=None)
Add histogram to summary.
add_image(tag, img_tensor, global_step=None, walltime=None, dataformats='CHW')
Add image data to summary.
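
A minimal sketch of writing a short loss curve, assuming the tensorboard package is installed; the tag "train/loss" and the temporary log directory are illustrative choices, not requirements:

```python
import tempfile
from torch.utils.tensorboard import SummaryWriter

# Write a few scalar points into a temporary log directory;
# point TensorBoard at this directory to visualize them.
logdir = tempfile.mkdtemp()
writer = SummaryWriter(log_dir=logdir)
for step in range(5):
    writer.add_scalar("train/loss", 1.0 / (step + 1), global_step=step)
writer.close()
```

After running this, tensorboard --logdir with the chosen directory displays the curve in the Scalars tab.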

C++ Extension

torch.utils.cpp_extension.load(
    name,
    sources,
    extra_cflags=None,
    extra_cuda_cflags=None,
    extra_ldflags=None,
    extra_include_paths=None,
    build_directory=None,
    verbose=False,
    with_cuda=None,
    is_python_module=True,
    is_standalone=False
)
Loads a PyTorch C++ extension just-in-time (JIT).
name (str): The name of the extension to build.
sources (list[str]): A list of relative or absolute paths to C++ source files.
verbose (bool, default: False): If True, turns on verbose logging of the build.
with_cuda (bool, default: None): Determines whether CUDA headers and libraries are added to the build. If None, this is inferred automatically from the source files.
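
A hedged sketch of how load is typically called; "my_ops.cpp" is a hypothetical source file, and building requires a C++ toolchain and ninja, so the call is guarded on the file existing:

```python
import os
from torch.utils.cpp_extension import load

# "my_ops.cpp" is a hypothetical source file for illustration only;
# load() compiles it just-in-time and imports it as a Python module.
if os.path.exists("my_ops.cpp"):
    my_ops = load(
        name="my_ops",       # name of the resulting extension module
        sources=["my_ops.cpp"],
        verbose=True,        # log the compiler invocations
    )
```

On success, the returned module exposes the functions bound in the C++ sources (for example via pybind11), callable directly from Python.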

Example Usage

import torch
from torch.utils.data import DataLoader, TensorDataset

# Create a simple dataset
data = torch.randn(100, 10)
labels = torch.randint(0, 2, (100,))
dataset = TensorDataset(data, labels)

# Create a DataLoader
loader = DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,
    num_workers=2
)

# Iterate through batches
for batch_data, batch_labels in loader:
    print(batch_data.shape, batch_labels.shape)
