
Overview

LRU (Linear Recurrent Unit) is a minimal diagonal state space model that achieves strong performance with very few parameters. It uses a diagonal complex matrix for state transitions, so the recurrence reduces to element-wise operations, keeping computation cheap while maintaining competitive accuracy. Its simplicity makes it an excellent choice when parameter count and computational cost matter.

Paper Reference

Resurrecting Recurrent Neural Networks for Long Sequences (Orvieto et al., 2023): https://arxiv.org/abs/2303.06349

Import

from lrnnx.models.lti import LRU

Parameters

d_model (int, required)
Model dimension: size of the input and output features.

d_state (int, required)
State dimension: number of complex diagonal elements in the recurrent matrix.

r_min (float, default: 0.0)
Minimum radius for Lambda (eigenvalue) initialization. Lambda values are initialized uniformly on a ring between r_min and r_max.

r_max (float, default: 1.0)
Maximum radius for Lambda initialization. Values closer to 1 enable longer memory.

max_phase (float, default: 2 * math.pi)
Maximum phase angle for Lambda initialization. Controls the frequency range of the recurrent dynamics.

Usage Example

Basic Usage

import torch
from lrnnx.models.lti import LRU

# Create LRU model
model = LRU(d_model=64, d_state=64)

# Forward pass
x = torch.randn(2, 128, 64)  # (batch, length, features)
y = model(x)

print(y.shape)  # torch.Size([2, 128, 64])

Custom Initialization Range

import math
import torch
from lrnnx.models.lti import LRU

# Configure eigenvalue initialization
model = LRU(
    d_model=128,
    d_state=128,
    r_min=0.8,          # Higher min radius for longer memory
    r_max=0.99,         # Close to 1 for very long dependencies
    max_phase=math.pi,  # Restrict phase range
)

x = torch.randn(4, 512, 128)
y = model(x)

Autoregressive Inference

import torch
from lrnnx.models.lti import LRU

model = LRU(d_model=64, d_state=64)
batch_size = 2

# Allocate inference cache
cache = model.allocate_inference_cache(batch_size=batch_size)

# Generate sequence step-by-step
for t in range(100):
    x_t = torch.randn(batch_size, 64)  # Single timestep
    y_t, cache = model.step(x_t, cache)
    # y_t.shape: (batch_size, 64)

Stacked LRU Architecture

import torch
import torch.nn as nn
from lrnnx.models.lti import LRU

class DeepLRU(nn.Module):
    def __init__(self, d_model=256, d_state=256, n_layers=6):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Sequential(
                LRU(d_model=d_model, d_state=d_state),
                nn.LayerNorm(d_model),
                nn.Linear(d_model, d_model * 4),
                nn.GELU(),
                nn.Linear(d_model * 4, d_model),
            )
            for _ in range(n_layers)
        ])
    
    def forward(self, x):
        for layer in self.layers:
            x = layer(x) + x  # Residual connection
        return x

model = DeepLRU()
x = torch.randn(2, 1024, 256)
y = model(x)

Key Features

Diagonal Parameterization

LRU uses a purely diagonal complex state matrix:
Lambda = exp(-exp(nu_log) + i * exp(theta_log))
Where:
  • nu_log: log of the radius decay (real part); the magnitude |Lambda| = exp(-exp(nu_log)) is always strictly inside the unit circle
  • theta_log: log of the phase/frequency (imaginary part)
  • The resulting complex eigenvalues are initialized on a ring between r_min and r_max
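Because |Lambda| = exp(-exp(nu_log)) is bounded in (0, 1) for any real nu_log, the recurrence is stable by construction. A minimal NumPy sketch of the parameterization (illustrative only; variable names mirror the formula, not lrnnx internals):

```python
import numpy as np

rng = np.random.default_rng(0)
d_state = 8

# Unconstrained parameters, as the model stores them
nu_log = rng.normal(size=d_state)
theta_log = rng.normal(size=d_state)

# Lambda = exp(-exp(nu_log) + i * exp(theta_log))
Lam = np.exp(-np.exp(nu_log) + 1j * np.exp(theta_log))

# |Lambda| = exp(-exp(nu_log)) lies in (0, 1): stable by construction
assert np.all(np.abs(Lam) < 1.0)
```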

Minimal Parameters

LRU has the fewest parameters among SSMs:

Component     Shape                 Parameters
nu_log        (d_state,)            d_state
theta_log     (d_state,)            d_state
B_re, B_im    (d_state, d_model)    2 * d_state * d_model
C_re, C_im    (d_model, d_state)    2 * d_model * d_state
D             (d_model,)            d_model
gamma_log     (d_state,)            d_state

Total: ~4 * d_state * d_model + 3 * d_state + d_model
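The total above can be turned into a quick budgeting helper (a hypothetical utility for illustration, not part of lrnnx):

```python
def lru_param_count(d_model: int, d_state: int) -> int:
    """Exact parameter count from the table above."""
    return (
        2 * d_state * d_model    # B_re, B_im
        + 2 * d_model * d_state  # C_re, C_im
        + 3 * d_state            # nu_log, theta_log, gamma_log
        + d_model                # D
    )

print(lru_param_count(64, 64))    # 16640
print(lru_param_count(128, 128))  # 66048
```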

Normalization

LRU uses a learned normalization factor gamma:
gamma = exp(gamma_log)
B_norm = gamma * B  # Normalize input
Initially: gamma = sqrt(1 - |Lambda|^2) (can also be kept fixed)
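This initialization keeps the state variance near 1 for white-noise input: a scalar recurrence s ← λ·s + γ·x accumulates variance γ² / (1 − |λ|²), which equals 1 when γ = sqrt(1 − |λ|²). A quick NumPy simulation (illustrative, not lrnnx code) confirms it:

```python
import numpy as np

rng = np.random.default_rng(0)
lam = 0.95 * np.exp(1j * 0.5)            # one diagonal eigenvalue
gamma = np.sqrt(1.0 - np.abs(lam) ** 2)  # the LRU normalization init

# Scalar recurrence with white-noise input, scaled by gamma
state = 0.0 + 0.0j
energy = []
for _ in range(50_000):
    state = lam * state + gamma * rng.normal()
    energy.append(np.abs(state) ** 2)

# Average state energy settles close to 1.0
print(np.mean(energy[1000:]))
```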

FFT Convolution

LRU computes the kernel as powers of Lambda:
K[n, l] = Lambda[n] ** l  # Element-wise power
y = opt_ssm_forward(x, K, B_norm, C)
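The FFT path is mathematically equivalent to unrolling the recurrence. The sketch below verifies this with plain NumPy (`opt_ssm_forward` is internal to lrnnx, so this only mirrors the underlying math):

```python
import numpy as np

rng = np.random.default_rng(0)
L, d = 32, 4                                       # sequence length, state size
lam = 0.9 * np.exp(1j * rng.uniform(0, np.pi, d))  # diagonal Lambda
u = rng.normal(size=(L, d))                        # input, already projected by B_norm

# Reference: unrolled recurrence s_t = lam * s_{t-1} + u_t
s = np.zeros(d, dtype=complex)
states_rec = np.empty((L, d), dtype=complex)
for t in range(L):
    s = lam * s + u[t]
    states_rec[t] = s

# FFT path: per-channel convolution with the kernel K[l] = lam ** l
K = lam[None, :] ** np.arange(L)[:, None]          # (L, d)
n = 2 * L                                          # zero-pad to avoid circular wrap-around
states_fft = np.fft.ifft(np.fft.fft(K, n, axis=0) * np.fft.fft(u, n, axis=0), axis=0)[:L]

assert np.allclose(states_rec, states_fft)
```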

Initialization Details

Lambda (Eigenvalues)

Initialized uniformly on a ring:
# Radius distributed between r_min and r_max
u1 = uniform(0, 1)
radius = sqrt(u1 * (r_max^2 - r_min^2) + r_min^2)

# Phase distributed between 0 and max_phase  
u2 = uniform(0, 1)
phase = max_phase * u2

Lambda = radius * exp(i * phase)
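The same sampling runs directly in NumPy; the square root on the radius makes the eigenvalues uniform over the area of the annulus rather than clustered near r_min (illustrative sketch, not lrnnx internals):

```python
import numpy as np

rng = np.random.default_rng(0)
d_state, r_min, r_max, max_phase = 1_000, 0.8, 0.99, np.pi

# Radius: inverse-CDF sampling, uniform over the ring's area
u1 = rng.uniform(size=d_state)
radius = np.sqrt(u1 * (r_max**2 - r_min**2) + r_min**2)

# Phase: uniform in [0, max_phase)
u2 = rng.uniform(size=d_state)
phase = max_phase * u2

Lam = radius * np.exp(1j * phase)
assert np.all((np.abs(Lam) >= r_min) & (np.abs(Lam) <= r_max))
```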

Input/Output Projections

  • B (input): Glorot initialization scaled by 1/√(2*d_model)
  • C (output): Glorot initialization scaled by 1/√d_state
  • D (skip): Random initialization

Normalization Factor

Initialized to maintain unit variance:
gamma = sqrt(1 - |Lambda|^2)

Recurrent Dynamics

Forward Pass (Convolution Mode)

# Compute kernel: element-wise powers of the diagonal
K[l] = Lambda ** l            # l = 0, 1, ..., L-1
B_norm = gamma * B

# FFT convolution over the sequence, plus the input skip connection
y = FFT_conv(x, K, B_norm, C) + x * D

Step (Recurrent Mode)

# Update state
state_{t+1} = Lambda * state_t + B_norm @ x_t

# Output
y_t = (C @ state_{t+1}).real + D * x_t
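Written out with explicit shapes, one step looks like the following NumPy sketch of the math above (the real lrnnx `step` also manages its inference cache):

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_state = 3, 5
lam = 0.9 * np.exp(1j * rng.uniform(0, np.pi, d_state))
B_norm = rng.normal(size=(d_state, d_model)) + 1j * rng.normal(size=(d_state, d_model))
C = rng.normal(size=(d_model, d_state)) + 1j * rng.normal(size=(d_model, d_state))
D = rng.normal(size=d_model)

def step(x_t, state):
    """One recurrent step: complex state update, real-valued output."""
    state = lam * state + B_norm @ x_t
    y_t = (C @ state).real + D * x_t
    return y_t, state

state = np.zeros(d_state, dtype=complex)
for _ in range(4):
    y_t, state = step(rng.normal(size=d_model), state)
print(y_t.shape)  # (3,)
```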

Performance Characteristics

Speed

Fastest SSM:
  • Minimal parameters → less computation
  • Diagonal structure → element-wise ops
  • No gating or convolutions

Memory

Lowest memory usage:
  • Smallest parameter count
  • Simple state representation
  • Efficient caching

Accuracy

Competitive performance:
  • Despite simplicity, matches more complex models
  • Strong on long-range tasks
  • Scales well with depth

Performance Tips

For very long sequences, set r_max close to 1.0 (e.g., 0.99) to enable longer memory.
LRU works best in deep architectures. Stack multiple LRU layers with residual connections for best results.
LRU is an LTI model and does not support input-dependent dynamics. For selective processing, use Mamba.

When to Use LRU

Use LRU when:
  • Parameter efficiency is critical
  • You need fast training/inference
  • Working with very long sequences
  • You want a simple baseline
  • Mobile/edge deployment is needed
Consider alternatives when:
  • You need maximum expressiveness → S4
  • You want input-dependent selection → Mamba
  • You need gating mechanisms → RG-LRU

Comparison with Other Models

Model    Parameters    Speed    Memory    Performance
LRU      ⭐⭐⭐⭐⭐         ⭐⭐⭐⭐⭐    ⭐⭐⭐⭐⭐     ⭐⭐⭐
S4       ⭐⭐            ⭐⭐⭐      ⭐⭐⭐       ⭐⭐⭐⭐⭐
S4D      ⭐⭐⭐⭐          ⭐⭐⭐⭐     ⭐⭐⭐⭐      ⭐⭐⭐⭐
Mamba    ⭐⭐            ⭐⭐⭐      ⭐⭐⭐       ⭐⭐⭐⭐

LRU offers the best efficiency-performance tradeoff.

Advanced Usage

Custom Frequency Distribution

import math
import torch
from lrnnx.models.lti import LRU

# Focus on low frequencies
model_low_freq = LRU(
    d_model=64,
    d_state=64,
    r_max=0.99,
    max_phase=math.pi / 4,  # Limit to low frequencies
)

# Full frequency range
model_full_freq = LRU(
    d_model=64,
    d_state=64,
    r_max=0.95,
    max_phase=2 * math.pi,  # All frequencies
)

Hybrid Architecture

import math
import torch
import torch.nn as nn
from lrnnx.models.lti import LRU

class HybridModel(nn.Module):
    def __init__(self, d_model=256, d_state=256):
        super().__init__()
        # Low-frequency LRU
        self.lru_low = LRU(
            d_model=d_model,
            d_state=d_state // 2,
            max_phase=math.pi / 2,
        )
        # High-frequency LRU
        self.lru_high = LRU(
            d_model=d_model,
            d_state=d_state // 2,
            max_phase=2 * math.pi,
        )
        self.combine = nn.Linear(d_model * 2, d_model)
    
    def forward(self, x):
        y_low = self.lru_low(x)
        y_high = self.lru_high(x)
        y = self.combine(torch.cat([y_low, y_high], dim=-1))
        return y + x

See Also

  • RG-LRU - Gated LTV variant
  • S5 - Simple SSM with multiple discretizations
  • S4D - Diagonal S4 with more features
