Memory Management
MLX provides functions to monitor and manage GPU memory usage. Understanding memory management is crucial for training large models and optimizing performance.

Overview

MLX uses a caching allocator to efficiently manage GPU memory. The allocator:
  • Caches freed memory for reuse
  • Reduces allocation overhead
  • Can be configured with limits
  • Provides detailed usage statistics

Memory Functions

get_active_memory

mlx.core.get_active_memory() -> int
Get the current amount of active GPU memory in bytes. Active memory includes all allocated arrays that are currently in use. Returns:
  • Active memory in bytes
Example:
import mlx.core as mx

# Check initial memory
print(f"Initial memory: {mx.get_active_memory() / 1024**2:.2f} MB")

# Allocate arrays
x = mx.random.normal((10000, 10000))
y = mx.random.normal((10000, 10000))
mx.eval(x, y)

print(f"After allocation: {mx.get_active_memory() / 1024**2:.2f} MB")

# Free arrays
del x, y

print(f"After deletion: {mx.get_active_memory() / 1024**2:.2f} MB")
Monitoring during training:
import mlx.core as mx
import mlx.nn as nn

class MemoryTracker:
    """Record active GPU memory at labeled checkpoints and report a timeline."""

    def __init__(self):
        # Chronological (label, megabytes) snapshots, oldest first.
        self.history = []

    def track(self, label=""):
        """Snapshot the current active memory under *label* and echo it."""
        usage_mb = mx.get_active_memory() / (1024**2)
        self.history.append((label, usage_mb))
        print(f"{label}: {usage_mb:.2f} MB")

    def report(self):
        """Print every recorded snapshot in insertion order."""
        print("\nMemory usage timeline:")
        for entry in self.history:
            print(f"  {entry[0]:30s} {entry[1]:8.2f} MB")

tracker = MemoryTracker()

tracker.track("Start")
model = create_model()
tracker.track("After model creation")

for epoch in range(5):
    for batch in data_loader:
        loss, grads = compute_loss(model, batch)
        optimizer.update(model, grads)
    
    tracker.track(f"After epoch {epoch + 1}")

tracker.report()

get_peak_memory

mlx.core.get_peak_memory() -> int
Get the peak GPU memory usage in bytes since the last reset. This tracks the maximum active memory usage over time. Returns:
  • Peak memory in bytes
Example:
import mlx.core as mx

# Reset peak counter
mx.reset_peak_memory()

# Run computation
for i in range(10):
    size = (i + 1) * 1000
    x = mx.random.normal((size, size))
    y = x @ x.T
    mx.eval(y)
    
    current = mx.get_active_memory() / (1024**2)
    peak = mx.get_peak_memory() / (1024**2)
    print(f"Iteration {i}: Current={current:.1f} MB, Peak={peak:.1f} MB")

print(f"\nPeak memory usage: {mx.get_peak_memory() / (1024**2):.2f} MB")
Profiling a function:
import mlx.core as mx

def profile_memory(func, *args, **kwargs):
    """Run *func*, evaluate its result, and print a peak-memory profile.

    Prints initial/peak/final active memory and the apparent leak
    (final minus initial), then returns whatever *func* returned.
    """
    # Start peak tracking from a clean slate and record the floor.
    mx.reset_peak_memory()
    before = mx.get_active_memory()

    result = func(*args, **kwargs)
    mx.eval(result)  # force lazy computation so memory use is realized

    peak_bytes = mx.get_peak_memory()
    after = mx.get_active_memory()

    to_mb = 1024**2
    print(f"Memory profile for {func.__name__}:")
    print(f"  Initial: {before / to_mb:.2f} MB")
    print(f"  Peak:    {peak_bytes / to_mb:.2f} MB")
    print(f"  Final:   {after / to_mb:.2f} MB")
    print(f"  Leaked:  {(after - before) / to_mb:.2f} MB")

    return result

def my_computation(n):
    """Create an (n, n) standard-normal matrix and return its Gram product."""
    matrix = mx.random.normal((n, n))
    return matrix @ matrix.T

