
Overview

RG-LRU (Real-Gated Linear Recurrent Unit) is a gated state space model from the Griffin architecture that combines the efficiency of the diagonal LRU with input-dependent gating. It achieves performance competitive with Mamba while being simpler and faster. RG-LRU uses a recurrent gate to modulate state updates and an input gate to filter inputs, enabling selective processing without the complexity of fully input-dependent state matrices.

Paper Reference

Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models https://arxiv.org/abs/2402.19427

Import

from lrnnx.models.ltv import RGLRU

Parameters

d_model
int
required
Model dimension - size of input and output features.
d_conv
int
default:"4"
Temporal convolution kernel size. Typically 3-4 for local mixing.
expand
int
default:"1"
Expansion factor for inner dimension. d_inner = expand * d_model. Usually 1 for RG-LRU.
c
float
default:"8.0"
Fixed scalar for recurrent gate scaling. Controls the range of gate values.
a_init_range
Tuple[float, float]
default:"(0.9, 0.999)"
Tuple (lo, hi) for initializing the recurrence base a; a is drawn uniformly from this range, which must lie within (0, 1). Higher values give longer memory.
conv_bias
bool
default:"True"
Whether the Conv1d layer uses a bias term.
bias
bool
default:"False"
Whether linear projections use bias.
use_fast_path
bool
default:"True"
Use fused CUDA kernel when available for significant speedup.
layer_idx
int
default:"None"
Layer index for multi-layer caching during inference.
device
torch.device
default:"None"
Device for model parameters.
dtype
torch.dtype
default:"None"
Data type for model parameters.

Usage Example

Basic Usage

import torch
from lrnnx.models.ltv import RGLRU

# Create RG-LRU model
model = RGLRU(d_model=64, d_conv=4)

# Forward pass
x = torch.randn(2, 128, 64)  # (batch, length, features)
y = model(x)

print(y.shape)  # torch.Size([2, 128, 64])

Language Modeling Configuration

import torch
from lrnnx.models.ltv import RGLRU

# Typical configuration for language modeling
model = RGLRU(
    d_model=768,
    d_conv=4,
    expand=1,
    c=8.0,
    a_init_range=(0.9, 0.999),
    use_fast_path=True,
)

x = torch.randn(4, 2048, 768)
y = model(x)

Autoregressive Inference

import torch
from lrnnx.models.ltv import RGLRU

model = RGLRU(d_model=256, d_conv=4)
batch_size = 2
max_seqlen = 1024

# Allocate inference cache
cache = model.allocate_inference_cache(
    batch_size=batch_size,
    max_seqlen=max_seqlen
)

# Initialize offset
cache["seqlen_offset"] = 0

# Generate sequence token-by-token
for t in range(100):
    x_t = torch.randn(batch_size, 1, 256)
    y_t, cache = model.step(x_t, cache)
    # y_t.shape: (batch_size, 1, 256)

Custom Memory Range

import torch
from lrnnx.models.ltv import RGLRU

# Configure for very long memory
model = RGLRU(
    d_model=128,
    d_conv=4,
    a_init_range=(0.95, 0.9995),  # Higher range for longer memory
    c=10.0,  # Larger scaling
)

x = torch.randn(2, 4096, 128)
y = model(x)

Key Features

Gated Recurrence

RG-LRU uses two gates:
# Recurrent gate (controls state decay)
recurrent_gate = sigmoid(recurrent_proj(x))  # (B, D, L)
delta = c * recurrent_gate  # Scale by constant c

# Input gate (filters input)
input_gate = sigmoid(input_proj(x))  # (B, D, L)
u_gated = input_gate * x  # Gated input

Diagonal Recurrence

Like LRU, uses diagonal state transitions:
# Learnable base in (0, 1)
a = sigmoid(a_log)  # (d_inner, 1)

# Gated update
a_bar = a ** delta  # Element-wise power
sqrt_term = sqrt(1 - a_bar**2)

# State update
state_{t+1} = a_bar * state_t + sqrt_term * u_gated
The sqrt_term factor scales the gated input so the state maintains unit variance when the input is white noise.
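This variance-preserving property can be checked numerically. The sketch below runs the scalar recurrence with a fixed decay on white-noise input (the fixed a_bar value is illustrative, not taken from the library):

```python
import torch

torch.manual_seed(0)
a_bar = torch.tensor(0.95)             # decay after gating, in (0, 1)
sqrt_term = torch.sqrt(1 - a_bar**2)   # variance-preserving input scale

# Run the scalar recurrence on unit-variance white-noise input
state = torch.tensor(0.0)
states = []
for t in range(10000):
    u = torch.randn(())
    state = a_bar * state + sqrt_term * u
    states.append(state)

# After burn-in, the state variance stays near 1
var = torch.stack(states[1000:]).var()
print(f"{var:.3f}")  # ~1.0
```

Without the sqrt_term scaling, the stationary variance would instead be 1 / (1 - a_bar**2), blowing up as a_bar approaches 1.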

Two-Stream Architecture

# Stream 1: Gate pathway
gate = gelu(gate_proj(x))  # (B, L, D)

# Stream 2: RG-LRU pathway  
x -> conv1d -> gates -> rglru_scan -> y

# Merge
out = gate * y

Architecture Details

Forward Pass

  1. Gate Stream: gate = gelu(gate_proj(x))
    • Simple gating pathway
  2. RG-LRU Stream:
    • Input Projection: x = in_proj(x)
    • Conv1d: x = conv1d(x) (causal)
    • Gate Projections: Compute recurrent_gate and input_gate
    • Gated Scan: Update state with gated recurrence
  3. Merge: out = out_proj(gate * y)
    • Multiply streams and project

Recurrent Update

The core RG-LRU recurrence:
# Compute a_bar from base and gate
a_bar = a ** (c * recurrent_gate)  # (B, D, L)

# Normalization factor
sqrt_term = sqrt(1 - a_bar**2)

# Update
state_{t+1} = a_bar * state_t + sqrt_term * (input_gate * x_t)

# Output (with d_state=1, the state itself is the output)
y_t = state_{t+1}
Note: RG-LRU uses d_state=1 (scalar state per channel).

Initialization

  • a (base): Uniform in [a_init_range[0], a_init_range[1]]
    • Typically (0.9, 0.999)
    • Higher values → longer memory
  • Gate projections: Standard initialization
    • With bias (important for gates)
  • c (scaling): Fixed constant (default 8.0)
    • Not learned
    • Controls gate range
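Since a is recovered through a sigmoid, initializing it uniformly in a range amounts to storing the logit of the sampled values. A sketch (a_log follows the pseudocode above; the parameter name in lrnnx may differ):

```python
import torch

d_inner = 8
a_init_range = (0.9, 0.999)

# Sample the base a uniformly in (lo, hi), then store its logit so that
# sigmoid(a_log) recovers a during the forward pass.
a = torch.empty(d_inner).uniform_(*a_init_range)
a_log = torch.logit(a)

recovered = torch.sigmoid(a_log)
print(torch.allclose(recovered, a))  # True
```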

Performance Characteristics

Speed

Fast:
  • Simpler than Mamba
  • Diagonal structure (element-wise ops)
  • Efficient CUDA kernels available

Memory

Efficient:
  • d_state=1 (minimal state)
  • No large intermediate tensors
  • Small cache for inference
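As a concrete (hypothetical) sizing: for an inner dimension of 768 and a batch of 4, the recurrent state is one scalar per channel, versus 16 per channel for a Mamba-style layer with d_state = 16:

```python
d_inner = 768   # hypothetical inner dimension
batch = 4

rg_lru_state = batch * d_inner * 1   # RG-LRU: d_state = 1
mamba_state = batch * d_inner * 16   # Mamba:  d_state = 16

print(rg_lru_state, mamba_state)  # 3072 49152
```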

Accuracy

Competitive:
  • Close to Mamba on many tasks
  • Strong on language modeling
  • Scales well with model size

Comparison with Other Models

| Model  | Gating       | State Dim | Speed      | Accuracy   |
|--------|--------------|-----------|------------|------------|
| RG-LRU | ✅ Recurrent | 1         | ⭐⭐⭐⭐⭐ | ⭐⭐⭐⭐   |
| Mamba  | ✅ Selective | 16        | ⭐⭐⭐⭐   | ⭐⭐⭐⭐⭐ |
| LRU    | ❌ None      | N         | ⭐⭐⭐⭐⭐ | ⭐⭐⭐     |
| S7     | ✅ Selective | N         | ⭐⭐⭐     | ⭐⭐⭐⭐   |
RG-LRU offers the best speed-accuracy tradeoff among gated models.

Performance Tips

RG-LRU is a strong Mamba alternative when you want similar performance with a simpler architecture and faster training.
The c parameter controls gate scaling. The default c=8.0 works well, but values in the 4.0-16.0 range can be worth exploring.
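To see what c does: the gated decay is a ** (c * sigmoid(r)), so c sets how far below the base a the effective decay can be pushed. A quick illustration with fixed, illustrative values:

```python
import torch

a = torch.tensor(0.99)                           # base decay (illustrative)
gate = torch.sigmoid(torch.linspace(-6, 6, 5))   # recurrent-gate values in (0, 1)

for c in (4.0, 8.0, 16.0):
    a_bar = a ** (c * gate)                      # effective per-step decay
    # Larger c lets a saturated gate push a_bar much lower (faster forgetting)
    print(f"c={c:4.1f}  a_bar in [{a_bar.min().item():.3f}, {a_bar.max().item():.3f}]")
```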
Install causal-conv1d for a fast Conv1d:
pip install "causal-conv1d>=1.1.0"
Otherwise, RG-LRU falls back to a slower pure-PyTorch implementation.
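That fallback amounts to a causal depthwise Conv1d, which can be emulated in plain PyTorch by left-padding and trimming the right edge (a sketch of the idea, not the library's exact fallback code):

```python
import torch
import torch.nn as nn

d_model, d_conv = 8, 4
# Depthwise conv; padding=d_conv-1 pads both sides, so trim the right edge
conv = nn.Conv1d(d_model, d_model, d_conv, groups=d_model, padding=d_conv - 1)

x = torch.randn(2, 32, d_model)                 # (batch, length, features)
xc = conv(x.transpose(1, 2))[..., :x.shape[1]]  # trim -> causal
y = xc.transpose(1, 2)                          # back to (batch, length, features)
print(y.shape)  # torch.Size([2, 32, 8])
```

Because the right edge is trimmed, the output at position t depends only on inputs up to t.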

When to Use RG-LRU

Use RG-LRU when:
  • You want gated/selective processing
  • Simpler than Mamba is desired
  • Fast training is important
  • Working on language modeling
  • You need a strong baseline
Consider alternatives when:
  • You need maximum accuracy → Mamba
  • No gating needed → LRU
  • Want input-dependent B, C → Mamba or S7

Advanced Usage

Stacked Architecture

import torch.nn as nn
from lrnnx.models.ltv import RGLRU

class StackedRGLRU(nn.Module):
    def __init__(self, d_model=256, n_layers=12):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.ModuleDict({
                'rglru': RGLRU(d_model=d_model, d_conv=4),
                'norm': nn.LayerNorm(d_model),
                'mlp': nn.Sequential(
                    nn.Linear(d_model, d_model * 4),
                    nn.GELU(),
                    nn.Linear(d_model * 4, d_model),
                ),
            })
            for _ in range(n_layers)
        ])
    
    def forward(self, x):
        for layer in self.layers:
            # RG-LRU block
            x = layer['rglru'](x) + x
            x = layer['norm'](x)
            # MLP block
            x = layer['mlp'](x) + x
        return x

Hybrid with Attention

import torch.nn as nn
from lrnnx.models.ltv import RGLRU

class HybridBlock(nn.Module):
    def __init__(self, d_model=256, n_heads=8):
        super().__init__()
        self.rglru = RGLRU(d_model=d_model, d_conv=4)
        self.attn = nn.MultiheadAttention(
            d_model, n_heads, batch_first=True
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
    
    def forward(self, x):
        # RG-LRU for long-range
        x = self.rglru(x) + x
        x = self.norm1(x)
        # Attention for local
        attn_out, _ = self.attn(x, x, x)
        x = attn_out + x
        x = self.norm2(x)
        return x

Custom Memory Configuration

import torch
from lrnnx.models.ltv import RGLRU

# Short memory (fast decay)
model_short = RGLRU(
    d_model=128,
    a_init_range=(0.7, 0.9),
    c=4.0,
)

# Long memory (slow decay)
model_long = RGLRU(
    d_model=128,
    a_init_range=(0.95, 0.9995),
    c=16.0,
)

RG-LRU vs LRU

Key Differences

| Aspect         | LRU        | RG-LRU            |
|----------------|------------|-------------------|
| Type           | LTI        | LTV               |
| Gating         | None       | Recurrent + Input |
| State dim      | d_state    | 1 (per channel)   |
| Dynamics       | Fixed      | Input-dependent   |
| Expressiveness | ⭐⭐⭐     | ⭐⭐⭐⭐          |
| Speed (train)  | ⭐⭐⭐⭐⭐ | ⭐⭐⭐⭐          |
| Use case       | General    | Language          |

When to Use Which

Use LRU if:
  • Simplicity is key
  • LTI model is sufficient
  • Maximum speed needed
Use RG-LRU if:
  • Need selective processing
  • Language modeling
  • Input-dependent dynamics

See Also

  • Mamba - More complex selective SSM
  • LRU - Non-gated LTI variant
  • S7 - Alternative selective model
