
Overview

S7 (Selective and Simplified State Space Layers) is a selective state space model that makes all SSM matrices (A, B, C, D) input-dependent while using a simplified architecture. It combines the expressiveness of input-dependent dynamics with HiPPO initialization for strong long-range modeling. S7 is designed to be simpler than Mamba while maintaining competitive selective processing capabilities.

Paper Reference

S7: Selective and Simplified State Space Layers for Sequence Modeling
https://arxiv.org/abs/2410.03464

Import

from lrnnx.models.ltv import S7

Parameters

  • d_model (int, required) - Model dimension; the size of the input and output features.
  • d_state (int, required) - State dimension. Must be divisible by J. Typically 64-256.
  • J (int, default: 1) - Number of blocks for HiPPO initialization. d_state must be divisible by J.
  • use_fast_path (bool, default: True) - Whether to use the CUDA fast path, when available, for a significant speedup.
  • layer_idx (int, default: None) - Layer index in multi-layer models; used for caching during inference.
  • device (torch.device, default: None) - Device for model parameters.
  • dtype (torch.dtype, default: None) - Data type for model parameters.

Usage Example

Basic Usage

import torch
from lrnnx.models.ltv import S7

# Create S7 model
model = S7(d_model=64, d_state=64)

# Forward pass
x = torch.randn(2, 128, 64)  # (batch, length, features)
y = model(x)

print(y.shape)  # torch.Size([2, 128, 64])

With HiPPO Blocks

import torch
from lrnnx.models.ltv import S7

# Use multiple HiPPO blocks
model = S7(
    d_model=128,
    d_state=128,  # Must be divisible by J
    J=4,  # 4 blocks of 32 states each
    use_fast_path=True,
)

x = torch.randn(4, 256, 128)
y = model(x)

Autoregressive Inference

import torch
from lrnnx.models.ltv import S7

model = S7(d_model=256, d_state=64)
batch_size = 2
max_seqlen = 1024

# Allocate inference cache
cache = model.allocate_inference_cache(
    batch_size=batch_size,
    max_seqlen=max_seqlen
)

# Initialize offset
cache["seqlen_offset"] = 0

# Generate sequence step-by-step
for t in range(100):
    x_t = torch.randn(batch_size, 1, 256)
    y_t, cache = model.step(x_t, cache)
    # y_t.shape: (batch_size, 1, 256)

Language Modeling Setup

import torch
import torch.nn as nn
from lrnnx.models.ltv import S7

class S7LanguageModel(nn.Module):
    def __init__(self, vocab_size, d_model=512, d_state=128, n_layers=12):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.layers = nn.ModuleList([
            nn.ModuleDict({
                's7': S7(d_model=d_model, d_state=d_state, J=4),
                'norm': nn.LayerNorm(d_model),
                'mlp': nn.Sequential(
                    nn.Linear(d_model, d_model * 4),
                    nn.GELU(),
                    nn.Linear(d_model * 4, d_model),
                ),
            })
            for _ in range(n_layers)
        ])
        self.norm = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size)
    
    def forward(self, x):
        x = self.embedding(x)
        for layer in self.layers:
            # S7 already includes residual internally
            x = layer['s7'](x)
            x = layer['norm'](x)
            x = layer['mlp'](x) + x
        x = self.norm(x)
        return self.head(x)

Key Features

All Input-Dependent Matrices

Unlike Mamba (fixed A) or RG-LRU (fixed a), S7 makes everything input-dependent:
# All projected from input
A, B, C, D, bias = x_proj(x)

# A: (B, N, L) - per-timestep state dynamics
# B: (B, N, H, L) - input matrix
# C: (B, H, N, L) - output matrix  
# D: (B, H, L) - skip connection
# bias: (B, N, L) - state bias
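Because a single projection emits all five parameter groups at once, its output width grows quickly with d_state. A back-of-the-envelope sketch (not library code), using the per-parameter sizes listed under "Forward Pass":

```python
# Width the combined parameter projection would need for each group,
# per the sizes given in the Forward Pass section (illustrative only).
d_model, d_state = 64, 64

proj_width = (
    d_state              # A: one value per state channel
    + d_model * d_state  # B: input matrix
    + d_model * d_state  # C: output matrix
    + d_model            # D: skip connection
    + d_state            # bias: state bias
)

print(proj_width)  # 8384 for d_model=64, d_state=64
```

This is the main source of S7's extra parameter cost relative to Mamba, which projects fewer quantities from the input.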

HiPPO Initialization

S7 uses HiPPO (High-order Polynomial Projection Operators) for A:
# Base HiPPO matrix (DPLR form)
base_params = make_DPLR_HiPPO(d_state // J)

# Repeated J times
base_params = base_params.repeat(J)

# Added to input-dependent A
A_effective = A_input_dependent + base_params
HiPPO provides strong inductive bias for long-range dependencies.

Simplified Discretization

S7 uses a custom discretization scheme:
# Compute A_bar
A_sq_half = A^2 + 0.5
A_bar = 1 - 1 / A_sq_half

# Update
state_{t+1} = A_bar * state_t + (B^T @ x + bias)

# Output
y_t = C @ state_{t+1} + D * x_t
This is simpler than ZOH or bilinear transforms.
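A useful property of this scheme: since A^2 + 0.5 >= 0.5, the term 1 / A_sq_half lies in (0, 2], so A_bar is confined to [-1, 1) for any real A. A plain-Python sketch of the formula above:

```python
# Sketch of S7's discretization: A_bar = 1 - 1 / (A**2 + 0.5).
# Because A**2 + 0.5 >= 0.5, A_bar always lies in [-1, 1).
def discretize(a):
    a_sq_half = a * a + 0.5
    return 1.0 - 1.0 / a_sq_half

for a in [-10.0, -1.0, 0.0, 0.5, 3.0]:
    a_bar = discretize(a)
    assert -1.0 <= a_bar < 1.0  # bounded state transition for any real A
```

The bound keeps the recurrence stable without the exponentials required by ZOH discretization.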

Residual Connection Built-in

S7 includes a residual connection inside the model:
y = s7_forward(x)
out = y + x  # Built into S7

Architecture Details

Forward Pass

  1. Input Projection: x = in_proj(x)
    • Linear projection (no expansion)
  2. Parameter Projection: A, B, C, D, bias = x_proj(x)
    • Project to all SSM parameters
    • A: d_state
    • B: d_model * d_state
    • C: d_model * d_state
    • D: d_model
    • bias: d_state
  3. Add HiPPO: A = A + base_params
    • Add learned HiPPO initialization
  4. Selective Scan: y = s7_scan(x, A, B, C, bias)
    • Core SSM computation
    • All matrices are input-dependent
  5. Add D: y = y + D * x
    • Input-dependent skip connection
  6. Gating: y = sigmoid(gate_proj(gelu(y))) * y
    • Self-gating mechanism
  7. Residual: out = y + input
    • Add original input
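The seven steps above can be traced with scalars. The toy weights below stand in for the learned linear projections and are purely illustrative, not the library's actual parameters:

```python
import math

def gelu(v):  # tanh approximation of GELU
    return 0.5 * v * (1.0 + math.tanh(math.sqrt(2 / math.pi) * (v + 0.044715 * v**3)))

def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))

def s7_forward_scalar(x_in, state, hippo=-0.5):
    x = 0.9 * x_in                                    # 1. input projection
    A, B, C, D, bias = 0.2 * x, 0.3, 1.0, 0.5, 0.0    # 2. parameter projection (A input-dependent)
    A = A + hippo                                     # 3. add HiPPO base
    a_bar = 1.0 - 1.0 / (A * A + 0.5)                 # 4. selective scan (one step)
    state = a_bar * state + (B * x + bias)
    y = C * state
    y = y + D * x                                     # 5. input-dependent skip
    y = sigmoid(0.8 * gelu(y)) * y                    # 6. self-gating
    return y + x_in, state                            # 7. residual

out, state = s7_forward_scalar(1.0, 0.0)
```

Note that because A is projected from x in step 2, the effective transition a_bar changes at every timestep; that is the "selective" part.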

State Update

The S7 recurrence:
# Discretize A
A_sq_half = A^2 + 0.5
A_bar = 1 - 1 / A_sq_half

# Compute input projection
Bu = B^T @ x + bias

# Update state
state_{t+1} = A_bar * state_t + Bu

# Compute output
y_t = C @ state_{t+1} + D * x_t
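A toy trace of this recurrence over a few timesteps (scalar state and hand-picked constants; the real model derives A, B, C, D, and bias from the input at each step):

```python
# Run the S7 recurrence for three steps with a constant input of 1.0.
def discretize(A):
    return 1.0 - 1.0 / (A * A + 0.5)

A, B, C, D, bias = 2.0, 0.3, 1.0, 0.5, 0.0
state, outputs = 0.0, []
for x in [1.0, 1.0, 1.0]:
    state = discretize(A) * state + (B * x + bias)  # state update
    outputs.append(C * state + D * x)               # output from updated state
```

With A_bar = 7/9 here, the state accumulates input geometrically toward a fixed point, so the outputs rise monotonically under constant input.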

State Representation

S7 maintains a real-valued state:
state: (batch_size, d_state) dtype=float32
Unlike the LTI models, which maintain a complex-valued state, S7 keeps its state real-valued.

HiPPO Initialization

What is HiPPO?

HiPPO (High-order Polynomial Projection Operators) provides matrices that:
  • Compress history into fixed-size state
  • Preserve polynomial features
  • Enable long-range dependencies

DPLR Form

S7 uses the DPLR (Diagonal Plus Low-Rank) HiPPO variant:
# Generate base matrix
base, _, _, _, _ = make_DPLR_HiPPO(d_state // J)

# Learned on top of HiPPO structure
A_effective = A_learned + base

Multiple Blocks (J > 1)

Using J > 1 creates multiple HiPPO blocks:
# J=4, d_state=128
# Creates 4 blocks of 32 states each
base = make_DPLR_HiPPO(32).repeat(4)
This can improve capacity for diverse patterns.
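The repetition itself is just tiling: a base vector of d_state // J entries repeated J times. A stand-in sketch with placeholder values (the real base comes from make_DPLR_HiPPO):

```python
# Tiling sketch for J HiPPO blocks; the values below are placeholders for
# the DPLR-HiPPO parameters returned by make_DPLR_HiPPO(d_state // J).
d_state, J = 128, 4
block = d_state // J                       # 32 states per block
base = [-(k + 1) for k in range(block)]    # placeholder per-block diagonal
tiled = base * J                           # equivalent of .repeat(J)

assert len(tiled) == d_state
assert tiled[:block] == tiled[block:2 * block]  # identical blocks
```

Each block starts from the same HiPPO structure, but because the learned, input-dependent part of A is added per state channel, the blocks can diverge during training.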

Performance Characteristics

Expressiveness

Highly expressive:
  • All matrices input-dependent
  • HiPPO initialization
  • Self-gating mechanism

Speed

⚠️ Moderate:
  • Slower than Mamba (more projections)
  • Faster than full attention
  • Benefits from CUDA kernels

Memory

⚠️ Moderate:
  • More parameters than Mamba
  • Stores all projected matrices
  • Reasonable for most tasks

Performance Tips

Use J=4 or J=8 for best results. This creates multiple HiPPO blocks that can specialize for different patterns.
S7 is more parameter-heavy than Mamba but can be more expressive due to input-dependent A, B, C, D.
Make sure d_state is divisible by J. For example:
  • d_state=64, J=4 ✅
  • d_state=64, J=5 ❌ (64 is not divisible by 5)

When to Use S7

Use S7 when:
  • You want maximum selectivity (all matrices input-dependent)
  • Long-range dependencies are critical (HiPPO)
  • You need strong inductive bias
  • Research/experimentation on selective SSMs
Consider alternatives when:
  • Speed is critical → RG-LRU
  • Simplicity preferred → Mamba
  • Parameter efficiency needed → RG-LRU

Comparison with Other Selective Models

| Model  | Input-Dep A | Input-Dep B, C | HiPPO | Gating | Speed |
|--------|-------------|----------------|-------|--------|-------|
| S7     | ✅          | ✅             | ✅    | ✅     | ⭐⭐⭐ |
| Mamba  | ❌          | ✅             | ❌    | ✅     | ⭐⭐⭐⭐ |
| RG-LRU | ❌ (gates)  | ❌             | ❌    | ✅     | ⭐⭐⭐⭐⭐ |
S7 is the most selective but also the slowest.

Advanced Usage

Variable HiPPO Blocks

import torch
import torch.nn as nn
from lrnnx.models.ltv import S7

class MultiScaleS7(nn.Module):
    def __init__(self, d_model=256):
        super().__init__()
        # Coarse scale (fewer blocks)
        self.s7_coarse = S7(d_model=d_model, d_state=64, J=2)
        # Fine scale (more blocks)
        self.s7_fine = S7(d_model=d_model, d_state=128, J=8)
        self.combine = nn.Linear(d_model * 2, d_model)
    
    def forward(self, x):
        y_coarse = self.s7_coarse(x)
        y_fine = self.s7_fine(x)
        return self.combine(torch.cat([y_coarse, y_fine], dim=-1))

Hybrid with Mamba

import torch.nn as nn
from lrnnx.models.ltv import S7, Mamba

class HybridModel(nn.Module):
    def __init__(self, d_model=256):
        super().__init__()
        # S7 for input-dependent A
        self.s7 = S7(d_model=d_model, d_state=128, J=4)
        # Mamba for efficiency
        self.mamba = Mamba(d_model=d_model, d_state=16)
    
    def forward(self, x):
        x = self.s7(x)  # S7 already has residual
        x = self.mamba(x) + x  # Mamba residual
        return x

See Also

  • Mamba - Faster selective SSM
  • RG-LRU - Simpler gated model
  • S5 - Non-selective baseline
