The context_profiler module provides convenient decorators and context managers for profiling the GPU memory usage of PyTorch code.
Functions
get_global_profiler()
Get or create the global profiler instance.
from gpumemprof import get_global_profiler
profiler = get_global_profiler(device="cuda:0")
device
Optional[Union[str, int, torch.device]]
default:"None"
GPU device to use. If None, uses current device
set_global_profiler()
Set the global profiler instance.
from gpumemprof import set_global_profiler, GPUMemoryProfiler
custom_profiler = GPUMemoryProfiler(device="cuda:1", track_tensors=True)
set_global_profiler(custom_profiler)
profiler (GPUMemoryProfiler): Profiler instance to set as the global profiler
profile_function()
Decorator to profile a function’s GPU memory usage.
from gpumemprof import profile_function
@profile_function
def train_step(model, batch):
output = model(batch)
return output
# With custom name
@profile_function(name="forward_pass")
def forward(x):
return model(x)
# With custom profiler
@profile_function(profiler=my_profiler, device="cuda:1")
def compute(data):
return process(data)
func
Optional[Callable]
default:"None"
Function to profile (when used as @profile_function)
name
Optional[str]
default:"None"
Custom name for the profiled function
device
Optional[Union[str, int, torch.device]]
default:"None"
GPU device to use for profiling
profiler
Optional[GPUMemoryProfiler]
default:"None"
Custom profiler instance to use
profile_context()
Context manager for profiling a block of code.
from gpumemprof import profile_context
with profile_context("data_loading") as prof:
data = load_data()
data = preprocess(data)
# With custom profiler
with profile_context("training", profiler=my_profiler) as prof:
model.train()
name (str): Name for the profiled context
device
Optional[Union[str, int, torch.device]]
default:"None"
GPU device to use
profiler
Optional[GPUMemoryProfiler]
default:"None"
Custom profiler instance
start_monitoring()
Start global memory monitoring.
from gpumemprof import start_monitoring, stop_monitoring
start_monitoring(interval=0.1)
# ... run your code ...
stop_monitoring()
interval (float): Monitoring interval in seconds
device
Optional[Union[str, int, torch.device]]
default:"None"
GPU device to monitor
stop_monitoring()
Stop global memory monitoring.
from gpumemprof import stop_monitoring
stop_monitoring()
get_summary()
Get global profiler summary.
from gpumemprof import get_summary
summary = get_summary()
print(summary)
Summary of profiling results
clear_results()
Clear global profiler results.
from gpumemprof import clear_results
clear_results()
get_profile_results()
Get recent profile results from global profiler.
from gpumemprof import get_profile_results
# Get all results
results = get_profile_results()
# Get last 10 results
recent = get_profile_results(limit=10)
limit
Optional[int]
default:"None"
Maximum number of recent results to return
List of profiling results
profile_model_training()
Profile an entire training loop.
from gpumemprof import profile_model_training
results = profile_model_training(
model=my_model,
train_loader=dataloader,
epochs=1,
device="cuda:0"
)
print(f"Batches profiled: {len(results['batch_results'])}")
print(f"Total memory allocated: {results['overall_summary']['total_memory_allocated']}")
train_loader: DataLoader providing the training data
epochs (int): Number of epochs to profile
device
Optional[Union[str, int, torch.device]]
default:"None"
GPU device to use
Dictionary with batch results, epoch summaries, and overall summary
Classes
ProfiledModule
Wrapper for PyTorch modules that automatically profiles forward passes.
from gpumemprof import ProfiledModule
# Wrap a module
model = MyModel()
profiled_model = ProfiledModule(model, name="my_model")
# Use normally - forward passes are profiled automatically
output = profiled_model(input_data)
Constructor
name
Optional[str]
default:"None"
Name for profiling (defaults to class name)
device
Optional[Union[str, int, torch.device]]
default:"None"
GPU device to use
profiler
Optional[GPUMemoryProfiler]
default:"None"
Custom profiler instance
Methods
forward()
Forward pass with automatic profiling.
output = profiled_model(input_data)
MemoryProfiler
High-level memory profiler with convenient methods.
from gpumemprof import MemoryProfiler
profiler = MemoryProfiler(device="cuda:0")
Constructor
device
Optional[Union[str, int, torch.device]]
default:"None"
GPU device to profile
Methods
start_monitoring()
Start continuous memory monitoring.
profiler.start_monitoring(interval=0.5)
interval (float): Monitoring interval in seconds
stop_monitoring()
Stop continuous memory monitoring.
profiler.stop_monitoring()
profile()
Profile a function call.
result = profiler.profile(my_function, arg1, arg2, kwarg=value)
Positional and keyword arguments forwarded to the profiled function
context()
Context manager for profiling code blocks.
with profiler.context("preprocessing"):
data = preprocess(raw_data)
wrap_module()
Wrap a PyTorch module for automatic profiling.
profiled_model = profiler.wrap_module(model, name="resnet50")
name
Optional[str]
default:"None"
Name for profiling
Wrapped module with automatic profiling
get_summary()
Get profiling summary.
summary = profiler.get_summary()
clear()
Clear profiling results.
save_results()
Save profiling results to file.
profiler.save_results("profile_results.json")
load_results()
Load profiling results from file.
data = profiler.load_results("profile_results.json")
Context Manager Support
with MemoryProfiler() as profiler:
profiler.start_monitoring()
# ... code ...
Example Usage
import torch
from gpumemprof import (
profile_function,
profile_context,
ProfiledModule,
MemoryProfiler
)
# Decorator pattern
@profile_function(name="data_processing")
def process_data(data):
return data.cuda() * 2
# Context manager pattern
with profile_context("model_forward"):
output = model(input_data)
# Module wrapping
profiled_model = ProfiledModule(model, name="resnet")
output = profiled_model(data) # Automatically profiled
# High-level profiler
profiler = MemoryProfiler()
profiler.start_monitoring()
with profiler.context("training_epoch"):
for batch in dataloader:
loss = train_step(batch)
profiler.stop_monitoring()
summary = profiler.get_summary()
profiler.save_results("training_profile.json")