result = profile_memory(my_computation, 5000)

reset_peak_memory

mlx.core.reset_peak_memory() -> None
Reset the peak memory counter. Use this before profiling specific code sections to get accurate peak usage measurements. Example:
import mlx.core as mx

# Profile different operations
operations = {
    "matmul": lambda: mx.random.normal((5000, 5000)) @ mx.random.normal((5000, 5000)),
    "conv2d": lambda: mx.conv2d(mx.random.normal((8, 224, 224, 3)), mx.random.normal((64, 3, 3, 3))),
    "softmax": lambda: mx.softmax(mx.random.normal((1000, 10000)))
}

for name, op in operations.items():
    mx.reset_peak_memory()
    result = op()
    mx.eval(result)
    peak = mx.get_peak_memory() / (1024**2)
    print(f"{name:15s}: {peak:8.2f} MB peak")

get_cache_memory

mlx.core.get_cache_memory() -> int
Get the amount of cached GPU memory in bytes. Cached memory is memory that has been freed but not returned to the system, available for reuse. Returns:
  • Cached memory in bytes
Example:
import mlx.core as mx

def memory_status():
    """Print active, cached, and total GPU memory plus utilization.

    Utilization is active as a percentage of active+cached; when nothing
    is allocated or cached it is reported as "n/a" instead of dividing
    by zero.
    """
    active = mx.get_active_memory() / (1024**2)
    cached = mx.get_cache_memory() / (1024**2)
    total = active + cached
    
    print(f"Active:  {active:8.2f} MB")
    print(f"Cached:  {cached:8.2f} MB")
    print(f"Total:   {total:8.2f} MB")
    # BUG FIX: the original divided unconditionally and raised
    # ZeroDivisionError when no memory was in use.
    if total > 0:
        print(f"Efficiency: {active/total*100:.1f}% utilized")
    else:
        print("Efficiency: n/a (no memory in use)")

print("Before allocation:")
memory_status()

# Allocate and free
x = mx.random.normal((10000, 10000))
mx.eval(x)
print("\nAfter allocation:")
memory_status()

del x
print("\nAfter deletion (memory cached):")
memory_status()

mx.clear_cache()
print("\nAfter clearing cache:")
memory_status()

clear_cache

mlx.core.clear_cache() -> None
Clear the memory cache and return memory to the system. Use this when you need to free up memory for other applications or when running multiple experiments sequentially. Example:
import mlx.core as mx

def train_model(config):
    """Train one model for *config*, then release its GPU memory.

    Relies on `create_model` and `train` defined elsewhere in the
    surrounding example. Clearing the cache between runs keeps
    sequential experiments from accumulating memory.
    """
    # Train model
    model = create_model(config)
    train(model)
    
    # Clear memory before next run.
    del model
    # BUG FIX: the original queried get_cache_memory() AFTER clearing,
    # which always printed ~0 MB. Capture the cache size first so the
    # message reports what was actually released.
    released = mx.get_cache_memory()
    mx.clear_cache()
    
    print(f"Memory released: {released / (1024**2):.2f} MB")

# Train multiple models without memory accumulation
for config in configs:
    train_model(config)
Memory cleanup in notebooks:
import mlx.core as mx

# After a large experiment
del model, optimizer, data_loader
mx.clear_cache()

print(f"Memory freed: {mx.get_cache_memory() / (1024**2):.2f} MB")
print(f"Active memory: {mx.get_active_memory() / (1024**2):.2f} MB")

set_memory_limit

mlx.core.set_memory_limit(limit: int, relaxed: bool = True) -> None
Set the maximum amount of GPU memory MLX can use. Parameters:
  • limit (int): Memory limit in bytes (0 means unlimited)
  • relaxed (bool): If True, allows temporary exceeding of limit. Default: True
Example:
import mlx.core as mx
import mlx.core.metal as metal

if metal.is_available():
    info = metal.device_info()
    total_memory = info['memory']
    
    # Use at most 80% of available memory
    limit = int(total_memory * 0.8)
    mx.set_memory_limit(limit, relaxed=True)
    
    print(f"Memory limit set to {limit / (1024**3):.2f} GB")
