
Overview

Centaurus is an advanced state space model that introduces intra-state mode mixing through sub-states. Unlike traditional SSMs, whose state is a flat vector, Centaurus decomposes each state channel into multiple sub-states that interact through a learned mixing matrix. Centaurus supports four modes (neck, DWS, full, pointwise) that control how input and output channels interact with the state space, allowing flexible architecture design.

Paper Reference

Centaurus: Let SSMs be Conv Nets https://openreview.net/forum?id=PkpNRmBZ32

Import

from lrnnx.models.lti import Centaurus, CentaurusNeck, CentaurusDWS, CentaurusFull, CentaurusPWNeck

Parameters

d_model (int, required)
  Model dimension: size of input and output features.
d_state (int, required)
  Number of state channels. Each channel has sub_state_dim sub-states.
sub_state_dim (int, required)
  Number of sub-states per state channel. Controls intra-state mixing capacity.
discretization (Literal['zoh', 'bilinear', 'dirac', 'async'], default: 'zoh')
  Discretization method. Currently only ZOH is fully supported.
mode (Literal['neck', 'pointwise', 'pw', 's5', 'dws', 'full'], default: 'neck')
  Architecture mode controlling channel interaction:
  • 'neck': Bottleneck with dense projections
  • 'dws': Depthwise-separable (one state per channel)
  • 'full': Fully connected (one state per input-output pair)
  • 'pointwise' / 'pw' / 's5': Pointwise bottleneck (flattened sub-states)

Usage Example

Basic Usage

import torch
from lrnnx.models.lti import Centaurus

# Create Centaurus with neck mode
model = Centaurus(
    d_model=64,
    d_state=64,
    sub_state_dim=8,
    mode="neck"
)

# Forward pass
x = torch.randn(2, 128, 64)  # (batch, length, features)
y = model(x)

print(y.shape)  # torch.Size([2, 128, 64])

Different Modes

import torch
from lrnnx.models.lti import (
    CentaurusNeck,
    CentaurusDWS,
    CentaurusFull,
    CentaurusPWNeck
)

# Bottleneck mode (default)
model_neck = CentaurusNeck(
    d_model=64,
    d_state=32,
    sub_state_dim=8,
)

# Depthwise-separable mode
model_dws = CentaurusDWS(
    d_model=64,
    d_state=64,  # Must match d_model for DWS
    sub_state_dim=8,
)

# Fully connected mode
model_full = CentaurusFull(
    d_model=64,
    d_state=4096,  # d_model^2 states
    sub_state_dim=8,
)

# Pointwise bottleneck mode
model_pw = CentaurusPWNeck(
    d_model=64,
    d_state=32,
    sub_state_dim=8,
)

x = torch.randn(2, 128, 64)
y_neck = model_neck(x)
y_dws = model_dws(x)
y_full = model_full(x)
y_pw = model_pw(x)

Using the Wrapper

import torch
from lrnnx.models.lti import Centaurus

# All modes accessible via wrapper
model = Centaurus(
    d_model=64,
    d_state=64,
    sub_state_dim=8,
    mode="dws",  # Specify mode
    discretization="zoh"
)

x = torch.randn(2, 128, 64)
y = model(x)

Autoregressive Inference

import torch
from lrnnx.models.lti import CentaurusNeck

model = CentaurusNeck(
    d_model=64,
    d_state=64,
    sub_state_dim=8,
)

batch_size = 2
cache = model.allocate_inference_cache(batch_size=batch_size)

# Generate sequence step-by-step
for t in range(100):
    x_t = torch.randn(batch_size, 64)
    y_t, cache = model.step(x_t, cache)
    # y_t.shape: (batch_size, 64)

Architecture Modes

Neck Mode (Bottleneck)

Use: General-purpose, balanced performance
  • Dense input projection: B: (d_state, d_model)
  • Dense output projection: C: (d_model, d_state)
  • Bottleneck through state dimension
model = CentaurusNeck(d_model=256, d_state=64, sub_state_dim=8)

DWS Mode (Depthwise-Separable)

Use: Parameter efficiency, one state per channel
  • Diagonal projections: one state per input/output channel
  • d_state must equal d_model
  • Minimal parameters
model = CentaurusDWS(d_model=128, d_state=128, sub_state_dim=8)

Full Mode (Fully Connected)

Use: Maximum expressiveness, state per connection
  • State for each (input, output) pair
  • d_state = d_model * d_model
  • Highest capacity but most parameters
model = CentaurusFull(d_model=32, d_state=1024, sub_state_dim=8)

Pointwise Mode (Flattened Sub-states)

Use: Simplified variant, no E-mixing
  • Flattens (d_state, sub_state_dim) into single dimension
  • No mixing matrix E
  • Delta shared across sub-states
model = CentaurusPWNeck(d_model=64, d_state=32, sub_state_dim=8)
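The d_state constraints above can be summarized in a small validation helper (the helper is hypothetical, not part of lrnnx):

```python
def check_state_size(mode: str, d_model: int, d_state: int) -> bool:
    """Hypothetical check of each mode's d_state constraint, per the descriptions above."""
    if mode == "dws":
        return d_state == d_model        # one state per channel
    if mode == "full":
        return d_state == d_model ** 2   # one state per (input, output) pair
    return True                          # neck / pointwise: d_state is a free choice

print(check_state_size("dws", 128, 128))    # True
print(check_state_size("full", 32, 1024))   # True
print(check_state_size("dws", 128, 64))     # False
```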

Key Features

Intra-State Mixing

Centaurus’s key innovation is sub-state decomposition:
# Traditional SSM
state: (d_state,)

# Centaurus (most modes)
state: (d_state, sub_state_dim)

# Mixing via matrix E
mixed_state = einsum('nm,bnm->bn', E, state.real)
This allows the model to learn temporal patterns at multiple scales within each state channel.
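A runnable NumPy sketch of the mixing step, with illustrative shapes:

```python
import numpy as np

rng = np.random.default_rng(0)
batch, d_state, sub = 2, 4, 3

# Complex-valued sub-states, as in most Centaurus modes
state = (rng.standard_normal((batch, d_state, sub))
         + 1j * rng.standard_normal((batch, d_state, sub)))
E = rng.standard_normal((d_state, sub))  # learned mixing matrix (shapes illustrative)

# Collapse the sub-state axis: one mixed value per state channel
mixed = np.einsum('nm,bnm->bn', E, state.real)
print(mixed.shape)  # (2, 4)
```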

ZOH Discretization

Centaurus uses implicit Zero-Order Hold discretization:
# Compute discrete kernel: exp(dtA * l) for each time step l
dtA = delta[:, None] * A  # (d_state, sub_state_dim)
K = exp(einsum('nm,l->nml', dtA, time_range))  # exp(dtA)**l, per sub-state
mixed_K = einsum('nml,nm->nl', K.real, E)
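A runnable NumPy sketch of this kernel computation, with illustrative shapes and random parameters (note that exp(dtA * l) equals exp(dtA) raised elementwise to the l-th power):

```python
import numpy as np

rng = np.random.default_rng(0)
d_state, sub, L = 3, 2, 5

# Complex diagonal dynamics with negative real parts (stable); shapes illustrative
A = (-np.abs(rng.standard_normal((d_state, sub)))
     + 1j * rng.standard_normal((d_state, sub)))
delta = np.full(d_state, 0.01)           # per-channel timestep
E = rng.standard_normal((d_state, sub))  # mixing matrix

dtA = delta[:, None] * A                             # (d_state, sub_state_dim)
time_range = np.arange(L)
K = np.exp(np.einsum('nm,l->nml', dtA, time_range))  # exp(dtA * l) == exp(dtA)**l
mixed_K = np.einsum('nml,nm->nl', K.real, E)         # collapse sub-states per channel
print(mixed_K.shape)  # (3, 5)
```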

Learned Time Scales

Each state channel has a learnable delta (timestep):
delta = exp(log_delta)  # (d_state,)
Initialized log-spaced from 0.001 to 0.1.
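A minimal sketch of this initialization, assuming simple log-linear spacing between the stated endpoints:

```python
import numpy as np

d_state = 8
# Log-spaced timesteps from 0.001 to 0.1 (endpoints per the description above)
log_delta = np.linspace(np.log(0.001), np.log(0.1), d_state)
delta = np.exp(log_delta)
print(delta[0], delta[-1])  # ~0.001 ... ~0.1
```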

State Representation

Most Modes (Neck, DWS, Full)

state: (batch_size, d_state, sub_state_dim) dtype=complex
  • Each of d_state channels has sub_state_dim sub-states
  • Complex-valued for frequency modeling
  • Mixed via matrix E before output

Pointwise Mode

state: (batch_size, d_state * sub_state_dim) dtype=complex
  • Flattened representation
  • No E-mixing
  • Simpler but less structured
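The two layouts carry the same values; a reshape converts between them. An illustrative NumPy sketch:

```python
import numpy as np

batch, d_state, sub = 2, 4, 3
structured = (np.arange(batch * d_state * sub)
              .reshape(batch, d_state, sub)
              .astype(np.complex64))              # (batch, d_state, sub_state_dim)

flat = structured.reshape(batch, d_state * sub)   # pointwise-mode layout
print(flat.shape)  # (2, 12)

# Round trip: flattening loses structure, not information
print(np.array_equal(flat.reshape(batch, d_state, sub), structured))  # True
```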

Parameter Count Comparison

Mode      | Parameters | Use Case
----------|------------|------------------------
DWS       | Minimal    | Efficiency-critical
Neck      | Moderate   | General purpose
Pointwise | Moderate   | Simplified variant
Full      | Maximum    | Expressiveness-critical
Formula:
  • Neck: ~d_state * d_model * 2 + smaller terms
  • DWS: ~2 * d_model (when d_state=d_model)
  • Full: ~d_state (but d_state = d_model^2)
  • Pointwise: ~(d_state * sub_state_dim) * d_model * 2
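These rough formulas can be compared numerically; the helper below is illustrative only and ignores the smaller terms:

```python
def approx_params(mode, d_model, d_state, sub_state_dim=8):
    """Rough projection-parameter counts per the formulas above (smaller terms ignored)."""
    if mode == "neck":
        return 2 * d_state * d_model
    if mode == "dws":
        return 2 * d_model                       # assumes d_state == d_model
    if mode == "full":
        return d_state                           # with d_state = d_model ** 2
    if mode == "pointwise":
        return 2 * d_state * sub_state_dim * d_model
    raise ValueError(f"unknown mode: {mode}")

d_model = 64
print("dws      :", approx_params("dws", d_model, d_model))        # 128
print("neck     :", approx_params("neck", d_model, 32))            # 4096
print("full     :", approx_params("full", d_model, d_model ** 2))  # 4096
print("pointwise:", approx_params("pointwise", d_model, 32))       # 32768
```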

Performance Tips

Start with neck mode for general tasks. It offers the best balance of performance and parameter efficiency.
The sub_state_dim parameter controls the multi-scale capacity. Values of 4-16 typically work well.
Currently, only discretization="zoh" is fully supported. Other methods are experimental.

When to Use Each Mode

Neck Mode

✅ General-purpose tasks
✅ Balanced performance/efficiency
✅ When d_state < d_model (bottleneck)

DWS Mode

✅ Parameter efficiency critical
✅ Depthwise processing sufficient
✅ When d_model is moderate

Full Mode

✅ Maximum expressiveness needed
✅ Small d_model (e.g., 16-32)
✅ Complex input-output relationships

Pointwise Mode

✅ Simplified implementation
✅ When E-mixing not needed
✅ Compatibility with S5-style code

Advanced Usage

Multi-Scale Architecture

import torch
import torch.nn as nn
from lrnnx.models.lti import CentaurusNeck

class MultiScaleCentaurus(nn.Module):
    def __init__(self, d_model=128):
        super().__init__()
        # Coarse scale (few sub-states)
        self.coarse = CentaurusNeck(
            d_model=d_model,
            d_state=64,
            sub_state_dim=4,
        )
        # Fine scale (many sub-states)
        self.fine = CentaurusNeck(
            d_model=d_model,
            d_state=64,
            sub_state_dim=16,
        )
        self.combine = nn.Linear(d_model * 2, d_model)
    
    def forward(self, x):
        y_coarse = self.coarse(x)
        y_fine = self.fine(x)
        return self.combine(torch.cat([y_coarse, y_fine], dim=-1))

Hybrid DWS + Neck

import torch.nn as nn
from lrnnx.models.lti import CentaurusDWS, CentaurusNeck

class HybridCentaurus(nn.Module):
    def __init__(self, d_model=128):
        super().__init__()
        # Depthwise for efficiency
        self.dws = CentaurusDWS(
            d_model=d_model,
            d_state=d_model,
            sub_state_dim=8,
        )
        # Neck for mixing
        self.neck = CentaurusNeck(
            d_model=d_model,
            d_state=64,
            sub_state_dim=8,
        )
    
    def forward(self, x):
        x = self.dws(x) + x
        x = self.neck(x) + x
        return x

See Also

  • S5 - Simpler SSM baseline
  • LRU - Minimal diagonal SSM
  • Mamba - Input-dependent selective model
