Overview

The LRU_UNet is a U-Net architecture built with Linear Recurrent Units (LRUs) for sequence-to-sequence tasks. It follows the classic U-Net design with an encoder-decoder structure, skip connections, and hierarchical feature processing at multiple resolutions. This architecture is particularly well-suited for tasks like audio denoising, as demonstrated in the aTENNuate paper.

Architecture

The model consists of three main components:
  1. Encoder (Downsampling path):
    • Each stage contains an LRU layer followed by downsampling
    • Downsampling doubles the number of channels and reduces sequence length
    • Skip connections preserve features from each resolution
  2. Bottleneck:
    • Central LRU layer processing the most compressed representation
  3. Decoder (Upsampling path):
    • Each stage upsamples and halves the number of channels
    • Skip connections from encoder are added before LRU processing
    • Reconstructs the original sequence resolution
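The encoder/decoder control flow described above can be sketched in plain Python. This is a hypothetical illustration of the skip-connection bookkeeping, not the library's actual implementation: `downsample`, `upsample`, and list addition stand in for the real LRU layers and resampling blocks.

```python
def downsample(x, factor=2):
    # Reduce sequence length by keeping every `factor`-th element.
    return x[::factor]

def upsample(x, factor=2):
    # Restore sequence length by repeating each element `factor` times.
    return [v for v in x for _ in range(factor)]

def unet_forward(x, n_layers=2):
    skips = []
    # Encoder: store a skip connection at each resolution, then downsample.
    for _ in range(n_layers):
        skips.append(x)
        x = downsample(x)
    # (Bottleneck processing of the most compressed representation goes here.)
    # Decoder: upsample, then add the matching skip connection.
    for skip in reversed(skips):
        x = upsample(x)
        x = [a + b for a, b in zip(x, skip)]
    return x

out = unet_forward(list(range(8)), n_layers=2)
print(len(out))  # sequence length is restored: 8
```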

Class Signature

LRU_UNet(
    d_model: int,
    d_state: int,
    n_layers: int,
    downsample_factor: int = 2,
)

Parameters

  • d_model (int, required): Input feature dimension (number of channels)
  • d_state (int, required): Hidden state dimension for the LRU layers
  • n_layers (int, required): Number of downsampling/upsampling stages. The total sequence length reduction factor is downsample_factor ** n_layers.
  • downsample_factor (int, default: 2): Downsampling/upsampling factor for each stage. The sequence length is reduced by this factor at each encoder stage.
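The relationship between n_layers, downsample_factor, and the total reduction at the bottleneck can be checked directly:

```python
# Total sequence-length reduction at the bottleneck is
# downsample_factor ** n_layers.
for n_layers, factor in [(4, 2), (3, 4)]:
    print(f"n_layers={n_layers}, factor={factor} -> {factor ** n_layers}x")
```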

Usage Example

Audio Denoising

import torch
from lrnnx.architectures.lru_unet import LRU_UNet

# Create U-Net for audio denoising
model = LRU_UNet(
    d_model=64,      # 64 audio channels/features
    d_state=128,     # LRU hidden state size
    n_layers=4,      # 4 downsampling stages (16x total reduction)
    downsample_factor=2,
)

# Process audio sequence
# Input: (batch_size, channels, time_steps)
noisy_audio = torch.randn(2, 64, 16000)  # 2 examples, 64 features, 16000 time steps
denoised = model(noisy_audio)  # (2, 64, 16000)

Sequence-to-Sequence Processing

# U-Net with 3 stages for sequence transformation
model = LRU_UNet(
    d_model=128,
    d_state=256,
    n_layers=3,
    downsample_factor=4,  # More aggressive downsampling
)

# Input shape: (batch, channels, sequence_length)
x = torch.randn(4, 128, 1024)
output = model(x)  # (4, 128, 1024) - same shape as input

Variable-Length Sequences

The model automatically handles padding for sequences that aren’t divisible by the total downsampling factor:
model = LRU_UNet(d_model=32, d_state=64, n_layers=2, downsample_factor=2)

# Sequence length = 101 (not divisible by 2^2 = 4)
x = torch.randn(1, 32, 101)
output = model(x)  # (1, 32, 101) - padding is automatically applied and removed
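The internal padding amounts to rounding the input length up to the next multiple of the total downsampling factor. A hypothetical helper (not part of the library API) that mirrors this arithmetic:

```python
import math

def padded_length(T, factor, n_layers):
    # Round T up to the next multiple of the total downsampling factor,
    # factor ** n_layers. The model pads to this length internally and
    # crops the output back to T.
    total = factor ** n_layers
    return math.ceil(T / total) * total

print(padded_length(101, 2, 2))  # -> 104 (3 padded steps, removed at output)
```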

Methods

forward

forward(x: torch.Tensor) -> torch.Tensor

Forward pass through the U-Net.

Arguments:
  • x (torch.Tensor): Input sequence of shape (B, C_in, T) where:
    • B is batch size
    • C_in is the number of channels (must equal d_model)
    • T is the sequence length
Returns:
  • torch.Tensor: Processed sequence of shape (B, C_in, T) (same shape as input)
Note: The model automatically handles padding when the sequence length is not divisible by downsample_factor ** n_layers.

Architecture Details

Channel Progression

With d_model=64 and n_layers=3:
  • Input: 64 channels
  • After stage 1: 128 channels
  • After stage 2: 256 channels
  • After stage 3: 512 channels (bottleneck)
  • After upsampling stage 1: 256 channels
  • After upsampling stage 2: 128 channels
  • After upsampling stage 3: 64 channels (output)

Sequence Length Progression

With downsample_factor=2 and n_layers=3, input length T:
  • Input: T
  • After stage 1: T/2
  • After stage 2: T/4
  • After stage 3: T/8 (bottleneck)
  • After upsampling stage 1: T/4
  • After upsampling stage 2: T/2
  • After upsampling stage 3: T (output)
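Both progressions above can be reproduced with a few lines of arithmetic. These are hypothetical helpers for illustration, not part of the library:

```python
def channel_progression(d_model, n_layers):
    # Channels double at each encoder stage, then halve back down.
    enc = [d_model * 2 ** i for i in range(n_layers + 1)]
    return enc + enc[-2::-1]

def length_progression(T, factor, n_layers):
    # Sequence length shrinks by `factor` per encoder stage, then grows back.
    down = [T // factor ** i for i in range(n_layers + 1)]
    return down + down[-2::-1]

print(channel_progression(64, 3))      # [64, 128, 256, 512, 256, 128, 64]
print(length_progression(1024, 2, 3))  # [1024, 512, 256, 128, 256, 512, 1024]
```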

Use Cases

  • Audio denoising: Remove noise from audio signals (see aTENNuate tutorial)
  • Speech enhancement: Improve speech quality in noisy conditions
  • Signal restoration: Reconstruct clean signals from corrupted inputs
  • Time series processing: Any sequence-to-sequence transformation task
