Overview
lrnnx achieves high performance through custom CUDA kernels that implement efficient forward and backward passes for Linear RNN operations. The library includes three main kernel types optimized for different model architectures.
Custom CUDA kernels are compiled during installation, which can take up to 30 minutes depending on your system. The compilation happens automatically when you install lrnnx.
Kernel Types
1. Selective Scan Kernels
Located in csrc/selective_scan/, these kernels power selective state space models like Mamba, S6, and S7.
Forward kernels (each file below has a matching backward kernel in csrc/selective_scan/backward_kernels/):
```
csrc/selective_scan/forward_kernels/
├── selective_scan_fp32_real_mamba.cu       # Mamba discretization
├── selective_scan_fp32_real_zoh.cu         # Zero-order hold
├── selective_scan_fp32_real_bilinear.cu    # Bilinear discretization
├── selective_scan_fp32_real_dirac.cu       # Dirac discretization
├── selective_scan_fp32_real_s7.cu          # S7 variant
├── selective_scan_fp32_real_rglru.cu       # RG-LRU variant
├── selective_scan_fp32_complex_mamba.cu    # Complex Mamba
├── selective_scan_fp32_complex_zoh.cu      # Complex ZOH
├── selective_scan_fp32_complex_bilinear.cu # Complex bilinear
└── selective_scan_fp32_complex_dirac.cu    # Complex Dirac
```
Key Features:
Input-dependent state transitions (selective mechanism)
Multiple discretization schemes (Mamba, ZOH, Bilinear, Dirac)
Support for both real- and complex-valued computations
Optimized for time-varying models (LTV)
Used by: Mamba, S6, S7, RG-LRU
2. Simplified Scan Kernels
Located in csrc/simplified_scan/, these kernels implement efficient scans for simpler state space models.
Forward kernels (each has a matching backward kernel in csrc/simplified_scan/backward_kernels/):
```
csrc/simplified_scan/forward_kernels/
├── simplified_scan_fp32_zoh.cu      # Zero-order hold
├── simplified_scan_fp32_bilinear.cu # Bilinear discretization
└── simplified_scan_fp32_dirac.cu    # Dirac discretization
```
Key Features:
Fixed (time-invariant) state transitions
Three discretization methods
Lower memory footprint than selective scan
Optimized for time-invariant models (LTI)
Used by: LRU, S5 (parallel scan variant)
3. Structured Kernels (S4)
Located in csrc/s4/, these implement specialized operations for structured state space models.
```
csrc/s4/
├── cauchy.cpp      # Python bindings
├── cauchy_cuda.cu  # Cauchy multiplication kernel
├── cauchy.py       # Python interface
├── vandermonde.py  # Vandermonde operations
├── map.h           # Helper utilities
└── tuner.py        # Kernel auto-tuning
```
Key Features:
Cauchy matrix multiplication for structured matrices
Vandermonde matrix operations
Auto-tuning for optimal performance
Specialized for diagonal state space models
Used by: S4, S4D
Discretization Methods
The kernels support multiple discretization schemes for converting continuous-time systems to discrete-time:
Mamba Discretization
```python
from lrnnx.models.ltv import Mamba

# Uses Mamba discretization by default
model = Mamba(d_model=64, d_state=16, discretization="mamba")
```
Zero-Order Hold (ZOH)
```python
from lrnnx.models.ltv import Mamba

# Use ZOH discretization
model = Mamba(d_model=64, d_state=16, discretization="zoh")
```
Bilinear Transform
```python
from lrnnx.models.ltv import Mamba

# Use bilinear discretization
model = Mamba(d_model=64, d_state=16, discretization="bilinear")
```
Dirac Delta
```python
from lrnnx.models.ltv import Mamba

# Use Dirac discretization
model = Mamba(d_model=64, d_state=16, discretization="dirac")
```
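For reference, the standard formulas behind two of these schemes, as summarized from the state space model literature (this is general background, not taken from the lrnnx source; here Δ is the step size and A, B are the continuous-time parameters):

```latex
% Zero-order hold (ZOH)
\bar{A} = \exp(\Delta A), \qquad
\bar{B} = (\Delta A)^{-1}\left(\exp(\Delta A) - I\right)\Delta B

% Bilinear (Tustin) transform
\bar{A} = \left(I - \tfrac{\Delta}{2}A\right)^{-1}\left(I + \tfrac{\Delta}{2}A\right), \qquad
\bar{B} = \left(I - \tfrac{\Delta}{2}A\right)^{-1}\Delta B
```

The Mamba paper pairs ZOH for \(\bar{A}\) with a simplified Euler-style rule \(\bar{B} = \Delta B\); consult the kernel sources for the exact rules lrnnx implements.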
The choice of discretization can noticeably affect model quality and stability; experiment with the available schemes for your use case.
Kernel Architecture
Forward Pass
The forward kernels implement the core state space recurrence:
```python
# Conceptual pseudocode of what the kernels compute
for t in range(seq_len):
    # State update
    state = A * state + B * input[t]
    # Output
    output[t] = C * state + D * input[t]
```
The CUDA kernels optimize this with:
Warp-level parallelism for batch and channel dimensions
Shared memory for coefficient caching
Register blocking for improved memory bandwidth
Fused operations to minimize memory traffic
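To make the recurrence concrete, here is a minimal sequential reference implementation in NumPy (an illustrative sketch of the math above, not the lrnnx CPU fallback; the function name `reference_scan` is invented for this example):

```python
import numpy as np

def reference_scan(x, A, B, C, D):
    """Sequential reference for the recurrence the kernels compute:
    h[t] = A @ h[t-1] + B @ x[t];  y[t] = C @ h[t] + D * x[t]."""
    seq_len, d_in = x.shape
    h = np.zeros(A.shape[0])          # initial state h[-1] = 0
    y = np.zeros((seq_len, d_in))
    for t in range(seq_len):
        h = A @ h + B @ x[t]          # state update
        y[t] = C @ h + D * x[t]       # output with skip connection
    return y
```

The CUDA kernels compute the same quantities but parallelize over batch, channels, and (via scan algorithms) the time dimension.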
Backward Pass
The backward kernels compute gradients with respect to all parameters:
dL/dA - Gradient for state transition matrix
dL/dB - Gradient for input projection
dL/dC - Gradient for output projection
dL/dD - Gradient for skip connection
dL/dx - Gradient for inputs
These use reverse-mode automatic differentiation optimized for recurrent operations.
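The adjoint recursion behind these gradients can be sketched for a scalar recurrence h[t] = a·h[t-1] + b·x[t], y[t] = c·h[t]. The backward pass runs the sequence in reverse, accumulating the state adjoint λ[t] = c·dy[t] + a·λ[t+1] (an illustrative sketch, not the kernel code; `forward`/`backward` are names invented here):

```python
import numpy as np

def forward(a, b, c, x):
    """Scalar SSM: h[t] = a*h[t-1] + b*x[t], y[t] = c*h[t]."""
    h, hs = 0.0, []
    for xt in x:
        h = a * h + b * xt
        hs.append(h)
    hs = np.asarray(hs)
    return c * hs, hs

def backward(a, b, c, x, hs, dy):
    """Reverse-mode gradients through the recurrence."""
    T = len(x)
    lam, da, db = 0.0, 0.0, 0.0
    dx = np.zeros(T)
    dc = float(np.dot(dy, hs))          # dL/dc = sum_t dy[t]*h[t]
    for t in reversed(range(T)):
        lam = c * dy[t] + a * lam       # adjoint of the state h[t]
        h_prev = hs[t - 1] if t > 0 else 0.0
        da += lam * h_prev              # dL/da = sum_t lam[t]*h[t-1]
        db += lam * x[t]                # dL/db = sum_t lam[t]*x[t]
        dx[t] = lam * b                 # dL/dx[t] = lam[t]*b
    return da, db, dc, dx
```

The reverse scan over λ is exactly why the backward kernels need `reverse_scan.cuh`: the adjoint is itself a linear recurrence, run right-to-left.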
Python Bindings
The CUDA kernels are exposed to Python through PyBind11:
```cpp
// csrc/selective_scan/bindings.cpp
#include "selective_scan.h"
#include <torch/python.h>
#include <torch/extension.h>

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("fwd", &selective_scan_fwd, "Selective scan forward");
    m.def("bwd", &selective_scan_bwd, "Selective scan backward");
}
```
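Extensions like this are typically compiled through PyTorch's `torch.utils.cpp_extension` machinery. A hedged sketch of what a corresponding `setup.py` entry might look like (the actual lrnnx build script may differ; the extension name and source list here are illustrative):

```python
# Hypothetical build sketch using torch.utils.cpp_extension
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name="selective_scan_cuda",
    ext_modules=[
        CUDAExtension(
            name="selective_scan_cuda",
            sources=[
                "csrc/selective_scan/bindings.cpp",
                "csrc/selective_scan/forward_kernels/selective_scan_fp32_real_mamba.cu",
            ],
        )
    ],
    cmdclass={"build_ext": BuildExtension},
)
```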
These bindings are then used from Python:
```python
from lrnnx.ops.selective_scan import selective_scan_fn

# High-level API automatically dispatches to CUDA kernels
y = selective_scan_fn(
    x, dt, A, B, C, D,
    z=z,
    delta_bias=delta_bias,
    delta_softplus=True,
    discretization="mamba",
)
```
Performance Optimizations
Custom CUDA kernels provide significant speedups over naive PyTorch implementations:
Fused Operations
Multiple operations are fused into single kernels, reducing memory traffic:
```python
# Instead of three separate operations (three kernel launches):
dt = F.softplus(dt + bias)   # Kernel launch 1
dA = torch.exp(dt * A)       # Kernel launch 2
state = state * dA           # Kernel launch 3

# A single fused kernel does all three in one launch
selective_scan_fwd(...)
```
Optimized Memory Access
Kernels use shared memory and register blocking:
10-100x faster than naive implementations
Optimized memory coalescing
Reduced global memory transactions
Recurrence Optimization
Special handling for sequential dependencies:
Warp-level synchronization
Efficient scan algorithms
Parallel prefix sums where applicable
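The reason a sequential-looking recurrence can use parallel prefix sums at all is that h[t] = a[t]·h[t-1] + u[t] composes under an associative operator: applying two steps in a row is itself one affine step. A minimal sketch (illustrative, not the kernel code; `parallel_prefix_scan` is a name invented here) simulating a Hillis-Steele inclusive scan:

```python
import numpy as np

def combine(e1, e2):
    """Associative operator: composing h -> a1*h + u1 then h -> a2*h + u2
    gives h -> (a1*a2)*h + (a2*u1 + u2)."""
    a1, u1 = e1
    a2, u2 = e2
    return (a1 * a2, a2 * u1 + u2)

def parallel_prefix_scan(a, u):
    """Hillis-Steele inclusive scan over `combine`.
    On hardware this is O(log T) parallel steps; simulated sequentially here."""
    elems = list(zip(a, u))
    n, step = len(elems), 1
    while step < n:
        nxt = list(elems)
        for i in range(step, n):
            nxt[i] = combine(elems[i - step], elems[i])
        elems = nxt
        step *= 2
    return np.array([u_acc for _, u_acc in elems])  # states h[0..T-1]
```

Because `combine` is associative, the prefix results can be computed in any grouping order, which is what lets GPU scan algorithms parallelize across the time dimension.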
Installation and Compilation
Kernels are automatically compiled during package installation:
```shell
# From PyPI (compiles kernels automatically)
pip install lrnnx --no-build-isolation

# From source
git clone https://github.com/SforAiDl/lrnnx.git
cd lrnnx
pip install -e . --no-build-isolation
```
Compilation time: The full installation can take ~30 minutes depending on:
Number of CPU cores available
CUDA toolkit version
Whether causal-conv1d is included
The library compiles kernels for multiple discretization schemes and data types, which is why compilation takes time.
Compilation Requirements
CUDA Toolkit: 11.8 or higher
C++ Compiler: GCC 7+ or Clang
PyTorch: Must be installed first with matching CUDA version
Python: 3.8 or higher
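If you want to check these minima programmatically, a small helper that compares dotted version strings can be handy (a convenience sketch, not part of lrnnx; the function name is invented here):

```python
def meets_minimum(version, minimum):
    """True if dotted version string `version` >= `minimum`.
    Ignores local suffixes like '+cu118' and non-numeric parts."""
    def parse(v):
        return tuple(int(p) for p in v.split("+")[0].split(".") if p.isdigit())
    return parse(version) >= parse(minimum)
```

For example, `meets_minimum(torch.version.cuda, "11.8")` checks the CUDA toolkit PyTorch was built against, and `meets_minimum(platform.python_version(), "3.8")` checks the interpreter.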
Checking Kernel Availability
Verify that kernels compiled successfully:
```python
import torch
from lrnnx.ops.selective_scan import selective_scan_fn
from lrnnx.ops.simplified_scan import simplified_scan_fn

# Check if CUDA kernels are available
print("CUDA available:", torch.cuda.is_available())

# Try running a small forward pass
x = torch.randn(2, 32, 10, device="cuda")   # (batch, d_model, seq_len), matching dt
dt = torch.randn(2, 32, 10, device="cuda")
A = torch.randn(32, 16, device="cuda")
B = torch.randn(2, 16, 10, device="cuda")
C = torch.randn(2, 16, 10, device="cuda")
D = torch.randn(32, device="cuda")

try:
    y = selective_scan_fn(x, dt, A, B, C, D, discretization="mamba")
    print("Selective scan kernel: OK")
except Exception as e:
    print(f"Selective scan kernel: FAILED - {e}")
```
CPU Fallback
For debugging or CPU-only environments, the library provides CPU implementations:
```cpp
// csrc/selective_scan/selective_scan_cpu.cpp
// Provides slower but correct CPU reference implementations
```

```python
# Automatically uses the CPU version when CUDA is unavailable
model = Mamba(d_model=64, d_state=16).cpu()  # Uses CPU kernels
```
CPU implementations are primarily for testing and debugging. For production use, always use CUDA.
Advanced: Kernel Auto-Tuning
The S4 Cauchy kernels include an auto-tuning system for optimal performance:
```shell
# Auto-tune the Cauchy kernel for your hardware
cd csrc/s4
./tune_cauchy.sh
```
This finds optimal thread block sizes and memory configurations for your specific GPU.
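Conceptually, an auto-tuner times each candidate configuration and keeps the fastest. A minimal sketch of that loop (illustrative only; lrnnx's actual tuner.py interface may differ, and `autotune`/`run_kernel` are names invented here):

```python
import time

def autotune(run_kernel, candidates, warmup=2, iters=5):
    """Return the candidate configuration with the lowest mean runtime."""
    best_cfg, best_time = None, float("inf")
    for cfg in candidates:
        for _ in range(warmup):       # warm caches before timing
            run_kernel(cfg)
        start = time.perf_counter()
        for _ in range(iters):
            run_kernel(cfg)
        elapsed = (time.perf_counter() - start) / iters
        if elapsed < best_time:
            best_cfg, best_time = cfg, elapsed
    return best_cfg
```

A real GPU tuner sweeps parameters such as thread block sizes (e.g. 32, 64, 128 threads) and typically caches the winning configuration per device so the sweep only runs once.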
Source Code Organization
```
csrc/
├── common.h                       # Shared utilities and macros
├── reverse_scan.cuh               # Reverse scan for backward pass
├── selective_scan/
│   ├── selective_scan_cpu.cpp     # CPU reference implementation
│   ├── bindings.cpp               # Python bindings
│   ├── forward_kernels/           # 10 forward kernel variants
│   └── backward_kernels/          # 10 backward kernel variants
├── simplified_scan/
│   ├── simplified_scan_cpu.cpp    # CPU reference implementation
│   ├── bindings.cpp               # Python bindings
│   ├── forward_kernels/           # 3 forward kernel variants
│   └── backward_kernels/          # 3 backward kernel variants
└── s4/
    ├── cauchy.cpp                 # Cauchy bindings
    ├── cauchy_cuda.cu             # Cauchy CUDA kernel
    ├── cauchy.py                  # Python interface
    └── tuner.py                   # Auto-tuning system
```
Next Steps
Training Guide: learn how to train models using these kernels
API Reference: explore the Python APIs built on these kernels