Overview

S7 (Selective and Simplified State Space Layer) is a linear time-varying (LTV) model that combines input-dependent state space parameters with a simplified architecture. It uses HiPPO initialization for the state-transition matrix, and its A, B, C, and D parameters all vary with the input. Key features:
  • Fully time-varying parameters (A, B, C, D all depend on input)
  • HiPPO-based initialization for stable long-range dependencies
  • No convolution layer (simpler than Mamba/RGLRU)
  • Gated output with residual connection
  • Custom discretization scheme

Import

from lrnnx.models.ltv import S7

Class Signature

class S7(LTV_LRNN)

Constructor

__init__

def __init__(
    self,
    d_model: int,
    d_state: int,
    J: int = 1,
    use_fast_path: bool = True,
    layer_idx: Optional[int] = None,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
)
Initializes the S7 model.
d_model
int
required
Model dimension (input/output dimension).
d_state
int
required
State dimension. Must be divisible by J.
J
int
default:"1"
Number of blocks for HiPPO initialization. The state space is divided into J blocks, each initialized with HiPPO parameters.
use_fast_path
bool
default:"True"
Whether to use the CUDA fast path if available. Enables fused kernel implementation for better performance.
layer_idx
int | None
default:"None"
Layer index for multi-layer models, used for caching in stacked architectures.
device
torch.device | None
default:"None"
Device for the model parameters. If None, uses default device.
dtype
torch.dtype | None
default:"None"
Data type for the model parameters. If None, uses default dtype.

Methods

forward

def forward(
    self,
    hidden_states: Tensor,
    integration_timesteps: Optional[Tensor] = None,
    lengths: Optional[Tensor] = None,
    inference_cache: Optional[Dict[str, Any]] = None,
) -> Tensor
Forward pass through the S7 layer.
hidden_states
torch.Tensor
Input tensor of shape (B, L, H) where:
  • B = batch size
  • L = sequence length
  • H = model dimension (d_model)
integration_timesteps
torch.Tensor | None
default:"None"
Timesteps for async/event-driven discretization. Shape (B, L). Currently unused but kept for interface compatibility.
lengths
torch.Tensor | None
default:"None"
Lengths of sequences for variable-length batches. Shape (B,). Currently unused.
inference_cache
Dict[str, Any] | None
default:"None"
Cache for autoregressive generation. If provided, must contain:
  • "lrnn_state": S7 state tensor
  • "seqlen_offset": Current position in sequence
output
torch.Tensor
Output tensor of shape (B, L, H).

step

def step(
    self,
    hidden_states: Tensor,
    inference_cache: Dict[str, Any],
    **kwargs,
) -> Tuple[Tensor, Dict[str, Any]]
Performs a single recurrent step of S7 for autoregressive inference.
hidden_states
torch.Tensor
Input at current timestep, shape (B, 1, H).
inference_cache
Dict[str, Any]
Cache dictionary containing:
  • "lrnn_state": S7 state, shape (B, N) where N = d_state
  • "seqlen_offset": Current position in sequence
**kwargs
Any
Additional keyword arguments (unused).
output
Tuple[torch.Tensor, Dict[str, Any]]
A tuple containing:
  • Output tensor at the current timestep, shape (B, 1, H)
  • Updated cache dictionary

allocate_inference_cache

def allocate_inference_cache(
    self,
    batch_size: int,
    max_seqlen: int,
    dtype: Optional[torch.dtype] = None,
    **kwargs,
) -> Dict[str, Any]
Allocates cache for S7 autoregressive inference.
batch_size
int
The batch size for inference.
max_seqlen
int
Maximum sequence length. Unused but kept for interface consistency.
dtype
torch.dtype | None
default:"None"
Data type for allocated tensors. If None, uses model’s parameter dtype.
**kwargs
Any
Additional keyword arguments (unused).
cache
Dict[str, Any]
Cache dictionary containing:
  • "lrnn_state": Zero-initialized state, shape (B, N)
  • "seqlen_offset": Position counter, initialized to 0

Examples

Basic Usage

import torch
from lrnnx.models.ltv import S7

# Create S7 model
model = S7(d_model=64, d_state=64)

# Forward pass
x = torch.randn(2, 128, 64)
y = model(x)
print(y.shape)  # torch.Size([2, 128, 64])

Multi-Block Initialization

import torch
from lrnnx.models.ltv import S7

# Use 4 HiPPO blocks for initialization
model = S7(
    d_model=64,
    d_state=64,  # Must be divisible by J
    J=4  # 4 blocks of 16 dimensions each
)

x = torch.randn(2, 128, 64)
y = model(x)

Autoregressive Inference

import torch
from lrnnx.models.ltv import S7

model = S7(d_model=64, d_state=64)
model.eval()

# Allocate cache
cache = model.allocate_inference_cache(
    batch_size=2,
    max_seqlen=100
)

# Generate sequence step by step
outputs = []
for t in range(100):
    x_t = get_next_input(t)  # Shape: (2, 1, 64)
    y_t, cache = model.step(x_t, cache)
    outputs.append(y_t)

output = torch.cat(outputs, dim=1)  # Shape: (2, 100, 64)

Large State Space

import torch
from lrnnx.models.ltv import S7

# S7 can handle large state dimensions efficiently
model = S7(
    d_model=256,
    d_state=256,  # Large state space
    J=8,
    use_fast_path=True  # Use CUDA kernels for performance
)

x = torch.randn(4, 512, 256)
y = model(x)

Architecture Details

Time-Varying Parameters

S7 computes all SSM parameters from the input:
# From input x, compute:
A, B, C, D, bias = x_proj(x)  # All time-varying

# A is augmented with learned HiPPO base:
A = A + base_params

# Discretization:
A_bar = 1 - 1 / (A**2 + 0.5)

# Recurrence:
h_t = A_bar * h_{t-1} + (B^T @ x + bias)
y_t = C @ h_t + D * x
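The recurrence above can be checked with a per-sequence reference loop. This is a sketch, not the fused implementation: it assumes a diagonal transition (so `A_bar` acts elementwise) and processes a single channel, with `A_bar`, `Bu`, and `C` of shape `(L, N)`:

```python
import torch

def s7_recurrence_ref(A_bar, Bu, C, D, x):
    """Reference loop for one channel of one sequence.

    A_bar, Bu, C: (L, N); D: scalar; x: (L,).
    Bu plays the role of (B^T @ x + bias) from the pseudocode above.
    """
    L, N = A_bar.shape
    h = torch.zeros(N)
    ys = []
    for t in range(L):
        h = A_bar[t] * h + Bu[t]                  # h_t = A_bar * h_{t-1} + Bu_t
        ys.append(torch.dot(C[t], h) + D * x[t])  # y_t = C @ h_t + D * x_t
    return torch.stack(ys)
```

Note that the discretization A_bar = 1 - 1/(A**2 + 0.5) always lies in [-1, 1), since A**2 + 0.5 ≥ 0.5, so the elementwise state update cannot blow up.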

Residual and Gating

The output includes gating and a residual connection:
gate = sigmoid(gate_proj(gelu(y)))
output = gate * y + input  # Residual connection
This stabilizes training and improves gradient flow.
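The gating path can be sketched as follows. Here `gate_proj` is assumed to be a `Linear(d_model, d_model)`; the actual projection layout is internal to the layer:

```python
import torch
import torch.nn.functional as F

def gated_residual(y, x, gate_proj):
    """gate = sigmoid(gate_proj(gelu(y))); output = gate * y + x."""
    gate = torch.sigmoid(gate_proj(F.gelu(y)))
    return gate * y + x

# Usage with an assumed Linear gate projection:
d_model = 64
gate_proj = torch.nn.Linear(d_model, d_model)
y = torch.randn(2, 128, d_model)   # SSM output
x = torch.randn(2, 128, d_model)   # layer input (residual branch)
out = gated_residual(y, x, gate_proj)  # shape (2, 128, 64)
```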

HiPPO Initialization

S7 uses HiPPO (High-order Polynomial Projection Operators) initialization for the base transition matrix, which provides:
  • Stable long-range dependencies
  • Theoretically-grounded initialization
  • Better out-of-the-box performance
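For reference, the classic HiPPO-LegS transition matrix from Gu et al.'s HiPPO construction can be built as below. Which HiPPO variant S7 uses for its base matrix is an implementation detail, so treat this as an illustrative sketch:

```python
import torch

def hippo_legs(N: int) -> torch.Tensor:
    """HiPPO-LegS matrix: A[n, k] = -sqrt(2n+1)*sqrt(2k+1) for n > k,
    -(n + 1) on the diagonal, and 0 above the diagonal."""
    A = torch.zeros(N, N)
    for n in range(N):
        for k in range(n):
            A[n, k] = -((2 * n + 1) ** 0.5) * ((2 * k + 1) ** 0.5)
        A[n, n] = -(n + 1)
    return A
```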

State Space Dimensions

Unlike Mamba (typically d_state=16) or RGLRU (d_state=1), S7 commonly uses larger state dimensions (e.g., 64-256) to increase model capacity. The state dimension should be divisible by the number of blocks J.
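A quick sanity check of the divisibility constraint (a hypothetical helper, not part of the lrnnx API):

```python
def hippo_block_size(d_state: int, J: int) -> int:
    """Size of each HiPPO block; d_state must be divisible by J."""
    if d_state % J != 0:
        raise ValueError(f"d_state={d_state} is not divisible by J={J}")
    return d_state // J

# e.g. d_state=256 with J=8 gives 8 HiPPO blocks of size 32 each
```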

See Also

  • LTV_LRNN - Base class for LTV models
  • Mamba - Selective state space model with convolution
  • RGLRU - Gated linear recurrent unit
