
Overview

Classifier is a flexible architecture for sequence classification and regression tasks. It supports multiple LRNN layer types, variable-length sequences, token embeddings, and hierarchical pooling strategies.

Class Definition

from lrnnx.architectures import Classifier
from lrnnx.models.lti.lru import LRU

model = Classifier(
    input_dim=128,
    num_classes=10,
    d_model=256,
    d_state=64,
    n_layers=4,
    lrnn_cls=LRU,
    pooling="last"
)

Parameters

input_dim
int
required
Number of input features. Ignored when vocab_size is provided (for token-based inputs).
num_classes
int
default:"0"
Number of output classes for classification. Set to 0 for regression tasks.
output_dim
int
default:"1"
Number of regression outputs (used when num_classes=0).
d_model
int
default:"128"
Hidden dimension of the model.
d_state
int
default:"64"
State dimension for the LRNN layers.
n_layers
int
default:"4"
Number of LRNN processing blocks.
lrnn_cls
type | list[type]
default:"LRU"
LRNN class or list of classes (one per layer). Can be a class object or string name. Available options:
  • LRU or "LRU" - Linear Recurrent Unit
  • S5 or "S5" - Simplified State Space
  • Centaurus or "Centaurus" - Centaurus mixer
Example: [LRU, S5, LRU, S5] or ["LRU", "S5", "LRU", "S5"]
pooling
str
default:"last"
Final pooling strategy for sequence outputs. Options:
  • "last" - Use the last timestep (respects lengths if provided)
  • "mean" - Average over all timesteps
  • "max" - Max pooling over all timesteps
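The three strategies can be sketched in plain PyTorch. This is an illustrative reimplementation, not the library's internal code; `pool` is a hypothetical helper, and the `"mean"`/`"max"` branches here ignore padding (masked variants would be needed for correct results on padded batches):

```python
import torch

def pool(x, strategy, lengths=None):
    # x: (B, L, D) sequence outputs; lengths: optional (B,) valid lengths
    if strategy == "last":
        if lengths is None:
            return x[:, -1]                       # final timestep of every sequence
        idx = (lengths - 1).clamp(min=0)          # last valid index per sequence
        idx = idx.view(-1, 1, 1).expand(-1, 1, x.size(-1))
        return x.gather(1, idx).squeeze(1)        # (B, D)
    if strategy == "mean":
        return x.mean(dim=1)                      # unmasked average over timesteps
    if strategy == "max":
        return x.max(dim=1).values                # unmasked max over timesteps
    raise ValueError(f"unknown pooling strategy: {strategy}")
```

Note how `"last"` uses `lengths` to pick the final *valid* timestep of each sequence rather than the padded end of the batch.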
dropout
float
default:"0.1"
Dropout probability applied after each LRNN layer.
intermediate_pooling
str | list[str]
default:"none"
Pooling strategy for each layer to reduce sequence length. Options:
  • "none" - No intermediate pooling
  • "stride" - Strided sampling
  • "mean" - Average pooling
  • "max" - Max pooling
Can be a single string (applied to all layers) or a list of length n_layers.
pooling_factor
int | list[int]
default:"2"
Factor by which to reduce sequence length at each layer (when using intermediate pooling). Can be a single integer (same for all layers) or a list of length n_layers.
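Assuming each layer with intermediate pooling divides the sequence length by its factor (floor division; this is a sketch of the arithmetic, not the library's code), the final length can be estimated as:

```python
def pooled_length(seq_len, intermediate_pooling, pooling_factor):
    """Estimate the sequence length after hierarchical pooling (illustrative only)."""
    for mode, factor in zip(intermediate_pooling, pooling_factor):
        if mode != "none":       # layers with "none" leave the length unchanged
            seq_len = seq_len // factor
    return seq_len

print(pooled_length(1000, ["none", "stride", "stride", "none"], [1, 2, 2, 1]))  # → 250
```

This matches the Hierarchical Pooling example below, where two stride-2 layers reduce a 1000-step sequence by 4x.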
vocab_size
int
default:"None"
Size of vocabulary for token embeddings. When provided, the model expects token IDs as input instead of continuous features.
embedding_dim
int
default:"None"
Dimension of token embeddings. Defaults to d_model if not specified.
max_position_embeddings
int
default:"None"
Maximum sequence length for learned positional embeddings. Only used if positional embeddings are enabled in the embedding layer.
padding_idx
int
default:"0"
Index of the padding token in the vocabulary (when using token embeddings).
lrnn_params
dict
default:"None"
Additional parameters passed to LRNN module constructors. Must include all required constructor arguments for the LRNN class. Examples:
  • LRU: {"d_model": 256, "d_state": 64}
  • S5: {"d_model": 256, "d_state": 64, "discretization": "zoh"}

Methods

forward

output = model.forward(
    x,
    lengths=None,
    integration_timesteps=None
)
Forward pass of the classifier/regressor.
x
torch.Tensor
required
Input tensor:
  • Token IDs of shape (B, L) when using embeddings
  • Continuous features of shape (B, L, input_dim) otherwise
lengths
torch.Tensor
default:"None"
Actual sequence lengths of shape (B,) for variable-length sequences.
integration_timesteps
torch.Tensor
default:"None"
Timesteps of shape (B, L) for LTV models (e.g., Mamba).
output
torch.Tensor
  • Classification: Logits of shape (B, num_classes)
  • Regression: Values of shape (B, output_dim)

Example Usage

Classification with Continuous Features

import torch
from lrnnx.architectures import Classifier
from lrnnx.models.lti.lru import LRU

# Create classifier
model = Classifier(
    input_dim=128,
    num_classes=10,
    d_model=256,
    d_state=64,
    n_layers=4,
    lrnn_cls=LRU,
    pooling="last",
    lrnn_params={"d_model": 256, "d_state": 64}
).cuda()

# Input: batch=32, seq_len=100, features=128
x = torch.randn(32, 100, 128).cuda()
lengths = torch.randint(50, 100, (32,)).cuda()

# Forward pass
logits = model(x, lengths=lengths)
print(logits.shape)  # (32, 10)

Classification with Token Embeddings

import torch
from lrnnx.architectures import Classifier

# Create classifier with token embeddings
model = Classifier(
    input_dim=0,  # Ignored when vocab_size is set
    num_classes=5,
    d_model=256,
    d_state=64,
    n_layers=4,
    vocab_size=10000,
    embedding_dim=256,
    pooling="mean",
    lrnn_params={"d_model": 256, "d_state": 64}
).cuda()

# Input: batch=16, seq_len=50 (token IDs)
token_ids = torch.randint(0, 10000, (16, 50)).cuda()

# Forward pass
logits = model(token_ids)
print(logits.shape)  # (16, 5)

Regression Task

import torch
from lrnnx.architectures import Classifier
from lrnnx.models.lti.s5 import S5

# Create regressor
model = Classifier(
    input_dim=64,
    num_classes=0,  # 0 for regression
    output_dim=3,   # Predict 3 values
    d_model=128,
    d_state=32,
    n_layers=3,
    lrnn_cls=S5,
    pooling="mean",
    lrnn_params={"d_model": 128, "d_state": 32, "discretization": "zoh"}
).cuda()

# Input
x = torch.randn(8, 200, 64).cuda()

# Forward pass
output = model(x)
print(output.shape)  # (8, 3)

Hierarchical Pooling

import torch
from lrnnx.architectures import Classifier

# Use intermediate pooling to reduce sequence length
model = Classifier(
    input_dim=128,
    num_classes=10,
    d_model=256,
    d_state=64,
    n_layers=4,
    intermediate_pooling=["none", "stride", "stride", "none"],
    pooling_factor=[1, 2, 2, 1],  # Reduce by 4x total
    pooling="last",
    lrnn_params={"d_model": 256, "d_state": 64}
).cuda()

# Input: very long sequence
x = torch.randn(4, 1000, 128).cuda()
logits = model(x)
print(logits.shape)  # (4, 10)

Mixed LRNN Types

import torch
from lrnnx.architectures import Classifier
from lrnnx.models.lti.lru import LRU
from lrnnx.models.lti.s5 import S5

# Use different LRNN types per layer
model = Classifier(
    input_dim=64,
    num_classes=5,
    d_model=128,
    d_state=64,
    n_layers=4,
    lrnn_cls=[LRU, S5, LRU, S5],  # Alternate between LRU and S5
    pooling="mean",
    lrnn_params={"d_model": 128, "d_state": 64}
).cuda()

x = torch.randn(8, 100, 64).cuda()
logits = model(x)
print(logits.shape)  # (8, 5)

Notes

  • Set num_classes > 0 for classification tasks, num_classes=0 for regression
  • The lengths parameter enables proper handling of variable-length sequences
  • Intermediate pooling can significantly reduce computation for long sequences
  • When using token embeddings, the model automatically handles the embedding and projection layers
  • The lrnn_params dict must contain all required parameters for the chosen LRNN class
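To see why lengths matters, here is one way padded timesteps could be excluded from mean pooling. This is an illustrative sketch with a hypothetical `masked_mean` helper, not lrnnx's actual implementation:

```python
import torch

def masked_mean(x, lengths):
    # x: (B, L, D); lengths: (B,) — average each sequence over its valid timesteps only
    mask = torch.arange(x.size(1)).unsqueeze(0) < lengths.unsqueeze(1)  # (B, L) validity mask
    mask = mask.unsqueeze(-1).to(x.dtype)                               # (B, L, 1)
    return (x * mask).sum(dim=1) / lengths.unsqueeze(1).to(x.dtype)
```

Without the mask, padding timesteps would be averaged in and skew the pooled representation, especially for batches with widely varying sequence lengths.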
