
Overview

Mamba is a Selective State Space Model (SSM) that supports optional event-based processing. When integration_timesteps is provided, it uses asymmetric discretization with separate dtA and dtB for event-driven processing. Otherwise, it uses standard Mamba discretization. Key features:
  • Input-dependent selectivity via learned ∆, B, and C parameters
  • Hardware-efficient implementation with fused CUDA kernels
  • Support for event-based/asynchronous discretization
  • Causal 1D convolution for local context
  • Efficient autoregressive inference with state caching

Import

from lrnnx.models.ltv import Mamba

Class Signature

class Mamba(LTV_LRNN)

Constructor

__init__

def __init__(
    self,
    d_model: int,
    d_state: int = 16,
    d_conv: int = 4,
    expand: int = 2,
    dt_rank: Union[int, str] = "auto",
    dt_min: float = 0.001,
    dt_max: float = 0.1,
    dt_init: str = "random",
    dt_scale: float = 1.0,
    dt_init_floor: float = 1e-4,
    conv_bias: bool = True,
    bias: bool = False,
    use_fast_path: bool = True,
    layer_idx: Optional[int] = None,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    discretization: str = "mamba",
)
Initializes a Mamba model.
d_model
int
required
Model dimension (input/output dimension).
d_state
int
default:"16"
SSM state dimension (N). Controls the capacity of the state space.
d_conv
int
default:"4"
Convolution kernel size for local context.
expand
int
default:"2"
Expansion factor for inner dimension. Inner dimension = d_model * expand.
dt_rank
Union[int, str]
default:"'auto'"
Rank for delta projection. If "auto", uses ceil(d_model / 16).
dt_min
float
default:"0.001"
Minimum value for delta initialization.
dt_max
float
default:"0.1"
Maximum value for delta initialization.
dt_init
str
default:"'random'"
Initialization method for delta. Options: "random" or "constant".
dt_scale
float
default:"1.0"
Scale factor for dt initialization.
dt_init_floor
float
default:"1e-4"
Floor value for dt initialization to prevent numerical instability.
conv_bias
bool
default:"True"
Whether to use bias in the convolution layer.
bias
bool
default:"False"
Whether to use bias in linear projections.
use_fast_path
bool
default:"True"
Whether to use fused CUDA kernels when available. Significantly improves performance.
layer_idx
int | None
default:"None"
Layer index for multi-layer caching in stacked architectures.
device
torch.device | None
default:"None"
Device for parameters. If None, uses default device.
dtype
torch.dtype | None
default:"None"
Data type for parameters. If None, uses default dtype.
discretization
str
default:"'mamba'"
Discretization type. Options: "mamba", "zoh", "bilinear", "dirac".
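
For instance, the "auto" setting described above resolves dt_rank to ceil(d_model / 16); a quick sketch of the resulting rank for a few common model sizes:

```python
import math

# dt_rank = "auto" resolves to ceil(d_model / 16)
for d_model in (64, 128, 768):
    dt_rank = math.ceil(d_model / 16)
    print(d_model, dt_rank)  # 64 -> 4, 128 -> 8, 768 -> 48
```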

Methods

forward

def forward(
    self,
    hidden_states: torch.Tensor,
    integration_timesteps: Optional[torch.Tensor] = None,
    lengths: Optional[torch.Tensor] = None,
    inference_cache: Optional[Dict[str, Any]] = None,
) -> torch.Tensor
Forward pass through Mamba.
hidden_states
torch.Tensor
Input tensor of shape (B, L, D) where:
  • B = batch size
  • L = sequence length
  • D = model dimension (d_model)
integration_timesteps
torch.Tensor | None
default:"None"
Time intervals between events, shape (B, L). When provided, uses asymmetric discretization with separate dtA and dtB for event-driven processing. This enables the model to handle non-uniform time intervals between sequence elements.
lengths
torch.Tensor | None
default:"None"
Not currently used by Mamba. Kept for interface consistency.
inference_cache
Dict[str, Any] | None
default:"None"
Cache for autoregressive generation. If provided, must contain:
  • "conv_state": Convolution state tensor
  • "lrnn_state": SSM state tensor
  • "seqlen_offset": Current position in sequence
output
torch.Tensor
Output tensor of shape (B, L, D).

step

def step(
    self,
    x: torch.Tensor,
    inference_cache: Dict[str, Any],
    integration_timesteps: Optional[torch.Tensor] = None,
    **kwargs,
) -> Tuple[torch.Tensor, Dict[str, Any]]
Performs a single recurrent step of Mamba for autoregressive inference.
x
torch.Tensor
Input at current timestep, shape (B, 1, D).
inference_cache
Dict[str, Any]
Cache dictionary containing:
  • "conv_state": Convolution state, shape (B, D_inner, d_conv)
  • "lrnn_state": SSM state, shape (B, D_inner, N)
  • "seqlen_offset": Current position in sequence
integration_timesteps
torch.Tensor | None
default:"None"
Integration timestep for this step, shape (B, 1) or (B,). When provided, uses event-based asymmetric discretization.
**kwargs
Any
Additional keyword arguments (unused).
output
Tuple[torch.Tensor, Dict[str, Any]]
A tuple containing:
  • Output at current timestep, shape (B, 1, D)
  • Updated cache dictionary

allocate_inference_cache

def allocate_inference_cache(
    self,
    batch_size: int,
    max_seqlen: int,
    dtype: Optional[torch.dtype] = None,
    **kwargs,
) -> Dict[str, Any]
Allocates cache for Mamba autoregressive inference.
batch_size
int
The batch size for inference.
max_seqlen
int
Maximum sequence length. Not used by Mamba but kept for interface consistency.
dtype
torch.dtype | None
default:"None"
Data type for allocated tensors. If None, uses model’s parameter dtype.
**kwargs
Any
Additional arguments (unused).
cache
Dict[str, Any]
Cache dictionary containing:
  • "conv_state": Convolution state, shape (B, D_inner, d_conv)
  • "lrnn_state": SSM state, shape (B, D_inner, N)
  • "seqlen_offset": Current position (initialized to 0)

Examples

Basic Usage

import torch
from lrnnx.models.ltv import Mamba

# Create Mamba model
model = Mamba(d_model=64, d_state=16, d_conv=4)

# Standard forward pass
x = torch.randn(2, 128, 64)
y = model(x)
print(y.shape)  # torch.Size([2, 128, 64])

Event-Based Processing

import torch
from lrnnx.models.ltv import Mamba

model = Mamba(d_model=64, d_state=16)

# Input with variable time intervals between events
x = torch.randn(2, 128, 64)
integration_timesteps = torch.rand(2, 128)  # Non-uniform timesteps

# Uses asymmetric discretization with separate dtA and dtB
y = model(x, integration_timesteps=integration_timesteps)

Autoregressive Inference

import torch
from lrnnx.models.ltv import Mamba

model = Mamba(d_model=64, d_state=16)
model.eval()

# Allocate cache
batch_size = 2
max_seq_len = 100
cache = model.allocate_inference_cache(
    batch_size=batch_size,
    max_seqlen=max_seq_len
)

# Generate sequence autoregressively
outputs = []
for t in range(max_seq_len):
    x_t = torch.randn(batch_size, 1, 64)  # placeholder; replace with real per-step input
    y_t, cache = model.step(x_t, cache)
    outputs.append(y_t)

output = torch.cat(outputs, dim=1)  # Shape: (2, 100, 64)

Custom Discretization

import torch
from lrnnx.models.ltv import Mamba

# Use zero-order hold discretization instead of default Mamba
model = Mamba(
    d_model=64,
    d_state=16,
    discretization="zoh"  # Options: "mamba", "zoh", "bilinear", "dirac"
)

x = torch.randn(2, 128, 64)
y = model(x)

See Also

  • LTV_LRNN - Base class for LTV models
  • RGLRU - Alternative LTV architecture
  • S7 - Simplified state space model
