Overview
The LRNNLMHeadModel is a flexible language model architecture that combines Linear Recurrent Neural Networks (LRNNs) with a language modeling head for next-token prediction. It supports mixing different sequence model types across layers, enabling hybrid architectures that combine the strengths of various LRNN variants.
This architecture is inspired by the Mamba implementation and supports both Linear Time-Invariant (LTI) and Linear Time-Varying (LTV) models.
Architecture
The model consists of:
- Token embedding layer: Maps input token IDs to dense vectors
- LRNN backbone: Stack of residual blocks, each containing:
- A mixer layer (LRNN or attention)
- Optional gated MLP for additional expressiveness
- Layer normalization (LayerNorm or RMSNorm)
- Residual connections
- Language modeling head: Linear projection to vocabulary size for next-token prediction
Class Signature
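A constructor sketch is shown below. Apart from d_model and n_layer, which appear in the parameter descriptions, the keyword names are illustrative assumptions inferred from those descriptions, not a verified signature.

```python
# Hypothetical constructor sketch; keyword names other than d_model and
# n_layer are assumptions, not the verified API.
model = LRNNLMHeadModel(
    d_model=512,            # model dimension (hidden size)
    d_state=64,             # state dimension for LRNN layers (name assumed)
    n_layer=4,              # number of layers
    vocab_size=50257,       # vocabulary size (name assumed)
    mixer_types=["S5", "S5", "attn", "Mamba"],  # one mixer per layer (name assumed)
    d_intermediate=2048,    # MLP width; 0 disables the MLP (name assumed)
)
```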
Parameters
Model dimension (hidden size)
State dimension for LRNN layers
Number of layers in the model
Size of the vocabulary
List of mixer type names for each layer. Must have length equal to n_layer. Supported mixer types:
- "LRU" - Linear Recurrent Unit
- "S4" - Structured State Space (S4)
- "S4D" - Diagonal S4
- "S5" - Simplified S4
- "Centaurus" - Centaurus model
- "Mamba" - Mamba (inspired by the Mamba paper)
- "RGLRU" - RG-LRU (from the RG-LRU paper)
- "S7" - S7 model
- "attn" - Multi-head attention
Example: ["S5", "S5", "attn", "Mamba"] for a 4-layer hybrid model.
Intermediate dimension for MLP layers. Set to 0 to disable the MLP (single residual per layer). Typical value is 4 * d_model for a standard transformer-style FFN.
Additional arguments for mixer layers. Can be:
- A dict mapping mixer type names to their kwargs, e.g. {"S5": {"dt_min": 0.001}, "attn": {"num_heads": 8}}
- A single dict applied to all mixers
MLP class to use (defaults to GatedMLP)
Epsilon value for layer normalization
Whether to use RMSNorm instead of LayerNorm
Whether to use fused add+norm operations (requires Triton kernels)
Whether to compute residuals in float32 for numerical stability
Whether to tie input and output embeddings (weight sharing)
Pad vocabulary size to multiple of this value for efficiency
Configuration for weight initialization
Device to place tensors on
Data type for tensors
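As an illustration of the vocabulary-padding parameter, the padded vocabulary size can be computed as follows. This is a standalone sketch; pad_vocab is a hypothetical helper, not part of the library.

```python
def pad_vocab(vocab_size: int, multiple: int) -> int:
    # Round vocab_size up to the nearest multiple so the LM head's
    # output projection has GEMM-friendly dimensions.
    if vocab_size % multiple != 0:
        vocab_size += multiple - (vocab_size % multiple)
    return vocab_size

print(pad_vocab(50257, 8))  # 50264: GPT-2's vocab rounded up to a multiple of 8
```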
Usage Example
Basic Language Model
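A minimal usage sketch. The module path and keyword names are assumptions, as noted in the comments.

```python
import torch
from lrnn import LRNNLMHeadModel  # module path is an assumption

# All-S5 backbone: 4 layers, hidden size 512 (keyword names are assumptions)
model = LRNNLMHeadModel(
    d_model=512,
    d_state=64,
    n_layer=4,
    vocab_size=50257,
    mixer_types=["S5"] * 4,
    d_intermediate=2048,
)

input_ids = torch.randint(0, 50257, (2, 128))  # (B, L)
out = model(input_ids)
print(out.logits.shape)  # (2, 128, vocab_size) per the forward docs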
Hybrid Architecture with Attention
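A hybrid stack mixing S5, attention, and Mamba layers, with per-type mixer kwargs. This is a sketch; keyword names are assumptions.

```python
from lrnn import LRNNLMHeadModel  # module path is an assumption

model = LRNNLMHeadModel(
    d_model=512,
    d_state=64,
    n_layer=4,
    vocab_size=50257,
    mixer_types=["S5", "S5", "attn", "Mamba"],  # length must equal n_layer
    # Per-type kwargs: only layers of the matching mixer type receive them
    mixer_kwargs={"S5": {"dt_min": 0.001}, "attn": {"num_heads": 8}},
)
```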
Autoregressive Generation
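A greedy decoding loop built on allocate_inference_cache and step, assuming model is an already-constructed LRNNLMHeadModel. The cache-allocation signature is an assumption based on the method descriptions below.

```python
import torch

B, max_new = 1, 20
tokens = torch.tensor([[42]])  # (B, 1): a single-token prompt

# Per-layer recurrent state caches (exact signature is an assumption)
caches = model.allocate_inference_cache(batch_size=B, max_seqlen=max_new)

generated = [tokens]
for _ in range(max_new):
    out = model.step(tokens, caches)  # logits: (B, 1, vocab_size)
    tokens = out.logits[:, -1].argmax(-1, keepdim=True)  # greedy pick, (B, 1)
    generated.append(tokens)

sequence = torch.cat(generated, dim=1)  # (B, max_new + 1)
```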
Save and Load
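Standard PyTorch checkpointing applies, assuming the class is a torch.nn.Module and model is an already-constructed instance. Construction kwargs below are assumptions.

```python
import torch

# Save only the weights
torch.save(model.state_dict(), "lrnn_lm.pt")

# Rebuild with the same hyperparameters, then load the weights
model2 = LRNNLMHeadModel(
    d_model=512, d_state=64, n_layer=4, vocab_size=50257,
    mixer_types=["S5"] * 4, d_intermediate=2048,
)
model2.load_state_dict(torch.load("lrnn_lm.pt", map_location="cpu"))
model2.eval()
```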
Methods
forward
- input_ids (torch.Tensor): Input token IDs of shape (B, L)
- inference_params (Dict, optional): Parameters for inference mode
- num_last_tokens (int): If > 0, only return logits for the last n tokens
- integration_timesteps (torch.Tensor, optional): Timesteps for LTV models (shape: (B, L))
- lengths (torch.Tensor, optional): Sequence lengths for variable-length sequences (shape: (B,))
Returns:
- namedtuple with logits field of shape (B, L, vocab_size)
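For example, when only the next-token distribution is needed, num_last_tokens can restrict the logits computed (a sketch, assuming model and input_ids as above):

```python
out = model(input_ids, num_last_tokens=1)
# Logits restricted to the final position: (B, 1, vocab_size)
next_token = out.logits[:, -1].argmax(dim=-1)
```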
step
- input_ids (torch.Tensor): Input token IDs of shape (B, 1), a single token
- caches (Dict): Dictionary mapping layer indices to their cached states
- integration_timesteps (torch.Tensor, optional): Integration timesteps for LTV models
Returns:
- namedtuple with logits field of shape (B, 1, vocab_size)
allocate_inference_cache
References
- Implementation reference: Mamba
- Mamba: Linear-Time Sequence Modeling with Selective State Spaces
- Parallelizing Linear Recurrent Attention (RG-LRU)
