Overview

The HierarchicalClassifier (also exported as Classifier) is a flexible architecture for sequence classification and regression built from Linear Recurrent Neural Networks (LRNNs). It supports variable-length sequences, hierarchical pooling across layers, and mixing different LRNN types within one stack. The architecture is inspired by the Event-SSM paper and is designed for tasks such as:
  • Text classification
  • Sentiment analysis
  • Time series classification
  • Event-based vision tasks
  • Regression on sequential data

Architecture

The model consists of:
  • Input projection/embedding: Maps raw features or tokens to model dimension
  • Stack of LRNN blocks: Each block contains:
    • LRNN layer (LRU, S5, Centaurus, etc.)
    • Residual connection
    • Layer normalization
    • Dropout
    • Optional intermediate pooling to reduce sequence length
  • Final pooling: Aggregates sequence to single vector (last/mean/max)
  • Output head: Linear projection to class logits or regression values
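The tensor shapes implied by these stages can be traced in plain Python (a sketch with assumed sizes, not lrnnx code; it assumes default settings and no intermediate pooling):

```python
# Hypothetical shape trace: follow a batch through the stages above.
B, L, input_dim, d_model, num_classes = 8, 128, 64, 128, 10
n_layers = 4

shape = (B, L, input_dim)       # raw continuous features
shape = (B, L, d_model)         # input projection to model dimension
for _ in range(n_layers):       # LRNN blocks: residual + norm keep the shape
    shape = (B, L, d_model)
shape = (B, d_model)            # final pooling collapses the time axis
shape = (B, num_classes)        # output head produces class logits
print(shape)                    # (8, 10)
```

With intermediate pooling enabled, the L dimension would shrink between blocks instead of staying fixed.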

Class Signature

Classifier(
    input_dim: int,
    num_classes: int = 0,
    output_dim: int = 1,
    d_model: int = 128,
    d_state: int = 64,
    n_layers: int = 4,
    lrnn_cls: Union[Type[nn.Module], List[Type[nn.Module]]] = LRU,
    pooling: Literal["mean", "last", "max"] = "last",
    dropout: float = 0.1,
    intermediate_pooling: Union[
        Literal["none", "stride", "mean", "max"],
        List[Literal["none", "stride", "mean", "max"]],
    ] = "none",
    pooling_factor: Union[int, List[int]] = 2,
    vocab_size: Optional[int] = None,
    embedding_dim: Optional[int] = None,
    max_position_embeddings: Optional[int] = None,
    padding_idx: Optional[int] = 0,
    lrnn_params: Optional[dict] = None,
)

Parameters

input_dim (int, required)
Number of input features. Ignored when vocab_size is provided for token embeddings.

num_classes (int, default: 0)
Number of output classes for classification. Set to 0 for regression tasks.

output_dim (int, default: 1)
Number of regression outputs (only used when num_classes=0).

d_model (int, default: 128)
Hidden dimension of the model.

d_state (int, default: 64)
State dimension of the LRNN layers.

n_layers (int, default: 4)
Number of LRNN layers.

lrnn_cls (type | str | list[type | str], default: LRU)
LRNN class, or list of classes (one per layer), to use. Can be:
  • A single class: LRU, S5, Centaurus
  • A string: "LRU", "S5", "Centaurus"
  • A list of classes/strings for heterogeneous layers: ["LRU", "S5", LRU, Centaurus]

pooling (str, default: "last")
Final pooling strategy for aggregating the sequence into a single vector:
  • "last": Use the last valid timestep (respects variable lengths)
  • "mean": Average pooling over the sequence
  • "max": Max pooling over the sequence

dropout (float, default: 0.1)
Dropout probability for regularization.

intermediate_pooling (str | list[str], default: "none")
Pooling strategy for reducing sequence length within the layer stack:
  • "none": No intermediate pooling
  • "stride": Strided selection (every k-th element)
  • "mean": Average pooling
  • "max": Max pooling
Can be a single string or a list (one per layer) for layer-specific pooling.

pooling_factor (int | list[int], default: 2)
Factor by which the sequence length is reduced at each layer that uses intermediate pooling. Can be a single int or a list (one per layer).

vocab_size (int, default: None)
Vocabulary size for token embeddings. When provided, the model expects token IDs as input instead of continuous features.

embedding_dim (int, default: None)
Dimension of token embeddings (defaults to d_model if not specified).

max_position_embeddings (int, default: None)
Maximum sequence length for positional embeddings.

padding_idx (int, default: 0)
Index of the padding token for the embedding layer.

lrnn_params (dict, default: None)
Additional keyword arguments passed to the LRNN constructors. Example: {"d_model": 128, "d_state": 64}

Usage Examples

Text Classification

import torch
from lrnnx.architectures.classifier import Classifier

# Binary sentiment classification
model = Classifier(
    input_dim=0,  # Ignored when using embeddings
    num_classes=2,  # Positive/negative
    d_model=256,
    d_state=128,
    n_layers=4,
    lrnn_cls="S5",
    pooling="mean",  # Average over sequence
    vocab_size=10000,  # 10k vocabulary
    dropout=0.2,
)

# Input: token IDs (batch_size, seq_len)
input_ids = torch.randint(0, 10000, (8, 128))
logits = model(input_ids)  # (8, 2)
predictions = logits.argmax(dim=-1)  # (8,)

Time Series Classification with Variable Lengths

# Event classification with variable-length sequences
model = Classifier(
    input_dim=64,  # 64 sensor features
    num_classes=10,  # 10 event types
    d_model=128,
    d_state=64,
    n_layers=3,
    lrnn_cls="LRU",
    pooling="last",  # Use last valid timestep
)

# Input: continuous features (batch_size, seq_len, features)
x = torch.randn(4, 200, 64)
lengths = torch.tensor([150, 180, 120, 200])  # Actual lengths

logits = model(x, lengths=lengths)  # (4, 10)

Hierarchical Pooling for Long Sequences

# Reduce sequence length progressively through layers
model = Classifier(
    input_dim=32,
    num_classes=5,
    d_model=128,
    d_state=64,
    n_layers=4,
    lrnn_cls="S5",
    pooling="mean",
    # Halve the sequence length at the first three layers:
    # 1000 -> 500 -> 250 -> 125 (last layer keeps its input length)
    intermediate_pooling=["stride", "stride", "stride", "none"],
    pooling_factor=2,
)

x = torch.randn(2, 1000, 32)  # Very long sequences
logits = model(x)  # (2, 5)
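To see how far a given configuration shortens the sequence, the per-layer lengths can be estimated with a small helper. This is a sketch, not part of lrnnx; the rounding convention (strided selection of every k-th element keeps ceil(L / k) items) is an assumption:

```python
import math

# Hypothetical helper: sequence length after each block, assuming
# "stride"/"mean"/"max" reduce length to ceil(L / factor) and
# "none" leaves it unchanged.
def pooled_lengths(seq_len, poolings, factor=2):
    out = [seq_len]
    for p in poolings:
        out.append(out[-1] if p == "none" else math.ceil(out[-1] / factor))
    return out

print(pooled_lengths(1000, ["stride", "stride", "stride", "none"]))
# [1000, 500, 250, 125, 125] under the ceil assumption
```

Shorter sequences in the later layers reduce both compute and memory, at the cost of temporal resolution.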

Regression Task

# Predict 3 continuous values from sequence
model = Classifier(
    input_dim=16,
    num_classes=0,  # 0 = regression mode
    output_dim=3,   # 3 output values
    d_model=64,
    d_state=32,
    n_layers=2,
    lrnn_cls="LRU",
    pooling="last",
)

x = torch.randn(8, 50, 16)
predictions = model(x)  # (8, 3)
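The head-sizing rule that num_classes and output_dim imply can be written out explicitly (a plain-Python sketch of the convention described above, not the library's internals):

```python
# Assumed to mirror Classifier's num_classes / output_dim convention:
# num_classes > 0 selects classification logits, 0 selects regression.
def head_out_features(num_classes: int, output_dim: int) -> int:
    return num_classes if num_classes > 0 else output_dim

print(head_out_features(0, 3))    # regression model above -> 3
print(head_out_features(10, 3))   # a 10-class classifier  -> 10
```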

Heterogeneous LRNN Layers

from lrnnx.models.lti.lru import LRU
from lrnnx.models.lti.s5 import S5
from lrnnx.models.lti.centaurus import Centaurus

# Different LRNN type for each layer
model = Classifier(
    input_dim=64,
    num_classes=10,
    d_model=128,
    d_state=64,
    n_layers=4,
    lrnn_cls=[LRU, S5, S5, Centaurus],  # Mix different types
    pooling="mean",
    lrnn_params={"d_model": 128, "d_state": 64},
)

Methods

forward

forward(
    x: Tensor,
    lengths: Optional[Tensor] = None,
    integration_timesteps: Optional[Tensor] = None,
) -> Tensor
Forward pass of the classifier/regressor. Arguments:
  • x (torch.Tensor): Input tensor
    • Token IDs of shape (B, L) when using embeddings
    • Continuous features of shape (B, L, input_dim) otherwise
  • lengths (torch.Tensor, optional): Actual sequence lengths of shape (B,) for variable-length sequences
  • integration_timesteps (torch.Tensor, optional): Timesteps of shape (B, L) for LTV models
Returns:
  • torch.Tensor:
    • Classification logits of shape (B, num_classes) when num_classes > 0
    • Regression values of shape (B, output_dim) when num_classes = 0
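Why lengths matters for pooling="last" can be seen with a tiny padded batch (a pure-Python illustration, independent of lrnnx):

```python
# Two sequences padded to length 5; zeros are padding.
seqs = [[1, 2, 3, 0, 0],
        [4, 5, 6, 7, 8]]
lengths = [3, 5]

# Naive "last index" pooling reads padding for the first sequence:
naive = [seq[-1] for seq in seqs]                        # [0, 8]
# Length-aware "last" pooling reads the final *valid* timestep:
last_valid = [seq[n - 1] for seq, n in zip(seqs, lengths)]  # [3, 8]
print(naive, last_valid)
```

The same reasoning applies to "mean" pooling, where padded positions must be excluded from the average.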

Use Cases

  • Text classification: Sentiment analysis, topic classification, spam detection
  • Event-based vision: Event camera classification tasks
  • Time series classification: Activity recognition, anomaly detection
  • Sequence regression: Predicting continuous values from sequential data
  • Audio classification: Speaker recognition, audio event detection
