Overview
S7 (Selective and Simplified State Space Layers) is a selective state space model that makes all SSM matrices (A, B, C, D) input-dependent while using a simplified architecture. It combines the expressiveness of input-dependent dynamics with HiPPO initialization for strong long-range modeling.
S7 is designed to be simpler than Mamba while maintaining competitive selective processing capabilities.
Paper Reference
S7: Selective and Simplified State Space Layers for Sequence Modeling
https://arxiv.org/abs/2410.03464
Import
from lrnnx.models.ltv import S7
Parameters
d_model
int
Model dimension - size of input and output features.
d_state
int
State dimension. Must be divisible by J. Typically 64-256.
J
int
Number of blocks for HiPPO initialization. d_state must be divisible by J.
use_fast_path
bool
Whether to use the CUDA fast path, if available, for a significant speedup.
Layer index for multi-layer models, used for caching during inference.
device
torch.device
default:"None"
Device for model parameters.
dtype
torch.dtype
default:"None"
Data type for model parameters.
Usage Example
Basic Usage
import torch
from lrnnx.models.ltv import S7
# Create S7 model
model = S7(d_model=64, d_state=64)
# Forward pass
x = torch.randn(2, 128, 64) # (batch, length, features)
y = model(x)
print(y.shape) # torch.Size([2, 128, 64])
With HiPPO Blocks
import torch
from lrnnx.models.ltv import S7
# Use multiple HiPPO blocks
model = S7(
d_model=128,
d_state=128, # Must be divisible by J
J=4, # 4 blocks of 32 states each
use_fast_path=True,
)
x = torch.randn(4, 256, 128)
y = model(x)
Autoregressive Inference
import torch
from lrnnx.models.ltv import S7
model = S7(d_model=256, d_state=64)
batch_size = 2
max_seqlen = 1024
# Allocate inference cache
cache = model.allocate_inference_cache(
batch_size=batch_size,
max_seqlen=max_seqlen
)
# Initialize offset
cache["seqlen_offset"] = 0
# Generate sequence step-by-step
for t in range(100):
x_t = torch.randn(batch_size, 1, 256)
y_t, cache = model.step(x_t, cache)
# y_t.shape: (batch_size, 1, 256)
Language Modeling Setup
import torch
import torch.nn as nn
from lrnnx.models.ltv import S7
class S7LanguageModel(nn.Module):
def __init__(self, vocab_size, d_model=512, d_state=128, n_layers=12):
super().__init__()
self.embedding = nn.Embedding(vocab_size, d_model)
self.layers = nn.ModuleList([
nn.ModuleDict({
's7': S7(d_model=d_model, d_state=d_state, J=4),
'norm': nn.LayerNorm(d_model),
'mlp': nn.Sequential(
nn.Linear(d_model, d_model * 4),
nn.GELU(),
nn.Linear(d_model * 4, d_model),
),
})
for _ in range(n_layers)
])
self.norm = nn.LayerNorm(d_model)
self.head = nn.Linear(d_model, vocab_size)
def forward(self, x):
x = self.embedding(x)
for layer in self.layers:
# S7 already includes residual internally
x = layer['s7'](x)
x = layer['norm'](x)
x = layer['mlp'](x) + x
x = self.norm(x)
return self.head(x)
Key Features
Unlike Mamba (fixed A) or RG-LRU (fixed a), S7 makes everything input-dependent:
# All projected from input
A, B, C, D, bias = x_proj(x)
# A: (B, N, L) - per-timestep state dynamics
# B: (B, N, H, L) - input matrix
# C: (B, H, N, L) - output matrix
# D: (B, H, L) - skip connection
# bias: (B, N, L) - state bias
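As a rough sketch of how such parameters could be produced (the name `x_proj`, the single fused linear layer, and the `(batch, length, features)` layout are assumptions for illustration, not the library's internal implementation), one projection can be split into the five groups using the per-parameter sizes listed under Architecture Details:

```python
import torch
import torch.nn as nn

# Hypothetical sketch: one linear layer whose output is split into the
# five input-dependent SSM parameters. Sizes follow the doc's
# "Parameter Projection" step; the real layout may differ.
d_model, d_state = 8, 16
sizes = [d_state,            # A: per-timestep state dynamics
         d_model * d_state,  # B: input matrix
         d_model * d_state,  # C: output matrix
         d_model,            # D: skip connection
         d_state]            # bias: state bias
x_proj = nn.Linear(d_model, sum(sizes))

x = torch.randn(2, 5, d_model)  # (batch, length, features)
A, B, C, D, bias = torch.split(x_proj(x), sizes, dim=-1)
B = B.view(2, 5, d_model, d_state)  # one input matrix per timestep
C = C.view(2, 5, d_model, d_state)  # one output matrix per timestep
```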
HiPPO Initialization
S7 uses HiPPO (High-order Polynomial Projection Operators) for A:
# Base HiPPO matrix (DPLR form)
base_params = make_DPLR_HiPPO(d_state // J)
# Repeated J times
base_params = base_params.repeat(J)
# Added to input-dependent A
A_effective = A_input_dependent + base_params
HiPPO provides strong inductive bias for long-range dependencies.
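For intuition, here is a sketch of the standard DPLR HiPPO-LegS construction, following the S4/S5 line of work. This is not lrnnx's actual `make_DPLR_HiPPO` (which, per the snippets on this page, may return additional values); function names here are illustrative:

```python
import numpy as np

def make_hippo_legs(N):
    # Negated HiPPO-LegS matrix: A[n,k] = -sqrt((2n+1)(2k+1)) for n > k,
    # -(n+1) on the diagonal, 0 above the diagonal.
    p = np.sqrt(1 + 2 * np.arange(N))
    A = np.tril(p[:, None] * p[None, :]) - np.diag(np.arange(N))
    return -A

def make_dplr_hippo(N):
    # Write A = (normal part) - P P^T with P[n] = sqrt(n + 1/2), then
    # diagonalize the normal part S = A + P P^T. S equals -1/2 I plus a
    # skew-symmetric matrix, so every eigenvalue has real part -1/2.
    A = make_hippo_legs(N)
    P = np.sqrt(np.arange(N) + 0.5)
    S = A + P[:, None] * P[None, :]
    lam_real = np.mean(np.diagonal(S)) * np.ones(N)  # all -1/2
    lam_imag, V = np.linalg.eigh(S * -1j)  # skew part becomes Hermitian after * -1j
    return lam_real + 1j * lam_imag, V

Lambda, V = make_dplr_hippo(32)
```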
Simplified Discretization
S7 uses a custom discretization scheme:
# Compute A_bar
A_sq_half = A^2 + 0.5
A_bar = 1 - 1 / A_sq_half
# Update
state_{t+1} = A_bar * state_t + (B^T @ x + bias)
# Output
y_t = C @ state_{t+1} + D * x_t
This is simpler than ZOH or bilinear transforms.
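One useful property of this scheme: since A^2 + 0.5 >= 0.5 for any real A, the discretized coefficient A_bar = 1 - 1/(A^2 + 0.5) always lies in [-1, 1), so the recurrence never amplifies the state regardless of what the input-dependent projection produces. A quick numeric check:

```python
import numpy as np

# A_bar = 1 - 1/(A^2 + 0.5) stays in [-1, 1) for any real A:
# it reaches its minimum of -1 at A = 0 and approaches (but never
# reaches) 1 as |A| grows.
A = np.linspace(-100.0, 100.0, 200001)
A_bar = 1.0 - 1.0 / (A ** 2 + 0.5)
assert A_bar.min() >= -1.0
assert A_bar.max() < 1.0
```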
Residual Connection Built-in
S7 includes a residual connection inside the model:
y = s7_forward(x)
out = y + x # Built into S7
Architecture Details
Forward Pass
- Input Projection: x = in_proj(x)
  - Linear projection (no expansion)
- Parameter Projection: A, B, C, D, bias = x_proj(x)
  - Projects to all SSM parameters
  - A: d_state
  - B: d_model * d_state
  - C: d_model * d_state
  - D: d_model
  - bias: d_state
- Add HiPPO: A = A + base_params
  - Adds the learned HiPPO initialization
- Selective Scan: y = s7_scan(x, A, B, C, bias)
  - Core SSM computation; all matrices are input-dependent
- Add D: y = y + D * x
  - Input-dependent skip connection
- Gating: y = sigmoid(gate_proj(gelu(y))) * y
- Residual: out = y + input
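The gating and residual steps at the end of the pass can be sketched in isolation. Treating `gate_proj` as a plain square linear layer is an assumption here; the real module may differ:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Sketch of the final steps: self-gating followed by the built-in
# residual. gate_proj as a d_model -> d_model linear layer is assumed.
d_model = 16
gate_proj = nn.Linear(d_model, d_model)

x = torch.randn(2, 8, d_model)  # the block's original input
y = torch.randn(2, 8, d_model)  # stand-in for the scan output (+ D skip)
y = torch.sigmoid(gate_proj(F.gelu(y))) * y  # gating
out = y + x                                  # residual
```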
State Update
The S7 recurrence:
# Discretize A
A_sq_half = A^2 + 0.5
A_bar = 1 - 1 / A_sq_half
# Compute input projection
Bu = B^T @ x + bias
# Update state
state_{t+1} = A_bar * state_t + Bu
# Compute output
y_t = C @ state_{t+1} + D * x_t
State Representation
S7 maintains a real-valued state:
state: (batch_size, d_state) dtype=float32
Unlike LTI models (complex state), S7 uses real values.
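Putting the recurrence together, a deliberately slow reference implementation can clarify the data flow. The `(batch, length, ...)` shape layout below is chosen for readability and is not the library's internal layout:

```python
import torch

def s7_scan_reference(x, A, B, C, D, bias):
    """Naive per-timestep evaluation of the S7 recurrence above.
    x: (batch, L, H); A, bias: (batch, L, N);
    B, C: (batch, L, H, N); D: (batch, L, H).
    """
    batch, L, H = x.shape
    N = A.shape[-1]
    A_bar = 1.0 - 1.0 / (A ** 2 + 0.5)  # simplified discretization
    state = x.new_zeros(batch, N)       # real-valued state
    ys = []
    for t in range(L):
        Bu = torch.einsum('bhn,bh->bn', B[:, t], x[:, t]) + bias[:, t]
        state = A_bar[:, t] * state + Bu  # input-dependent transition
        y_t = torch.einsum('bhn,bn->bh', C[:, t], state) + D[:, t] * x[:, t]
        ys.append(y_t)
    return torch.stack(ys, dim=1)

y = s7_scan_reference(
    torch.randn(2, 6, 4),     # x
    torch.randn(2, 6, 8),     # A
    torch.randn(2, 6, 4, 8),  # B
    torch.randn(2, 6, 4, 8),  # C
    torch.randn(2, 6, 4),     # D
    torch.randn(2, 6, 8),     # bias
)
```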
HiPPO Initialization
What is HiPPO?
HiPPO (High-order Polynomial Projection Operators) provides matrices that:
- Compress history into fixed-size state
- Preserve polynomial features
- Enable long-range dependencies
S7 uses the DPLR (Diagonal Plus Low-Rank) HiPPO variant:
# Generate base matrix
base, _, _, _, _ = make_DPLR_HiPPO(d_state // J)
# Learned on top of HiPPO structure
A_effective = A_learned + base
Multiple Blocks (J > 1)
Using J > 1 creates multiple HiPPO blocks:
# J=4, d_state=128
# Creates 4 blocks of 32 states each
base = make_DPLR_HiPPO(32).repeat(4)
This can improve capacity for diverse patterns.
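Illustratively, the repetition amounts to tiling one block's parameters along the state dimension (the values below are stand-ins, not real HiPPO eigenvalues):

```python
import numpy as np

# J=4 blocks over d_state=128: the base parameters for one 32-dim HiPPO
# block are repeated 4 times along the state dimension. arange is only a
# stand-in for the block's HiPPO-derived values.
d_state, J = 128, 4
base = np.arange(d_state // J, dtype=np.float64)
full = np.tile(base, J)
```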
Expressiveness
✅ Highly expressive:
- All matrices input-dependent
- HiPPO initialization
- Self-gating mechanism
Speed
⚠️ Moderate:
- Slower than Mamba (more projections)
- Faster than full attention
- Benefits from CUDA kernels
Memory
⚠️ Moderate:
- More parameters than Mamba
- Stores all projected matrices
- Reasonable for most tasks
Use J=4 or J=8 for best results. This creates multiple HiPPO blocks that can specialize for different patterns.
S7 is more parameter-heavy than Mamba but can be more expressive due to input-dependent A, B, C, D.
Make sure d_state is divisible by J. For example:
d_state=64, J=4 ✅
d_state=64, J=5 ❌ (64 not divisible by 5)
When to Use S7
✅ Use S7 when:
- You want maximum selectivity (all matrices input-dependent)
- Long-range dependencies are critical (HiPPO)
- You need strong inductive bias
- Research/experimentation on selective SSMs
❌ Consider alternatives when:
- Speed is critical → RG-LRU
- Simplicity preferred → Mamba
- Parameter efficiency needed → RG-LRU
Comparison with Other Selective Models
| Model | Input-Dep A | Input-Dep B,C | HiPPO | Gating | Speed |
|---|---|---|---|---|---|
| S7 | ✅ | ✅ | ✅ | ✅ | ⭐⭐⭐ |
| Mamba | ❌ | ✅ | ❌ | ✅ | ⭐⭐⭐⭐ |
| RG-LRU | ❌ | ❌ (gates) | ❌ | ✅ | ⭐⭐⭐⭐⭐ |
S7 is the most selective but also the slowest.
Advanced Usage
Variable HiPPO Blocks
import torch
import torch.nn as nn
from lrnnx.models.ltv import S7
class MultiScaleS7(nn.Module):
def __init__(self, d_model=256):
super().__init__()
# Coarse scale (fewer blocks)
self.s7_coarse = S7(d_model=d_model, d_state=64, J=2)
# Fine scale (more blocks)
self.s7_fine = S7(d_model=d_model, d_state=128, J=8)
self.combine = nn.Linear(d_model * 2, d_model)
def forward(self, x):
y_coarse = self.s7_coarse(x)
y_fine = self.s7_fine(x)
return self.combine(torch.cat([y_coarse, y_fine], dim=-1))
Hybrid with Mamba
import torch.nn as nn
from lrnnx.models.ltv import S7, Mamba
class HybridModel(nn.Module):
def __init__(self, d_model=256):
super().__init__()
# S7 for input-dependent A
self.s7 = S7(d_model=d_model, d_state=128, J=4)
# Mamba for efficiency
self.mamba = Mamba(d_model=d_model, d_state=16)
def forward(self, x):
x = self.s7(x) # S7 already has residual
x = self.mamba(x) + x # Mamba residual
return x
See Also
- Mamba - Faster selective SSM
- RG-LRU - Simpler gated model
- S5 - Non-selective baseline