context_profiler — API Reference
The context_profiler module provides convenient decorators and context managers for profiling PyTorch code.

Functions

get_global_profiler()

Get or create the global profiler instance.
from gpumemprof import get_global_profiler

profiler = get_global_profiler(device="cuda:0")
device
Optional[Union[str, int, torch.device]]
default:"None"
GPU device to use. If None, the current device is used
return
GPUMemoryProfiler
Global profiler instance

set_global_profiler()

Set the global profiler instance.
from gpumemprof import set_global_profiler, GPUMemoryProfiler

custom_profiler = GPUMemoryProfiler(device="cuda:1", track_tensors=True)
set_global_profiler(custom_profiler)
profiler
GPUMemoryProfiler
Profiler instance to set as global

profile_function()

Decorator to profile a function’s GPU memory usage.
from gpumemprof import profile_function

@profile_function
def train_step(model, batch):
    output = model(batch)
    return output

# With custom name
@profile_function(name="forward_pass")
def forward(x):
    return model(x)

# With custom profiler
@profile_function(profiler=my_profiler, device="cuda:1")
def compute(data):
    return process(data)
func
Optional[Callable]
default:"None"
Function to profile (when used as @profile_function)
name
Optional[str]
default:"None"
Custom name for the profiled function
device
Optional[Union[str, int, torch.device]]
default:"None"
GPU device to use for profiling
profiler
Optional[GPUMemoryProfiler]
default:"None"
Custom profiler instance to use

profile_context()

Context manager for profiling a block of code.
from gpumemprof import profile_context

with profile_context("data_loading") as prof:
    data = load_data()
    data = preprocess(data)

# With custom profiler
with profile_context("training", profiler=my_profiler) as prof:
    model.train()
name
str
default:"'context'"
Name for the profiled context
device
Optional[Union[str, int, torch.device]]
default:"None"
GPU device to use
profiler
Optional[GPUMemoryProfiler]
default:"None"
Custom profiler instance

start_monitoring()

Start global memory monitoring.
from gpumemprof import start_monitoring, stop_monitoring

start_monitoring(interval=0.1)
# ... run your code ...
stop_monitoring()
interval
float
default:"0.1"
Monitoring interval in seconds
device
Optional[Union[str, int, torch.device]]
default:"None"
GPU device to monitor

stop_monitoring()

Stop global memory monitoring.
from gpumemprof import stop_monitoring

stop_monitoring()

get_summary()

Get the global profiler's summary.
from gpumemprof import get_summary

summary = get_summary()
print(summary)
return
Dict[str, Any]
Summary of profiling results

clear_results()

Clear the global profiler's results.
from gpumemprof import clear_results

clear_results()

get_profile_results()

Get recent profile results from the global profiler.
from gpumemprof import get_profile_results

# Get all results
results = get_profile_results()

# Get last 10 results
recent = get_profile_results(limit=10)
limit
Optional[int]
default:"None"
Maximum number of recent results to return
return
List[ProfileResult]
List of profiling results

profile_model_training()

Profile an entire training loop.
from gpumemprof import profile_model_training

results = profile_model_training(
    model=my_model,
    train_loader=dataloader,
    epochs=1,
    device="cuda:0"
)

print(f"Batches profiled: {len(results['batch_results'])}")
print(f"Total memory allocated: {results['overall_summary']['total_memory_allocated']}")
model
torch.nn.Module
PyTorch model to profile
train_loader
Any
DataLoader for training data
epochs
int
default:"1"
Number of epochs to profile
device
Optional[Union[str, int, torch.device]]
default:"None"
GPU device to use
return
Dict[str, Any]
Dictionary with batch results, epoch summaries, and overall summary

Classes

ProfiledModule

Wrapper for PyTorch modules that automatically profiles forward passes.
from gpumemprof import ProfiledModule

# Wrap a module
model = MyModel()
profiled_model = ProfiledModule(model, name="my_model")

# Use normally - forward passes are profiled automatically
output = profiled_model(input_data)

Constructor

module
torch.nn.Module
PyTorch module to wrap
name
Optional[str]
default:"None"
Name for profiling (defaults to class name)
device
Optional[Union[str, int, torch.device]]
default:"None"
GPU device to use
profiler
Optional[GPUMemoryProfiler]
default:"None"
Custom profiler instance

Methods

forward()
Forward pass with automatic profiling.
output = profiled_model(input_data)

MemoryProfiler

High-level memory profiler with convenient methods.
from gpumemprof import MemoryProfiler

profiler = MemoryProfiler(device="cuda:0")

Constructor

device
Optional[Union[str, int, torch.device]]
default:"None"
GPU device to profile

Methods

start_monitoring()
Start continuous memory monitoring.
profiler.start_monitoring(interval=0.5)
interval
float
default:"0.1"
Monitoring interval in seconds
stop_monitoring()
Stop continuous memory monitoring.
profiler.stop_monitoring()
profile()
Profile a function call.
result = profiler.profile(my_function, arg1, arg2, kwarg=value)
func
Callable
Function to profile
*args
Any
Function arguments
**kwargs
Any
Function keyword arguments
return
ProfileResult
Profiling result
context()
Context manager for profiling code blocks.
with profiler.context("preprocessing"):
    data = preprocess(raw_data)
name
str
default:"'context'"
Name for the context
wrap_module()
Wrap a PyTorch module for automatic profiling.
profiled_model = profiler.wrap_module(model, name="resnet50")
module
torch.nn.Module
Module to wrap
name
Optional[str]
default:"None"
Name for profiling
return
ProfiledModule
Wrapped module with automatic profiling
get_summary()
Get profiling summary.
summary = profiler.get_summary()
clear()
Clear profiling results.
profiler.clear()
save_results()
Save profiling results to file.
profiler.save_results("profile_results.json")
filename
str
Output filename
load_results()
Load profiling results from file.
data = profiler.load_results("profile_results.json")
filename
str
Input filename
return
Dict[str, Any]
Loaded profiling data

Context Manager Support

with MemoryProfiler() as profiler:
    profiler.start_monitoring()
    # ... code ...

Example Usage

import torch
from gpumemprof import (
    profile_function,
    profile_context,
    ProfiledModule,
    MemoryProfiler
)

# Decorator pattern
@profile_function(name="data_processing")
def process_data(data):
    return data.cuda() * 2

# Context manager pattern
with profile_context("model_forward"):
    output = model(input_data)

# Module wrapping
profiled_model = ProfiledModule(model, name="resnet")
output = profiled_model(data)  # Automatically profiled

# High-level profiler
profiler = MemoryProfiler()
profiler.start_monitoring()

with profiler.context("training_epoch"):
    for batch in dataloader:
        loss = train_step(batch)

profiler.stop_monitoring()
summary = profiler.get_summary()
profiler.save_results("training_profile.json")

Build docs developers (and LLMs) love