## Overview
Linear RNNs in PyTorch require special handling during inference. Following the approach from Mamba, lrnnx implements CUDA graph-based inference, which reduces per-step CPU overhead and provides a >10x speedup over a plain Python for loop.

The generation API lives in `lrnnx/utils/generation.py` and works with all models in the library.
## Quick Start
### Import the generation API
```python
import torch
from lrnnx.utils.generation import capture_graph, generate
from lrnnx.models.lti import LRU
```
### Create and prepare the model
```python
# Initialize model in eval mode on CUDA
model = LRU(d_model=64, d_state=64).cuda().eval()

# Model configuration
batch_size = 4
H = 64  # d_model
```
### Capture the CUDA graph (one-time setup)
```python
# Capture the graph once and reuse it for all generations
cache = capture_graph(model, batch_size=batch_size, H=H)
```
> **Note:** Graph capture is a one-time operation. The captured graph is tied to the specific batch size.
### Generate sequences
```python
# Create seed input of shape (B, H)
x0 = torch.randn(batch_size, H, device="cuda")

# Generate 512 steps with CUDA graph replay
output = generate(
    model,
    x0,
    num_steps=512,
    graph_cache=cache,  # use the captured graph for the speedup
)

# Output shape: (batch_size, num_steps, H)
print(output.shape)  # torch.Size([4, 512, 64])
```
## CUDA Graph Optimization
### Why CUDA Graphs?
Standard autoregressive generation in PyTorch incurs significant CPU overhead (Python dispatch and kernel launches) at every step. CUDA graphs eliminate this by:

1. **Recording** the computational graph once, during `capture_graph()`
2. **Replaying** the pre-recorded graph at each timestep with near-zero CPU overhead

Together, this yields the >10x inference speedup.
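To make the per-step work concrete, here is a toy diagonal linear-RNN step in plain Python (illustrative only, not lrnnx's implementation): every generation step applies the same fixed recurrence `h ← a⊙h + b⊙x`, `y = c⊙h`, which is exactly the kind of static, shape-fixed computation a CUDA graph can record once and replay.

```python
# Toy diagonal linear RNN step: the identical computation runs at
# every generation step, which is what a CUDA graph records once.
def lrnn_step(h, x, a, b, c):
    """One recurrence step: h <- a*h + b*x (elementwise), output y = c*h."""
    h = [ai * hi + bi * xi for ai, hi, bi, xi in zip(a, h, b, x)]
    y = [ci * hi for ci, hi in zip(c, h)]
    return h, y

# Autoregressive rollout: feed each output back in as the next input
a, b, c = [0.9, 0.5], [1.0, 1.0], [1.0, 2.0]
h, x = [0.0, 0.0], [1.0, 1.0]
outputs = []
for _ in range(3):
    h, y = lrnn_step(h, x, a, b, c)
    outputs.append(y)
    x = y  # autoregressive feedback
print(outputs[0])  # [1.0, 2.0]
```

In eager PyTorch, each such step costs a round of Python dispatch and kernel launches; graph replay collapses that to a single launch.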
### How It Works
**With CUDA graphs (fast):**

```python
import torch
from lrnnx.utils.generation import capture_graph, generate
from lrnnx.models.lti import LRU

model = LRU(d_model=64, d_state=64).cuda().eval()

# One-time graph capture
cache = capture_graph(model, batch_size=4, H=64)

# Seed token
x0 = torch.randn(4, 64, device="cuda")

# Fast generation with graph replay
output = generate(model, x0, num_steps=512, graph_cache=cache)
```
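Without a captured graph, generation falls back to an eager Python loop that pays dispatch and kernel-launch overhead on every step. Conceptually (a plain-Python sketch with a hypothetical `step_fn`, not the library's code):

```python
# Eager autoregressive loop: every iteration pays Python dispatch and
# kernel-launch overhead, which CUDA graph replay avoids.
def generate_eager(step_fn, x0, num_steps):
    """step_fn is a hypothetical single-step function: x -> next x."""
    outputs, x = [], x0
    for _ in range(num_steps):  # CPU-side loop, one dispatch per step
        x = step_fn(x)
        outputs.append(x)
    return outputs

# Example with a trivial step function
out = generate_eager(lambda x: x * 0.5, 8.0, num_steps=3)
print(out)  # [4.0, 2.0, 1.0]
```

In lrnnx, this slow path corresponds to calling `generate(...)` with `graph_cache=None` (the default).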
## Complete Example
```python
import torch
from lrnnx.models.lti import LRU
from lrnnx.utils.generation import capture_graph, generate

# Configuration
batch_size = 4
d_model = 64
d_state = 64
num_steps = 512

# Initialize the model
model = LRU(d_model=d_model, d_state=d_state).cuda()
model.eval()  # important: set to eval mode

# Capture the CUDA graph (do this once)
print("Capturing CUDA graph...")
cache = capture_graph(model, batch_size=batch_size, H=d_model)

# Create seed input
x0 = torch.randn(batch_size, d_model, device="cuda")

# Generate with CUDA graph replay
print("Generating...")
with torch.inference_mode():
    output = generate(model, x0, num_steps=num_steps, graph_cache=cache)

print(f"Generated output shape: {output.shape}")
# Output: Generated output shape: torch.Size([4, 512, 64])
```
## Event-Based Inference
Some models support event-driven timesteps during generation:
```python
import torch
from lrnnx.models.lti import S5
from lrnnx.utils.generation import capture_graph, generate

model = S5(d_model=64, d_state=64).cuda().eval()

# Capture with event mode enabled
cache = capture_graph(
    model,
    batch_size=4,
    H=64,
    event_mode=True,  # enable event-driven timesteps
)

x0 = torch.randn(4, 64, device="cuda")

# Provide the integration timestep (reused at every step)
integration_timesteps = torch.ones(4, 1, device="cuda") * 0.1

output = generate(
    model,
    x0,
    num_steps=512,
    graph_cache=cache,
    integration_timesteps=integration_timesteps,
)
```
> **Note:** When using `integration_timesteps`, you must capture the graph with `event_mode=True`.
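For intuition, event-driven state-space models typically rediscretize their continuous-time dynamics using the supplied timestep, so a larger Δt integrates the dynamics over a longer interval. A toy zero-order-hold sketch for a scalar state (illustrative only, not S5's internals):

```python
import math

# Toy zero-order-hold discretization of dh/dt = a*h + b*u for one step.
# The discrete transition depends on the integration timestep dt.
def zoh_step(h, u, a, b, dt):
    a_bar = math.exp(a * dt)           # discrete state transition
    b_bar = (a_bar - 1.0) / a * b      # discrete input coefficient (a != 0)
    return a_bar * h + b_bar * u

h = 1.0
h_small = zoh_step(h, u=0.0, a=-1.0, b=1.0, dt=0.1)
h_large = zoh_step(h, u=0.0, a=-1.0, b=1.0, dt=1.0)
print(h_small, h_large)  # decay depends on dt: ~0.905 vs ~0.368
```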
## Benchmarking Inference
The library includes built-in benchmarks to measure inference performance:
```python
from benchmarks.benchmark_inference import benchmark_sequence_length
from lrnnx.models.lti import LRU

def model_fn():
    return LRU(d_model=128, d_state=64).cuda().eval()

# Benchmark CUDA-graph inference across sequence lengths
results = benchmark_sequence_length(
    model_fn,
    seq_lengths=[64, 128, 256, 512, 1024, 2048],
    batch_size=32,
    repeats=5,
)

for seq_len, times in results.items():
    avg_time = sum(times) / len(times)
    print(f"Seq len {seq_len}: {avg_time:.2f} ms")
```
See `benchmarks/benchmark_inference.py` for the complete benchmarking utilities, including:

- `benchmark_sequence_length()` - vary the generation length
- `benchmark_model_dimension()` - vary the model size
- `benchmark_batch_size()` - vary the batch size
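Raw times can be hard to compare across sequence lengths; a small post-processing helper can convert per-run millisecond timings into tokens/second throughput (`throughput` is a hypothetical helper, not part of the benchmark module; it assumes times are in milliseconds, as printed above):

```python
# Convert per-run timings (ms) into tokens/sec throughput.
# Each run generates batch_size * seq_len tokens.
def throughput(results_ms, batch_size):
    out = {}
    for seq_len, times in results_ms.items():
        avg_s = sum(times) / len(times) / 1000.0  # ms -> s
        out[seq_len] = batch_size * seq_len / avg_s
    return out

# Example with made-up timings
fake = {64: [10.0, 10.0], 128: [16.0, 16.0]}
tps = throughput(fake, batch_size=32)
print(round(tps[64]))   # 204800
print(round(tps[128]))  # 256000
```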
## API Reference
### capture_graph()
Captures a CUDA graph for the model’s single-step recurrence.
**Parameters:**

- `model` (`LTI_LRNN | LTV_LRNN`) - model on CUDA in eval mode
- `batch_size` (`int`) - batch size to capture for
- `H` (`int`) - model input/output dimension (`d_model`)
- `max_seqlen` (`int`, optional) - maximum sequence length; default: `1`
- `event_mode` (`bool`, optional) - enable event-driven timesteps; default: `False`
- `device` (`torch.device`, optional) - CUDA device; inferred from the model if `None`
- `n_warmups` (`int`, optional) - warmup iterations before capture; default: `3`

**Returns:**

- `CUDAGraphStepCache` - opaque cache object to pass to `generate()`

**Example:**

```python
cache = capture_graph(model, batch_size=4, H=64)
```
### generate()
Autoregressive generation with optional CUDA graph acceleration.
**Parameters:**

- `model` (`LTI_LRNN | LTV_LRNN`) - model on CUDA in eval mode
- `x` (`torch.Tensor`) - seed input of shape `(batch, H)`
- `num_steps` (`int`) - number of autoregressive steps
- `graph_cache` (`CUDAGraphStepCache`, optional) - pre-captured graph from `capture_graph()`; default: `None`
- `integration_timesteps` (`torch.Tensor`, optional) - integration timesteps of shape `(batch, 1)` for event models; default: `None`

**Returns:**

- `torch.Tensor` - generated sequence of shape `(batch, num_steps, H)`

**Example:**

```python
output = generate(model, x0, num_steps=512, graph_cache=cache)
```
## Best Practices

### Always use CUDA graphs

Capture once, then reuse the cache for all generations with the same batch size:

```python
cache = capture_graph(model, batch_size=4, H=64)

# Reuse the cache for multiple generations
out1 = generate(model, x1, num_steps=100, graph_cache=cache)
out2 = generate(model, x2, num_steps=200, graph_cache=cache)
```
### Recapture for different batch sizes

CUDA graphs are fixed-shape, so create a separate cache per batch size:

```python
cache_b4 = capture_graph(model, batch_size=4, H=64)
cache_b8 = capture_graph(model, batch_size=8, H=64)

out_b4 = generate(model, x4, num_steps=100, graph_cache=cache_b4)
out_b8 = generate(model, x8, num_steps=100, graph_cache=cache_b8)
```
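If your workload mixes batch sizes, a small memoization wrapper keeps one cache per size (a hypothetical helper sketched with a stub capture function; in practice you would pass a closure over `capture_graph` in its place):

```python
# Keep one graph cache per batch size, capturing lazily on first use.
class GraphCachePool:
    def __init__(self, capture_fn):
        # capture_fn: batch_size -> cache,
        # e.g. lambda bs: capture_graph(model, batch_size=bs, H=64)
        self.capture_fn = capture_fn
        self.caches = {}

    def get(self, batch_size):
        if batch_size not in self.caches:  # capture once per shape
            self.caches[batch_size] = self.capture_fn(batch_size)
        return self.caches[batch_size]

# Stub capture function standing in for capture_graph
captures = []
pool = GraphCachePool(lambda bs: captures.append(bs) or f"cache_b{bs}")
pool.get(4); pool.get(8); pool.get(4)
print(captures)  # [4, 8]  -- each batch size captured exactly once
```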
### Use inference_mode for best performance

Wrap generation in `torch.inference_mode()` to disable autograd tracking:

```python
with torch.inference_mode():
    output = generate(model, x0, num_steps=512, graph_cache=cache)
```
## Troubleshooting
### Batch Size Mismatch

If you get a batch size mismatch error:

```text
ValueError: Batch size 8 != captured 4. Re-capture with capture_graph(model, batch_size=8).
```

**Solution:** recapture the graph with the correct batch size:

```python
cache = capture_graph(model, batch_size=8, H=64)
```
### Memory Issues

If graph capture fails due to insufficient GPU memory:

```python
import gc
import torch

# Free cached memory before capture
gc.collect()
torch.cuda.empty_cache()

cache = capture_graph(model, batch_size=4, H=64)
```
## Next Steps

- **Training Guide**: learn how to train lrnnx models
- **Custom Kernels**: understand the high-performance CUDA kernels