Overview
This guide shows you how to train Linear RNN models using lrnnx. The library provides easy-to-use training APIs for both time-invariant (LTI) and time-varying (LTV) models.
Quick Start
Import models
Import the model architectures from lrnnx:

```python
from lrnnx.models.lti import LRU, S4, S4D, S5
```
Instantiate model
Create a model instance in training mode:

```python
import torch
from lrnnx.models.lti import LRU

# Model parameters
d_model = 64  # Model dimension
d_state = 64  # State dimension

# Create model on CUDA
model = LRU(d_model=d_model, d_state=d_state).cuda()
model.train()  # Set to training mode
```
Create input tensors
Prepare your training data:

```python
batch_size = 32
seq_len = 128
d_model = 64

# Create input tensor (B, L, H)
x = torch.randn(
    batch_size, seq_len, d_model,
    dtype=torch.float32,
    device="cuda",
)
```
All lrnnx models expect input of shape (batch_size, seq_len, d_model) where:
batch_size: Number of sequences in the batch
seq_len: Length of each sequence
d_model: Feature dimension
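Real datasets are rarely pre-batched: variable-length sequences must be padded to a common seq_len before they can be stacked into a (B, L, H) tensor. A minimal padding sketch in plain Python — the helper pad_batch is illustrative, not part of lrnnx:

```python
def pad_batch(sequences, pad_value=0.0):
    """Right-pad variable-length feature sequences into a (B, L, H) nested list.

    Each sequence is a list of H-dimensional feature vectors; shorter
    sequences are padded with rows of pad_value up to the longest length.
    """
    max_len = max(len(seq) for seq in sequences)
    h = len(sequences[0][0])
    padded = []
    for seq in sequences:
        pad_rows = [[pad_value] * h for _ in range(max_len - len(seq))]
        padded.append(list(seq) + pad_rows)
    return padded

# Two sequences of lengths 1 and 2 with H = 2 features
batch = pad_batch([[[1.0, 2.0]], [[3.0, 4.0], [5.0, 6.0]]])
# batch now has shape (2, 2, 2); wrap it with torch.tensor(...) to train
```

The padded positions carry no information, so for sequence-level losses you would typically also build a mask from the original lengths.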
Forward and backward pass
Run the training loop:

```python
import torch.nn as nn
import torch.optim as optim

# Define loss and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    # Forward pass
    output = model(x)  # Shape: (batch_size, seq_len, d_model)

    # Compute loss (example: reconstruction)
    loss = criterion(output, x)

    # Backward pass
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print(f"Epoch {epoch}, Loss: {loss.item():.4f}")
```
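Recurrent models trained on long sequences can occasionally produce large gradients; clipping the global gradient norm before optimizer.step() is a common safeguard (in PyTorch: torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)). The computation that call performs, sketched in plain Python for clarity:

```python
import math

def clip_by_global_norm(grads, max_norm):
    """Rescale gradient vectors so their combined L2 norm is at most max_norm."""
    total = math.sqrt(sum(g * g for vec in grads for g in vec))
    if total <= max_norm:
        return grads
    scale = max_norm / total
    return [[g * scale for g in vec] for vec in grads]

# Gradients with global norm sqrt(9 + 16 + 144) = 13
grads = [[3.0, 4.0], [0.0, 12.0]]
clipped = clip_by_global_norm(grads, 1.0)  # rescaled to global norm 1.0
```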
Complete Training Example
Here’s a complete example showing forward and backward passes:

```python
import torch
import torch.nn as nn
import torch.optim as optim
from lrnnx.models.lti import LRU

# Model configuration
d_model = 64
d_state = 64
batch_size = 32
seq_len = 128

# Initialize model
model = LRU(d_model=d_model, d_state=d_state).cuda()
model.train()

# Create sample data
x = torch.randn(batch_size, seq_len, d_model, device="cuda")
target = torch.randn(batch_size, seq_len, d_model, device="cuda")

# Loss and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Single training step
optimizer.zero_grad()
output = model(x)
loss = criterion(output, target)
loss.backward()
optimizer.step()

print(f"Loss: {loss.item():.4f}")
```
Event-Based Training
Some models (S5, S6, Mamba) support event-based processing with custom integration timesteps:
```python
import torch
from lrnnx.models.lti import S5

model = S5(d_model=64, d_state=64).cuda()
model.train()

x = torch.randn(32, 128, 64, device="cuda")

# Provide integration timesteps (B, L)
integration_timesteps = torch.rand(32, 128, device="cuda")

# Forward pass with custom timesteps
output = model(x, integration_timesteps=integration_timesteps)
```
When using integration_timesteps, ensure they are positive values representing the time intervals between events.
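For event streams, the integration timesteps are usually the gaps between consecutive event timestamps rather than uniform random values. A plain-Python sketch of deriving strictly positive intervals from sorted timestamps — the helper to_intervals and its eps padding are illustrative, not part of lrnnx:

```python
def to_intervals(timestamps, eps=1e-6):
    """Turn sorted event timestamps into strictly positive integration intervals.

    The first event has no predecessor, so its interval is set to eps;
    zero gaps (simultaneous events) are also clamped to eps.
    """
    intervals = [eps]
    for prev, curr in zip(timestamps, timestamps[1:]):
        dt = curr - prev
        intervals.append(dt if dt > 0 else eps)
    return intervals

# Irregularly spaced events
dts = to_intervals([0.0, 0.1, 0.4, 0.45, 1.0])
# All entries positive; feed a (B, L) tensor of such values as integration_timesteps
```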
Benchmarking
The library includes built-in benchmarking utilities to measure training throughput:
```python
from benchmarks.benchmark_training import benchmark_sequence_length
from lrnnx.models.lti import LRU

def model_fn():
    return LRU(d_model=64, d_state=64).cuda()

# Benchmark across different sequence lengths
results = benchmark_sequence_length(
    model_fn,
    seq_lengths=[128, 256, 512, 1024, 2048],
    batch_size=32,
    repeats=5,
)

for seq_len, times in results.items():
    avg_time = sum(times) / len(times)
    print(f"Seq len {seq_len}: {avg_time:.2f} ms")
```
See the full benchmarking suite in benchmarks/benchmark_training.py for more examples including:
Varying model dimensions
Varying batch sizes
Multi-run statistics
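The per-step latencies above are easy to convert into training throughput, which is often the more comparable number across models. A small sketch (throughput_tokens_per_sec is an illustrative helper, not part of the benchmarking suite):

```python
def throughput_tokens_per_sec(batch_size, seq_len, avg_ms):
    """Convert average step latency in milliseconds into tokens processed per second."""
    return batch_size * seq_len * 1000.0 / avg_ms

# Example: batch 32, sequence length 1024, 12.5 ms per training step
rate = throughput_tokens_per_sec(32, 1024, 12.5)
print(f"{rate / 1e6:.2f}M tokens/s")  # 2.62M tokens/s
```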
Mixed Precision Training
For faster training with lower memory usage, use automatic mixed precision:
```python
import torch
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler
from lrnnx.models.ltv import Mamba

model = Mamba(d_model=64, d_state=16).cuda()
optimizer = torch.optim.Adam(model.parameters())
scaler = GradScaler()

x = torch.randn(32, 128, 64, device="cuda")
target = torch.randn(32, 128, 64, device="cuda")

num_epochs = 10
for epoch in range(num_epochs):
    optimizer.zero_grad()

    # Forward pass with autocast
    with autocast():
        output = model(x)
        loss = nn.functional.mse_loss(output, target)

    # Backward pass with gradient scaling
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```
Available Models
All models follow the same training interface:
LTI Models (Time-Invariant)
S4 - Structured State Space model
S4D - Diagonal variant of S4
S5 - Simplified State Space model
LRU - Linear Recurrent Unit
LTV Models (Time-Varying)
Mamba - Selective State Space model
S6/S7 - Extensions with selective mechanisms
RGLRU - Real-Gated Linear Recurrent Unit
Next Steps
Inference Guide - Learn about fast autoregressive generation with CUDA graphs
Custom Kernels - Understand the CUDA kernels powering lrnnx