
Overview

This guide shows you how to train Linear RNN models using lrnnx. The library provides easy-to-use training APIs for both linear time-invariant (LTI) and linear time-varying (LTV) models.

Quick Start

1. Import models

Import the model architectures from lrnnx:
from lrnnx.models.lti import LRU, S4, S4D, S5

2. Instantiate model

Create a model instance in training mode:
import torch
from lrnnx.models.lti import LRU

# Model parameters
d_model = 64      # Model dimension
d_state = 64      # State dimension

# Create model on CUDA
model = LRU(d_model=d_model, d_state=d_state).cuda()
model.train()  # Set to training mode

3. Create input tensors

Prepare your training data:
batch_size = 32
seq_len = 128
d_model = 64

# Create input tensor (B, L, H)
x = torch.randn(
    batch_size, seq_len, d_model,
    dtype=torch.float32,
    device="cuda"
)
All lrnnx models expect input of shape (batch_size, seq_len, d_model) where:
  • batch_size: Number of sequences in the batch
  • seq_len: Length of each sequence
  • d_model: Feature dimension
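If your data arrives in the (seq_len, batch_size, d_model) layout used by classic PyTorch RNN modules, transpose it before feeding it to the model. A minimal sketch (the shapes here are illustrative, not required values):

```python
import torch

batch_size, seq_len, d_model = 4, 16, 8

# Data in RNN-style (L, B, H) layout
x_rnn = torch.randn(seq_len, batch_size, d_model)

# Convert to the (B, L, H) layout that lrnnx models expect
x = x_rnn.transpose(0, 1).contiguous()

assert x.shape == (batch_size, seq_len, d_model)
```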

4. Forward and backward pass

Run the training loop:
import torch.nn as nn
import torch.optim as optim

# Define loss and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Training loop
num_epochs = 10  # number of passes over the data
for epoch in range(num_epochs):
    # Forward pass
    output = model(x)  # Shape: (batch_size, seq_len, d_model)

    # Compute loss (example: reconstruction)
    loss = criterion(output, x)

    # Backward pass
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

Complete Training Example

Here’s a complete example showing forward and backward passes:
import torch
import torch.nn as nn
import torch.optim as optim
from lrnnx.models.lti import LRU

# Model configuration
d_model = 64
d_state = 64
batch_size = 32
seq_len = 128

# Initialize model
model = LRU(d_model=d_model, d_state=d_state).cuda()
model.train()

# Create sample data
x = torch.randn(batch_size, seq_len, d_model, device="cuda")
target = torch.randn(batch_size, seq_len, d_model, device="cuda")

# Loss and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Single training step
optimizer.zero_grad()
output = model(x)
loss = criterion(output, target)
loss.backward()
optimizer.step()

print(f"Loss: {loss.item():.4f}")

Event-Based Training

Some models (S5, S6, Mamba) support event-based processing with custom integration timesteps:
from lrnnx.models.lti import S5

model = S5(d_model=64, d_state=64).cuda()
model.train()

x = torch.randn(32, 128, 64, device="cuda")

# Provide strictly positive integration timesteps (B, L)
integration_timesteps = torch.rand(32, 128, device="cuda") + 1e-3

# Forward pass with custom timesteps
output = model(x, integration_timesteps=integration_timesteps)
When using integration_timesteps, ensure they are positive values representing the time intervals between events.
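If your data comes as irregularly sampled events, a common way to build the timesteps is to difference the sorted event times. The sketch below shows that pattern using plain PyTorch; it is not an lrnnx API, and the synthetic event times are only for illustration:

```python
import torch

batch_size, seq_len = 32, 128

# Sorted, positive event times for each sequence (B, L)
event_times = torch.cumsum(torch.rand(batch_size, seq_len) + 1e-3, dim=-1)

# Intervals between consecutive events; prepending t=0 makes the first
# interval the time to the first event, so every entry stays positive
integration_timesteps = torch.diff(
    event_times, dim=-1, prepend=torch.zeros(batch_size, 1)
)

assert integration_timesteps.shape == (batch_size, seq_len)
assert (integration_timesteps > 0).all()
```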

Benchmarking Training Performance

The library includes built-in benchmarking utilities to measure training throughput:
from benchmarks.benchmark_training import benchmark_sequence_length
from lrnnx.models.lti import LRU

def model_fn():
    return LRU(d_model=64, d_state=64).cuda()

# Benchmark across different sequence lengths
results = benchmark_sequence_length(
    model_fn,
    seq_lengths=[128, 256, 512, 1024, 2048],
    batch_size=32,
    repeats=5
)

for seq_len, times in results.items():
    avg_time = sum(times) / len(times)
    print(f"Seq len {seq_len}: {avg_time:.2f} ms")
See the full benchmarking suite in benchmarks/benchmark_training.py for more examples including:
  • Varying model dimensions
  • Varying batch sizes
  • Multi-run statistics
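If you just want a quick number without the benchmark suite, a hand-rolled timing loop works for any (B, L, H) module. This is a generic sketch, shown with a stand-in nn.Linear rather than an lrnnx model; note the torch.cuda.synchronize() calls, without which GPU timings are meaningless:

```python
import time
import torch
import torch.nn as nn

def time_training_step(model, x, repeats=5):
    """Average wall-clock time (ms) of one forward + backward pass."""
    criterion = nn.MSELoss()

    # Warm-up run so allocator and compilation effects don't skew the timing
    criterion(model(x), x).backward()

    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(repeats):
        model.zero_grad()
        criterion(model(x), x).backward()
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / repeats * 1e3

# Stand-in module with the same (B, L, H) -> (B, L, H) interface
model = nn.Linear(64, 64)
x = torch.randn(8, 128, 64)
print(f"{time_training_step(model, x):.2f} ms")
```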

Mixed Precision Training

For faster training with lower memory usage, use automatic mixed precision:
import torch
import torch.nn as nn
from torch.amp import autocast, GradScaler
from lrnnx.models.ltv import Mamba

model = Mamba(d_model=64, d_state=16).cuda()
optimizer = torch.optim.Adam(model.parameters())
scaler = GradScaler("cuda")

x = torch.randn(32, 128, 64, device="cuda")
target = torch.randn(32, 128, 64, device="cuda")

num_epochs = 10
for epoch in range(num_epochs):
    optimizer.zero_grad()

    # Forward pass with autocast (runs ops in reduced precision where safe)
    with autocast("cuda"):
        output = model(x)
        loss = nn.functional.mse_loss(output, target)

    # Backward pass with gradient scaling to avoid float16 underflow
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

Available Models

All models follow the same training interface:

LTI Models (Time-Invariant)

  • S4 - Structured State Space model
  • S4D - Diagonal variant of S4
  • S5 - Simplified State Space model
  • LRU - Linear Recurrent Unit

LTV Models (Time-Varying)

  • Mamba - Selective State Space model
  • S6/S7 - Extensions with selective mechanisms
  • RGLRU - Real-Gated Linear Recurrent Unit
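Because every model maps (B, L, H) to (B, L, H), a single training step can be written once and reused across architectures. The sketch below exercises it with a stand-in nn.Linear; any of the lrnnx models above should drop in, assuming the shared interface described in this guide:

```python
import torch
import torch.nn as nn

def train_step(model, optimizer, x, target):
    """One optimization step for any (B, L, H) -> (B, L, H) sequence model."""
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(x), target)
    loss.backward()
    optimizer.step()
    return loss.item()

# Stand-in module; swap in LRU, S4, Mamba, etc. without changing train_step
model = nn.Linear(64, 64)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
x = torch.randn(8, 32, 64)
target = torch.randn(8, 32, 64)

loss = train_step(model, optimizer, x, target)
print(f"loss: {loss:.4f}")
```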

Next Steps

Inference Guide

Learn about fast autoregressive generation with CUDA graphs

Custom Kernels

Understand the CUDA kernels powering lrnnx
