
Overview

lrnnx achieves high performance through custom CUDA kernels that implement efficient forward and backward passes for Linear RNN operations. The library includes three main kernel types optimized for different model architectures.
Custom CUDA kernels are compiled during installation, which can take up to 30 minutes depending on your system. The compilation happens automatically when you install lrnnx.

Kernel Types

1. Selective Scan Kernels

Located in csrc/selective_scan/, these kernels power selective state space models like Mamba, S6, and S7.
csrc/selective_scan/forward_kernels/
├── selective_scan_fp32_real_mamba.cu      # Mamba discretization
├── selective_scan_fp32_real_zoh.cu        # Zero-order hold
├── selective_scan_fp32_real_bilinear.cu   # Bilinear discretization
├── selective_scan_fp32_real_dirac.cu      # Dirac discretization
├── selective_scan_fp32_real_s7.cu         # S7 variant
├── selective_scan_fp32_real_rglru.cu      # RG-LRU variant
├── selective_scan_fp32_complex_mamba.cu   # Complex Mamba
├── selective_scan_fp32_complex_zoh.cu     # Complex ZOH
├── selective_scan_fp32_complex_bilinear.cu
└── selective_scan_fp32_complex_dirac.cu
Key Features:
  • Input-dependent state transitions (selective mechanism)
  • Multiple discretization schemes (Mamba, ZOH, Bilinear, Dirac)
  • Support for both real and complex valued computations
  • Optimized for time-varying models (LTV)
Used by: Mamba, S6, S7, RG-LRU
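The "selective" mechanism means the effective transition depends on the input at every step. As a rough sketch in the style of Mamba's recipe (the projection names `W_dt`, `W_B`, `W_C` and all shapes here are illustrative assumptions, not lrnnx's actual kernel signature):

```python
import torch
import torch.nn.functional as F

def selective_scan_sketch(x, A, W_dt, W_B, W_C):
    """Schematic selective scan: dt, B, C are projected from the input
    at every step, so the state transition is input-dependent.

    x: (L, dim), A: (dim, state), W_dt: (dim, dim), W_B/W_C: (dim, state).
    """
    L, dim = x.shape
    h = x.new_zeros(dim, A.shape[1])
    ys = []
    for t in range(L):
        dt = F.softplus(x[t] @ W_dt)      # per-step step size, (dim,)
        Bt = x[t] @ W_B                   # input-dependent B, (state,)
        Ct = x[t] @ W_C                   # input-dependent C, (state,)
        dA = torch.exp(dt[:, None] * A)   # Mamba-style discretization
        h = dA * h + (dt[:, None] * Bt) * x[t][:, None]
        ys.append(h @ Ct)                 # readout, (dim,)
    return torch.stack(ys)                # (L, dim)

x = torch.randn(5, 3)
y = selective_scan_sketch(x, -torch.rand(3, 4), torch.randn(3, 3),
                          torch.randn(3, 4), torch.randn(3, 4))
```

The fused CUDA kernels compute the same kind of recurrence without materializing the per-step intermediates in global memory.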

2. Simplified Scan Kernels

Located in csrc/simplified_scan/, these kernels implement efficient scans for simpler state space models.
csrc/simplified_scan/forward_kernels/
├── simplified_scan_fp32_zoh.cu
├── simplified_scan_fp32_bilinear.cu
└── simplified_scan_fp32_dirac.cu
Key Features:
  • Fixed (time-invariant) state transitions
  • Three discretization methods
  • Lower memory footprint than selective scan
  • Optimized for time-invariant models (LTI)
Used by: LRU, S5 (parallel scan variant)

3. Structured Kernels (S4)

Located in csrc/s4/, these implement specialized operations for structured state space models.
csrc/s4/
├── cauchy.cpp              # Python bindings
├── cauchy_cuda.cu          # Cauchy multiplication kernel
├── cauchy.py               # Python interface
├── vandermonde.py          # Vandermonde operations
├── map.h                   # Helper utilities
└── tuner.py                # Kernel auto-tuning
Key Features:
  • Cauchy matrix multiplication for structured matrices
  • Vandermonde matrix operations
  • Auto-tuning for optimal performance
  • Specialized for diagonal state space models
Used by: S4, S4D
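The Cauchy multiplication these kernels accelerate has a simple dense reference form, y[i] = sum_j v[j] / (z[i] - w[j]). A naive PyTorch sketch (shapes are illustrative, not lrnnx's kernel signature):

```python
import torch

def cauchy_mult_ref(v, z, w):
    """Naive Cauchy multiplication: y[i] = sum_j v[j] / (z[i] - w[j]).

    v: (N,) weights, w: (N,) poles, z: (L,) evaluation points.
    """
    # Materializes the full (L, N) Cauchy matrix explicitly -- the
    # O(L*N) memory traffic that the fused CUDA kernel avoids.
    cauchy = 1.0 / (z[:, None] - w[None, :])
    return cauchy @ v

v = torch.randn(16, dtype=torch.cfloat)
w = torch.randn(16, dtype=torch.cfloat) - 0.5
z = torch.randn(64, dtype=torch.cfloat) + 0.5
y = cauchy_mult_ref(v, z, w)
```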

Discretization Methods

The kernels support multiple discretization schemes for converting continuous-time systems to discrete-time:

Mamba Discretization

from lrnnx.models.ltv import Mamba

# Uses Mamba discretization by default
model = Mamba(d_model=64, d_state=16, discretization="mamba")

Zero-Order Hold (ZOH)

from lrnnx.models.ltv import Mamba

# Use ZOH discretization
model = Mamba(d_model=64, d_state=16, discretization="zoh")

Bilinear Transform

from lrnnx.models.ltv import Mamba

# Use bilinear discretization
model = Mamba(d_model=64, d_state=16, discretization="bilinear")

Dirac Delta

from lrnnx.models.ltv import Mamba

# Use Dirac discretization
model = Mamba(d_model=64, d_state=16, discretization="dirac")
Different discretization methods can affect model performance. Experiment with different schemes for your use case.
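For intuition, here is how two of these schemes discretize a scalar system h' = a·h + b·x with step size dt (standard textbook forms; the kernels apply the same maps elementwise to full tensors):

```python
import torch

a = torch.tensor(-0.5)   # continuous-time transition
b = torch.tensor(1.0)    # continuous-time input coefficient
dt = torch.tensor(0.1)   # step size

# Zero-order hold: exact when the input is held constant over each step
a_zoh = torch.exp(dt * a)
b_zoh = (a_zoh - 1) / a * b

# Bilinear (Tustin) transform: trapezoidal approximation
a_bil = (1 + dt * a / 2) / (1 - dt * a / 2)
b_bil = dt * b / (1 - dt * a / 2)
```

Both yield a discrete update h[t] = ā·h[t-1] + b̄·x[t]; they differ in how faithfully ā and b̄ track the continuous dynamics for a given dt.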

Kernel Architecture

Forward Pass

The forward kernels implement the core state space recurrence:
# Conceptual pseudocode of what the kernels compute
# (A, B, C, D are the already-discretized state space parameters)
for t in range(seq_len):
    # State update
    state = A * state + B * input[t]
    # Output (with skip connection through D)
    output[t] = C * state + D * input[t]
The CUDA kernels optimize this with:
  • Warp-level parallelism for batch and channel dimensions
  • Shared memory for coefficient caching
  • Register blocking for improved memory bandwidth
  • Fused operations to minimize memory traffic
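The recurrence above can be written as a slow but runnable sequential PyTorch reference; shapes here are illustrative assumptions, not lrnnx's exact kernel signature:

```python
import torch

def scan_ref(x, A, B, C, D):
    """Sequential reference scan: h_t = A h_{t-1} + B x_t, y_t = C h_t + D x_t.

    x: (L, dim), A: (state, state), B: (state, dim),
    C: (dim, state), D: (dim,).
    """
    h = x.new_zeros(A.shape[0])
    ys = []
    for t in range(x.shape[0]):
        h = A @ h + B @ x[t]          # state update
        ys.append(C @ h + D * x[t])   # output with skip connection
    return torch.stack(ys)            # (L, dim)

x = torch.randn(10, 4)
y = scan_ref(x, 0.9 * torch.eye(8), torch.randn(8, 4),
             torch.randn(4, 8), torch.randn(4))
```

Each loop iteration here issues several separate tensor ops; the CUDA kernels collapse the whole sequence into fused launches.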

Backward Pass

The backward kernels compute gradients with respect to all parameters:
  • dL/dA - Gradient for state transition matrix
  • dL/dB - Gradient for input projection
  • dL/dC - Gradient for output projection
  • dL/dD - Gradient for skip connection
  • dL/dx - Gradient for inputs
These use reverse-mode automatic differentiation optimized for recurrent operations.
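These gradients can be sanity-checked numerically on a toy scalar-state scan with `torch.autograd.gradcheck` (a self-contained sketch, not lrnnx's test suite):

```python
import torch

def scan_ref(x, A, B, C, D):
    # Sequential scalar-state scan: h_t = A*h_{t-1} + B*x_t, y_t = C*h_t + D*x_t
    h = torch.zeros((), dtype=x.dtype)
    ys = []
    for xt in x:
        h = A * h + B * xt
        ys.append(C * h + D * xt)
    return torch.stack(ys)

# gradcheck compares analytic gradients against finite differences;
# it requires float64 inputs with requires_grad=True.
x = torch.randn(6, dtype=torch.float64, requires_grad=True)
A, B, C, D = (torch.randn((), dtype=torch.float64, requires_grad=True)
              for _ in range(4))
assert torch.autograd.gradcheck(scan_ref, (x, A, B, C, D))
```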

Python Bindings

The CUDA kernels are exposed to Python through PyBind11:
// csrc/selective_scan/bindings.cpp
#include "selective_scan.h"
#include <torch/python.h>
#include <torch/extension.h>

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("fwd", &selective_scan_fwd, "Selective scan forward");
    m.def("bwd", &selective_scan_bwd, "Selective scan backward");
}
And used in Python like:
from lrnnx.ops.selective_scan import selective_scan_fn

# High-level API automatically dispatches to CUDA kernels
y = selective_scan_fn(
    x, dt, A, B, C, D,
    z=z,
    delta_bias=delta_bias,
    delta_softplus=True,
    discretization="mamba"
)

Performance Benefits

Custom CUDA kernels provide significant speedups over naive PyTorch implementations:

1. Fused Operations

Multiple operations are fused into single kernels, reducing memory bandwidth:
# Instead of separate operations:
dt = F.softplus(dt + bias)  # Kernel launch 1
dA = torch.exp(dt * A)       # Kernel launch 2
state = state * dA           # Kernel launch 3

# Single fused kernel does all three
selective_scan_fwd(...)      # One kernel launch

2. Optimized Memory Access

Kernels use shared memory and register blocking:
  • 10-100x faster than naive implementations
  • Optimized memory coalescing
  • Reduced global memory transactions

3. Recurrence Optimization

Special handling for sequential dependencies:
  • Warp-level synchronization
  • Efficient scan algorithms
  • Parallel prefix sums where applicable
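Parallel prefix sums apply because the step h → a·h + b composes associatively: applying (a1, b1) then (a2, b2) is the single step (a1·a2, a2·b1 + b2). A small check of this identity (illustrative, not the kernel's code):

```python
import torch

def combine(e1, e2):
    """Associative combine for h -> a*h + b steps: apply e1, then e2."""
    a1, b1 = e1
    a2, b2 = e2
    return a1 * a2, a2 * b1 + b2

a = torch.rand(8)
b = torch.randn(8)

# Sequential recurrence from h_0 = 0
h = torch.zeros(())
for t in range(8):
    h = a[t] * h + b[t]

# Same result by folding the associative combine; a tree reduction
# over these pairs is exactly what a parallel prefix scan exploits.
acc = (a[0], b[0])
for t in range(1, 8):
    acc = combine(acc, (a[t], b[t]))

assert torch.allclose(h, acc[1])
```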

Installation and Compilation

Kernels are automatically compiled during package installation:
# From PyPI (compiles kernels automatically)
pip install lrnnx --no-build-isolation

# From source
git clone https://github.com/SforAiDl/lrnnx.git
cd lrnnx
pip install -e . --no-build-isolation
Compilation time: The full installation can take ~30 minutes depending on:
  • Number of CPU cores available
  • CUDA toolkit version
  • Whether causal-conv1d is included
The library compiles kernels for multiple discretization schemes and data types, which is why compilation takes time.

Compilation Requirements

  • CUDA Toolkit: 11.8 or higher
  • C++ Compiler: GCC 7+ or Clang
  • PyTorch: Must be installed first with matching CUDA version
  • Python: 3.8 or higher
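A quick way to confirm the PyTorch/CUDA pairing before kicking off the long compile (standard PyTorch attributes, nothing lrnnx-specific):

```python
import torch

# The CUDA version PyTorch was built against must match the installed toolkit
print("PyTorch:", torch.__version__)
print("Built against CUDA:", torch.version.cuda)   # None on CPU-only builds
print("CUDA runtime usable:", torch.cuda.is_available())
```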

Checking Kernel Availability

Verify that kernels compiled successfully:
import torch
from lrnnx.ops.selective_scan import selective_scan_fn
from lrnnx.ops.simplified_scan import simplified_scan_fn

# Check if CUDA kernels are available
print("CUDA available:", torch.cuda.is_available())

# Try running a small forward pass
x = torch.randn(2, 32, 10, device="cuda")
dt = torch.randn(2, 32, 10, device="cuda")
A = torch.randn(32, 16, device="cuda")
B = torch.randn(2, 16, 10, device="cuda")
C = torch.randn(2, 16, 10, device="cuda")
D = torch.randn(32, device="cuda")

try:
    y = selective_scan_fn(x, dt, A, B, C, D, discretization="mamba")
    print("Selective scan kernel: OK")
except Exception as e:
    print(f"Selective scan kernel: FAILED - {e}")

CPU Fallback

For debugging or CPU-only environments, the library provides CPU implementations:
// csrc/selective_scan/selective_scan_cpu.cpp
// Provides slower but correct CPU reference implementations
from lrnnx.models.ltv import Mamba

# Automatically uses the CPU implementation when CUDA is unavailable
model = Mamba(d_model=64, d_state=16).cpu()  # Uses CPU kernels
CPU implementations are primarily for testing and debugging. For production use, always use CUDA.

Advanced: Kernel Auto-Tuning

The S4 Cauchy kernels include an auto-tuning system for optimal performance:
# Auto-tune Cauchy kernel for your hardware
cd csrc/s4
./tune_cauchy.sh
This finds optimal thread block sizes and memory configurations for your specific GPU.

Source Code Organization

csrc/
├── common.h                    # Shared utilities and macros
├── reverse_scan.cuh            # Reverse scan for backward pass
├── selective_scan/
│   ├── selective_scan_cpu.cpp  # CPU reference implementation
│   ├── bindings.cpp            # Python bindings
│   ├── forward_kernels/        # 10 forward kernel variants
│   └── backward_kernels/       # 10 backward kernel variants
├── simplified_scan/
│   ├── simplified_scan_cpu.cpp # CPU reference implementation
│   ├── bindings.cpp            # Python bindings
│   ├── forward_kernels/        # 3 forward kernel variants
│   └── backward_kernels/       # 3 backward kernel variants
└── s4/
    ├── cauchy.cpp              # Cauchy bindings
    ├── cauchy_cuda.cu          # Cauchy CUDA kernel
    ├── cauchy.py               # Python interface
    ├── vandermonde.py          # Vandermonde operations
    ├── map.h                   # Helper utilities
    └── tuner.py                # Auto-tuning system

Next Steps

Training Guide

Learn how to train models using these kernels

API Reference

Explore the Python APIs built on these kernels
