
Overview

Autoregressive inference with linear RNNs in PyTorch needs special handling, because a naive per-step loop is dominated by CPU launch overhead rather than GPU compute. Following the approach popularized by Mamba, lrnnx implements CUDA-graph-based inference, which removes this per-step CPU overhead and provides a >10x speedup over a plain Python for-loop.
The generation API is located in lrnnx/utils/generation.py and works with all models in the library.

Quick Start

1. Import the generation API

import torch
from lrnnx.utils.generation import capture_graph, generate
from lrnnx.models.lti import LRU

2. Create and prepare the model

# Initialize model in eval mode on CUDA
model = LRU(d_model=64, d_state=64).cuda().eval()

# Model configuration
batch_size = 4
H = 64  # d_model

3. Capture the CUDA graph (one-time setup)

# Capture graph once - reuse for all generations
cache = capture_graph(model, batch_size=batch_size, H=H)
Graph capture is a one-time operation. The captured graph is tied to the specific batch size.

4. Generate sequences

# Create seed input (B, H)
x0 = torch.randn(batch_size, H, device="cuda")

# Generate 512 steps with CUDA graph replay
output = generate(
    model,
    x0,
    num_steps=512,
    graph_cache=cache  # Use captured graph for 10x speedup
)

# Output shape: (batch_size, num_steps, H)
print(output.shape)  # torch.Size([4, 512, 64])

CUDA Graph Optimization

Why CUDA Graphs?

Standard autoregressive generation in PyTorch pays significant CPU overhead on every step: each kernel is dispatched from Python anew. CUDA graphs eliminate this by:
  1. Recording the computational graph once, during capture_graph()
  2. Replaying the pre-recorded graph at each timestep with near-zero CPU overhead
This yields a >10x speedup for inference.
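For contrast, the baseline that graph replay replaces looks roughly like the plain loop below. This is a sketch, not lrnnx code; step_fn stands in for one recurrence step of the model:

```python
def generate_naive(step_fn, x0, num_steps):
    """Naive autoregressive loop: every iteration dispatches fresh
    kernels from Python, so per-step CPU overhead dominates for
    small models. step_fn maps (batch, H) -> (batch, H)."""
    outputs = []
    x = x0
    for _ in range(num_steps):
        x = step_fn(x)       # one full Python -> CUDA dispatch per step
        outputs.append(x)
    return outputs           # later stacked into (batch, num_steps, H)
```

capture_graph() removes the per-iteration dispatch cost by recording the kernels of a single step once and replaying them for every subsequent step.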

How It Works

import torch
from lrnnx.utils.generation import capture_graph, generate
from lrnnx.models.lti import LRU

model = LRU(d_model=64, d_state=64).cuda().eval()

# One-time graph capture
cache = capture_graph(model, batch_size=4, H=64)

# Seed token
x0 = torch.randn(4, 64, device="cuda")

# Fast generation with graph replay
output = generate(model, x0, num_steps=512, graph_cache=cache)

Complete Example

import torch
from lrnnx.models.lti import LRU
from lrnnx.utils.generation import capture_graph, generate

# Configuration
batch_size = 4
d_model = 64
d_state = 64
num_steps = 512

# Initialize model
model = LRU(d_model=d_model, d_state=d_state).cuda()
model.eval()  # Important: set to eval mode

# Capture CUDA graph (do this once)
print("Capturing CUDA graph...")
cache = capture_graph(model, batch_size=batch_size, H=d_model)

# Create seed input
x0 = torch.randn(batch_size, d_model, device="cuda")

# Generate with CUDA graph
print("Generating...")
with torch.inference_mode():
    output = generate(model, x0, num_steps=num_steps, graph_cache=cache)

print(f"Generated output shape: {output.shape}")
# Output: Generated output shape: torch.Size([4, 512, 64])

Event-Based Inference

Some models support event-driven timesteps during generation:
from lrnnx.models.lti import S5
from lrnnx.utils.generation import capture_graph, generate

model = S5(d_model=64, d_state=64).cuda().eval()

# Capture with event mode enabled
cache = capture_graph(
    model,
    batch_size=4,
    H=64,
    event_mode=True  # Enable event-driven timesteps
)

x0 = torch.randn(4, 64, device="cuda")

# Provide integration timestep (reused at every step)
integration_timesteps = torch.ones(4, 1, device="cuda") * 0.1

output = generate(
    model,
    x0,
    num_steps=512,
    graph_cache=cache,
    integration_timesteps=integration_timesteps
)
When using integration_timesteps, you must capture the graph with event_mode=True.

Benchmarking Inference

The library includes built-in benchmarks to measure inference performance:
from benchmarks.benchmark_inference import benchmark_sequence_length
from lrnnx.models.lti import LRU

def model_fn():
    return LRU(d_model=128, d_state=64).cuda().eval()

# Benchmark CUDA-graph inference across sequence lengths
results = benchmark_sequence_length(
    model_fn,
    seq_lengths=[64, 128, 256, 512, 1024, 2048],
    batch_size=32,
    repeats=5
)

for seq_len, times in results.items():
    avg_time = sum(times) / len(times)
    print(f"Seq len {seq_len}: {avg_time:.2f} ms")
See benchmarks/benchmark_inference.py for complete benchmarking utilities including:
  • benchmark_sequence_length() - Vary generation length
  • benchmark_model_dimension() - Vary model size
  • benchmark_batch_size() - Vary batch size
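For a quick manual measurement without the benchmark module, a simple wall-clock timer is enough. This is a generic sketch, not part of lrnnx; for GPU work, synchronize inside the callable (e.g. call torch.cuda.synchronize() after generate()) so kernel time is actually counted:

```python
import time

def time_fn(fn, repeats=5, warmup=2):
    """Return the mean latency of fn() in milliseconds.
    Warmup runs are excluded so one-time costs (graph capture,
    allocator warm-up) do not skew the measurement."""
    for _ in range(warmup):
        fn()
    samples = []
    for _ in range(repeats):
        t0 = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - t0) * 1000.0)
    return sum(samples) / len(samples)
```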

API Reference

capture_graph()

Captures a CUDA graph for the model’s single-step recurrence. Parameters:
  • model (LTI_LRNN | LTV_LRNN) - Model on CUDA in eval mode
  • batch_size (int) - Batch size to capture for
  • H (int) - Model input/output dimension (d_model)
  • max_seqlen (int, optional) - Maximum sequence length, default: 1
  • event_mode (bool, optional) - Enable event-driven timesteps, default: False
  • device (torch.device, optional) - CUDA device, inferred from model if None
  • n_warmups (int, optional) - Warmup iterations before capture, default: 3
Returns:
  • CUDAGraphStepCache - Opaque cache object to pass to generate()
Example:
cache = capture_graph(model, batch_size=4, H=64)

generate()

Autoregressive generation with optional CUDA graph acceleration. Parameters:
  • model (LTI_LRNN | LTV_LRNN) - Model on CUDA in eval mode
  • x (torch.Tensor) - Seed input, shape (batch, H)
  • num_steps (int) - Number of autoregressive steps
  • graph_cache (CUDAGraphStepCache, optional) - Pre-captured graph from capture_graph(), default: None
  • integration_timesteps (torch.Tensor, optional) - Integration timesteps, shape (batch, 1), for event-driven models; requires a graph captured with event_mode=True, default: None
Returns:
  • torch.Tensor - Generated sequence, shape (batch, num_steps, H)
Example:
output = generate(model, x0, num_steps=512, graph_cache=cache)

Performance Tips

1. Always use CUDA graphs

Capture once, reuse for all generations with the same batch size:
cache = capture_graph(model, batch_size=4, H=64)
# Reuse cache for multiple generations
out1 = generate(model, x1, num_steps=100, graph_cache=cache)
out2 = generate(model, x2, num_steps=200, graph_cache=cache)

2. Recapture for different batch sizes

CUDA graphs are fixed-shape, so create separate caches:
cache_b4 = capture_graph(model, batch_size=4, H=64)
cache_b8 = capture_graph(model, batch_size=8, H=64)

out_b4 = generate(model, x4, num_steps=100, graph_cache=cache_b4)
out_b8 = generate(model, x8, num_steps=100, graph_cache=cache_b8)
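If your workload mixes batch sizes, a small dict keyed by batch size keeps this bookkeeping out of the generation loop. This is a hypothetical helper, not part of lrnnx; capture_fn would be something like functools.partial(capture_graph, model, H=64):

```python
class GraphCachePool:
    """Lazily captures and reuses one graph cache per batch size."""

    def __init__(self, capture_fn):
        # capture_fn(batch_size=...) -> cache object
        self._capture_fn = capture_fn
        self._caches = {}

    def get(self, batch_size):
        # Capture at most once per batch size; reuse thereafter.
        if batch_size not in self._caches:
            self._caches[batch_size] = self._capture_fn(batch_size=batch_size)
        return self._caches[batch_size]
```

Usage: pool = GraphCachePool(functools.partial(capture_graph, model, H=64)), then cache = pool.get(x.shape[0]) before each generate() call.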

3. Use inference_mode for best performance

with torch.inference_mode():
    output = generate(model, x0, num_steps=512, graph_cache=cache)

Troubleshooting

Batch Size Mismatch

If you get an error about batch size mismatch:
ValueError: Batch size 8 != captured 4. Re-capture with capture_graph(model, batch_size=8).
Solution: Recapture the graph with the correct batch size:
cache = capture_graph(model, batch_size=8, H=64)

Memory Issues

If graph capture fails due to memory:
import gc
import torch

# Free memory before capture
gc.collect()
torch.cuda.empty_cache()

cache = capture_graph(model, batch_size=4, H=64)

Next Steps

Training Guide

Learn how to train lrnnx models

Custom Kernels

Understand the high-performance CUDA kernels
