Overview

S5 (Simplified State Space) is a clean, easy-to-understand implementation of state space models. It provides a straightforward SSM architecture with support for multiple discretization methods and optional conjugate symmetry. S5 is ideal for learning SSM concepts and serves as a strong baseline for sequence modeling tasks.

Paper Reference

Simplified State Space Layers for Sequence Modeling https://openreview.net/forum?id=Ai8Hw3AXqks

Import

from lrnnx.models.lti import S5

Parameters

d_model
int
required
Model dimension - size of input and output features.
d_state
int
required
State dimension (P in the original paper). Internal dimension of the SSM state space.
discretization
Literal['zoh', 'bilinear', 'dirac', 'no_discretization']
required
Discretization method to use:
  • 'zoh': Zero-Order Hold (recommended)
  • 'bilinear': Bilinear transform
  • 'dirac': Dirac delta approximation
  • 'no_discretization': Skip discretization step
conj_sym
bool
default:"False"
If True, uses conjugate symmetry for the state space model. Currently not implemented.

Usage Example

Basic Usage

import torch
from lrnnx.models.lti import S5

# Create S5 model with ZOH discretization
model = S5(d_model=64, d_state=64, discretization="zoh")

# Forward pass
x = torch.randn(2, 128, 64)  # (batch, length, features)
y = model(x)

print(y.shape)  # torch.Size([2, 128, 64])

Different Discretization Methods

import torch
from lrnnx.models.lti import S5

# Zero-Order Hold (default, recommended)
model_zoh = S5(d_model=64, d_state=64, discretization="zoh")

# Bilinear transform (better frequency response)
model_bilinear = S5(d_model=64, d_state=64, discretization="bilinear")

# Dirac delta (simpler, faster)
model_dirac = S5(d_model=64, d_state=64, discretization="dirac")

x = torch.randn(2, 128, 64)
y_zoh = model_zoh(x)
y_bilinear = model_bilinear(x)
y_dirac = model_dirac(x)

Autoregressive Inference

import torch
from lrnnx.models.lti import S5

model = S5(d_model=64, d_state=64, discretization="zoh")
batch_size = 2

# Allocate inference cache
cache = model.allocate_inference_cache(batch_size=batch_size)

# Generate sequence step-by-step
outputs = []
for t in range(100):
    x_t = torch.randn(batch_size, 64)  # Single timestep input
    y_t, cache = model.step(x_t, cache)
    outputs.append(y_t)
    # y_t.shape: (batch_size, 64)

output_sequence = torch.stack(outputs, dim=1)  # (batch_size, 100, 64)

Language Modeling Setup

import torch
import torch.nn as nn
from lrnnx.models.lti import S5

class S5LanguageModel(nn.Module):
    def __init__(self, vocab_size, d_model=256, d_state=64, n_layers=4):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.layers = nn.ModuleList([
            S5(d_model=d_model, d_state=d_state, discretization="zoh")
            for _ in range(n_layers)
        ])
        self.norm = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size)
    
    def forward(self, x):
        x = self.embedding(x)  # (batch, seq_len, d_model)
        for layer in self.layers:
            x = layer(x) + x  # Residual connection
        x = self.norm(x)
        return self.head(x)

model = S5LanguageModel(vocab_size=50000)

Key Features

Simple Architecture

S5 has a minimal, easy-to-understand structure:
# Continuous-time SSM
h'(t) = A h(t) + B x(t)
y(t) = C h(t) + D x(t)

# Discretized
h_t = A_bar h_{t-1} + B_bar x_t
y_t = C h_t + D x_t
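The discretized recurrence above can be sketched with plain scalars (the values below are illustrative, not the library's parameters):

```python
# Minimal sketch of the discrete SSM recurrence for a single scalar state.
a_bar, b_bar, c, d = 0.9, 0.1, 1.0, 0.5

def ssm_scan(xs):
    """Run h_t = a_bar * h_{t-1} + b_bar * x_t; y_t = c * h_t + d * x_t."""
    h = 0.0
    ys = []
    for x in xs:
        h = a_bar * h + b_bar * x
        ys.append(c * h + d * x)
    return ys

ys = ssm_scan([1.0, 0.0, 0.0])
# An impulse input decays geometrically through the state: h = 0.1, 0.09, 0.081
```

In the real model, `a_bar`, `b_bar`, `c`, and `d` are vectors/matrices and the scan runs over `d_state` channels in parallel; the scalar case shows the mechanics.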

Initialization

S5 uses well-motivated initialization:
  • A matrix: Complex diagonal with:
    • Real part: -0.5 (parameterized via an inverse softplus so it stays negative)
    • Imaginary part: π * [0, 1, 2, ..., N-1] (frequency spacing)
  • B matrix: Ones scaled by 1/√H
  • C matrix: Random Gaussian scaled by √(2/N)
  • D matrix: Random Gaussian scaled by √(2/H)
  • dt: Log-spaced from 0.001 to 0.1
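The scheme above can be sketched as follows (variable names and shapes are illustrative, not the library's internals; H = d_model, N = d_state):

```python
import math
import torch

H, N = 64, 64

# A: complex diagonal with real part -0.5 (the library stores it via an
# inverse softplus) and imaginary part pi * [0, 1, ..., N-1].
A_re = torch.full((N,), -0.5)
A_im = math.pi * torch.arange(N, dtype=torch.float32)
A = torch.complex(A_re, A_im)

# B: ones scaled by 1/sqrt(H); C, D: Gaussian with the stated scales.
B = torch.ones(N, H) / math.sqrt(H)
C = torch.randn(H, N) * math.sqrt(2.0 / N)
D = torch.randn(H) * math.sqrt(2.0 / H)

# dt: log-spaced timescales, 10^-3 = 0.001 up to 10^-1 = 0.1.
dt = torch.logspace(-3, -1, N)
```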

FFT Convolution

Like S4, S5 uses FFT for efficient parallel training:
# Compute kernel from powers of A_bar:
#   k_l = C @ A_bar^l @ B_bar,   l = 0, ..., L-1
y = FFT_conv(x, k)
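As a concrete sketch, here is FFT-based causal convolution with a diagonal SSM kernel, checked against the step-by-step recurrence (single input/output channel for clarity; all names are illustrative, not the library's internals):

```python
import torch

L, N = 16, 8
A_bar = 0.9 * torch.exp(1j * torch.linspace(0, 3.0, N))  # diagonal entries, |.| < 1
B_bar = torch.ones(N, dtype=torch.cfloat)
C = torch.randn(N, dtype=torch.cfloat)
x = torch.randn(L)

# Kernel k_l = C @ diag(A_bar)^l @ B_bar for l = 0 .. L-1.
powers = A_bar.unsqueeze(0) ** torch.arange(L, dtype=torch.float32).unsqueeze(1)
k = (powers * B_bar * C).sum(-1).real  # (L,)

# Causal convolution via FFT: zero-pad to 2L to avoid circular wrap-around.
y_fft = torch.fft.irfft(
    torch.fft.rfft(x, n=2 * L) * torch.fft.rfft(k, n=2 * L), n=2 * L
)[:L]

# Same result computed by the sequential recurrence, for verification.
h = torch.zeros(N, dtype=torch.cfloat)
y_rec = []
for t in range(L):
    h = A_bar * h + B_bar * x[t]
    y_rec.append((C * h).sum().real)
y_rec = torch.stack(y_rec)
```

The FFT path costs O(L log L) instead of the O(L) sequential steps, and both produce the same output up to floating-point error.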

Discretization Methods

Zero-Order Hold (ZOH)

The continuous input is held constant between timesteps:
A_bar = exp(dt * A)
B_bar = (A_bar - I) @ A^{-1} @ B
Use when: General purpose (recommended default)
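Because A is diagonal in S5, the matrix formulas above reduce to elementwise operations. A scalar sketch for one diagonal entry:

```python
import math

# ZOH discretization of one diagonal entry (illustrative values).
a, b, dt = -0.5, 1.0, 0.01
a_bar = math.exp(dt * a)              # exp(dt * A), elementwise
b_bar = (a_bar - 1.0) / a * b         # (A_bar - I) A^{-1} B, elementwise
```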

Bilinear Transform

Tustin’s method using trapezoidal integration:
A_bar = (I + dt/2 * A) @ (I - dt/2 * A)^{-1}
B_bar = (I - dt/2 * A)^{-1} @ dt * B
Use when: Better frequency response preservation needed

Dirac Delta

Simplest discretization (a first-order, forward-Euler-style update):
A_bar = I + dt * A
B_bar = dt * B
Use when: Speed is critical, approximate discretization acceptable
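A quick numeric comparison of the three A_bar updates on a single diagonal entry (scalar sketch) shows the accuracy ordering: ZOH matches the exact matrix exponential, bilinear is a second-order approximation, and the Euler-style update is first-order.

```python
import math

a, dt = -2.0, 0.1
exact = math.exp(dt * a)                        # ZOH A_bar (exact)
bilinear = (1 + dt / 2 * a) / (1 - dt / 2 * a)  # bilinear A_bar
euler = 1 + dt * a                              # first-order A_bar

err_bilinear = abs(bilinear - exact)
err_euler = abs(euler - exact)
# err_bilinear is roughly 30x smaller than err_euler at this step size.
```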

Architecture Details

Forward Pass

  1. Discretize: Convert continuous SSM to discrete-time
  2. Compute kernel: Generate convolution kernel from A_bar
  3. FFT convolution: Apply kernel to input efficiently
  4. Output projection: Add D skip connection

State Representation

S5 maintains a complex-valued state of dimension d_state:
state: (batch_size, d_state) dtype=complex64
The output is real-valued through:
y = (C @ state).real + D * x
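A single recurrent step with a complex state can be sketched as follows (illustrative shapes and names; the library's step() manages this cache internally, and D acts as a per-feature skip):

```python
import torch

N, H = 8, 4  # d_state, d_model
A_bar = 0.9 * torch.ones(N, dtype=torch.cfloat)
B_bar = torch.ones(N, H, dtype=torch.cfloat)
C = torch.randn(H, N, dtype=torch.cfloat)
D = torch.randn(H)

state = torch.zeros(2, N, dtype=torch.cfloat)  # (batch_size, d_state), complex64
x_t = torch.randn(2, H)                        # single real-valued timestep

# One step of the recurrence; output is projected back to real values.
state = A_bar * state + x_t.to(torch.cfloat) @ B_bar.T
y_t = (state @ C.T).real + D * x_t
```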

Performance Tips

Start with discretization="zoh" - it’s the most robust choice for general sequence modeling.
S5 is simpler than S4/S4D but equally effective for many tasks. It’s a great starting point for understanding SSMs.
The conj_sym parameter is not yet implemented. Leave it as False.

Comparison with Other Models

| Feature      | S5       | S4          | S4D       |
| ------------ | -------- | ----------- | --------- |
| Complexity   | ⭐ Simple | ⭐⭐⭐ Complex | ⭐⭐ Medium |
| Speed        | ⭐⭐⭐      | ⭐⭐⭐         | ⭐⭐⭐⭐      |
| Parameters   | Fewer    | More        | Medium    |
| Code clarity | ⭐⭐⭐⭐⭐    | ⭐⭐⭐         | ⭐⭐⭐⭐      |
| Performance  | ⭐⭐⭐      | ⭐⭐⭐⭐        | ⭐⭐⭐⭐      |

When to Use S5

Use S5 when:
  • Learning about SSMs
  • You want clean, readable code
  • You need a strong baseline
  • Simplicity is valued
Consider alternatives when:
  • You need maximum performance → S4D
  • You want input-dependent dynamics → Mamba
  • You need minimal parameters → LRU

See Also

  • S4 - Original structured SSM with DPLR
  • S4D - Diagonal variant
  • LRU - Minimal diagonal SSM
