Overview
LRU (Linear Recurrent Unit) is a minimal diagonal state space model that achieves strong performance with very few parameters. It uses a diagonal complex matrix for state transitions, making it extremely efficient while maintaining competitive accuracy.
LRU is characterized by its simplicity and efficiency, making it an excellent choice when parameter count and computational cost matter.
Paper Reference
Resurrecting Recurrent Neural Networks for Long Sequences
https://arxiv.org/abs/2303.06349
Import

```python
from lrnnx.models.lti import LRU
```
Parameters
- `d_model` (int): Model dimension; the size of the input and output features.
- `d_state` (int): State dimension; the number of complex diagonal elements in the recurrent matrix.
- `r_min` (float): Minimum radius for Lambda (eigenvalue) initialization. Lambda values are initialized uniformly on a ring between `r_min` and `r_max`.
- `r_max` (float): Maximum radius for Lambda initialization. Values closer to 1 enable longer memory.
- `max_phase` (float, default `2 * math.pi`): Maximum phase angle for Lambda initialization. Controls the frequency range of the recurrent dynamics.
Usage Example
Basic Usage
```python
import torch
from lrnnx.models.lti import LRU

# Create LRU model
model = LRU(d_model=64, d_state=64)

# Forward pass
x = torch.randn(2, 128, 64)  # (batch, length, features)
y = model(x)
print(y.shape)  # torch.Size([2, 128, 64])
```
Custom Initialization Range
```python
import math
import torch
from lrnnx.models.lti import LRU

# Configure eigenvalue initialization
model = LRU(
    d_model=128,
    d_state=128,
    r_min=0.8,          # Higher min radius for longer memory
    r_max=0.99,         # Close to 1 for very long dependencies
    max_phase=math.pi,  # Restrict phase range
)

x = torch.randn(4, 512, 128)
y = model(x)
```
Autoregressive Inference
```python
import torch
from lrnnx.models.lti import LRU

model = LRU(d_model=64, d_state=64)
batch_size = 2

# Allocate inference cache
cache = model.allocate_inference_cache(batch_size=batch_size)

# Generate sequence step-by-step
for t in range(100):
    x_t = torch.randn(batch_size, 64)  # Single timestep
    y_t, cache = model.step(x_t, cache)
    # y_t.shape: (batch_size, 64)
```
Stacked LRU Architecture
```python
import torch
import torch.nn as nn
from lrnnx.models.lti import LRU

class DeepLRU(nn.Module):
    def __init__(self, d_model=256, d_state=256, n_layers=6):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Sequential(
                LRU(d_model=d_model, d_state=d_state),
                nn.LayerNorm(d_model),
                nn.Linear(d_model, d_model * 4),
                nn.GELU(),
                nn.Linear(d_model * 4, d_model),
            )
            for _ in range(n_layers)
        ])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x) + x  # Residual connection
        return x

model = DeepLRU()
x = torch.randn(2, 1024, 256)
y = model(x)
```
Key Features
Diagonal Parameterization
LRU uses a purely diagonal complex state matrix:

```
Lambda = exp(-exp(nu_log) + i * exp(theta_log))
```

where:

- `nu_log`: log of the radius decay (real part)
- `theta_log`: log of the phase/frequency (imaginary part)
- The result is a set of complex eigenvalues on a ring inside the unit circle.
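The parameterization above can be sketched directly in PyTorch. This is a standalone illustration, not the library's internals; the tensor names `nu_log` and `theta_log` follow the text.

```python
import torch

# Standalone sketch of the exponential parameterization above.
torch.manual_seed(0)
d_state = 8
nu_log = torch.randn(d_state)     # learnable log of the radius decay
theta_log = torch.randn(d_state)  # learnable log of the phase

# Lambda = exp(-exp(nu_log) + i * exp(theta_log))
Lambda = torch.exp(torch.complex(-torch.exp(nu_log), torch.exp(theta_log)))

# Because the real part of the exponent is -exp(nu_log) < 0,
# |Lambda| = exp(-exp(nu_log)) < 1: the recurrence is stable by construction.
print(Lambda.abs().max())  # always < 1
```

The point of the double exponential is that stability holds for any value of the unconstrained parameters, so no projection or clipping is needed during training.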
Minimal Parameters
LRU has the fewest parameters among the SSMs covered here:
| Component | Shape | Parameters |
|---|---|---|
| nu_log | (d_state,) | d_state |
| theta_log | (d_state,) | d_state |
| B_re, B_im | (d_state, d_model) | 2 * d_state * d_model |
| C_re, C_im | (d_model, d_state) | 2 * d_model * d_state |
| D | (d_model,) | d_model |
| gamma_log | (d_state,) | d_state |
Total: ~4 * d_state * d_model + 3 * d_state + d_model
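As a sanity check, the total above can be tallied with a tiny helper (`lru_param_count` is a hypothetical function for illustration, not part of the library):

```python
def lru_param_count(d_model: int, d_state: int) -> int:
    """Tally the table above: the B and C pairs, three log-vectors, and D."""
    return 4 * d_state * d_model + 3 * d_state + d_model

# e.g. the d_model = d_state = 64 configuration from the basic example:
print(lru_param_count(64, 64))  # 16640
```

At roughly `4 * d_state * d_model`, the count is dominated by the input/output projections; the recurrence itself contributes only `3 * d_state` scalars.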
Normalization
LRU uses a learned normalization factor gamma:

```
gamma = exp(gamma_log)
B_norm = gamma * B  # Normalize input
```

Initially `gamma = sqrt(1 - |Lambda|^2)` (it can also be kept fixed).
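The choice `gamma = sqrt(1 - |Lambda|^2)` is what keeps the hidden state at unit variance under white-noise input. A scalar simulation (illustration only, real-valued for simplicity) makes this concrete:

```python
import torch

# For s_{t+1} = lam * s_t + gamma * u_t with unit-variance white noise u_t,
# Var(s) converges to gamma^2 / (1 - lam^2). Choosing gamma = sqrt(1 - lam^2)
# therefore keeps the state at unit variance regardless of how close lam is to 1.
torch.manual_seed(0)
lam = 0.95
gamma = (1 - lam**2) ** 0.5

s = torch.zeros(10_000)  # 10k independent scalar chains
for _ in range(2_000):
    s = lam * s + gamma * torch.randn(10_000)

print(s.var())  # close to 1.0
```

Without this factor, eigenvalues near the unit circle would blow up the state variance by a factor of `1 / (1 - |Lambda|^2)`.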
FFT Convolution
LRU computes the convolution kernel as powers of Lambda:

```
K[n, l] = Lambda[n] ** l  # Element-wise power
y = opt_ssm_forward(x, K, B_norm, C)
```
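`opt_ssm_forward` is the library's fused kernel; an unfused reference version of the same computation might look like the sketch below (shapes assumed from the rest of this page, and `lru_conv_reference` is a hypothetical name):

```python
import torch

def lru_conv_reference(x, Lambda, B_norm, C, D):
    # x: (B, L, H) real; Lambda: (N,) complex; B_norm: (N, H) complex;
    # C: (H, N) complex; D: (H,) real.
    Bsz, L, H = x.shape
    # K[n, l] = Lambda[n] ** l
    K = Lambda.unsqueeze(1) ** torch.arange(L).to(x.dtype)     # (N, L)
    u = x.to(B_norm.dtype) @ B_norm.T                          # lift: (B, L, N)
    # causal convolution of u with K via zero-padded FFT
    U = torch.fft.fft(u.transpose(1, 2), n=2 * L)              # (B, N, 2L)
    s = torch.fft.ifft(U * torch.fft.fft(K, n=2 * L))[..., :L]  # states
    y = (s.transpose(1, 2) @ C.T).real + x * D                 # project + skip
    return y
```

Unrolling the recurrence `state_{t+1} = Lambda * state_t + B_norm @ x_t` step by step produces exactly the same output, which makes a convenient correctness check for fused kernels.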
Initialization Details
Lambda (Eigenvalues)
Initialized uniformly on a ring:
```
# Radius distributed between r_min and r_max
u1 = uniform(0, 1)
radius = sqrt(u1 * (r_max^2 - r_min^2) + r_min^2)

# Phase distributed between 0 and max_phase
u2 = uniform(0, 1)
phase = max_phase * u2

Lambda = radius * exp(i * phase)
```
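The same sampling in PyTorch (a sketch; the names mirror the pseudocode above):

```python
import math
import torch

torch.manual_seed(0)
d_state, r_min, r_max, max_phase = 64, 0.4, 0.99, 2 * math.pi

u1 = torch.rand(d_state)
u2 = torch.rand(d_state)
# the sqrt makes the samples uniform over the *area* of the annulus,
# not merely uniform in radius
radius = torch.sqrt(u1 * (r_max**2 - r_min**2) + r_min**2)
phase = max_phase * u2
Lambda = radius * torch.exp(1j * phase)

# the log-parameters the model stores recover the same eigenvalues
nu_log = torch.log(-torch.log(radius))
print(torch.allclose(torch.exp(-torch.exp(nu_log)), radius))  # True
```

Shrinking `max_phase` concentrates the eigenvalues near the positive real axis, which biases the model toward slowly varying (low-frequency) dynamics.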
- B (input): Glorot initialization scaled by 1/√(2 * d_model)
- C (output): Glorot initialization scaled by 1/√d_state
- D (skip): Random initialization
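One plausible reading of these scalings is normal draws at the stated standard deviations; the library's exact initializer may differ in detail, so treat this purely as a sketch:

```python
import math
import torch

torch.manual_seed(0)
d_model, d_state = 64, 64

# Assumed interpretation: zero-mean normal draws at the stated scales.
B_re = torch.randn(d_state, d_model) / math.sqrt(2 * d_model)
B_im = torch.randn(d_state, d_model) / math.sqrt(2 * d_model)
C_re = torch.randn(d_model, d_state) / math.sqrt(d_state)
C_im = torch.randn(d_model, d_state) / math.sqrt(d_state)
D = torch.randn(d_model)
```

The 1/√(2 * d_model) factor on B accounts for the real and imaginary parts jointly contributing to the lifted input's variance.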
Normalization Factor
Initialized to maintain unit variance:

```
gamma = sqrt(1 - |Lambda|^2)
```
Recurrent Dynamics
Forward Pass (Convolution Mode)
```
# Compute kernel
K = Lambda^[0, 1, ..., L-1]  # Powers of the diagonal
B_norm = gamma * B

# FFT convolution
y = FFT_conv(x, K @ B_norm, C) + x * D
```
Step (Recurrent Mode)
```
# Update state
state_{t+1} = Lambda * state_t + B_norm @ x_t

# Output
y_t = (C @ state_{t+1}).real + D * x_t
```
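A self-contained sketch of one recurrent step, following the equations above (`lru_step` is a hypothetical helper, not the library's `model.step`):

```python
import torch

def lru_step(x_t, state, Lambda, B_norm, C, D):
    # x_t: (B, H) real input; state: (B, N) complex hidden state.
    state = Lambda * state + x_t.to(B_norm.dtype) @ B_norm.T  # update state
    y_t = (state @ C.T).real + D * x_t                        # project + skip
    return y_t, state
```

Iterating this step over a sequence reproduces the convolution-mode output, but in O(1) memory per timestep, which is what makes autoregressive generation cheap.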
Speed
✅ Fastest SSM:
- Minimal parameters → less computation
- Diagonal structure → element-wise ops
- No gating or convolutions
Memory
✅ Lowest memory usage:
- Smallest parameter count
- Simple state representation
- Efficient caching
Accuracy
✅ Competitive performance:
- Despite simplicity, matches more complex models
- Strong on long-range tasks
- Scales well with depth
For very long sequences, set r_max close to 1.0 (e.g., 0.99) to enable longer memory.
LRU works best in deep architectures. Stack multiple LRU layers with residual connections for best results.
LRU is an LTI model and does not support input-dependent dynamics. For selective processing, use Mamba.
When to Use LRU
✅ Use LRU when:
- Parameter efficiency is critical
- You need fast training/inference
- Working with very long sequences
- You want a simple baseline
- Mobile/edge deployment is needed
❌ Consider alternatives when:
- You need maximum expressiveness → S4
- You want input-dependent selection → Mamba
- You need gating mechanisms → RG-LRU
Comparison with Other Models
| Model | Parameters | Speed | Memory | Performance |
|---|---|---|---|---|
| LRU | ⭐⭐⭐⭐⭐ | ⭐⭐⭐⭐⭐ | ⭐⭐⭐⭐⭐ | ⭐⭐⭐⭐ |
| S4 | ⭐⭐⭐ | ⭐⭐⭐ | ⭐⭐⭐ | ⭐⭐⭐⭐ |
| S4D | ⭐⭐⭐⭐ | ⭐⭐⭐⭐ | ⭐⭐⭐⭐ | ⭐⭐⭐⭐ |
| Mamba | ⭐⭐ | ⭐⭐⭐ | ⭐⭐ | ⭐⭐⭐⭐⭐ |
In this comparison, LRU offers the best efficiency-performance tradeoff.
Advanced Usage
Custom Frequency Distribution
```python
import math
from lrnnx.models.lti import LRU

# Focus on low frequencies
model_low_freq = LRU(
    d_model=64,
    d_state=64,
    r_max=0.99,
    max_phase=math.pi / 4,  # Limit to low frequencies
)

# Full frequency range
model_full_freq = LRU(
    d_model=64,
    d_state=64,
    r_max=0.95,
    max_phase=2 * math.pi,  # All frequencies
)
```
Hybrid Architecture
```python
import math

import torch
import torch.nn as nn
from lrnnx.models.lti import LRU

class HybridModel(nn.Module):
    def __init__(self, d_model=256, d_state=256):
        super().__init__()
        # Low-frequency LRU
        self.lru_low = LRU(
            d_model=d_model,
            d_state=d_state // 2,
            max_phase=math.pi / 2,
        )
        # High-frequency LRU
        self.lru_high = LRU(
            d_model=d_model,
            d_state=d_state // 2,
            max_phase=2 * math.pi,
        )
        self.combine = nn.Linear(d_model * 2, d_model)

    def forward(self, x):
        y_low = self.lru_low(x)
        y_high = self.lru_high(x)
        y = self.combine(torch.cat([y_low, y_high], dim=-1))
        return y + x
```
See Also
- RG-LRU - Gated LTV variant
- S5 - Simple SSM with multiple discretizations
- S4D - Diagonal S4 with more features