
Overview

S4 (Structured State Space Sequence model) is a foundational LTI model that uses a DPLR (Diagonal Plus Low-Rank) parameterization for efficient computation. It employs complex diagonal matrices with low-rank corrections to capture long-range dependencies while maintaining computational efficiency. S4 uses FFT-based convolution for parallel training and supports efficient autoregressive inference through recurrent mode.

Paper Reference

Efficiently Modeling Long Sequences with Structured State Spaces
Original implementation: https://github.com/state-spaces/s4

Import

from lrnnx.models.lti import S4

Parameters

d_model
int
required
Model dimension - size of input and output features.
d_state
int
default:"64"
State dimension (N). Internal dimension of the SSM state space. Higher values increase capacity but also computation.
l_max
int
default:"None"
Maximum sequence length for the kernel. Must be specified for FFT convolution mode.
channels
int
default:"1"
Number of channels/heads. Allows multi-headed SSM processing.
bottleneck
int
default:"None"
Reduce dimension of inner layer (e.g., used in GSS). If specified, adds an input linear projection.
gate
int
default:"None"
Add multiplicative gating (e.g., used in GSS). Creates a gated pathway for enhanced expressiveness.
final_act
str
default:"'glu'"
Activation after final linear layer. Options: 'glu', 'id' (no activation), or None (no linear layer).
dropout
float
default:"0.0"
Standard dropout probability.
tie_dropout
bool
default:"False"
Tie dropout mask across sequence length, emulating nn.Dropout1d.
transposed
bool
default:"True"
Backbone axis ordering: (B, H, L) if True, (B, L, H) if False.

SSM Configuration

dt_min
float
default:"0.001"
Minimum value for dt (timestep) initialization.
dt_max
float
default:"0.1"
Maximum value for dt initialization.
dt_tie
bool
default:"True"
Tie dt across channels - uses same timestep for all channels.
dt_transform
str
default:"'exp'"
Transformation to apply to dt. Options: 'exp', 'softplus', etc.
dt_fast
bool
default:"False"
Fast dt initialization mode.
rank
int
default:"1"
Rank of the low-rank correction for DPLR parameterization.
n_ssm
int
default:"None"
Number of independent SSMs. Defaults to d_model if not specified.
init
str
default:"'legs'"
Initialization method for the A matrix. Options: 'legs', 'lin', 'hippo', etc.
deterministic
bool
default:"False"
Use deterministic initialization for reproducibility.
real_transform
str
default:"'exp'"
Transformation for the real part of A matrix.
imag_transform
str
default:"'none'"
Transformation for the imaginary part of A matrix.
is_real
bool
default:"False"
Whether to use real-valued SSMs (instead of complex).
lr
float
default:"None"
Specific learning rate for SSM parameters. Useful for differential learning rates.
wd
float
default:"0.0"
Specific weight decay for SSM parameters.
verbose
bool
default:"True"
Print initialization information during setup.

Usage Example

Basic Usage

import torch
from lrnnx.models.lti import S4

# Create S4 model
model = S4(d_model=64, d_state=64, l_max=1024)

# Forward pass
x = torch.randn(2, 1024, 64)  # (batch, length, features)
y, state = model(x)

print(y.shape)  # torch.Size([2, 1024, 64])

Autoregressive Inference

import torch
from lrnnx.models.lti import S4

model = S4(d_model=64, d_state=64, l_max=1024)
batch_size = 2

# Allocate inference cache
cache = model.allocate_inference_cache(batch_size=batch_size)

# Process sequence step-by-step
for t in range(100):
    x_t = torch.randn(batch_size, 64)  # Single timestep
    y_t, cache = model.step(x_t, cache)
    # y_t shape: (batch_size, 64)

With Gating and Bottleneck

model = S4(
    d_model=256,
    d_state=64,
    l_max=2048,
    bottleneck=2,      # Reduce to 128 dims internally
    gate=2,            # Add gating mechanism
    dropout=0.1,
    channels=4,        # Multi-headed
)

x = torch.randn(4, 2048, 256)
y, state = model(x)

Key Features

DPLR Parameterization

S4 uses a Diagonal Plus Low-Rank decomposition:
A = Λ - PP*
Where:
  • Λ is a complex diagonal matrix
  • P is a low-rank correction matrix
  • This enables O(N) computation instead of O(N²)
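The identity above can be checked numerically. The following is a minimal NumPy sketch (the matrices here are random illustrations, not the library's actual HiPPO-derived parameters) showing that multiplying by the DPLR matrix needs only a diagonal multiply plus two rank-r products, instead of a dense N×N matvec:

```python
import numpy as np

N, r = 64, 1
rng = np.random.default_rng(0)

# Diagonal part: complex eigenvalues with negative real part (for stability)
Lam = -0.5 + 1j * np.arange(N)
# Low-rank correction P of shape (N, r)
P = rng.standard_normal((N, r)) + 1j * rng.standard_normal((N, r))

x = rng.standard_normal(N) + 0j

# O(N) matvec with A = Lam - P P*: diagonal multiply + two rank-r products
y_fast = Lam * x - P @ (P.conj().T @ x)

# Dense O(N^2) reference
A = np.diag(Lam) - P @ P.conj().T
y_dense = A @ x
```

Both paths give the same result; the fast path never materializes the N×N matrix.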

FFT Convolution

For training, S4 materializes the length-L convolution kernel and applies it with FFT:
# Kernel: K_l = C @ A^l @ B for l = 0, ..., L-1
K = (C B, C A B, C A^2 B, ..., C A^{L-1} B)
# Causal convolution via FFT (with zero-padding)
y = IFFT(FFT(K) * FFT(x))
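A runnable NumPy sketch of the causal FFT convolution, using a random stand-in kernel for illustration (in S4 the kernel comes from the SSM parameters). Zero-padding to 2L avoids the circular wrap-around of a plain FFT product:

```python
import numpy as np

rng = np.random.default_rng(0)
L = 16
K = rng.standard_normal(L)  # stand-in SSM kernel (single channel)
x = rng.standard_normal(L)  # input sequence

# Causal convolution via FFT: pad to 2L so the circular conv equals the linear conv
n = 2 * L
y = np.fft.irfft(np.fft.rfft(K, n) * np.fft.rfft(x, n), n)[:L]

# Direct O(L^2) causal convolution for comparison
y_ref = np.array([sum(K[j] * x[t - j] for j in range(t + 1)) for t in range(L)])
```

The FFT path runs in O(L log L) versus O(L²) for the direct sum.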

Recurrent Mode

For inference, S4 unrolls the discretized SSM recurrently:
h_t = A @ h_{t-1} + B @ x_t
y_t = C @ h_t + D @ x_t
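The recurrence can be checked against the convolutional view above. A NumPy sketch with illustrative random, already-discretized parameters (a stable diagonal A for simplicity, not the library's actual discretization):

```python
import numpy as np

rng = np.random.default_rng(0)
N, L = 4, 16
A = np.diag(rng.uniform(0.1, 0.9, N))  # stable discretized state matrix
B = rng.standard_normal(N)
C = rng.standard_normal(N)
D = 0.5

x = rng.standard_normal(L)

# Recurrent mode: one O(N) state update per timestep
h = np.zeros(N)
y_rec = np.empty(L)
for t in range(L):
    h = A @ h + B * x[t]
    y_rec[t] = C @ h + D * x[t]

# Convolutional view for comparison: K_l = C A^l B
K = np.array([C @ np.linalg.matrix_power(A, l) @ B for l in range(L)])
y_conv = np.array([K[: t + 1][::-1] @ x[: t + 1] for t in range(L)]) + D * x
```

Both modes compute the same outputs; the recurrent form is what `model.step` uses for autoregressive decoding.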

Architecture Details

Forward Pass Structure

  1. Optional bottleneck projection: Reduce input dimension
  2. Optional input gating: Multiplicative gating pathway
  3. SSM convolution: Core state space computation
  4. D skip connection: Direct input-output connection
  5. Activation: GELU activation
  6. Optional output gating: Gate the SSM output
  7. Output projection: Final linear layer with optional GLU
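The seven steps above can be sketched as a single function. The name s4_block and the weight names (W_in, W_gate, W_out) are hypothetical and for illustration only; the real module also handles channels, dropout, and the transposed layout:

```python
import numpy as np

def gelu(z):
    # tanh approximation of GELU
    return 0.5 * z * (1 + np.tanh(np.sqrt(2 / np.pi) * (z + 0.044715 * z**3)))

def s4_block(x, kernel, D, W_in=None, W_gate=None, W_out=None):
    """Illustrative sketch of the forward-pass structure (x: (L, H))."""
    if W_in is not None:                       # 1. optional bottleneck projection
        x = x @ W_in
    g = x @ W_gate if W_gate is not None else None  # 2. optional input gating
    L = x.shape[0]
    n = 2 * L                                  # 3. SSM convolution via FFT (causal)
    y = np.fft.irfft(
        np.fft.rfft(kernel, n)[:, None] * np.fft.rfft(x, n, axis=0), n, axis=0
    )[:L]
    y = y + D * x                              # 4. D skip connection
    y = gelu(y)                                # 5. activation
    if g is not None:                          # 6. optional output gating
        y = y * g
    if W_out is not None:                      # 7. output projection
        y = y @ W_out
    return y

rng = np.random.default_rng(0)
L, H = 32, 8
y = s4_block(rng.standard_normal((L, H)), rng.standard_normal(L), D=0.5,
             W_out=rng.standard_normal((H, H)))
```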

Initialization

S4 uses specialized initialization:
  • A matrix: HiPPO initialization (HiPPO-LegS by default) for long-range memory
  • B, C matrices: Random initialization scaled by dimensions
  • dt: Sampled log-uniformly between dt_min and dt_max
  • D: Random initialization

Performance Tips

  • Set l_max to the maximum sequence length you’ll encounter during training for optimal FFT performance.
  • S4 is an LTI model and does not support variable timesteps or async discretization. For event-driven data, use Mamba (LTV) instead.
  • The rank parameter controls the expressiveness of the low-rank correction. Higher rank increases capacity but also computation; rank=1 is usually sufficient.

See Also

  • S4D - Diagonal variant with simpler parameterization
  • S5 - Simplified implementation with clearer code
  • Mamba - Input-dependent selective variant
