Overview
The selective scan operations provide CUDA-accelerated implementations of the Mamba SSM (Structured State Space Model) scan algorithm. These operations are the core computational primitives for Mamba models.

selective_scan_fn
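As background, the scan evaluates the discrete state-space recurrence sketched below. The symbols u (input) and Δ (timestep) follow the Mamba paper's notation and are not necessarily this API's argument names:

```latex
h_t = \bar{A}_t \, h_{t-1} + \bar{B}_t \, u_t, \qquad
y_t = C_t \, h_t + D \, u_t,
```
with, under the default "mamba" discretization,
```latex
\bar{A}_t = \exp(\Delta_t A), \qquad \bar{B}_t = \Delta_t B_t.
```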
Parameters:
- Input tensor of shape (batch, dim, seqlen)
- Delta tensor of shape (batch, dim, seqlen), controlling the discretization timesteps
- State matrix A of shape (dim, dstate) - typically negative real values for stability
- Input matrix B. Can be (batch, dstate, seqlen) for time-varying B, or (dim, dstate) for time-invariant B
- Output matrix C. Can be (batch, dstate, seqlen) for time-varying C, or (dim, dstate) for time-invariant C
- Skip connection vector of shape (dim,), for direct input-to-output connections
- Gating tensor of shape (batch, dim, seqlen), for SiLU gating (used in Mamba-2)
- Bias for delta of shape (dim,), added before discretization
- Asymmetric delta for A discretization, of shape (batch, dim, seqlen). When provided, A is discretized using deltaA while B uses delta, enabling asynchronous/event-based processing
- Whether to apply softplus activation to delta before discretization
- If True, returns (out, last_state), where last_state has shape (batch, dim, dstate). Note that gradients of the last state are not propagated in the backward pass
- Discretization method to use. Options:
  - "mamba": standard Mamba discretization (zero-order hold variant)
  - "zoh": zero-order hold discretization
  - "bilinear": bilinear transform (Tustin's method)
  - "dirac": Dirac delta (no delta scaling for B)

Returns:
The output tensor of shape (batch, dim, seqlen), or a tuple (output, last_state) if return_last_state=True.

Example
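The CUDA kernel itself requires a GPU, but the recurrence it implements can be sketched in plain NumPy. The function below is a reference sketch under the default "mamba" discretization with time-varying B and C; the name `selective_scan_ref` and the argument order are illustrative, not the actual lrnnx API.

```python
import numpy as np

def selective_scan_ref(u, delta, A, B, C, D=None):
    """Sequential reference for the selective scan (default "mamba"
    discretization): h_t = exp(delta_t * A) * h_{t-1} + delta_t * B_t * u_t,
    y_t = C_t . h_t + D * u_t.

    Shapes: u, delta (batch, dim, seqlen); A (dim, dstate);
    time-varying B, C (batch, dstate, seqlen); D (dim,)."""
    batch, dim, seqlen = u.shape
    dstate = A.shape[1]
    h = np.zeros((batch, dim, dstate))
    ys = []
    for t in range(seqlen):
        dt = delta[:, :, t]                       # (batch, dim)
        # Discretize: A exactly (matrix exponential of a diagonal A),
        # B with simple delta scaling (the Mamba "ZOH variant").
        dA = np.exp(dt[:, :, None] * A[None])     # (batch, dim, dstate)
        dB = dt[:, :, None] * B[:, None, :, t]    # (batch, dim, dstate)
        h = dA * h + dB * u[:, :, t, None]
        y = np.einsum("bdn,bn->bd", h, C[:, :, t])
        if D is not None:
            y = y + D[None] * u[:, :, t]          # skip connection
        ys.append(y)
    out = np.stack(ys, axis=-1)                   # (batch, dim, seqlen)
    return out, h  # h is the last state, shape (batch, dim, dstate)
```

This sequential loop is what the fused kernel parallelizes; it is useful for checking shapes and numerics on CPU.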
Discretization Methods
The discretization method controls how the continuous-time SSM parameters are converted to discrete-time. "mamba" (the default) is the zero-order hold variant used in the original Mamba paper: A is discretized exactly via the exponential, while B is simply scaled by delta. "zoh" applies the exact zero-order hold to both A and B, "bilinear" uses the bilinear (Tustin) transform, and "dirac" applies no delta scaling to B.

mamba_inner_fn
Parameters:
- Input tensor of shape (batch, 2*dim, seqlen), containing concatenated x and z (gating) inputs
- Conv1d weights of shape (dim, 1, kernel_size), for causal convolution
- Conv1d biases of shape (dim,), or None
- Projection weights for B, C, delta, of shape (delta_rank + 2*dstate, dim)
- Projection weights for delta, of shape (dim, delta_rank)
- Output projection weights of shape (d_model, dim)
- Output projection biases of shape (d_model,), or None
- State matrix A of shape (dim, dstate)
- State matrix B. If None, B is computed from input projections (variable B)
- State matrix C. If None, C is computed from input projections (variable C)
- Skip connection matrix D of shape (dim,)
- Bias for delta of shape (dim,)
- Bias for B projection of shape (dstate,)
- Bias for C projection of shape (dstate,)
- Whether to apply softplus to delta
- Gradient checkpointing level (0 or 1). Level 1 recomputes conv1d and delta in the backward pass to save memory
- RMS normalization weights for B of shape (dstate,)
- RMS normalization weights for C of shape (dstate,)
- RMS normalization weights for dt of shape (dim,)
- Epsilon for RMS normalization

Returns:
The projected output tensor of shape (batch, seqlen, d_model).

Example
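To make the fused pipeline concrete, here is an unfused NumPy sketch of the sequence of operations that mamba_inner_fn fuses, for the variable-B/C case without the optional RMS normalizations. The function name, helper names, and argument order are illustrative assumptions, not the lrnnx API:

```python
import numpy as np

def softplus(x):
    return np.log1p(np.exp(x))

def silu(x):
    return x / (1.0 + np.exp(-x))

def mamba_inner_ref(xz, conv_w, conv_b, x_proj_w, dt_proj_w,
                    out_proj_w, A, D, dt_bias, delta_rank, dstate):
    """Unfused sketch of the mamba_inner_fn pipeline (illustrative names).
    xz: (batch, 2*dim, seqlen) with concatenated x and z inputs."""
    batch, two_dim, seqlen = xz.shape
    dim = two_dim // 2
    x, z = xz[:, :dim], xz[:, dim:]
    # 1. Depthwise causal conv1d with weights (dim, 1, k): pad on the left.
    k = conv_w.shape[-1]
    xp = np.pad(x, ((0, 0), (0, 0), (k - 1, 0)))
    xc = np.stack([np.sum(xp[:, :, t:t + k] * conv_w[None, :, 0, :], axis=-1)
                   for t in range(seqlen)], axis=-1)
    xc = silu(xc + conv_b[None, :, None])
    # 2. Project to (delta_rank + 2*dstate), split into low-rank dt, B, C.
    proj = np.einsum("pd,bdl->bpl", x_proj_w, xc)
    dt_low = proj[:, :delta_rank]
    B = proj[:, delta_rank:delta_rank + dstate]
    C = proj[:, delta_rank + dstate:]
    # 3. Expand delta from its low-rank projection, add bias, softplus.
    delta = softplus(np.einsum("dr,brl->bdl", dt_proj_w, dt_low)
                     + dt_bias[None, :, None])
    # 4. Sequential selective scan, default "mamba" discretization.
    h = np.zeros((batch, dim, dstate))
    ys = []
    for t in range(seqlen):
        dA = np.exp(delta[:, :, t, None] * A[None])
        dB = delta[:, :, t, None] * B[:, None, :, t]
        h = dA * h + dB * xc[:, :, t, None]
        ys.append(np.einsum("bdn,bn->bd", h, C[:, :, t]) + D[None] * xc[:, :, t])
    # 5. SiLU-gate with z, then apply the output projection.
    y = np.stack(ys, axis=-1) * silu(z)
    return np.einsum("md,bdl->blm", out_proj_w, y)  # (batch, seqlen, d_model)
```

Running these steps separately materializes every intermediate tensor; the fused kernel avoids that, which is where the performance win comes from.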
Performance Notes
- This function fuses multiple operations into a single CUDA kernel for better performance
- Gradient checkpointing (checkpoint_lvl=1) trades computation for memory
- Requires the causal-conv1d package to be installed
Source Code
Source: lrnnx/ops/selective_scan.py