Sharing GPU with other processes:
import mlx.core as mx

# Reserve 16 GB for other applications
reserved_gb = 16
limit = (64 - reserved_gb) * 1024**3  # Assuming 64 GB total

mx.set_memory_limit(limit, relaxed=False)
print(f"MLX limited to {limit / (1024**3):.0f} GB")
When relaxed=False, allocations that would exceed the limit will fail immediately. Use relaxed=True for more flexibility.

set_cache_limit

mlx.core.set_cache_limit(limit: int) -> None
Set the maximum amount of cached memory. When the cache exceeds this limit, memory is returned to the system. Parameters:
  • limit (int): Cache limit in bytes (0 means unlimited)
Example:
import mlx.core as mx

# Limit cache to 4 GB
cache_limit = 4 * 1024**3
mx.set_cache_limit(cache_limit)

print(f"Cache limit set to {cache_limit / (1024**3):.0f} GB")

# Cache will automatically be trimmed when it exceeds 4 GB
for i in range(100):
    x = mx.random.normal((1000, 1000))
    y = x @ x
    mx.eval(y)
    del x, y
    
    if i % 10 == 0:
        cached = mx.get_cache_memory() / (1024**2)
        print(f"Iteration {i}: Cached memory = {cached:.2f} MB")

set_wired_limit

mlx.core.set_wired_limit(limit: int) -> None
Set the maximum amount of wired (pinned) memory. Wired memory is locked in physical RAM and cannot be paged out. On Apple Silicon with unified memory, this affects memory available to the GPU. Parameters:
  • limit (int): Wired memory limit in bytes (0 means unlimited)
Example:
import mlx.core as mx

# Limit wired memory to 32 GB on a Mac with 64 GB RAM
wired_limit = 32 * 1024**3
mx.set_wired_limit(wired_limit)

print(f"Wired memory limit set to {wired_limit / (1024**3):.0f} GB")

Practical Examples

Memory-Efficient Training

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

class MemoryEfficientTrainer:
    """Trainer that bounds MLX GPU memory use and reports usage while training.

    Parameters:
        model: the model to train.
        memory_limit_gb: optional overall memory budget in gigabytes;
            applied with relaxed=True so spikes may temporarily exceed it.
        optimizer: object whose ``update(model, grads)`` applies gradients.
            Required before calling ``train_epoch``.
        loss_fn: callable ``(model, batch) -> (loss, grads)``.
            Required before calling ``train_epoch``.
    """

    def __init__(self, model, memory_limit_gb=None, optimizer=None, loss_fn=None):
        self.model = model
        # BUG FIX: the original never stored these, so train_epoch()
        # crashed with AttributeError on self.loss_fn/self.optimizer.
        self.optimizer = optimizer
        self.loss_fn = loss_fn
        
        if memory_limit_gb:
            mx.set_memory_limit(memory_limit_gb * 1024**3, relaxed=True)
        
        # Cap cached-but-unused memory at 4 GB.
        mx.set_cache_limit(4 * 1024**3)
    
    def train_epoch(self, data_loader):
        """Run one epoch over *data_loader* with periodic memory hygiene."""
        mx.reset_peak_memory()
        
        for i, batch in enumerate(data_loader):
            loss, grads = self.loss_fn(self.model, batch)
            self.optimizer.update(self.model, grads)
            
            # Periodically return cached memory to the system.
            if i % 100 == 0:
                mx.clear_cache()
            
            # Periodic memory report.
            if i % 50 == 0:
                active = mx.get_active_memory() / (1024**2)
                peak = mx.get_peak_memory() / (1024**2)
                print(f"Batch {i}: Active={active:.0f}MB, Peak={peak:.0f}MB")

trainer = MemoryEfficientTrainer(model, memory_limit_gb=40)
trainer.train_epoch(train_loader)

Memory Leak Detection

import mlx.core as mx

