
Overview

The simplified scan operations provide CUDA-accelerated implementations for S5-style structured state space models. These operations use a diagonal SSM formulation with complex-valued inputs and projection matrices for efficient computation.

simplified_scan_fn

from lrnnx.ops.simplified_scan import simplified_scan_fn
Simplified SSM scan using a CUDA kernel. This S5-style scan uses projection matrices B and C to map between the input/output space and the state space.
Architecture: u (B, H, L) -> Bu = B @ u -> scan kernel -> x (B, P, L) -> y = C @ x -> y (B, H, L)
u
torch.Tensor
required
Complex input tensor of shape (batch, H, seqlen), dtype=complex64, where H is the hidden/input dimension
delta
torch.Tensor
required
Real timestep tensor of shape (batch, P, seqlen), dtype=float32, where P is the state dimension
A
torch.Tensor
required
Complex state matrix eigenvalues of shape (P,) or (P, 1), dtype=complex64. These are the diagonal elements of the state transition matrix
B
torch.Tensor
required
Complex projection matrix of shape (P, H), dtype=complex64. Projects input from H-dimensional space to P-dimensional state space
C
torch.Tensor
required
Complex projection matrix of shape (H, P), dtype=complex64. Projects state from P-dimensional space back to H-dimensional output
deltaA
torch.Tensor | None
default:"None"
Optional separate timestep for A discretization of shape (batch, P, seqlen), dtype=float32. If provided, A is discretized using deltaA while B uses delta, enabling asynchronous state dynamics
return_last_state
bool
default:"False"
Whether to return the last hidden state. If True, returns tuple (output, last_state) where last_state has shape (batch, P), dtype=complex64
discretization
str
default:"'bilinear'"
Discretization method. Options:
  • "bilinear": Bilinear transform (Tustin’s method) - recommended for S5
  • "zoh": Zero-order hold
  • "dirac": Dirac delta (no delta scaling)
output
torch.Tensor | tuple[torch.Tensor, torch.Tensor]
Complex output tensor of shape (batch, H, seqlen), dtype=complex64. If return_last_state=True, returns tuple (output, last_state) where last_state has shape (batch, P), dtype=complex64

Example

import torch
from lrnnx.ops.simplified_scan import simplified_scan_fn

batch, H, P, seqlen = 2, 64, 32, 128

# Create complex input and projection matrices
u = torch.randn(batch, H, seqlen, dtype=torch.complex64, device='cuda')
delta = torch.rand(batch, P, seqlen, device='cuda')
A = torch.complex(-torch.rand(P, device='cuda'), torch.randn(P, device='cuda'))  # negative real part for stability
B = torch.randn(P, H, dtype=torch.complex64, device='cuda')
C = torch.randn(H, P, dtype=torch.complex64, device='cuda')

# Run simplified scan
output = simplified_scan_fn(
    u, delta, A, B, C,
    discretization="bilinear"
)

print(output.shape)  # (2, 64, 128)
print(output.dtype)  # torch.complex64
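For checking results on small inputs, the recurrence the kernel computes can be reproduced with a naive pure-PyTorch loop. The following is a minimal sketch assuming the bilinear discretization; `reference_scan` is a hypothetical helper, not part of lrnnx:

```python
import torch

def reference_scan(u, delta, A, B, C):
    """Naive O(L) loop mirroring the kernel's recurrence
    (bilinear discretization): x[t] = A_bar[t] * x[t-1] + b[t] * Bu[t]."""
    batch, H, L = u.shape
    A = A.reshape(-1)[:, None]                 # (P, 1) for broadcasting
    P = A.shape[0]
    Bu = torch.einsum('ph,bhl->bpl', B, u)     # project input to state space
    dA = delta * A                             # (batch, P, L), complex
    A_bar = (1 + 0.5 * dA) / (1 - 0.5 * dA)    # bilinear transform
    b = delta / (1 - 0.5 * dA)                 # per-step input factor
    x = torch.zeros(batch, P, dtype=u.dtype, device=u.device)
    ys = []
    for t in range(L):                         # sequential scan over time
        x = A_bar[..., t] * x + b[..., t] * Bu[..., t]
        ys.append(x)
    xs = torch.stack(ys, dim=-1)               # (batch, P, L)
    return torch.einsum('hp,bpl->bhl', C, xs)  # project back to outputs
```

On small CUDA inputs this should agree with simplified_scan_fn(..., discretization="bilinear") up to numerical tolerance.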

Discretization Methods

The discretization methods convert the continuous-time state space parameters (A, B) to their discrete-time counterparts (A_bar, B_bar). When deltaA is not provided, deltaA = delta.

Bilinear (recommended for S5):
A_bar = (1 + 0.5 * deltaA * A) / (1 - 0.5 * deltaA * A)
B_bar = delta / (1 - 0.5 * delta * A) * B

Zero-Order Hold (ZOH):
A_bar = exp(deltaA * A)
B_bar = (exp(delta * A) - 1) / A * B

Dirac:
A_bar = exp(deltaA * A)
B_bar = B  # no delta scaling
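These rules translate directly to PyTorch. Below is a minimal sketch with a hypothetical `discretize` helper (not part of lrnnx); it returns A_bar plus the per-step factor applied to the already-projected input B @ u, with deltaA defaulting to delta:

```python
import torch

def discretize(A, delta, deltaA=None, method="bilinear"):
    """A: complex eigenvalues (P,); delta, deltaA: (batch, P, L).
    Returns (A_bar, b) where b is the factor applied to Bu = B @ u."""
    if deltaA is None:
        deltaA = delta                        # A and B share one timestep
    A = A.reshape(-1)[:, None]                # (P, 1) for broadcasting
    dA, dB = deltaA * A, delta * A            # both (batch, P, L), complex
    if method == "bilinear":
        A_bar = (1 + 0.5 * dA) / (1 - 0.5 * dA)
        b = delta / (1 - 0.5 * dB)
    elif method == "zoh":
        A_bar = torch.exp(dA)
        b = (torch.exp(dB) - 1) / A
    elif method == "dirac":
        A_bar = torch.exp(dA)
        b = torch.ones_like(A_bar)            # no delta scaling
    else:
        raise ValueError(f"unknown method: {method}")
    return A_bar, b
```

With Re(A) < 0 and delta > 0, both the bilinear and ZOH rules yield |A_bar| < 1, so the recurrence remains stable.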

s5_inner_fn

from lrnnx.ops.simplified_scan import s5_inner_fn
Complete S5 inner function using the CUDA kernel. This wraps the simplified scan and adds conjugate-symmetry handling and a skip connection. Forward pass:
  1. SSM scan: x[t] = A_bar * x[t-1] + B_bar * (B @ u)[t], y = C @ x
  2. Conjugate symmetry: y_real = (2 if conj_sym else 1) * Re(y)
  3. Skip connection: out = y_real + D * u.real
u
torch.Tensor
required
Complex input tensor of shape (batch, H, seqlen), dtype=complex64
delta
torch.Tensor
required
Real timestep tensor of shape (batch, P, seqlen), dtype=float32
A
torch.Tensor
required
Complex eigenvalues tensor of shape (P,) or (P, 1), dtype=complex64
B
torch.Tensor
required
Complex projection matrix of shape (P, H), dtype=complex64
C
torch.Tensor
required
Complex projection matrix of shape (H, P), dtype=complex64
D
torch.Tensor
required
Real skip connection tensor of shape (H,), dtype=float32, providing direct input-to-output connections
deltaA
torch.Tensor | None
default:"None"
Optional separate timestep for A discretization of shape (batch, P, seqlen), dtype=float32
discretization
str
default:"'bilinear'"
Discretization method: "bilinear", "zoh", or "dirac"
conj_sym
bool
default:"True"
If True, output is 2 * Re(y), leveraging conjugate symmetry. If False, output is Re(y). S5 models typically use conjugate symmetry for efficiency
output
torch.Tensor
Real output tensor of shape (batch, H, seqlen), dtype=float32

Example

import torch
from lrnnx.ops.simplified_scan import s5_inner_fn

batch, H, P, seqlen = 2, 64, 32, 128

# Create input tensors
u = torch.randn(batch, H, seqlen, dtype=torch.complex64, device='cuda')
delta = torch.rand(batch, P, seqlen, device='cuda')
A = torch.complex(-torch.rand(P, device='cuda'), torch.randn(P, device='cuda'))  # negative real part for stability
B = torch.randn(P, H, dtype=torch.complex64, device='cuda')
C = torch.randn(H, P, dtype=torch.complex64, device='cuda')
D = torch.randn(H, device='cuda')

# Run S5 inner function
output = s5_inner_fn(
    u, delta, A, B, C, D,
    discretization="bilinear",
    conj_sym=True
)

print(output.shape)  # (2, 64, 128)
print(output.dtype)  # torch.float32
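After the scan itself, the remaining forward-pass steps are simple elementwise operations. A sketch of steps 2-3 with a hypothetical `s5_postprocess` helper, where y is the complex output of the underlying scan:

```python
import torch

def s5_postprocess(y, u, D, conj_sym=True):
    """Apply conjugate symmetry and the skip connection to the
    complex scan output y of shape (batch, H, L)."""
    y_real = (2.0 if conj_sym else 1.0) * y.real   # step 2: conjugate symmetry
    return y_real + D[None, :, None] * u.real      # step 3: skip connection
```

The result is real-valued (float32), matching the output contract of s5_inner_fn.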

Conjugate Symmetry

When conj_sym=True, the output uses conjugate symmetry:
  • Only half the eigenvalues need to be stored (the other half are complex conjugates)
  • Output is 2 * Re(y) instead of Re(y)
  • Provides 2x memory efficiency for state representation
This is the standard configuration for S5 models.
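The identity behind this is that conjugate-pair states contribute conjugate outputs, whose sum is purely real. A small numeric sketch:

```python
import torch

# For a conjugate pair of states/projections, the pair's combined
# contribution equals twice the real part of one member's contribution,
# so the second half of the states never needs to be computed.
torch.manual_seed(0)
c = torch.randn(4, dtype=torch.complex64)  # output projection entries
x = torch.randn(4, dtype=torch.complex64)  # states at some timestep

full = c * x + c.conj() * x.conj()         # both halves of the pair
half = 2 * (c * x).real                    # conj_sym=True computation
assert torch.allclose(full.real, half)
assert full.imag.abs().max() < 1e-6        # imaginary parts cancel
```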

Architecture Comparison

Operation            Input Type   Output Type   Use Case
simplified_scan_fn   Complex      Complex       Core S5 scan, flexible output
s5_inner_fn          Complex      Real          Complete S5 layer with skip connection

Source Code

Source: lrnnx/ops/simplified_scan.py
