Overview
Mamba is a Selective State Space Model (SSM) that supports optional event-based processing. When integration_timesteps is provided, it uses asymmetric discretization with separate dtA and dtB for event-driven processing. Otherwise, it uses standard Mamba discretization.
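The asymmetric, event-driven update can be sketched as follows. This is an illustrative reference only, assuming a zero-order-hold style exponential for the state transition and an Euler-style scaling for the input; the module's actual kernels may differ in detail:

```python
import numpy as np

def event_step(h, x, A, B, C, dt_A, dt_B):
    """One event-driven SSM update with separate dtA and dtB.

    h: (N,) state; x: scalar input; A: (N,) diagonal state matrix (negative);
    B: (N,) input vector; C: (N,) output vector.
    dt_A scales the state decay; dt_B scales the input injection.
    """
    A_bar = np.exp(dt_A * A)      # discretized state transition
    B_bar = dt_B * B              # discretized input matrix
    h = A_bar * h + B_bar * x     # state update
    y = C @ h                     # readout
    return h, y

# With dt_A == dt_B this reduces to a standard uniform-step update.
h = np.zeros(4)
h, y = event_step(h, 1.0, A=-np.ones(4), B=np.ones(4), C=np.ones(4),
                  dt_A=0.5, dt_B=0.5)
```

Decoupling dt_A from dt_B lets the state decay over the true elapsed time between events while the input contribution is scaled independently.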
Key features:
- Input-dependent selectivity via learned ∆, B, and C parameters
- Hardware-efficient implementation with fused CUDA kernels
- Support for event-based/asynchronous discretization
- Causal 1D convolution for local context
- Efficient autoregressive inference with state caching
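The selectivity mechanism in the first feature can be illustrated with a minimal sequential scan. This is a reference sketch, not the fused kernel: the projection weights below are stand-ins for the learned ∆, B, and C parameters, and the real module also applies input/output projections and gating:

```python
import numpy as np

def selective_scan(x, A, W_dt, W_B, W_C, timesteps=None):
    """Minimal selective scan: dt, B, C depend on the input at each step.

    x: (L, D) input sequence; A: (D, N) diagonal state decay (negative).
    W_dt: (D,); W_B, W_C: (D, N) stand-in projection weights.
    timesteps: optional (L,) event intervals that stretch dt.
    """
    L, D = x.shape
    h = np.zeros((D, A.shape[1]))
    ys = np.empty((L, D))
    for t in range(L):
        u = x[t]                              # (D,)
        dt = np.log1p(np.exp(W_dt * u))       # softplus: input-dependent step
        if timesteps is not None:
            dt = dt * timesteps[t]            # scale by event interval
        B = u @ W_B                           # (N,) input-dependent
        C = u @ W_C                           # (N,) input-dependent
        A_bar = np.exp(dt[:, None] * A)       # (D, N)
        h = A_bar * h + (dt[:, None] * B) * u[:, None]
        ys[t] = h @ C                         # (D,)
    return ys

rng = np.random.default_rng(0)
L_seq, D, N = 6, 3, 4
y = selective_scan(rng.normal(size=(L_seq, D)),
                   A=-np.ones((D, N)),
                   W_dt=np.ones(D),
                   W_B=rng.normal(size=(D, N)) * 0.1,
                   W_C=rng.normal(size=(D, N)) * 0.1,
                   timesteps=np.full(L_seq, 0.5))
```

Because ∆, B, and C are recomputed from the input at every position, the model can choose per-token how much to retain, write, and read, which is what distinguishes a selective SSM from a time-invariant one.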
Import
Class Signature
Constructor
__init__
Parameters:
- Model dimension (input/output dimension).
- SSM state dimension (N). Controls the capacity of the state space.
- Convolution kernel size for local context.
- Expansion factor for inner dimension. Inner dimension = d_model * expand.
- Rank for the delta projection. If "auto", uses ceil(d_model / 16).
- Minimum value for delta initialization.
- Maximum value for delta initialization.
- Initialization method for delta. Options: "random" or "constant".
- Scale factor for dt initialization.
- Floor value for dt initialization to prevent numerical instability.
- Whether to use bias in the convolution layer.
- Whether to use bias in the linear projections.
- Whether to use fused CUDA kernels when available. Significantly improves performance.
- Layer index for multi-layer caching in stacked architectures.
- Device for parameters. If None, uses the default device.
- Data type for parameters. If None, uses the default dtype.
- Discretization type. Options: "mamba", "zoh", "bilinear", "dirac".

Methods
forward
Parameters:
- Input tensor of shape (B, L, D), where B = batch size, L = sequence length, and D = model dimension (d_model).
- Time intervals between events, shape (B, L). When provided, uses asymmetric discretization with separate dtA and dtB for event-driven processing. This enables the model to handle non-uniform time intervals between sequence elements.
- Not currently used by Mamba; kept for interface consistency.
- Cache for autoregressive generation. If provided, must contain:
  - "conv_state": Convolution state tensor
  - "lrnn_state": SSM state tensor
  - "seqlen_offset": Current position in the sequence

Returns: Output tensor of shape (B, L, D).

step
Parameters:
- Input at the current timestep, shape (B, 1, D).
- Cache dictionary containing:
  - "conv_state": Convolution state, shape (B, D_inner, d_conv)
  - "lrnn_state": SSM state, shape (B, D_inner, N)
  - "seqlen_offset": Current position in the sequence
- Integration timestep for this step, shape (B, 1) or (B,). When provided, uses event-based asymmetric discretization.
- Additional keyword arguments (unused).

Returns: A tuple containing:
- Output at the current timestep, shape (B, 1, D)
- Updated cache dictionary
allocate_inference_cache
Parameters:
- The batch size for inference.
- Maximum sequence length. Not used by Mamba but kept for interface consistency.
- Data type for allocated tensors. If None, uses the model's parameter dtype.
- Additional arguments (unused).

Returns: Cache dictionary containing:
- "conv_state": Convolution state, shape (B, D_inner, d_conv)
- "lrnn_state": SSM state, shape (B, D_inner, N)
- "seqlen_offset": Current position (initialized to 0)
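Putting the cache contract together, an autoregressive loop looks roughly like the toy sketch below. The cache layout matches what this page documents, but the update rules are illustrative stand-ins: the real step() also applies the input/output projections, gating, and selective discretization.

```python
import numpy as np

B_sz, D_inner, d_conv, N = 2, 8, 4, 16

# Cache layout as documented for allocate_inference_cache.
cache = {
    "conv_state": np.zeros((B_sz, D_inner, d_conv)),
    "lrnn_state": np.zeros((B_sz, D_inner, N)),
    "seqlen_offset": 0,
}

def toy_step(x, cache, dt=0.1):
    """Illustrative single-step update against the documented cache layout.

    x: (B, D_inner) current input.
    """
    # Shift the rolling convolution buffer and append the new input.
    cache["conv_state"] = np.roll(cache["conv_state"], -1, axis=-1)
    cache["conv_state"][..., -1] = x
    # Toy diagonal SSM recurrence on the hidden state.
    A_bar = np.exp(-dt)
    cache["lrnn_state"] = A_bar * cache["lrnn_state"] + dt * x[..., None]
    cache["seqlen_offset"] += 1
    y = cache["lrnn_state"].sum(-1)   # (B, D_inner) readout
    return y, cache

for _ in range(3):
    y, cache = toy_step(np.ones((B_sz, D_inner)), cache)
```

Note that the cache is both consumed and returned each step, and "seqlen_offset" advances by one per call, mirroring the step() interface above.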
