Overview

The LRNNLMHeadModel is a flexible language model architecture that combines Linear Recurrent Neural Networks (LRNNs) with a language modeling head for next-token prediction. It supports mixing different sequence model types across layers, enabling hybrid architectures that combine the strengths of various LRNN variants. This architecture is inspired by the Mamba implementation and supports both Linear Time-Invariant (LTI) and Linear Time-Varying (LTV) models.

Architecture

The model consists of:
  • Token embedding layer: Maps input token IDs to dense vectors
  • LRNN backbone: Stack of residual blocks, each containing:
    • A mixer layer (LRNN or attention)
    • Optional gated MLP for additional expressiveness
    • Layer normalization (LayerNorm or RMSNorm)
    • Residual connections
  • Language modeling head: Linear projection to vocabulary size for next-token prediction
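
The per-layer structure described above can be sketched in plain PyTorch. The block below is an illustrative pre-norm residual block, not the library's actual implementation: the mixer is a placeholder linear layer standing in for an LRNN or attention mixer, LayerNorm stands in for the configurable LayerNorm/RMSNorm, and the gated MLP is approximated by a simple SiLU-gated feed-forward.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SketchBlock(nn.Module):
    """Illustrative pre-norm residual block (hypothetical, not lrnnx's code):
    norm -> mixer -> residual, then norm -> gated MLP -> residual."""

    def __init__(self, d_model: int, d_intermediate: int):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.mixer = nn.Linear(d_model, d_model)  # stand-in for an LRNN/attention mixer
        self.norm2 = nn.LayerNorm(d_model)
        # stand-in for a gated MLP: gate and up projections, then a down projection
        self.fc_gate = nn.Linear(d_model, d_intermediate)
        self.fc_up = nn.Linear(d_model, d_intermediate)
        self.fc_down = nn.Linear(d_intermediate, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.mixer(self.norm1(x))  # mixer sub-block with residual
        h = self.norm2(x)
        h = F.silu(self.fc_gate(h)) * self.fc_up(h)  # gated MLP
        return x + self.fc_down(h)  # MLP sub-block with residual

block = SketchBlock(d_model=8, d_intermediate=32)
y = block(torch.randn(2, 5, 8))  # (batch, seq, d_model) -> same shape
```

Setting `d_intermediate=0` in the real model drops the MLP sub-block entirely, leaving a single residual per layer.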

Class Signature

LRNNLMHeadModel(
    d_model: int,
    d_state: int,
    n_layer: int,
    vocab_size: int,
    mixer_types: list,
    d_intermediate: int = 0,
    mixer_kwargs: Optional[Dict] = None,
    mlp_cls = None,
    norm_epsilon: float = 1e-5,
    rms_norm: bool = True,
    fused_add_norm: bool = True,
    residual_in_fp32: bool = False,
    tie_embeddings: bool = True,
    pad_vocab_size_multiple: int = 8,
    initializer_cfg: Optional[Dict[str, Any]] = None,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
)

Parameters

d_model (int, required)
Model dimension (hidden size).

d_state (int, required)
State dimension for the LRNN layers.

n_layer (int, required)
Number of layers in the model.

vocab_size (int, required)
Size of the vocabulary.

mixer_types (list, required)
List of mixer type names, one per layer; its length must equal n_layer. Supported mixer types:
  • "LRU" - Linear Recurrent Unit
  • "S4" - Structured State Space (S4)
  • "S4D" - Diagonal S4
  • "S5" - Simplified S4
  • "Centaurus" - Centaurus model
  • "Mamba" - Mamba (inspired by the Mamba paper)
  • "RGLRU" - RG-LRU (from the RG-LRU paper)
  • "S7" - S7 model
  • "attn" - Multi-head attention
Example: ["S5", "S5", "attn", "Mamba"] for a 4-layer hybrid model.

d_intermediate (int, default: 0)
Intermediate dimension for MLP layers. Set to 0 to disable the MLP (single residual per layer). A typical value is 4 * d_model for a standard transformer-style FFN.

mixer_kwargs (dict, default: None)
Additional arguments for mixer layers. Can be either:
  • a dict mapping mixer type names to their kwargs, e.g. {"S5": {"dt_min": 0.001}, "attn": {"num_heads": 8}}
  • a single dict applied to all mixers

mlp_cls (type, default: None)
MLP class to use (defaults to GatedMLP).

norm_epsilon (float, default: 1e-5)
Epsilon value for layer normalization.

rms_norm (bool, default: True)
Whether to use RMSNorm instead of LayerNorm.

fused_add_norm (bool, default: True)
Whether to use fused add+norm operations (requires Triton kernels).

residual_in_fp32 (bool, default: False)
Whether to compute residuals in float32 for numerical stability.

tie_embeddings (bool, default: True)
Whether to tie the input and output embeddings (weight sharing).

pad_vocab_size_multiple (int, default: 8)
Pad the vocabulary size up to a multiple of this value for efficiency.

initializer_cfg (dict, default: None)
Configuration for weight initialization.

device (torch.device, default: None)
Device to place tensors on.

dtype (torch.dtype, default: None)
Data type for tensors.
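
Two of these parameters follow small rules worth spelling out. The helpers below are illustrative sketches only (the real resolution happens inside the library): one shows how a mixer_kwargs dict might be resolved for a given mixer type, the other how the vocabulary size is rounded up to a multiple of pad_vocab_size_multiple.

```python
def resolve_mixer_kwargs(mixer_kwargs, mixer_type):
    """Hypothetical sketch of mixer_kwargs resolution.

    A dict keyed by a mixer type name yields per-type kwargs; any other
    dict is treated as a single set of kwargs applied to all mixers.
    """
    if mixer_kwargs is None:
        return {}
    if mixer_type in mixer_kwargs:
        return mixer_kwargs[mixer_type]  # per-type kwargs
    return mixer_kwargs  # single dict applied to all mixers

def padded_vocab_size(vocab_size, multiple=8):
    """Round vocab_size up to the next multiple, as pad_vocab_size_multiple does."""
    return ((vocab_size + multiple - 1) // multiple) * multiple

print(resolve_mixer_kwargs({"S5": {"dt_min": 0.001}, "attn": {"num_heads": 8}}, "S5"))
# -> {'dt_min': 0.001}
print(padded_vocab_size(50257, 8))  # GPT-2 vocabulary, padded
# -> 50264
```

Padding means the output logits can have a last dimension slightly larger than the nominal vocab_size; the extra slots correspond to unused padding tokens.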

Usage Example

Basic Language Model

import torch
from lrnnx.architectures.language_model import LRNNLMHeadModel

# Create a 4-layer hybrid model with S5 and Mamba layers
model = LRNNLMHeadModel(
    d_model=512,
    d_state=64,
    n_layer=4,
    vocab_size=50257,  # GPT-2 vocabulary
    mixer_types=["S5", "S5", "Mamba", "Mamba"],
    d_intermediate=2048,  # 4 * d_model
    rms_norm=True,
    tie_embeddings=True,
)

# Forward pass
input_ids = torch.randint(0, 50257, (2, 128))  # (batch_size, seq_len)
output = model(input_ids)
logits = output.logits  # (2, 128, 50257)

Hybrid Architecture with Attention

# Mix different LRNN types with attention
model = LRNNLMHeadModel(
    d_model=768,
    d_state=128,
    n_layer=6,
    vocab_size=50257,
    mixer_types=["S5", "S5", "attn", "Mamba", "Mamba", "attn"],
    d_intermediate=3072,
    mixer_kwargs={
        "S5": {"dt_min": 0.001},
        "attn": {"num_heads": 12},
    },
)

Autoregressive Generation

# Allocate cache for efficient generation
batch_size = 1
max_seqlen = 512
caches = model.allocate_inference_cache(
    batch_size=batch_size,
    max_seqlen=max_seqlen,
    dtype=torch.float32,
)

# Generate tokens one at a time
input_ids = torch.tensor([[50256]])  # Start token
for _ in range(20):
    output = model.step(input_ids, caches)
    next_token = output.logits.argmax(dim=-1)
    input_ids = next_token

Save and Load

# Save model
model.save_pretrained("./my_lrnn_lm")

# Load model
loaded_model = LRNNLMHeadModel.from_pretrained(
    "./my_lrnn_lm",
    device=torch.device("cuda"),
)

Methods

forward

forward(
    input_ids: Tensor,
    position_ids: Optional[Tensor] = None,
    inference_params: Optional[Dict] = None,
    num_last_tokens: int = 0,
    integration_timesteps: Optional[Tensor] = None,
    lengths: Optional[Tensor] = None,
    **mixer_kwargs,
) -> namedtuple
Forward pass of the language model. Arguments:
  • input_ids (torch.Tensor): Input token IDs of shape (B, L)
  • inference_params (Dict, optional): Parameters for inference mode
  • num_last_tokens (int): If > 0, only return logits for last n tokens
  • integration_timesteps (torch.Tensor, optional): Timesteps for LTV models (shape: (B, L))
  • lengths (torch.Tensor, optional): Sequence lengths for variable-length sequences (shape: (B,))
Returns:
  • namedtuple with logits field of shape (B, L, vocab_size)

step

step(
    input_ids: Tensor,
    caches: Dict,
    integration_timesteps: Optional[Tensor] = None,
) -> namedtuple
Single-step inference for autoregressive generation. Arguments:
  • input_ids (torch.Tensor): Input token IDs of shape (B, 1) — single token
  • caches (Dict): Dictionary mapping layer indices to their cached states
  • integration_timesteps (torch.Tensor, optional): Integration timesteps for LTV models
Returns:
  • namedtuple with logits field of shape (B, 1, vocab_size)

allocate_inference_cache

allocate_inference_cache(
    batch_size: int,
    max_seqlen: int,
    dtype: Optional[torch.dtype] = None,
    **kwargs,
) -> Dict
Allocate inference cache for autoregressive generation.

See Also

  • LRU - Linear Recurrent Unit
  • S5 - Simplified S4
  • Mamba - Mamba model
