Overview

LRU_UNet is a U-Net architecture built with Linear Recurrent Unit (LRU) layers for sequence-to-sequence tasks. It features an encoder-decoder structure with skip connections, using LRU layers for temporal processing and convolutional layers for downsampling/upsampling.

Class Definition

from lrnnx.architectures import LRU_UNet

model = LRU_UNet(
    d_model=128,
    d_state=64,
    n_layers=3,
    downsample_factor=2
)

Parameters

d_model (int, required)
  Input feature dimension (number of channels).

d_state (int, required)
  Hidden state dimension for the LRU layers.

n_layers (int, required)
  Number of downsampling/upsampling stages in the U-Net.

downsample_factor (int, default: 2)
  Downsampling/upsampling factor for each stage. The total downsampling
  factor is downsample_factor ** n_layers.
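For instance, the total downsampling factor can be checked directly (a hypothetical helper for illustration, not part of the lrnnx API):

```python
def total_downsample(downsample_factor: int, n_layers: int) -> int:
    """Overall length reduction applied by the encoder."""
    return downsample_factor ** n_layers

# With downsample_factor=2 and n_layers=3: 2 ** 3 = 8, so a sequence is
# shortened 8x at the bottleneck before being upsampled back.
print(total_downsample(2, 3))  # 8
```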

Methods

forward

output = model(x)
Forward pass through the U-Net.
x (torch.Tensor, required)
  Input sequence of shape (B, C_in, T) where:
    • B is batch size
    • C_in is number of input channels (must equal d_model)
    • T is sequence length

Returns:

output (torch.Tensor)
  Processed sequence of shape (B, C_in, T), matching the input dimensions.

Architecture Details

The U-Net consists of three main components:

Encoder (Downsampling Path)

  • Each stage contains:
    • LRU layer for temporal processing
    • Strided Conv1d for downsampling (doubles channels, reduces length by downsample_factor)
  • Skip connections are saved at each stage
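As a rough sketch of the stage structure described above (not the lrnnx implementation; nn.Identity() stands in for the LRU layer as a placeholder):

```python
import torch
import torch.nn as nn

class EncoderStage(nn.Module):
    """Illustrative encoder stage: temporal layer, then strided downsampling."""

    def __init__(self, channels: int, downsample_factor: int = 2):
        super().__init__()
        self.temporal = nn.Identity()  # placeholder for the LRU layer
        # Strided conv: doubles channels, shrinks length by downsample_factor
        self.down = nn.Conv1d(
            channels, channels * 2,
            kernel_size=downsample_factor,
            stride=downsample_factor,
        )

    def forward(self, x):
        skip = self.temporal(x)   # saved for the matching decoder stage
        return self.down(skip), skip

stage = EncoderStage(64)
y, skip = stage(torch.randn(4, 64, 1000))
print(y.shape)  # torch.Size([4, 128, 500])
```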

Bottleneck

  • Single LRU layer at the lowest resolution

Decoder (Upsampling Path)

  • Each stage contains:
    • ConvTranspose1d for upsampling (halves channels, increases length by downsample_factor)
    • LRU layer for temporal processing
    • Skip connection addition from encoder
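A matching decoder-stage sketch under the same assumptions (nn.Identity() again standing in for the LRU layer; not the lrnnx internals):

```python
import torch
import torch.nn as nn

class DecoderStage(nn.Module):
    """Illustrative decoder stage: upsample, temporal layer, additive skip."""

    def __init__(self, channels: int, downsample_factor: int = 2):
        super().__init__()
        # Transposed conv: halves channels, grows length by downsample_factor
        self.up = nn.ConvTranspose1d(
            channels, channels // 2,
            kernel_size=downsample_factor,
            stride=downsample_factor,
        )
        self.temporal = nn.Identity()  # placeholder for the LRU layer

    def forward(self, x, skip):
        x = self.temporal(self.up(x))
        return x + skip  # additive skip connection from the encoder

stage = DecoderStage(128)
out = stage(torch.randn(4, 128, 500), torch.randn(4, 64, 1000))
print(out.shape)  # torch.Size([4, 64, 1000])
```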

Padding Handling

The model automatically pads input sequences to be divisible by the total downsampling factor and crops the output to the original length.
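The pad-then-crop logic amounts to rounding the length up to the nearest multiple of the total downsampling factor. A minimal sketch (hypothetical helper, not the lrnnx internals):

```python
def padded_length(T: int, downsample_factor: int, n_layers: int) -> int:
    """Smallest length >= T divisible by downsample_factor ** n_layers."""
    factor = downsample_factor ** n_layers
    return ((T + factor - 1) // factor) * factor

# A length-1000 input with total factor 2 ** 3 = 8 needs no padding;
# length 1001 is padded up to 1008, then cropped back to 1001 afterwards.
print(padded_length(1000, 2, 3))  # 1000
print(padded_length(1001, 2, 3))  # 1008
```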

Example Usage

import torch
from lrnnx.architectures import LRU_UNet

# Create U-Net model
model = LRU_UNet(
    d_model=64,      # 64 input channels
    d_state=32,      # 32-dimensional hidden state
    n_layers=3,      # 3 stages of downsampling
    downsample_factor=2
).cuda()

# Input: batch=4, channels=64, length=1000
x = torch.randn(4, 64, 1000).cuda()

# Forward pass
output = model(x)
print(output.shape)  # (4, 64, 1000)

# Channel progression through layers:
# Encoder: 64 -> 128 -> 256 -> 512
# Decoder: 512 -> 256 -> 128 -> 64
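The channel progression in those comments follows from doubling at each of the n_layers encoder stages; a quick check in plain Python (no lrnnx needed):

```python
d_model, n_layers = 64, 3

# Channels double at each encoder stage and halve back in the decoder.
encoder_channels = [d_model * 2 ** i for i in range(n_layers + 1)]
decoder_channels = encoder_channels[::-1]

print(encoder_channels)  # [64, 128, 256, 512]
print(decoder_channels)  # [512, 256, 128, 64]
```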

Use Cases

  • Sequence denoising
  • Audio source separation
  • Time series forecasting with multi-scale features
  • Signal enhancement tasks
  • Any sequence-to-sequence task requiring hierarchical feature extraction

Notes

  • Input sequences are automatically padded if their length is not divisible by downsample_factor ** n_layers
  • The model maintains the original sequence length by cropping padded outputs
  • Skip connections help preserve fine-grained temporal information from the encoder
  • LRU layers operate on (B, T, C) format internally, while the model interface uses (B, C, T) format
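The last note implies the model transposes between channel-first and channel-last layouts internally; in PyTorch that conversion is a single transpose each way:

```python
import torch

x = torch.randn(4, 64, 1000)   # interface layout: (B, C, T)
x_tc = x.transpose(1, 2)       # LRU-internal layout: (B, T, C)
back = x_tc.transpose(1, 2)    # back to (B, C, T)

print(x_tc.shape)  # torch.Size([4, 1000, 64])
assert torch.equal(back, x)
```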
