TFMemoryProfiler

Main TensorFlow GPU memory profiler class for capturing memory snapshots and profiling functions.

Constructor

TFMemoryProfiler(
    device: Optional[str] = None,
    enable_tensor_tracking: bool = True
)
device
Optional[str]
default:"None"
TensorFlow device name (e.g., '/GPU:0', '/CPU:0'). When None, the device is auto-detected.
enable_tensor_tracking
bool
default:"True"
Whether to track individual tensor lifecycles

Methods

capture_snapshot

Capture the current memory state as a named snapshot.
def capture_snapshot(self, name: str = "snapshot") -> MemorySnapshot
name
str
default:"snapshot"
Name identifier for the snapshot
MemorySnapshot
object
Snapshot containing timestamp, GPU/CPU memory usage, tensor count, and utilization metrics

profile_function

Decorator that profiles the memory usage of a function.
def profile_function(self, func: Callable[..., Any]) -> Callable[..., Any]
func
Callable
Function to profile
Example:
profiler = TFMemoryProfiler()

@profiler.profile_function
def train_model(model, data):
    # Training code
    pass
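
To illustrate what a decorator of this kind does around each call, here is a minimal sketch. It is not the library's implementation: it uses Python's standard `tracemalloc` module (host memory only) as a stand-in for a GPU memory query, and `profile_memory` and `call_stats` are hypothetical names introduced for this example.

```python
import functools
import tracemalloc
from typing import Any, Callable, Dict

# Hypothetical per-function statistics store (not part of tfmemprof).
call_stats: Dict[str, Dict[str, int]] = {}


def profile_memory(func: Callable[..., Any]) -> Callable[..., Any]:
    """Record call count and peak traced host memory for `func`."""

    @functools.wraps(func)  # keep the wrapped function's name and docstring
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        started_here = not tracemalloc.is_tracing()
        if started_here:
            tracemalloc.start()
        tracemalloc.reset_peak()  # requires Python 3.9+
        before, _ = tracemalloc.get_traced_memory()
        try:
            return func(*args, **kwargs)
        finally:
            # Runs even if func raises, so failed calls are still recorded.
            _, peak = tracemalloc.get_traced_memory()
            stats = call_stats.setdefault(
                func.__name__, {"calls": 0, "peak_bytes": 0}
            )
            stats["calls"] += 1
            stats["peak_bytes"] = max(stats["peak_bytes"], peak - before)
            if started_here:
                tracemalloc.stop()

    return wrapper


@profile_memory
def allocate(n: int) -> int:
    data = [0] * n  # allocation that the wrapper measures
    return len(data)


result = allocate(100_000)
```

The `try`/`finally` structure is the point of the sketch: measurements are taken on both sides of the call regardless of how the function exits.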

profile_context

Context manager for profiling code blocks.
@contextmanager
def profile_context(self, name: str = "context") -> Iterator[None]
name
str
default:"context"
Name for the profiling context
Example:
profiler = TFMemoryProfiler()

with profiler.profile_context("training_epoch"):
    model.fit(x_train, y_train)
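
A context manager like this can be pictured as a pair of measurements taken on entry and exit. The sketch below shows that pattern under stated assumptions: `profile_block`, `read_memory_mb`, and `log` are hypothetical names, and the memory reader is a fake that returns fixed values rather than a real device query.

```python
import time
from contextlib import contextmanager
from typing import Callable, Iterator, List, Tuple

# Each entry: (block name, memory delta in MB, elapsed seconds).
LogEntry = Tuple[str, float, float]


@contextmanager
def profile_block(
    name: str,
    read_memory_mb: Callable[[], float],
    log: List[LogEntry],
) -> Iterator[None]:
    """Measure memory before and after the enclosed block."""
    start = time.perf_counter()
    before = read_memory_mb()
    try:
        yield
    finally:
        # Recorded even when the block raises.
        after = read_memory_mb()
        log.append((name, after - before, time.perf_counter() - start))


# Fake reader for illustration: 100 MB on entry, 140 MB on exit.
readings = iter([100.0, 140.0])
log: List[LogEntry] = []

with profile_block("training_epoch", lambda: next(readings), log):
    pass  # the profiled work would go here
```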

start_continuous_profiling

Start continuous memory profiling in a background thread.
def start_continuous_profiling(self, interval: float = 1.0) -> None
interval
float
default:"1.0"
Sampling interval in seconds

stop_continuous_profiling

Stop continuous memory profiling.
def stop_continuous_profiling(self) -> None
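
The start/stop pair follows the usual background-sampler pattern. The sketch below shows one way such a sampler could be built with the standard `threading` module. It is an assumption about the general technique, not the library's code: `BackgroundSampler` and `read_memory_mb` are hypothetical, and the reader returns a constant instead of querying a device.

```python
import threading
import time
from typing import Callable, List, Optional


class BackgroundSampler:
    """Sample a memory reader at a fixed interval on a daemon thread."""

    def __init__(self, read_memory_mb: Callable[[], float]) -> None:
        self._read = read_memory_mb
        self._stop_event = threading.Event()
        self._thread: Optional[threading.Thread] = None
        self.samples: List[float] = []

    def start(self, interval: float = 1.0) -> None:
        if self._thread is not None:
            return  # already running
        self._stop_event.clear()
        self._thread = threading.Thread(
            target=self._run, args=(interval,), daemon=True
        )
        self._thread.start()

    def _run(self, interval: float) -> None:
        while not self._stop_event.is_set():
            self.samples.append(self._read())
            # Event.wait acts as a sleep that stop() can interrupt.
            self._stop_event.wait(interval)

    def stop(self) -> None:
        if self._thread is None:
            return
        self._stop_event.set()
        self._thread.join()  # no more samples are added after this returns
        self._thread = None


sampler = BackgroundSampler(lambda: 42.0)
sampler.start(interval=0.01)
time.sleep(0.05)
sampler.stop()
```

Waiting on an `Event` rather than calling `time.sleep` lets `stop` return promptly even when the sampling interval is long.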

get_results

Get comprehensive profiling results.
def get_results(self) -> ProfileResult
ProfileResult
object
peak_memory_mb
float
Peak memory usage in MB
average_memory_mb
float
Average memory usage in MB
min_memory_mb
float
Minimum memory usage in MB
duration
float
Total profiling duration in seconds
memory_growth_rate
float
Memory growth rate in MB/second
snapshots
List[MemorySnapshot]
List of captured memory snapshots
function_profiles
Dict[str, Dict]
Per-function profiling statistics
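
To make the summary fields concrete, the sketch below derives them from a list of samples. `Sample` and `summarize` are hypothetical names, and the growth-rate formula (net memory change divided by duration) is an assumption; the library may compute it differently.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class Sample:
    """Hypothetical stand-in for a snapshot: time in seconds, memory in MB."""

    timestamp: float
    gpu_memory_mb: float


def summarize(samples: List[Sample]) -> Dict[str, float]:
    """Compute peak, average, minimum, duration, and growth rate."""
    values = [s.gpu_memory_mb for s in samples]
    duration = samples[-1].timestamp - samples[0].timestamp
    # Assumed definition: net change in memory over the whole run.
    growth = (values[-1] - values[0]) / duration if duration > 0 else 0.0
    return {
        "peak_memory_mb": max(values),
        "average_memory_mb": sum(values) / len(values),
        "min_memory_mb": min(values),
        "duration": duration,
        "memory_growth_rate": growth,
    }


samples = [Sample(0.0, 100.0), Sample(1.0, 150.0), Sample(2.0, 200.0)]
summary = summarize(samples)
```

With these three samples, memory rises by 100 MB over 2 seconds, so the assumed growth rate is 50 MB/second.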

reset

Reset profiler state and clear all data.
def reset(self) -> None

Usage as context manager

The profiler can be used as a context manager:
with TFMemoryProfiler() as profiler:
    # Your code here
    model.fit(x_train, y_train)
    
results = profiler.get_results()

MemorySnapshot

Represents a point-in-time memory snapshot.
timestamp
float
Unix timestamp of snapshot
name
str
Snapshot identifier
gpu_memory_mb
float
Current GPU memory usage in MB
cpu_memory_mb
float
Current CPU memory usage in MB
gpu_memory_reserved_mb
float
Reserved GPU memory in MB
gpu_utilization
float
GPU memory utilization percentage
num_tensors
int
Number of active tensors

ProfileResult

Comprehensive profiling results dataclass.
start_time
float
Start timestamp
end_time
float
End timestamp
peak_memory_mb
float
Peak memory usage
average_memory_mb
float
Average memory usage
total_allocations
int
Total memory allocations detected
total_deallocations
int
Total memory deallocations detected

Properties

duration: Total profiling duration in seconds
memory_growth_rate: Memory growth rate in MB/second

Example

import tensorflow as tf
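from tfmemprof import TFMemoryProfiler

# Placeholder training data so the example runs end to end (illustrative;
# substitute your own dataset). Labels are one-hot for categorical_crossentropy.
x_train = tf.random.normal((256, 32))
y_train = tf.one_hot(tf.random.uniform((256,), maxval=10, dtype=tf.int32), depth=10)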

# Create profiler
profiler = TFMemoryProfiler(device='/GPU:0')

# Profile a function
@profiler.profile_function
def build_model():
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(512, activation='relu'),
        tf.keras.layers.Dense(10, activation='softmax')
    ])
    return model

# Capture snapshots
profiler.capture_snapshot('before_training')

# Profile code blocks
with profiler.profile_context('training'):
    model = build_model()
    model.compile(optimizer='adam', loss='categorical_crossentropy')
    model.fit(x_train, y_train, epochs=5)

profiler.capture_snapshot('after_training')

# Get results
results = profiler.get_results()
print(f"Peak memory: {results.peak_memory_mb:.2f} MB")
print(f"Average memory: {results.average_memory_mb:.2f} MB")
print(f"Memory growth rate: {results.memory_growth_rate:.2f} MB/s")