This tutorial demonstrates the Classifier architecture from lrnnx on the ListOps task, showing how to build both single-model and hybrid multi-model classifiers.
Overview
You'll learn how to:
- Load and prepare the ListOps dataset
- Build a classifier with LRU layers
- Train and evaluate the model
- Create hybrid architectures mixing different LRNN models
Data Flow
- Input & tokenization: Raw ListOps expressions are tokenized into IDs (digits, brackets, operators) and padded/truncated to a fixed length
- Embedding: Token IDs are mapped to continuous vectors via an embedding layer, giving shape `(B, L, E)` where `B` = batch, `L` = sequence length, `E` = embedding size
- Per-block processing: The `Classifier` is a stack of blocks. Each block runs an LRNN mixer over sequence positions, then applies a residual connection, LayerNorm, and dropout
- Intermediate pooling (optional): Between blocks, the model can apply pooling (stride/mean/max) to aggregate nearby tokens, shortening the sequence length `L -> L'`
- Hybrid stacks: The `lrnn_cls` argument accepts either a single LRNN class for all blocks or a list of LRNN classes to construct heterogeneous layers (e.g., `[LRU, S5, Centaurus]`)
- Final collapse & head: After the final block, the sequence is collapsed to a single vector per example using the pooling mode (last/mean/max), producing `(B, E)`. A linear classification head maps this to logits `(B, num_classes)`
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
from datasets import load_dataset
from tqdm import tqdm
from lrnnx.architectures.classifier import Classifier
from lrnnx.models.lti.lru import LRU
# General
BATCH_SIZE = 32
NUM_EPOCHS = 10
# Single-model hyperparameters
EMBEDDING_DIM = 128
D_MODEL = 512
D_STATE = 256
N_LAYERS = 6
We use the ListOps-1000 dataset from Hugging Face, which contains mathematical expressions with nested operations (MAX, MIN, MED, SUM):
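For example, an expression such as `[ MAX 4 3 [ MIN 2 3 ] 1 0 ]` (shown here with space-separated tokens) evaluates to `4`: the inner `MIN` reduces to `2`, and the outer `MAX` over `4 3 2 1 0` gives `4`. The single-digit result is the classification target.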
print("Loading ListOps dataset from Hugging Face...")
ds = load_dataset("fengyang0317/listops-1000")
print("Dataset loaded!")
# Create a subset for faster experimentation
subset_fraction = 0.01
ds_subset = {}
for split in ds.keys():
total_examples = len(ds[split])
subset_size = int(total_examples * subset_fraction)
indices = np.random.choice(total_examples, subset_size, replace=False)
ds_subset[split] = ds[split].select(indices)
print(f"Train subset size: {len(ds_subset['train'])}")
print(f"Validation subset size: {len(ds_subset['validation'])}")
print(f"Test subset size: {len(ds_subset['test'])}")
The vocabulary consists of `[PAD]`, `[`, `]`, `MAX`, `MIN`, `MED`, `SUM`, and the digits `0`-`9`:
LISTOPS_VOCAB_SIZE = 20
LISTOPS_NUM_CLASSES = 10 # Output classes 0-9
LISTOPS_MAX_LEN = 200
class ListOpsTokenizer:
def __init__(self):
self.token_map = {
"[PAD]": 0,
"[": 1,
"]": 2,
"MAX": 3,
"MIN": 4,
"MED": 5,
"SUM": 6
}
# Add digits to token map
for i in range(10):
self.token_map[str(i)] = 7 + i
def tokenize(self, sequence, max_length):
"""Convert a ListOps sequence to token IDs"""
tokens = []
for token in sequence.split():
if token in self.token_map:
tokens.append(self.token_map[token])
seq_len = len(tokens)
# Pad or truncate
if len(tokens) > max_length:
tokens = tokens[:max_length]
seq_len = max_length
else:
tokens = tokens + [0] * (max_length - len(tokens))
return tokens, seq_len
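As a quick illustration (not part of the pipeline below), the tokenizer maps a short expression to padded IDs like this:
# Illustrative only: tokenize a short expression and pad it to length 10
tokenizer = ListOpsTokenizer()
ids, seq_len = tokenizer.tokenize("[ MAX 2 9 ]", max_length=10)
print(ids)      # [1, 3, 9, 16, 2, 0, 0, 0, 0, 0]
print(seq_len)  # 5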
class HFListOpsDataset(Dataset):
def __init__(self, hf_dataset, split, max_length=LISTOPS_MAX_LEN):
self.data = hf_dataset[split]
self.tokenizer = ListOpsTokenizer()
self.max_length = max_length
self.tokens = []
self.labels = []
self.lengths = []
print(f"Processing {split} dataset...")
for example in tqdm(self.data):
tokens, seq_len = self.tokenizer.tokenize(
example['Source'], self.max_length
)
self.tokens.append(tokens)
self.labels.append(int(example['Target']))
self.lengths.append(seq_len)
print(f"Processed {len(self.labels)} examples.")
def __len__(self):
return len(self.labels)
def __getitem__(self, idx):
return (
torch.tensor(self.tokens[idx], dtype=torch.long),
torch.tensor(self.labels[idx], dtype=torch.long),
torch.tensor(self.lengths[idx], dtype=torch.long)
)
# Create datasets and dataloaders
train_dataset = HFListOpsDataset(ds_subset, 'train', max_length=LISTOPS_MAX_LEN)
val_dataset = HFListOpsDataset(ds_subset, 'validation', max_length=LISTOPS_MAX_LEN)
test_dataset = HFListOpsDataset(ds_subset, 'test', max_length=LISTOPS_MAX_LEN)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)
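As a quick sanity check (illustrative only), you can pull one batch and confirm that the shapes match the data flow described above:
# Illustrative sanity check: inspect one batch from the training loader
tokens, labels, lengths = next(iter(train_loader))
print(tokens.shape)   # (BATCH_SIZE, LISTOPS_MAX_LEN), i.e. (B, L)
print(labels.shape)   # (BATCH_SIZE,)
print(lengths.shape)  # (BATCH_SIZE,)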
When instantiating a `Classifier`, the `lrnn_params` argument must contain the exact constructor arguments for your chosen `lrnn_cls`:
- `LRU`: `{"H": d_model, "N": d_state}`
- `S5`: `{"hid_dim": d_model, "state_dim": d_state, "discretization": "zoh"}`
- `Centaurus`: `{"d_model": d_model, "state_dim": d_state, "sub_state_dim": d_state}`
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Create LRNN classifier model
model = Classifier(
input_dim=None, # Not needed with embeddings
num_classes=LISTOPS_NUM_CLASSES,
vocab_size=LISTOPS_VOCAB_SIZE,
embedding_dim=EMBEDDING_DIM,
max_position_embeddings=LISTOPS_MAX_LEN,
padding_idx=0,
d_model=D_MODEL,
d_state=D_STATE,
n_layers=N_LAYERS,
lrnn_cls=LRU,
lrnn_params={"H": D_MODEL, "N": D_STATE}, # LRU constructor args
dropout=0.1,
pooling="last",
intermediate_pooling=["none"] * N_LAYERS,
).to(device)
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
# Setup training
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.1, betas=(0.9, 0.98))
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)
# Training loop
best_acc = 0.0
train_losses = []
val_accs = []
for epoch in range(NUM_EPOCHS):
model.train()
running_loss = 0.0
progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}")
for tokens, labels, lengths in progress_bar:
tokens = tokens.to(device)
labels = labels.to(device)
lengths = lengths.to(device)
optimizer.zero_grad()
outputs = model(tokens, lengths=lengths)
loss = criterion(outputs, labels)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
running_loss += loss.item()
progress_bar.set_postfix({'loss': loss.item()})
epoch_loss = running_loss / len(train_loader)
train_losses.append(epoch_loss)
# Validation
model.eval()
correct = 0
total = 0
with torch.no_grad():
for tokens, labels, lengths in val_loader:
tokens = tokens.to(device)
labels = labels.to(device)
lengths = lengths.to(device)
outputs = model(tokens, lengths=lengths)
_, predicted = torch.max(outputs, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
val_acc = 100.0 * correct / total
val_accs.append(val_acc)
print(f"Epoch {epoch+1} - Loss: {epoch_loss:.4f}, Val Acc: {val_acc:.2f}%")
# Save best model
if val_acc > best_acc:
best_acc = val_acc
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'accuracy': best_acc
}, "best_listops_model.pt")
print(f"New best model saved with accuracy: {best_acc:.2f}%")
scheduler.step()
print(f"Training complete! Best validation accuracy: {best_acc:.2f}%")
print("Evaluating on test set...")
model.eval()
correct = 0
total = 0
with torch.no_grad():
for tokens, labels, lengths in tqdm(test_loader):
tokens = tokens.to(device)
labels = labels.to(device)
lengths = lengths.to(device)
outputs = model(tokens, lengths=lengths)
_, predicted = torch.max(outputs, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
test_acc = 100.0 * correct / total
print(f"Test accuracy: {test_acc:.2f}%")
Advanced: Multi-Model Architecture
The `Classifier` supports heterogeneous architectures: you can use different LRNN models in different layers for hybrid designs.
How It Works
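Passing a list of LRNN classes to `lrnn_cls` builds one layer per entry. The sketch below is illustrative only: the import paths for `S5` and `Centaurus`, and the assumption that `lrnn_params` accepts a matching per-layer list of constructor dicts, are not confirmed by this tutorial and should be checked against the lrnnx documentation.
# Sketch only: assumed S5/Centaurus import paths and per-layer lrnn_params list
from lrnnx.models.lti.lru import LRU
from lrnnx.models.lti.s5 import S5                 # assumed import path
from lrnnx.models.lti.centaurus import Centaurus   # assumed import path

hybrid_model = Classifier(
    input_dim=None,
    num_classes=LISTOPS_NUM_CLASSES,
    vocab_size=LISTOPS_VOCAB_SIZE,
    embedding_dim=EMBEDDING_DIM,
    max_position_embeddings=LISTOPS_MAX_LEN,
    padding_idx=0,
    d_model=D_MODEL,
    d_state=D_STATE,
    n_layers=3,
    lrnn_cls=[LRU, S5, Centaurus],  # one LRNN class per layer
    lrnn_params=[                   # assumed: one constructor dict per layer
        {"H": D_MODEL, "N": D_STATE},
        {"hid_dim": D_MODEL, "state_dim": D_STATE, "discretization": "zoh"},
        {"d_model": D_MODEL, "state_dim": D_STATE, "sub_state_dim": D_STATE},
    ],
    dropout=0.1,
    pooling="last",
    intermediate_pooling=["none"] * 3,
).to(device)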
Example: S5 Multi-Layer Classifier
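A minimal sketch of a homogeneous S5 stack, reusing the assumed `S5` import from the sketch above and the `S5` constructor arguments listed earlier:
# Sketch only: a single LRNN class (S5) reused across all layers
s5_model = Classifier(
    input_dim=None,
    num_classes=LISTOPS_NUM_CLASSES,
    vocab_size=LISTOPS_VOCAB_SIZE,
    embedding_dim=EMBEDDING_DIM,
    max_position_embeddings=LISTOPS_MAX_LEN,
    padding_idx=0,
    d_model=D_MODEL,
    d_state=D_STATE,
    n_layers=N_LAYERS,
    lrnn_cls=S5,  # one class applied to every layer
    lrnn_params={"hid_dim": D_MODEL, "state_dim": D_STATE, "discretization": "zoh"},
    dropout=0.1,
    pooling="last",
    intermediate_pooling=["none"] * N_LAYERS,
).to(device)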
Conclusion
You've learned how to:
- Build hierarchical classifiers using LRNN models
- Process sequential data with tokenization and embeddings
- Train and evaluate models on the ListOps task
- Create hybrid architectures mixing different LRNN types
To go further:
- Increase `subset_fraction` to use more training data
- Experiment with different model sizes and layer counts
- Try hybrid architectures mixing LRU, S5, and Centaurus
- Tune hyperparameters like learning rate and dropout