class LeakDetector:
    """Detect growth in active GPU memory relative to a recorded baseline."""

    def __init__(self, tolerance_mb=10):
        # Growth below this many bytes is treated as noise, not a leak.
        self.tolerance = tolerance_mb * 1024**2
        self.baseline = None  # set by start()
    
    def start(self):
        """Record the baseline: active memory after clearing the cache."""
        mx.clear_cache()
        self.baseline = mx.get_active_memory()
        print(f"Baseline memory: {self.baseline / (1024**2):.2f} MB")
    
    def check(self, label=""):
        """Compare current active memory against the baseline.

        Returns True when growth exceeds the tolerance, False otherwise.
        Raises RuntimeError if start() was never called.
        """
        # BUG FIX: the original fell through to `current - None` and
        # crashed with an opaque TypeError when start() was skipped.
        if self.baseline is None:
            raise RuntimeError("LeakDetector.check() called before start()")
        mx.clear_cache()
        current = mx.get_active_memory()
        leaked = current - self.baseline
        
        if leaked > self.tolerance:
            print(f"⚠️  Potential leak {label}: {leaked / (1024**2):.2f} MB")
            return True
        else:
            print(f"✓  No leak {label}: {leaked / (1024**2):.2f} MB")
            return False

detector = LeakDetector(tolerance_mb=5)
detector.start()

for i in range(10):
    x = mx.random.normal((1000, 1000))
    y = x @ x
    mx.eval(y)
    del x, y
    
    detector.check(f"iteration {i}")

Dynamic Batch Size Selection

import mlx.core as mx
import mlx.core.metal as metal

def find_optimal_batch_size(model, input_shape, max_memory_gb=None):
    """Find the largest power-of-two batch size that fits in memory.

    Parameters:
        model: callable model under test.
        input_shape: per-sample input shape (without the batch dimension).
        max_memory_gb: memory budget in GB; defaults to 80% of device
            memory when Metal is available.

    Returns the largest batch size whose peak usage stayed in budget.
    Raises ValueError when no budget is given and none can be queried.
    """
    if max_memory_gb is None and metal.is_available():
        info = metal.device_info()
        max_memory_gb = info['memory'] / (1024**3) * 0.8  # Use 80%
    # BUG FIX: the original left max_memory_gb as None when Metal was
    # unavailable, then crashed comparing `peak_gb > None` below.
    if max_memory_gb is None:
        raise ValueError(
            "max_memory_gb must be provided when Metal is unavailable"
        )
    
    batch_size = 1
    max_batch = 1
    
    while True:
        batch_size *= 2
        mx.clear_cache()
        mx.reset_peak_memory()
        
        try:
            # Try a forward and backward pass at this batch size.
            x = mx.random.normal((batch_size,) + input_shape)
            output = model(x)
            loss = mx.mean(output)
            grads = mx.grad(lambda m, x: mx.mean(m(x)))(model, x)
            mx.eval(loss, grads)
            
            peak_gb = mx.get_peak_memory() / (1024**3)
            
            if peak_gb > max_memory_gb:
                break
            
            max_batch = batch_size
            print(f"Batch size {batch_size}: {peak_gb:.2f} GB - OK")
            
        except Exception as e:
            # An allocation failure (or model error) marks the upper bound.
            print(f"Batch size {batch_size} failed: {e}")
            break
    
    mx.clear_cache()
    print(f"\nOptimal batch size: {max_batch}")
    return max_batch

model = create_model()
batch_size = find_optimal_batch_size(model, input_shape=(3, 224, 224))

Memory Debugging Tips

  1. Use clear_cache() frequently: Especially in notebooks and between experiments
  2. Monitor peak memory: Track peak usage to find memory bottlenecks
  3. Set conservative limits: Leave headroom for memory spikes
  4. Profile incrementally: Add memory checks at key points in your code
  5. Delete large arrays: Use del explicitly and call clear_cache()
  6. Watch for leaks: Use the leak detector pattern shown above

See Also

  • mlx.core.metal.device_info — query total device memory and architecture