
Overview

RGLRU (Real-Gated Linear Recurrent Unit) is an LTV (linear time-varying) model based on the Griffin architecture. It uses a gated linear recurrence with learnable parameters for efficient sequence modeling. Key features:
  • Simple 1D state space (d_state=1)
  • Dual-stream architecture: gate path and recurrent path
  • Learnable recurrence base parameter in (0, 1)
  • Causal temporal convolution for local context
  • Efficient inference with state caching

Import

from lrnnx.models.ltv import RGLRU

Class Signature

class RGLRU(LTV_LRNN)

Constructor

__init__

def __init__(
    self,
    d_model: int,
    d_conv: int = 4,
    expand: int = 1,
    c: float = 8.0,
    a_init_range: Tuple[float, float] = (0.9, 0.999),
    conv_bias: bool = True,
    bias: bool = False,
    use_fast_path: bool = True,
    layer_idx: Optional[int] = None,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
)
Initialize RG-LRU block.
d_model
int
required
Model dimension (input/output dimension).
d_conv
int
default:"4"
Temporal convolution kernel size for local context modeling.
expand
int
default:"1"
Expansion factor for inner dimension. Inner dimension = d_model * expand.
c
float
default:"8.0"
Fixed scalar for recurrent gate scaling. Controls the effective timescale of the recurrence.
a_init_range
Tuple[float, float]
default:"(0.9, 0.999)"
Tuple (lo, hi) defining the initialization range for the recurrence base parameter a in (0, 1). Values closer to 1 create longer-range dependencies.
conv_bias
bool
default:"True"
Whether the Conv1D layer uses a bias term.
bias
bool
default:"False"
Whether Linear projections use bias terms.
use_fast_path
bool
default:"True"
Use the fused CUDA kernel when available for improved performance.
layer_idx
int | None
default:"None"
Layer index for multi-layer caching in stacked architectures.
device
torch.device | None
default:"None"
Device for parameters. If None, uses default device.
dtype
torch.dtype | None
default:"None"
Data type for parameters. If None, uses default dtype.
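The effect of `a_init_range` on the memory horizon can be eyeballed with the half-life of a plain geometric decay a^t (a simplification that ignores the gating, which further modulates the decay at each step):

```python
import math

def half_life(a: float) -> float:
    """Steps until a geometric decay a^t falls to 50% retention."""
    return math.log(0.5) / math.log(a)

for a in (0.9, 0.99, 0.999):
    print(f"a = {a}: ~{half_life(a):.0f} steps")
# a = 0.9 retains for ~7 steps; a = 0.999 for ~693 steps.
```

This is why the default range (0.9, 0.999) mixes channels with short and long effective memories.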

Methods

forward

def forward(
    self,
    hidden_states: Tensor,
    integration_timesteps: Optional[Tensor] = None,
    lengths: Optional[Tensor] = None,
    inference_cache: Optional[Dict[str, Any]] = None,
) -> Tensor
Forward pass through the RG-LRU block.
hidden_states
torch.Tensor
Input tensor of shape (B, L, D) where:
  • B = batch size
  • L = sequence length
  • D = model dimension (d_model)
integration_timesteps
torch.Tensor | None
default:"None"
Currently unused. Kept for LTV interface compatibility.
lengths
torch.Tensor | None
default:"None"
Currently unused. Kept for interface compatibility.
inference_cache
Dict[str, Any] | None
default:"None"
Cache dict for autoregressive generation. If provided, must contain:
  • "conv_state": Convolution state tensor
  • "lrnn_state": RG-LRU state tensor
  • "seqlen_offset": Current position in sequence
output
torch.Tensor
Output tensor of shape (B, L, D).

step

def step(
    self,
    hidden_states: Tensor,
    inference_cache: Dict[str, Any],
    **kwargs,
) -> Tuple[Tensor, Dict[str, Any]]
Single recurrent step for autoregressive inference.
hidden_states
torch.Tensor
Input tensor of shape (B, 1, D) for single timestep.
inference_cache
Dict[str, Any]
Cache dictionary containing:
  • "conv_state": Convolution state, shape (B, D_inner, d_conv)
  • "lrnn_state": RG-LRU state, shape (B, D_inner, 1)
  • "seqlen_offset": Current position in sequence
**kwargs
Any
Additional keyword arguments (unused).
output
Tuple[torch.Tensor, Dict[str, Any]]
Tuple containing:
  • Output tensor of shape (B, 1, D)
  • Updated cache dictionary

allocate_inference_cache

def allocate_inference_cache(
    self,
    batch_size: int,
    max_seqlen: int,
    dtype: Optional[torch.dtype] = None,
    **kwargs,
) -> Dict[str, Any]
Allocate cache for autoregressive inference.
batch_size
int
Batch size for inference.
max_seqlen
int
Maximum sequence length. Unused but kept for interface consistency.
dtype
torch.dtype | None
default:"None"
Data type for cache tensors. If None, uses the model's parameter dtype.
**kwargs
Any
Additional keyword arguments (unused).
cache
Dict[str, Any]
Cache dictionary containing:
  • "conv_state": Shape (B, D_inner, d_conv)
  • "lrnn_state": Shape (B, D_inner, 1)
  • "seqlen_offset": Initialized to 0

Examples

Basic Usage

import torch
from lrnnx.models.ltv import RGLRU

# Create RGLRU model with default 1D state
model = RGLRU(d_model=64, d_conv=4)

# Forward pass
x = torch.randn(2, 128, 64)
y = model(x)
print(y.shape)  # torch.Size([2, 128, 64])

Custom Initialization

import torch
from lrnnx.models.ltv import RGLRU

# Create model with custom recurrence parameters
model = RGLRU(
    d_model=64,
    d_conv=4,
    expand=2,  # Increase capacity
    c=16.0,  # Increase gate scaling
    a_init_range=(0.95, 0.999),  # Favor longer-range dependencies
)

x = torch.randn(2, 128, 64)
y = model(x)

Autoregressive Inference

import torch
from lrnnx.models.ltv import RGLRU

model = RGLRU(d_model=64)
model.eval()

# Allocate cache
cache = model.allocate_inference_cache(
    batch_size=2,
    max_seqlen=100
)

# Generate sequence step by step
for t in range(100):
    x_t = torch.randn(2, 1, 64)  # stand-in for the real next input, shape (B, 1, D)
    y_t, cache = model.step(x_t, cache)
    # Use y_t to produce the next input...

Multi-Layer Stack

import torch
import torch.nn as nn
from lrnnx.models.ltv import RGLRU

class StackedRGLRU(nn.Module):
    def __init__(self, d_model, num_layers):
        super().__init__()
        self.layers = nn.ModuleList([
            RGLRU(d_model=d_model, layer_idx=i)
            for i in range(num_layers)
        ])
    
    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

model = StackedRGLRU(d_model=64, num_layers=4)
x = torch.randn(2, 128, 64)
y = model(x)

Architecture Details

Dual-Stream Processing

RGLRU processes input through two parallel streams:
  1. Gate stream: Linear → GeLU - produces multiplicative gating
  2. Recurrent stream: Linear → Conv1D → RG-LRU - performs temporal processing
The outputs are combined via element-wise multiplication before the final projection.
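As a rough illustration (not the library's internals: the module names, the depthwise-conv choice, and the fixed decay `a = 0.95` standing in for the learned, gated a_bar are all assumptions), the two streams can be sketched in plain PyTorch:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DualStreamSketch(nn.Module):
    """Illustrative dual-stream block: gate path * recurrent path."""

    def __init__(self, d_model: int, d_conv: int = 4, expand: int = 1):
        super().__init__()
        d_inner = d_model * expand
        self.gate_proj = nn.Linear(d_model, d_inner, bias=False)  # gate stream
        self.in_proj = nn.Linear(d_model, d_inner, bias=False)    # recurrent stream
        self.conv = nn.Conv1d(d_inner, d_inner, d_conv,
                              padding=d_conv - 1, groups=d_inner)  # depthwise conv
        self.out_proj = nn.Linear(d_inner, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: (B, L, D)
        L = x.shape[1]
        gate = F.gelu(self.gate_proj(x))               # Linear -> GELU
        u = self.in_proj(x).transpose(1, 2)            # (B, D_inner, L)
        u = self.conv(u)[..., :L].transpose(1, 2)      # trim padding to stay causal
        # Stand-in recurrence with a fixed decay instead of the gated a_bar:
        a = 0.95
        h, outs = torch.zeros_like(u[:, :1]), []
        for t in range(L):
            h = a * h + (1.0 - a ** 2) ** 0.5 * u[:, t:t + 1]
            outs.append(h)
        return self.out_proj(gate * torch.cat(outs, dim=1))

block = DualStreamSketch(d_model=64)
y = block(torch.randn(2, 16, 64))
print(y.shape)  # torch.Size([2, 16, 64])
```

The real block replaces the Python loop with a fused scan and computes the decay per timestep from the recurrent gate.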

Recurrence Equation

The RG-LRU recurrence is:
r_t = σ(W_r x_t)          # recurrent gate
i_t = σ(W_i x_t)          # input gate
a_bar_t = a^(c · r_t)
h_t = a_bar_t ⊙ h_{t-1} + sqrt(1 − a_bar_t²) ⊙ (i_t ⊙ x_t)
y_t = h_t
where:
  • a is the learnable base parameter in (0, 1)
  • c is the fixed scaling constant
  • σ is the sigmoid function and ⊙ denotes element-wise multiplication
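The update can be sanity-checked numerically for a single scalar channel, treating the gate values as given (in the module they come from learned projections of x):

```python
import math

def rglru_step(h_prev, x, a, c, recurrent_gate, input_gate):
    """One RG-LRU update for a scalar channel, gate values supplied directly."""
    a_bar = a ** (c * recurrent_gate)  # gated decay in (0, 1]
    return a_bar * h_prev + math.sqrt(1.0 - a_bar ** 2) * (input_gate * x)

a, c = 0.99, 8.0
# Recurrent gate closed (0.0): a_bar = 1, the state is carried over unchanged.
print(rglru_step(h_prev=1.0, x=0.5, a=a, c=c, recurrent_gate=0.0, input_gate=1.0))  # 1.0
# Recurrent gate open (1.0): a_bar = a^c, faster forgetting plus scaled input.
print(rglru_step(h_prev=1.0, x=0.5, a=a, c=c, recurrent_gate=1.0, input_gate=1.0))
```

The sqrt(1 − a_bar²) factor normalizes the input contribution so the state magnitude stays bounded regardless of the decay.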

See Also

  • LTV_LRNN - Base class for LTV models
  • Mamba - Selective state space model
  • S7 - Simplified state space model
