Overview
The simplified scan operations provide CUDA-accelerated implementations for S5-style structured state space models. These operations use a diagonal SSM formulation with complex-valued inputs and projection matrices for efficient computation.
simplified_scan_fn
u (B,H,L) -> Bu = B @ u -> kernel -> x (B,P,L) -> y = C @ x -> (B,H,L)
- u: Complex input tensor of shape (batch, H, seqlen), dtype=complex64, where H is the hidden/input dimension.
- delta: Real timestep tensor of shape (batch, P, seqlen), dtype=float32, where P is the state dimension.
- A: Complex state matrix eigenvalues of shape (P,) or (P, 1), dtype=complex64. These are the diagonal elements of the state transition matrix.
- B: Complex projection matrix of shape (P, H), dtype=complex64. Projects input from H-dimensional space to P-dimensional state space.
- C: Complex projection matrix of shape (H, P), dtype=complex64. Projects state from P-dimensional space back to H-dimensional output.
- deltaA: Optional separate timestep for A discretization of shape (batch, P, seqlen), dtype=float32. If provided, A is discretized using deltaA while B uses delta, enabling asynchronous state dynamics.
- return_last_state: Whether to return the last hidden state. If True, returns tuple (output, last_state) where last_state has shape (batch, P), dtype=complex64.
- Discretization method. Options:
  - "bilinear": Bilinear transform (Tustin’s method), recommended for S5
  - "zoh": Zero-order hold
  - "dirac": Dirac delta (no delta scaling)

Returns a complex output tensor of shape (batch, H, seqlen), dtype=complex64. If return_last_state=True, returns tuple (output, last_state) where last_state has shape (batch, P), dtype=complex64.
Example
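The page's own example did not survive here, so the following is only a NumPy sketch of the pipeline described above (u -> B @ u -> scan -> C @ x), not the CUDA kernel and not the library's API. The names `discretize` and `reference_scan` are hypothetical. The sketch assumes a zero initial state, a single `delta` (no separate `deltaA`), and the textbook diagonal forms of the bilinear and zero-order-hold transforms. The "dirac" option is left out because its formula is not given in the text.

```python
import numpy as np

def discretize(A, delta, method="bilinear"):
    """Textbook discretization of a diagonal SSM (assumed forms, not the library's kernel).

    A: (P,) complex eigenvalues; delta: (batch, P, seqlen) real timesteps.
    Returns A_bar, B_bar, both of shape (batch, P, seqlen).
    """
    dA = delta * A[None, :, None]
    if method == "bilinear":
        denom = 1.0 - 0.5 * dA
        return (1.0 + 0.5 * dA) / denom, delta / denom
    if method == "zoh":
        A_bar = np.exp(dA)
        return A_bar, (A_bar - 1.0) / A[None, :, None]
    raise ValueError(f"method not covered by this sketch: {method}")

def reference_scan(u, delta, A, B, C, method="bilinear"):
    """Sequential reference: x[t] = A_bar * x[t-1] + B_bar * (B @ u)[t], y = C @ x."""
    A = A.reshape(-1)                          # accept (P,) or (P, 1)
    A_bar, B_bar = discretize(A, delta, method)
    Bu = np.einsum("ph,bhl->bpl", B, u)        # project input into state space
    batch, P, L = Bu.shape
    x = np.zeros((batch, P), dtype=Bu.dtype)   # assumed zero initial state
    xs = np.empty_like(Bu)
    for t in range(L):
        x = A_bar[:, :, t] * x + B_bar[:, :, t] * Bu[:, :, t]
        xs[:, :, t] = x
    y = np.einsum("hp,bpl->bhl", C, xs)        # project state back to output space
    return y, x                                # x is the last state, shape (batch, P)

def _crandn(rng, *shape):
    """Random complex64 array (helper for the usage lines below)."""
    return (rng.standard_normal(shape) + 1j * rng.standard_normal(shape)).astype(np.complex64)

rng = np.random.default_rng(0)
batch, H, P, L = 2, 3, 4, 5
u = _crandn(rng, batch, H, L)
delta = rng.uniform(0.01, 0.1, (batch, P, L)).astype(np.float32)
A = (-rng.uniform(0.5, 1.5, P) + 1j * rng.standard_normal(P)).astype(np.complex64)
B = _crandn(rng, P, H)
C = _crandn(rng, H, P)

y, last_state = reference_scan(u, delta, A, B, C, method="bilinear")
print(y.shape, last_state.shape)
```

A sequential loop like this is slow, but it is convenient for checking shapes and for comparing against a fused kernel on small inputs.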
Discretization Methods
The discretization methods transform continuous-time state space parameters to discrete-time. The bilinear method is recommended for S5.
s5_inner_fn
- SSM scan: x[t] = A_bar * x[t-1] + B_bar * (B @ u)[t], y = C @ x
- Conjugate symmetry: y_real = (2 if conj_sym else 1) * Re(y)
- Skip connection: out = y_real + D * u.real
- u: Complex input tensor of shape (batch, H, seqlen), dtype=complex64.
- delta: Real timestep tensor of shape (batch, P, seqlen), dtype=float32.
- A: Complex eigenvalues tensor of shape (P,) or (P, 1), dtype=complex64.
- B: Complex projection matrix of shape (P, H), dtype=complex64.
- C: Complex projection matrix of shape (H, P), dtype=complex64.
- D: Real skip connection tensor of shape (H,), dtype=float32, providing direct input-to-output connections.
- deltaA: Optional separate timestep for A discretization of shape (batch, P, seqlen), dtype=float32.
- Discretization method: "bilinear", "zoh", or "dirac".
- conj_sym: If True, output is 2 * Re(y), leveraging conjugate symmetry. If False, output is Re(y). S5 models typically use conjugate symmetry for efficiency.

Returns a real output tensor of shape (batch, H, seqlen), dtype=float32.
Example
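The original example is missing from this page. As a stand-in, here is a NumPy sketch of the three steps listed for s5_inner_fn (scan, conjugate symmetry, skip connection). The function name `s5_inner_reference` is hypothetical. It takes already-discretized `A_bar` and `B_bar` of shape (batch, P, seqlen) rather than `delta` and a method name, and it assumes a zero initial state, so it illustrates the formulas rather than the library's signature.

```python
import numpy as np

def s5_inner_reference(u, A_bar, B_bar, B, C, D, conj_sym=True):
    """Reference for: scan -> (2 if conj_sym else 1) * Re(y) -> + D * u.real."""
    Bu = np.einsum("ph,bhl->bpl", B, u)          # (batch, P, seqlen)
    batch, P, L = Bu.shape
    x = np.zeros((batch, P), dtype=Bu.dtype)     # assumed zero initial state
    xs = np.empty_like(Bu)
    for t in range(L):
        # SSM scan: x[t] = A_bar * x[t-1] + B_bar * (B @ u)[t]
        x = A_bar[:, :, t] * x + B_bar[:, :, t] * Bu[:, :, t]
        xs[:, :, t] = x
    y = np.einsum("hp,bpl->bhl", C, xs)          # y = C @ x, complex
    y_real = (2.0 if conj_sym else 1.0) * y.real # conjugate symmetry
    return y_real + D[None, :, None] * u.real    # skip connection, real output

def _crandn(rng, *shape):
    """Random complex64 array (helper for the usage lines below)."""
    return (rng.standard_normal(shape) + 1j * rng.standard_normal(shape)).astype(np.complex64)

rng = np.random.default_rng(1)
batch, H, P, L = 2, 3, 4, 6
u = _crandn(rng, batch, H, L)
theta = rng.uniform(0.0, np.pi, (batch, P, L))
A_bar = (0.9 * np.exp(1j * theta)).astype(np.complex64)   # stable: |A_bar| < 1
B_bar = np.full((batch, P, L), 0.1, dtype=np.complex64)
B = _crandn(rng, P, H)
C = _crandn(rng, H, P)
D = rng.standard_normal(H).astype(np.float32)

out_sym = s5_inner_reference(u, A_bar, B_bar, B, C, D, conj_sym=True)
out_plain = s5_inner_reference(u, A_bar, B_bar, B, C, D, conj_sym=False)
print(out_sym.shape, out_sym.dtype)
```

The two calls differ only in the factor applied to Re(y); the skip term D * u.real is identical in both.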
Conjugate Symmetry
When conj_sym=True, the output uses conjugate symmetry:
- Only half the eigenvalues need to be stored (the other half are complex conjugates)
- Output is 2 * Re(y) instead of Re(y)
- Provides 2x memory efficiency for state representation
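The identity behind the 2 * Re(y) rule can be checked numerically. The sketch below uses a hypothetical helper `diag_scan` with a single real-valued input channel and already-discretized parameters, both simplifications chosen for illustration. It runs a diagonal scan once with P states and once with those states plus their complex conjugates. With a real input, the conjugate half of the state is the conjugate of the first half, so the full output is real and equals twice the real part of the half-size output.

```python
import numpy as np

def diag_scan(a, b, c, u):
    """y[t] = sum_k c[k] * x[k, t], where x[k, t] = a[k] * x[k, t-1] + b[k] * u[t]."""
    x = np.zeros_like(a)
    y = np.empty(len(u), dtype=np.complex128)
    for t, u_t in enumerate(u):
        x = a * x + b * u_t
        y[t] = np.sum(c * x)
    return y

rng = np.random.default_rng(0)
P, L = 4, 16
a = 0.9 * np.exp(1j * rng.uniform(0.0, np.pi, P))        # stable discrete eigenvalues
b = rng.standard_normal(P) + 1j * rng.standard_normal(P)
c = rng.standard_normal(P) + 1j * rng.standard_normal(P)
u = rng.standard_normal(L)                               # real input (needed for the identity)

# Half-size system: store only P eigenvalues.
y_half = diag_scan(a, b, c, u)

# Full system: the same P states plus their complex conjugates (2P states).
y_full = diag_scan(np.concatenate([a, a.conj()]),
                   np.concatenate([b, b.conj()]),
                   np.concatenate([c, c.conj()]),
                   u)

print(np.max(np.abs(y_full.imag)))                       # numerically zero
print(np.max(np.abs(y_full.real - 2.0 * y_half.real)))   # numerically zero
```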
Architecture Comparison
| Operation | Input Type | Output Type | Use Case |
|---|---|---|---|
| simplified_scan_fn | Complex | Complex | Core S5 scan, flexible output |
| s5_inner_fn | Complex | Real | Complete S5 layer with skip connection |
Source Code
Source: lrnnx/ops/simplified_scan.py