
Overview

The selective scan operations provide CUDA-accelerated implementations of the Mamba SSM (Structured State Space Model) scan algorithm. These operations are the core computational primitives for Mamba models.

selective_scan_fn

from lrnnx.ops.selective_scan import selective_scan_fn
Apply the CUDA selective scan function with optional gating and discretization methods.
u
torch.Tensor
required
Input tensor of shape (batch, dim, seqlen)
delta
torch.Tensor
required
Delta tensor of shape (batch, dim, seqlen) controlling the discretization timesteps
A
torch.Tensor
required
State matrix A of shape (dim, dstate) - typically negative real values for stability
B
torch.Tensor
required
Input matrix B. Can be:
  • (batch, dstate, seqlen) for time-varying B
  • (dim, dstate) for time-invariant B
C
torch.Tensor
required
Output matrix C. Can be:
  • (batch, dstate, seqlen) for time-varying C
  • (dim, dstate) for time-invariant C
D
torch.Tensor
default:"None"
Skip connection vector of shape (dim,) for direct input-to-output connections
z
torch.Tensor
default:"None"
Gating tensor of shape (batch, dim, seqlen) for SiLU gating (used in Mamba-2)
delta_bias
torch.Tensor
default:"None"
Bias for delta of shape (dim,) added before discretization
deltaA
torch.Tensor
default:"None"
Asymmetric delta for A discretization of shape (batch, dim, seqlen). When provided, A is discretized using deltaA while B uses delta, enabling asynchronous/event-based processing
delta_softplus
bool
default:"False"
Whether to apply softplus activation to delta before discretization
return_last_state
bool
default:"False"
If True, returns (out, last_state) where last_state has shape (batch, dim, dstate). Note that gradients of the last state are not propagated in the backward pass
discretization
str
default:"'mamba'"
Discretization method to use. Options:
  • "mamba": Standard Mamba discretization (zero-order hold variant)
  • "zoh": Zero-order hold discretization
  • "bilinear": Bilinear transform (Tustin’s method)
  • "dirac": Dirac delta (no delta scaling for B)
output
torch.Tensor | tuple[torch.Tensor, torch.Tensor]
The output tensor of shape (batch, dim, seqlen), or tuple (output, last_state) if return_last_state=True

Example

import torch
from lrnnx.ops.selective_scan import selective_scan_fn

batch, dim, seqlen, dstate = 2, 64, 128, 16

# Create input tensors
u = torch.randn(batch, dim, seqlen, device='cuda')
delta = torch.randn(batch, dim, seqlen, device='cuda')
A = -torch.rand(dim, dstate, device='cuda')  # Negative for stability
B = torch.randn(batch, dstate, seqlen, device='cuda')
C = torch.randn(batch, dstate, seqlen, device='cuda')
D = torch.randn(dim, device='cuda')

# Run selective scan
output = selective_scan_fn(
    u, delta, A, B, C, D,
    delta_softplus=True,
    discretization="mamba"
)

print(output.shape)  # (2, 64, 128)
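For intuition about what the CUDA kernel computes, the recurrence can be sketched as a sequential NumPy loop. This is an illustrative reference only (not the library's implementation): it assumes the default "mamba" discretization, time-varying B and C, and no gating, and it is far slower than the fused kernel.

```python
import numpy as np

def selective_scan_reference(u, delta, A, B, C, D=None):
    """Sequential reference for the selective scan ("mamba" discretization).

    u, delta: (batch, dim, seqlen); A: (dim, dstate);
    B, C: (batch, dstate, seqlen); D: (dim,) or None.
    Returns out (batch, dim, seqlen) and last_state (batch, dim, dstate).
    """
    batch, dim, seqlen = u.shape
    dstate = A.shape[1]
    x = np.zeros((batch, dim, dstate))            # hidden state
    out = np.zeros((batch, dim, seqlen))
    for t in range(seqlen):
        dt = delta[:, :, t]                       # (batch, dim)
        A_bar = np.exp(dt[..., None] * A)         # (batch, dim, dstate)
        B_bar = dt[..., None] * B[:, None, :, t]  # (batch, dim, dstate)
        x = A_bar * x + B_bar * u[:, :, t][..., None]
        out[:, :, t] = np.einsum('bdn,bn->bd', x, C[:, :, t])
    if D is not None:
        out += D[None, :, None] * u               # skip connection
    return out, x

# Tiny example on CPU
rng = np.random.default_rng(0)
b, d, L, n = 2, 4, 8, 3
u = rng.standard_normal((b, d, L))
delta = rng.random((b, d, L))
A = -rng.random((d, n))                           # negative for stability
B = rng.standard_normal((b, n, L))
C = rng.standard_normal((b, n, L))
out, last_state = selective_scan_reference(u, delta, A, B, C)
print(out.shape, last_state.shape)  # (2, 4, 8) (2, 4, 3)
```

The returned `x` corresponds to the `last_state` that `selective_scan_fn` exposes when `return_last_state=True`.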

Discretization Methods

The discretization method controls how the continuous-time SSM parameters A and B are converted to their discrete-time counterparts A_bar and B_bar.

Mamba (default):
A_bar = exp(delta * A)
B_bar = delta * B

Zero-Order Hold (ZOH):
A_bar = exp(delta * A)
B_bar = (exp(delta * A) - 1) / A * B

Bilinear (Tustin's method):
A_bar = (1 + 0.5*delta*A) / (1 - 0.5*delta*A)
B_bar = delta / (1 - 0.5*delta*A) * B

Dirac:
A_bar = exp(delta * A)
B_bar = B  # No delta scaling
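To see how the variants relate, here is a scalar NumPy comparison (values chosen for illustration). For small delta, the exact ZOH input matrix (exp(delta*A) - 1)/A * B approaches the Mamba simplification delta * B, and the bilinear A_bar closely tracks exp(delta*A).

```python
import numpy as np

delta, A, B = 0.01, -0.5, 1.2  # small timestep, stable (negative) A

# Mamba (simplified ZOH)
A_mamba = np.exp(delta * A)
B_mamba = delta * B

# Exact zero-order hold
A_zoh = np.exp(delta * A)
B_zoh = (np.exp(delta * A) - 1) / A * B

# Bilinear (Tustin's method)
A_bil = (1 + 0.5 * delta * A) / (1 - 0.5 * delta * A)
B_bil = delta / (1 - 0.5 * delta * A) * B

# Dirac: input treated as impulses, so B is not scaled by delta
A_dirac = np.exp(delta * A)
B_dirac = B

print(B_mamba, B_zoh)  # nearly equal for small delta
```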

mamba_inner_fn

from lrnnx.ops.selective_scan import mamba_inner_fn
Apply the fused Mamba inner function, which combines causal convolution, projections, selective scan, and output projection into a single optimized operation.
xz
torch.Tensor
required
Input tensor of shape (batch, 2*dim, seqlen) containing concatenated x and z (gating) inputs
conv1d_weight
torch.Tensor
required
Conv1d weights of shape (dim, 1, kernel_size) for causal convolution
conv1d_bias
torch.Tensor | None
required
Conv1d biases of shape (dim,) or None
x_proj_weight
torch.Tensor
required
Projection weights for B, C, delta of shape (delta_rank + 2*dstate, dim)
delta_proj_weight
torch.Tensor
required
Projection weights for delta of shape (dim, delta_rank)
out_proj_weight
torch.Tensor
required
Output projection weights of shape (d_model, dim)
out_proj_bias
torch.Tensor | None
required
Output projection biases of shape (d_model,) or None
A
torch.Tensor
required
State matrix A of shape (dim, dstate)
B
torch.Tensor
default:"None"
Input matrix B. If None, B is computed from the input projection (time-varying B)
C
torch.Tensor
default:"None"
Output matrix C. If None, C is computed from the input projection (time-varying C)
D
torch.Tensor
default:"None"
Skip connection vector D of shape (dim,)
delta_bias
torch.Tensor
default:"None"
Bias for delta of shape (dim,)
B_proj_bias
torch.Tensor
default:"None"
Bias for B projection of shape (dstate,)
C_proj_bias
torch.Tensor
default:"None"
Bias for C projection of shape (dstate,)
delta_softplus
bool
default:"True"
Whether to apply softplus to delta
checkpoint_lvl
int
default:"1"
Gradient checkpointing level (0 or 1). Level 1 recomputes the conv1d and delta in the backward pass to save memory
b_rms_weight
torch.Tensor
default:"None"
RMS normalization weights for B of shape (dstate,)
c_rms_weight
torch.Tensor
default:"None"
RMS normalization weights for C of shape (dstate,)
dt_rms_weight
torch.Tensor
default:"None"
RMS normalization weights for dt of shape (dim,)
b_c_dt_rms_eps
float
default:"1e-6"
Epsilon for RMS normalization
output
torch.Tensor
The projected output tensor of shape (batch, seqlen, d_model)

Example

import torch
from lrnnx.ops.selective_scan import mamba_inner_fn

batch, d_model, dim, seqlen = 2, 512, 256, 128
dstate, delta_rank = 16, 32
kernel_size = 4

# Create input and weight tensors
xz = torch.randn(batch, 2*dim, seqlen, device='cuda')
conv1d_weight = torch.randn(dim, 1, kernel_size, device='cuda')
conv1d_bias = torch.randn(dim, device='cuda')
x_proj_weight = torch.randn(delta_rank + 2*dstate, dim, device='cuda')
delta_proj_weight = torch.randn(dim, delta_rank, device='cuda')
out_proj_weight = torch.randn(d_model, dim, device='cuda')
out_proj_bias = torch.randn(d_model, device='cuda')
A = -torch.rand(dim, dstate, device='cuda')

# Run fused Mamba inner function
output = mamba_inner_fn(
    xz, conv1d_weight, conv1d_bias,
    x_proj_weight, delta_proj_weight,
    out_proj_weight, out_proj_bias,
    A, delta_softplus=True
)

print(output.shape)  # (2, 128, 512)

Performance Notes

  • This function fuses multiple operations into a single CUDA kernel for better performance
  • Gradient checkpointing (checkpoint_lvl=1) trades computation for memory
  • Requires the causal-conv1d package to be installed
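The fused kernel roughly corresponds to the following unfused pipeline. This is a NumPy shape sketch under stated assumptions, not the library's implementation: all weights are random stand-ins, and the selective scan step is elided (step 4 applies only the SiLU gate so the shapes flow through).

```python
import numpy as np

rng = np.random.default_rng(0)
batch, d_model, dim, seqlen = 2, 32, 16, 10
dstate, delta_rank, kernel_size = 4, 8, 4

xz = rng.standard_normal((batch, 2 * dim, seqlen))
conv1d_weight = rng.standard_normal((dim, 1, kernel_size))
x_proj_weight = rng.standard_normal((delta_rank + 2 * dstate, dim))
delta_proj_weight = rng.standard_normal((dim, delta_rank))
out_proj_weight = rng.standard_normal((d_model, dim))

# 1. Split the packed input into the scan branch x and the gate z
x, z = xz[:, :dim], xz[:, dim:]                      # each (batch, dim, seqlen)

# 2. Depthwise causal conv1d: left-pad so output at t only sees inputs <= t
x_pad = np.pad(x, ((0, 0), (0, 0), (kernel_size - 1, 0)))
x_conv = np.stack([
    sum(conv1d_weight[c, 0, k] * x_pad[:, c, k:k + seqlen]
        for k in range(kernel_size))
    for c in range(dim)
], axis=1)                                           # (batch, dim, seqlen)

# 3. Project to low-rank delta and time-varying B, C
proj = np.einsum('pd,bdl->bpl', x_proj_weight, x_conv)
dt_low, B, C = np.split(proj, [delta_rank, delta_rank + dstate], axis=1)
delta = np.einsum('dr,brl->bdl', delta_proj_weight, dt_low)

# 4. (the selective scan over x_conv with delta, A, B, C goes here;
#    only the SiLU gating by z is shown)
y = x_conv * (z / (1 + np.exp(-z)))

# 5. Output projection back to d_model, giving (batch, seqlen, d_model)
out = np.einsum('md,bdl->blm', out_proj_weight, y)
print(out.shape)  # (2, 10, 32)
```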

Source Code

Source: lrnnx/ops/selective_scan.py
