
Overview

The S4 kernel operations provide efficient convolution kernel computation for S4 (Structured State Space Sequence) models. These operations support both diagonal (S4D) and diagonal plus low-rank (DPLR) parameterizations of state space models.

S4Kernel

from lrnnx.ops.s4_kernel_interface import S4Kernel
SSM kernel for diagonal + low rank (DPLR) state matrices. Computes pure convolution operations for efficient sequence modeling.
Parameters:
  • d_model (int, required): Model dimension
  • l_max (int | None, required): Maximum sequence length. If None, kernel length is determined dynamically
  • channels (int, required): Number of output channels/heads
  • param_config (dict, required): Configuration dictionary containing:
      • Parameter references: A_real, A_imag, B, C, inv_dt, P (nn.Parameters owned by parent S4 model)
      • Dimensions: N, H, channels, rank, repeat
      • Flags: dt_fast, real_transform, imag_transform, dt_transform, is_real, deterministic, verbose

forward

Compute SSM convolution kernel.
Parameters:
  • state (torch.Tensor, default: None): State tensor to augment the kernel computation
  • rate (float, default: 1.0): Sampling rate for the kernel
  • L (int, default: None): Sequence length. If None, uses l_max / rate
Returns:
  • k_B (torch.Tensor): Convolution kernel of shape (channels, H, L)
  • k_state (torch.Tensor | None): Kernel state if state was provided, otherwise None

step

Perform single recurrent step.
Parameters:
  • u (torch.Tensor, required): Input tensor
  • state (torch.Tensor, required): Current state tensor
Returns:
  • y (torch.Tensor): Output tensor (real part)
  • new_state (torch.Tensor): Updated state tensor

default_state

Create default zero-initialized state.
Parameters:
  • *batch_shape (tuple, required): Variable-length argument list for batch dimensions
Returns:
  • state (torch.Tensor): Zero-initialized state tensor of shape (*batch_shape, H, N) or (*batch_shape, H, 2*N), depending on step mode

Example

import torch
from lrnnx.ops.s4_kernel_interface import S4Kernel

# Assume params are created by parent S4 model
param_config = {
    'A_real': A_real_param,
    'A_imag': A_imag_param,
    'B': B_param,
    'C': C_param,
    'P': P_param,
    'inv_dt': inv_dt_param,
    'N': 64,
    'H': 256,
    'channels': 1,
    'rank': 2,
    'repeat': 1,
    'dt_fast': False,
    'real_transform': 'exp',
    'imag_transform': 'none',
    'dt_transform': 'exp',
    'is_real': False,
    'deterministic': False,
    'verbose': False
}

kernel = S4Kernel(
    d_model=256,
    l_max=1024,
    channels=1,
    param_config=param_config
)

# Compute convolution kernel
k, _ = kernel(L=512)
print(k.shape)  # (1, 256, 512)

# Use for recurrent stepping
kernel._setup_step(mode='linear')
batch_size, H = 4, 256
state = kernel.default_state(batch_size)
u = torch.randn(batch_size, H, device='cuda')
y, new_state = kernel.step(u, state)

S4DKernel

from lrnnx.ops.s4_kernel_interface import S4DKernel
SSM kernel using diagonal state matrix (S4D model). Simpler and more efficient than DPLR for many tasks.
Parameters:
  • d_model (int, required): Model dimension
  • l_max (int | None, required): Maximum sequence length
  • channels (int, required): Number of output channels
  • param_config (dict, required): Configuration dictionary with an additional S4D-specific key:
      • disc: Discretization method ('zoh' or 'bilinear')
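For reference, the two discretization options map the continuous-time diagonal system (A, B) with step size dt to a discrete system via the standard zero-order-hold and bilinear formulas. The sketch below illustrates the math only; the function name is hypothetical and this is not lrnnx's implementation.

```python
import torch

def discretize_diagonal(A, B, dt, disc='zoh'):
    """Discretize a diagonal SSM. A, B: (H, N) complex; dt: (H, 1) real step sizes."""
    if disc == 'zoh':
        # Zero-order hold: A_bar = exp(dt*A), B_bar = (exp(dt*A) - 1) / A * B
        dA = torch.exp(dt * A)
        dB = B * (dA - 1.0) / A
    elif disc == 'bilinear':
        # Bilinear (Tustin): A_bar = (1 + dt*A/2) / (1 - dt*A/2), B_bar = dt / (1 - dt*A/2) * B
        dA = (1 + dt * A / 2) / (1 - dt * A / 2)
        dB = B * dt / (1 - dt * A / 2)
    else:
        raise ValueError(f"unknown disc: {disc}")
    return dA, dB

H, N = 2, 4
A = -0.5 + 1j * torch.randn(H, N)             # poles with negative real part (stable)
B = torch.randn(H, N, dtype=torch.complex64)
dt = torch.full((H, 1), 0.01)
dA, dB = discretize_diagonal(A, B, dt)
print(dA.shape, dB.shape)
```

Both methods agree to first order in dt; ZOH is exact for piecewise-constant inputs, while bilinear preserves stability of the continuous system.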

forward

Compute SSM convolution kernel.
Parameters:
  • L (int, required): Sequence length
  • state (torch.Tensor, default: None): State tensor
  • rate (float, default: 1.0): Sampling rate
Returns:
  • K (torch.Tensor): Convolution kernel of shape (channels, H, L)
  • K_state (torch.Tensor | None): Kernel state if state was provided, otherwise None

step

Single step operation.
Parameters:
  • u (torch.Tensor, required): Input tensor of shape (B, H)
  • state (torch.Tensor, required): Current state tensor of shape (B, H, N)
Returns:
  • y (torch.Tensor): Output tensor (scaled by 2 for conjugate symmetry)
  • next_state (torch.Tensor): Updated state tensor

Example

import torch
from lrnnx.ops.s4_kernel_interface import S4DKernel

param_config = {
    'A_real': A_real_param,
    'A_imag': A_imag_param,
    'B': B_param,
    'C': C_param,
    'inv_dt': inv_dt_param,
    'N': 64,
    'H': 256,
    'channels': 1,
    'rank': 1,
    'repeat': 1,
    'dt_fast': False,
    'real_transform': 'exp',
    'imag_transform': 'none',
    'dt_transform': 'exp',
    'is_real': False,
    'deterministic': False,
    'verbose': False,
    'disc': 'zoh'  # S4D-specific
}

kernel = S4DKernel(
    d_model=256,
    l_max=1024,
    channels=1,
    param_config=param_config
)

# Compute convolution kernel
K, _ = kernel(L=512)
print(K.shape)  # (1, 256, 512)

Cauchy Operations

from lrnnx.ops.s4_utils import get_cauchy_kernel
Returns the best available Cauchy multiplication function (CUDA extension, KeOps, or PyTorch fallback).
Returns:
  • cauchy_fn (callable): Cauchy kernel function with signature (v, z, w) -> torch.Tensor

cauchy_naive

Naive PyTorch fallback for Cauchy matrix multiplication.
Parameters:
  • v (torch.Tensor, required): Input tensor of shape (..., N)
  • z (torch.Tensor, required): Evaluation points tensor of shape (..., L)
  • w (torch.Tensor, required): Poles tensor of shape (..., N)
Returns:
  • output (torch.Tensor): The sum of v / (z - w) over N, of shape (..., L)
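The naive fallback amounts to broadcasting the Cauchy matrix 1 / (z - w) and contracting against v. A minimal equivalent sketch (the function name here is illustrative, not the library's):

```python
import torch

def cauchy_sketch(v, z, w):
    # output[..., l] = sum_n v[..., n] / (z[..., l] - w[..., n])
    # v: (..., N), z: (..., L), w: (..., N) -> (..., L)
    cauchy_matrix = v.unsqueeze(-1) / (z.unsqueeze(-2) - w.unsqueeze(-1))  # (..., N, L)
    return cauchy_matrix.sum(dim=-2)

N, L = 4, 6
v = torch.randn(N, dtype=torch.complex64)
z = torch.randn(L, dtype=torch.complex64)
w = torch.randn(N, dtype=torch.complex64)
out = cauchy_sketch(v, z, w)
print(out.shape)  # torch.Size([6])
```

This materializes the full (N, L) matrix, which is why the CUDA and KeOps backends exist: they fuse the computation and avoid the quadratic memory cost.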

Example

import torch
from lrnnx.ops.s4_utils import get_cauchy_kernel

cauchy_fn = get_cauchy_kernel()

N, L = 64, 256
v = torch.randn(N, dtype=torch.complex64, device='cuda')
z = torch.randn(L, dtype=torch.complex64, device='cuda')
w = torch.randn(N, dtype=torch.complex64, device='cuda')

result = cauchy_fn(v, z, w)
print(result.shape)  # (256,)

Vandermonde Operations

from lrnnx.ops.s4_utils import get_vandermonde_kernel, get_vandermonde_transpose_kernel
Returns the best available Vandermonde multiplication functions.

get_vandermonde_kernel

Returns:
  • vandermonde_fn (callable): Vandermonde kernel function with signature (v, x, L) -> torch.Tensor

log_vandermonde_naive

Naive PyTorch fallback for log Vandermonde multiplication.
Parameters:
  • v (torch.Tensor, required): Input tensor of shape (..., N)
  • x (torch.Tensor, required): Log-space base tensor of shape (..., N)
  • L (int, required): Sequence length
  • conj (bool, default: True): Whether to use conjugate symmetry
Returns:
  • output (torch.Tensor): The sum over N of v * exp(x)^l for each l in [0, L), of shape (..., L)
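A minimal PyTorch equivalent of the forward computation, ignoring the conj option (the function name is illustrative, not the library's):

```python
import torch

def log_vandermonde_sketch(v, x, L):
    # output[..., l] = sum_n v[..., n] * exp(x[..., n])**l, for l = 0..L-1
    # v: (..., N), x: (..., N) -> (..., L)
    powers = torch.arange(L, device=x.device)
    vandermonde = torch.exp(x.unsqueeze(-1) * powers)  # (..., N, L)
    return torch.einsum('...nl,...n->...l', vandermonde, v)

N, L = 4, 6
v = torch.randn(N, dtype=torch.complex64)
x = 0.1 * torch.randn(N, dtype=torch.complex64)
out = log_vandermonde_sketch(v, x, L)
print(out.shape)  # torch.Size([6])
```

Keeping x in log space lets the kernel compute exp(x)^l as exp(x * l), which is numerically better behaved for long sequences than repeated multiplication.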

get_vandermonde_transpose_kernel

Returns:
  • vandermonde_transpose_fn (callable): Transposed Vandermonde kernel with signature (u, v, x, L) -> torch.Tensor

log_vandermonde_transpose_naive

Naive PyTorch fallback for transposed log Vandermonde multiplication.
Parameters:
  • u (torch.Tensor, required): Input tensor of shape (..., L)
  • v (torch.Tensor, required): Input tensor of shape (..., N)
  • x (torch.Tensor, required): Log-space base tensor of shape (..., N)
  • L (int, required): Sequence length
Returns:
  • output (torch.Tensor): Output tensor of shape (..., N)
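The transposed variant contracts over the length axis instead of the state axis; a minimal sketch of the equivalent computation (function name illustrative, not the library's):

```python
import torch

def log_vandermonde_transpose_sketch(u, v, x, L):
    # output[..., n] = v[..., n] * sum_l u[..., l] * exp(x[..., n])**l
    # u: (..., L), v: (..., N), x: (..., N) -> (..., N)
    powers = torch.arange(L, device=x.device)
    vandermonde = torch.exp(x.unsqueeze(-1) * powers)  # (..., N, L)
    return torch.einsum('...l,...nl,...n->...n', u, vandermonde, v)

N, L = 4, 6
u = torch.randn(L, dtype=torch.complex64)
v = torch.randn(N, dtype=torch.complex64)
x = 0.1 * torch.randn(N, dtype=torch.complex64)
out = log_vandermonde_transpose_sketch(u, v, x, L)
print(out.shape)  # torch.Size([4])
```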

Example

import torch
from lrnnx.ops.s4_utils import get_vandermonde_kernel, get_vandermonde_transpose_kernel

vandermonde_fn = get_vandermonde_kernel()
vandermonde_T_fn = get_vandermonde_transpose_kernel()

N, L = 64, 256
v = torch.randn(N, dtype=torch.complex64, device='cuda')
x = torch.randn(N, dtype=torch.complex64, device='cuda')

# Forward Vandermonde
result = vandermonde_fn(v, x, L)
print(result.shape)  # (256,)

# Transposed Vandermonde
u = torch.randn(L, dtype=torch.complex64, device='cuda')
result_T = vandermonde_T_fn(u, v, x, L)
print(result_T.shape)  # (64,)

Performance Notes

  • CUDA Extension: Provides fastest performance when available
  • KeOps: Falls back to KeOps for GPU acceleration if CUDA extension not available
  • PyTorch Fallback: Pure PyTorch implementation used when neither CUDA nor KeOps available
The library automatically selects the best available implementation at runtime.
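The runtime selection follows the familiar try-import fallback pattern. A hypothetical sketch (the module name `cauchy_mult` and the function itself are illustrative, not lrnnx's actual internals):

```python
def select_backend():
    # Prefer the compiled CUDA extension, then KeOps, then pure PyTorch.
    try:
        import cauchy_mult  # noqa: F401  (hypothetical compiled-extension name)
        return 'cuda_extension'
    except ImportError:
        pass
    try:
        import pykeops  # noqa: F401
        return 'keops'
    except ImportError:
        return 'torch_fallback'

print(select_backend())
```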

Source Code

  • Kernel Interface: lrnnx/ops/s4_kernel_interface.py
  • Utilities: lrnnx/ops/s4_utils.py
