
Overview

LTV_LRNN is the abstract base class for all Linear Time-Varying (LTV) LRNN models in lrnnx. LTV models have time-varying dynamics: the system matrices (A, B, C, etc.) can change at each timestep based on the input. Key characteristics:
  • Cannot use FFT-based convolution for training (since kernels vary per timestep)
  • Support async/event-driven discretization with variable timesteps
  • Must use scan for both training and inference
  • More expressive than LTI models but computationally more expensive
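The recurrence behind these characteristics can be sketched as a sequential scan over per-timestep matrices. This is a minimal illustration, not lrnnx source; the function name and dense-matrix layout are assumptions for clarity:

```python
import torch

def ltv_scan(A, B, x):
    """Sequential scan for a time-varying linear recurrence.

    A: (B, L, N, N) per-timestep state matrices
    B: (B, L, N, H) per-timestep input matrices
    x: (B, L, H)    input sequence
    Returns stacked hidden states of shape (B, L, N).
    """
    batch, length, n, _ = A.shape
    h = torch.zeros(batch, n, dtype=x.dtype)
    states = []
    for t in range(length):
        # h_t = A_t h_{t-1} + B_t x_t. Because A_t and B_t differ per
        # step, the effective kernel is not a fixed convolution, so
        # FFT-based training does not apply and a scan is required.
        h = (torch.einsum("bnm,bm->bn", A[:, t], h)
             + torch.einsum("bnh,bh->bn", B[:, t], x[:, t]))
        states.append(h)
    return torch.stack(states, dim=1)
```

Concrete subclasses replace this Python loop with fused parallel-scan kernels, but the computation they perform is the same.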

Class Hierarchy

LRNN (lrnnx.core.base)
  └── LTV_LRNN (lrnnx.models.ltv.base)
       ├── Mamba
       ├── RGLRU
       └── S7

Import

from lrnnx.models.ltv.base import LTV_LRNN

Class Signature

class LTV_LRNN(LRNN)

Constructor

__init__

def __init__(
    self,
    discretization: Optional[
        Literal["zoh", "bilinear", "dirac", "async", "no_discretization"]
    ],
)
Initialize the LTV LRNN base class.
discretization
Literal['zoh', 'bilinear', 'dirac', 'async', 'no_discretization'] | None
Discretization method to use. Can be one of:
  • "zoh": Zero-order hold discretization
  • "bilinear": Bilinear transform discretization
  • "dirac": Dirac delta discretization
  • "async": Asynchronous/event-driven discretization
  • "no_discretization": No discretization applied
  • None: For models that handle discretization internally
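For diagonal state matrices (common in these models), zero-order hold discretization reduces to elementwise operations. A minimal sketch under that assumption; the function and variable names are illustrative, not the lrnnx implementation:

```python
import torch

def zoh_discretize(a, b, dt):
    """Zero-order hold for a diagonal continuous-time system.

    a:  (N,)       diagonal of the continuous state matrix A
    b:  (N, H)     input matrix B
    dt: (B, L, 1)  per-timestep integration intervals
    Returns (a_bar, b_bar), with a_bar of shape (B, L, N).
    """
    # ZOH: A_bar = exp(dt * A); B_bar = A^{-1} (A_bar - I) B,
    # which is elementwise in the diagonal case.
    a_bar = torch.exp(dt * a)                       # (B, L, N)
    b_bar = ((a_bar - 1.0) / a).unsqueeze(-1) * b   # (B, L, N, H)
    return a_bar, b_bar
```

With "async" discretization, the same formulas apply but dt carries the variable intervals between events rather than a constant step size.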

Abstract Methods

forward

@abstractmethod
def forward(
    self,
    x: Tensor,
    integration_timesteps: Optional[Tensor] = None,
    lengths: Optional[Tensor] = None,
    inference_cache: Optional[Dict[str, Any]] = None,
) -> Tensor
Forward pass through the LTV model. Must be implemented by subclasses.
x
torch.Tensor
Input tensor of shape (B, L, H) where:
  • B = batch size
  • L = sequence length
  • H = hidden dimension
integration_timesteps
torch.Tensor | None
default:"None"
Timesteps for async/event-driven discretization, shape (B, L). When provided, enables event-based processing where the time intervals between events may vary. If None, uniform timesteps are assumed. See Event-based State Space Models for more details.
lengths
torch.Tensor | None
default:"None"
Lengths of sequences in the batch, shape (B,). Required for variable-length sequences or bidirectional models.
inference_cache
Dict[str, Any] | None
default:"None"
Cache containing states and pre-computed values for efficient autoregressive generation. If provided during inference, enables incremental processing. Use allocate_inference_cache() to create.
output
torch.Tensor
Output tensor of shape (B, L, H), same shape as input.

step

@abstractmethod
def step(
    self,
    x: Tensor,
    inference_cache: Dict[str, Any],
    **kwargs,
) -> Tuple[Tensor, Dict[str, Any]]
Performs a single recurrent step of the LTV model for autoregressive inference. This method processes inputs one timestep at a time. Unlike LTI models, the dynamics may vary at each step based on the input.
x
torch.Tensor
Input at current timestep, shape (B, 1, H).
inference_cache
Dict[str, Any]
Cache dictionary containing model states. This is the same format returned by allocate_inference_cache(). The cache is updated in-place and also returned for convenience.
**kwargs
Any
Additional keyword arguments specific to the model implementation.
output
Tuple[torch.Tensor, Dict[str, Any]]
A tuple containing:
  • y: Output at current timestep, shape (B, 1, H)
  • inference_cache: Updated cache dictionary
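The input-dependent dynamics mentioned above can be made concrete with a small sketch of one step for a diagonal, selective recurrence. This is a hypothetical illustration, not a method of any lrnnx subclass; the projection modules are assumptions:

```python
import torch

def ltv_step(x_t, h_prev, a_proj, b_proj):
    """One recurrent step with input-dependent (selective) dynamics.

    x_t:    (B, 1, H) input at the current timestep
    h_prev: (B, N)    previous hidden state
    a_proj: module mapping (B, H) -> (B, N), produces the decay
    b_proj: module mapping (B, H) -> (B, N), produces the input drive
    """
    u = x_t.squeeze(1)                 # (B, H)
    a_t = torch.sigmoid(a_proj(u))     # decay in (0, 1), varies with input
    h = a_t * h_prev + b_proj(u)       # h_t = a_t * h_{t-1} + b_t
    return h
```

In a real subclass, the updated hidden state would be written back into inference_cache and projected to the (B, 1, H) output.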

allocate_inference_cache

@abstractmethod
def allocate_inference_cache(
    self,
    batch_size: int,
    max_seqlen: int,
    dtype: Optional[torch.dtype] = None,
    **kwargs,
) -> Dict[str, Any]
Allocates cache for efficient autoregressive inference. For LTV models, this typically includes:
  • Initial hidden state(s)
  • Any auxiliary states (e.g., convolution state for Mamba)
  • Metadata for tracking sequence position
batch_size
int
The batch size for inference.
max_seqlen
int
Maximum sequence length to support.
dtype
torch.dtype | None
default:"None"
Data type for allocated tensors. If None, uses the model’s default dtype.
**kwargs
Any
Additional model-specific arguments.
cache
Dict[str, Any]
Cache dictionary that can be passed to forward(). Should contain at minimum:
  • Model state tensors (e.g., "lrnn_state", "conv_state")
  • "seqlen_offset": Current position in the sequence

Example Usage

# LTV_LRNN is abstract; use a concrete implementation
import torch
from lrnnx.models.ltv import Mamba

# Create model
model = Mamba(d_model=64, d_state=16)

# Standard forward pass
x = torch.randn(2, 128, 64)
y = model(x)

# Event-based processing with variable timesteps
integration_timesteps = torch.rand(2, 128)  # Variable time intervals
y = model(x, integration_timesteps=integration_timesteps)

# Autoregressive inference
seq_len = x.shape[1]
cache = model.allocate_inference_cache(batch_size=2, max_seqlen=128)
for t in range(seq_len):
    x_t = x[:, t:t+1, :]  # Shape: (2, 1, 64)
    y_t, cache = model.step(x_t, cache)

See Also

  • Mamba - Selective State Space Model implementation
  • RGLRU - Recurrent Gated Linear Recurrent Unit
  • S7 - Selective and Simplified State Space Layer
  • LRNN Base Class - Parent class for all LRNN models
