Overview
The S4 kernel operations provide efficient convolution kernel computation for S4 (Structured State Space Sequence) models. These operations support both diagonal (S4D) and diagonal plus low-rank (DPLR) parameterizations of state space models.
S4Kernel
Parameters:
- Model dimension.
- Maximum sequence length. If None, the kernel length is determined dynamically.
- Number of output channels/heads.
- Configuration dictionary containing:
  - Parameter references: A_real, A_imag, B, C, inv_dt, P (nn.Parameters owned by the parent S4 model)
  - Dimensions: N, H, channels, rank, repeat
  - Flags: dt_fast, real_transform, imag_transform, dt_transform, is_real, deterministic, verbose
forward
Compute the SSM convolution kernel.
Parameters:
- State tensor used to augment the kernel computation.
- Sampling rate for the kernel.
- Sequence length. If None, uses l_max / rate.
Returns:
- Convolution kernel of shape (channels, H, L).
- Kernel state if state was provided, otherwise None.
step
Perform a single recurrent step.
Parameters:
- Input tensor.
- Current state tensor.
Returns:
- Output tensor (real part).
- Updated state tensor.
default_state
Create a default zero-initialized state.
Parameters:
- Variable-length argument list giving the batch dimensions (*batch_shape).
Returns:
- Zero-initialized state tensor of shape (*batch_shape, H, N) or (*batch_shape, H, 2*N), depending on the step mode.
Example
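A minimal sketch of how a kernel of shape (channels, H, L), like the one returned by forward, can be applied to an input sequence with a causal FFT convolution. It is written in NumPy rather than PyTorch and assumes an input of shape (B, H, L). The function names are hypothetical and are not part of lrnnx.

```python
import numpy as np

def fft_causal_conv(u, k):
    """Causal convolution y[..., t] = sum_{s<=t} k[..., s] * u[..., t - s].

    u: input of shape (B, H, L)
    k: kernel of shape (channels, H, L)
    returns: output of shape (B, channels, H, L)
    """
    L = u.shape[-1]
    n = 2 * L  # zero-pad to 2L so the circular FFT product becomes a linear convolution
    u_f = np.fft.rfft(u, n=n)  # (B, H, n//2 + 1)
    k_f = np.fft.rfft(k, n=n)  # (channels, H, n//2 + 1)
    # Broadcast (B, 1, H, F) * (1, channels, H, F) -> (B, channels, H, F)
    y_f = u_f[:, None, :, :] * k_f[None, :, :, :]
    return np.fft.irfft(y_f, n=n)[..., :L]  # keep the first L (causal) outputs

def direct_causal_conv(u, k):
    """O(L^2) reference implementation of the same causal convolution."""
    B, H, L = u.shape
    channels = k.shape[0]
    y = np.zeros((B, channels, H, L))
    for t in range(L):
        for s in range(t + 1):
            # (1, channels, H) * (B, 1, H) -> (B, channels, H)
            y[..., t] += k[None, :, :, s] * u[:, None, :, t - s]
    return y

rng = np.random.default_rng(0)
u = rng.standard_normal((2, 3, 8))
k = rng.standard_normal((4, 3, 8))
y = fft_causal_conv(u, k)
```

Zero-padding to length 2L prevents wrap-around, so the FFT result matches the direct loop while costing O(L log L) instead of O(L^2).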
S4DKernel
Parameters:
- Model dimension.
- Maximum sequence length.
- Number of output channels.
- Configuration dictionary with an additional S4D-specific key:
  - disc: Discretization method ("zoh" or "bilinear")
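The two values of disc correspond to the standard zero-order-hold and bilinear (Tustin) discretizations of a diagonal continuous-time system x'(t) = A x(t) + B u(t). This NumPy sketch applies both elementwise, assuming complex A and B of shape (H, N) and a step size dt of shape (H, 1). The helper name is hypothetical, and the sketch does not reproduce lrnnx's handling of inv_dt or its transforms.

```python
import numpy as np

def discretize_diagonal(A, B, dt, disc="zoh"):
    """Discretize a diagonal SSM so that x[k] = dA * x[k-1] + dB * u[k] elementwise.

    A, B: complex arrays of shape (H, N)
    dt:   real step sizes of shape (H, 1)
    """
    dtA = dt * A
    if disc == "zoh":
        # Zero-order hold: dA = exp(dt A), dB = (exp(dt A) - 1) / A * B
        dA = np.exp(dtA)
        dB = B * (np.exp(dtA) - 1.0) / A
    elif disc == "bilinear":
        # Bilinear: dA = (1 + dt A / 2) / (1 - dt A / 2), dB = dt B / (1 - dt A / 2)
        dA = (1.0 + dtA / 2) / (1.0 - dtA / 2)
        dB = B * dt / (1.0 - dtA / 2)
    else:
        raise ValueError(f"unknown discretization: {disc!r}")
    return dA, dB
```

For small dt the two methods agree closely, and both map a stable A (negative real part) to |dA| < 1.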
forward
Compute the SSM convolution kernel.
Parameters:
- Sequence length.
- State tensor.
- Sampling rate.
Returns:
- Convolution kernel of shape (channels, H, L).
- Kernel state if state was provided, otherwise None.
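For a diagonal state matrix the kernel reduces to a Vandermonde product. This NumPy sketch assumes ZOH discretization and the conjugate-symmetry convention noted for step: only one element of each conjugate pair is stored, so the kernel is twice the real part. Here dt * A acts as the log-space base of the Vandermonde matrix. The function name is hypothetical, and this is not lrnnx's implementation.

```python
import numpy as np

def s4d_kernel_zoh(A, B, C, dt, L):
    """Diagonal SSM convolution kernel under ZOH discretization.

    A, B: complex (H, N), one element of each conjugate pair
    C:    complex (channels, H, N)
    dt:   real (H, 1)
    Returns a real kernel of shape (channels, H, L):
        K[c, h, l] = 2 * Re( sum_n C[c,h,n] * dB[h,n] * exp(l * dt[h] * A[h,n]) )
    """
    dtA = dt * A                                    # log of dA under ZOH
    dB = B * (np.exp(dtA) - 1.0) / A                # ZOH input matrix
    v = C * dB                                      # fold B into C: (channels, H, N)
    powers = np.exp(dtA[..., None] * np.arange(L))  # (H, N, L), equals dA**l
    K = np.einsum("chn,hnl->chl", v, powers)        # Vandermonde product
    return 2 * K.real                               # conjugate pairs double the real part
```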
step
Single step operation.
Parameters:
- Input tensor of shape (B, H).
- Current state tensor of shape (B, H, N).
Returns:
- Output tensor, scaled by 2 to account for conjugate symmetry.
- Updated state tensor.
Example
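A minimal NumPy sketch of a diagonal recurrent step with the shapes listed above, under stated assumptions. It assumes already-discretized dA, dB of shape (H, N), a single output channel, and the factor 2 from conjugate symmetry. It then checks that unrolling the step from a zero state reproduces the causal convolution with the corresponding kernel. The helper s4d_step is hypothetical, not the lrnnx method.

```python
import numpy as np

def s4d_step(u, state, dA, dB, C):
    """One recurrent step of a diagonal SSM.

    u:      input of shape (B, H)
    state:  complex state of shape (B, H, N)
    dA, dB: discretized parameters of shape (H, N)
    C:      output matrix of shape (H, N) (single channel for simplicity)
    Returns (y, new_state) with y of shape (B, H).
    """
    new_state = dA * state + dB * u[..., None]          # x_k = dA x_{k-1} + dB u_k
    y = 2 * np.einsum("hn,bhn->bh", C, new_state).real  # factor 2: conjugate pairs
    return y, new_state

rng = np.random.default_rng(0)
Bsz, H, N, L = 2, 3, 4, 6
A = -0.5 + 1j * np.pi * np.arange(N) * np.ones((H, 1))
dA = np.exp(0.1 * A)   # ZOH with dt = 0.1
dB = (dA - 1.0) / A    # ZOH with B = 1
C = rng.standard_normal((H, N)) + 1j * rng.standard_normal((H, N))
u = rng.standard_normal((Bsz, H, L))

# Recurrent mode: start from a zero state and step through time.
state = np.zeros((Bsz, H, N), dtype=complex)
ys = []
for t in range(L):
    y_t, state = s4d_step(u[..., t], state, dA, dB, C)
    ys.append(y_t)
y_rec = np.stack(ys, axis=-1)  # (B, H, L)

# Convolution mode: K[h, l] = 2 Re sum_n C dB dA**l, then causal convolution.
K = 2 * np.einsum("hn,hnl->hl", C * dB, dA[..., None] ** np.arange(L)).real
y_conv = np.zeros_like(y_rec)
for t in range(L):
    for s in range(t + 1):
        y_conv[..., t] += K[:, s] * u[..., t - s]
```

The two modes agree because x_t = sum over j <= t of dA^(t-j) dB u_j, so the output is a convolution of u with K.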
Cauchy Operations
Cauchy kernel function with signature (v, z, w) -> torch.Tensor.
cauchy_naive
Naive PyTorch fallback for Cauchy matrix multiplication.
Parameters:
- Input tensor of shape (..., N).
- Evaluation points tensor of shape (..., L).
- Poles tensor of shape (..., N).
Returns:
- The sum over the N poles of v/(z-w), i.e. out[..., l] = sum_n v[..., n] / (z[..., l] - w[..., n]), with shape (..., L).
Example
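The documented formula can be expressed directly with broadcasting. This is a NumPy transcription with a hypothetical name. It does not attempt any conjugate-pair or dtype handling that the PyTorch fallback may perform.

```python
import numpy as np

def cauchy_naive_np(v, z, w):
    """out[..., l] = sum_n v[..., n] / (z[..., l] - w[..., n])

    v: (..., N), z: (..., L), w: (..., N); returns (..., L).
    """
    # Build the (..., N, L) Cauchy matrix by broadcasting, then sum over N.
    cauchy_matrix = v[..., :, None] / (z[..., None, :] - w[..., :, None])
    return cauchy_matrix.sum(axis=-2)
```

For example, with v = [1, 2], w = [0, 1] and z = [2, 3], the first output is 1/2 + 2/1 = 2.5.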
Vandermonde Operations
get_vandermonde_kernel
Vandermonde kernel function with signature (v, x, L) -> torch.Tensor.
log_vandermonde_naive
Naive PyTorch fallback for log Vandermonde multiplication.
Parameters:
- Input tensor of shape (..., N).
- Log-space base tensor of shape (..., N).
- Sequence length.
- Whether to use conjugate symmetry.
Returns:
- The sum over N of v * x^l for l in [0, L), with shape (..., L). Since x holds log-space bases, x^l here means exp(l * x).
get_vandermonde_transpose_kernel
Transposed Vandermonde kernel with signature (u, v, x, L) -> torch.Tensor.
log_vandermonde_transpose_naive
Naive PyTorch fallback for transposed log Vandermonde multiplication.
Parameters:
- Input tensor of shape (..., L).
- Input tensor of shape (..., N).
- Log-space base tensor of shape (..., N).
- Sequence length.
Returns:
- Output tensor of shape (..., N).
Example
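A NumPy sketch of both Vandermonde operations, built on the matrix V[n, l] = exp(l * x[n]). Two points are assumptions not confirmed by this page. First, with conjugate symmetry enabled the forward product returns twice its real part. Second, the transpose computes out[n] = v[n] * sum_l u[l] * V[n, l]. The function names are hypothetical.

```python
import numpy as np

def log_vandermonde_np(v, x, L, conj=True):
    """sum_n v[..., n] * exp(l * x[..., n]) for l in [0, L); returns (..., L)."""
    vander = np.exp(x[..., :, None] * np.arange(L))  # (..., N, L): (e^x)^l
    prod = np.einsum("...n,...nl->...l", v, vander)
    # Assumed convention: only one element of each conjugate pair is stored,
    # so the full sum over both elements equals twice the real part.
    return 2 * prod.real if conj else prod

def log_vandermonde_transpose_np(u, v, x, L):
    """out[..., n] = v[..., n] * sum_l u[..., l] * exp(l * x[..., n]); returns (..., N)."""
    vander = np.exp(x[..., :, None] * np.arange(L))  # (..., N, L)
    return np.einsum("...l,...n,...nl->...n", u, v, vander)
```

The transpose applies the same matrix from the other side. Summing its output over N therefore equals the dot product of u with the non-conjugated forward result.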
Performance Notes
- CUDA extension: fastest option, used when available.
- KeOps: GPU-accelerated fallback, used if the CUDA extension is not available.
- PyTorch fallback: pure PyTorch implementation, used when neither the CUDA extension nor KeOps is available.
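The fallback order above can be sketched as a simple availability check. This is an illustration only. The CUDA extension module name is a hypothetical placeholder, and the real library may detect backends differently, for example by attempting an import and catching ImportError. pykeops is the Python package name of KeOps.

```python
import importlib.util

def select_backend(cuda_ext_module="s4_cuda_ext"):
    """Return the first available backend in the documented order.

    cuda_ext_module is a hypothetical placeholder, not the actual name
    of the lrnnx CUDA extension.
    """
    if importlib.util.find_spec(cuda_ext_module) is not None:
        return "cuda"
    if importlib.util.find_spec("pykeops") is not None:  # KeOps package
        return "keops"
    return "pytorch"  # pure PyTorch naive fallback
```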
Source Code
- Kernel interface: lrnnx/ops/s4_kernel_interface.py
- Utilities: lrnnx/ops/s4_utils.py
