Overview
The `LRU_UNet` is a U-Net architecture built with Linear Recurrent Units (LRUs) for sequence-to-sequence tasks. It follows the classic U-Net design with an encoder-decoder structure, skip connections, and hierarchical feature processing at multiple resolutions.
This architecture is particularly well-suited for tasks like audio denoising, as demonstrated in the aTENNuate paper.
Architecture
The model consists of three main components:
Encoder (Downsampling path):
- Each stage contains an LRU layer followed by downsampling
- Downsampling doubles the number of channels and reduces sequence length
- Skip connections preserve features from each resolution
Bottleneck:
- Central LRU layer processing the most compressed representation
Decoder (Upsampling path):
- Each stage upsamples and halves the number of channels
- Skip connections from encoder are added before LRU processing
- Reconstructs the original sequence resolution
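The encoder, bottleneck, and decoder stages above imply a simple pattern of channel doubling and length reduction. The bookkeeping can be traced with a minimal pure-Python sketch (the helper name and example values are my own illustration, not the library's code):

```python
# Trace the (channels, sequence length) bookkeeping across the U-Net stages.
# Illustrative sketch of the documented shape rules, not LRU_UNet's actual code.
def unet_shapes(d_model, seq_len, n_layers, downsample_factor):
    """Return (stage name, channels, length) for each stage of the U-Net."""
    stages = [("input", d_model, seq_len)]
    c, t = d_model, seq_len
    for i in range(n_layers):            # encoder: channels double, length shrinks
        c *= 2
        t //= downsample_factor
        stages.append((f"encoder {i + 1}", c, t))
    for i in range(n_layers):            # decoder: channels halve, length grows
        c //= 2
        t *= downsample_factor
        stages.append((f"decoder {i + 1}", c, t))
    return stages
```

With `d_model=64`, `n_layers=3`, `downsample_factor=2`, and an input length of 1024, the bottleneck stage works on 512 channels at length 128, and the output returns to 64 channels at the full length, matching the progressions listed under Architecture Details.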
Class Signature
Parameters
- `d_model`: Input feature dimension (number of channels)
- Hidden state dimension for the LRU layers
- `n_layers`: Number of downsampling/upsampling stages. The total sequence length reduction factor is `downsample_factor ** n_layers`.
- `downsample_factor`: Downsampling/upsampling factor for each stage. The sequence length is reduced by this factor at each encoder stage.
Usage Example
Audio Denoising
Sequence-to-Sequence Processing
Variable-Length Sequences
The model automatically handles padding for sequences that aren’t divisible by the total downsampling factor.
Methods
forward
`x` (`torch.Tensor`): Input sequence of shape `(B, C_in, T)`, where:
- `B` is the batch size
- `C_in` is the number of channels (must equal `d_model`)
- `T` is the sequence length
Returns `torch.Tensor`: Processed sequence of shape `(B, C_in, T)` (same shape as input)
Inputs whose length is not divisible by `downsample_factor ** n_layers` are padded internally; the output is trimmed back to the original length `T`.
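The padding described under Variable-Length Sequences reduces to modular arithmetic. A sketch of that arithmetic (my own illustration of the documented behaviour, not the library's internal code):

```python
# How much padding makes a sequence length divisible by the total
# downsampling factor. Illustrative sketch, not LRU_UNet's actual code.
def pad_amount(T, downsample_factor, n_layers):
    """Samples to append so T becomes a multiple of downsample_factor ** n_layers."""
    factor = downsample_factor ** n_layers
    return (-T) % factor

pad_amount(1000, 2, 3)  # 1000 is already a multiple of 8, so 0
pad_amount(1001, 2, 3)  # next multiple of 8 is 1008, so 7
```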
Architecture Details
Channel Progression
With `d_model=64` and `n_layers=3`:
- Input: 64 channels
- After stage 1: 128 channels
- After stage 2: 256 channels
- After stage 3: 512 channels (bottleneck)
- After upsampling stage 1: 256 channels
- After upsampling stage 2: 128 channels
- After upsampling stage 3: 64 channels (output)
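The progression above is repeated doubling on the way down and halving on the way up; a quick check in Python (example values taken from this section):

```python
# Channel counts per stage for d_model=64, n_layers=3 (illustrative check).
d_model, n_layers = 64, 3
encoder = [d_model * 2 ** i for i in range(n_layers + 1)]  # [64, 128, 256, 512]
decoder = encoder[-2::-1]                                  # [256, 128, 64]
```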
Sequence Length Progression
With `downsample_factor=2`, `n_layers=3`, and input length `T`:
- Input: T
- After stage 1: T/2
- After stage 2: T/4
- After stage 3: T/8 (bottleneck)
- After upsampling stage 1: T/4
- After upsampling stage 2: T/2
- After upsampling stage 3: T (output)
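The length progression follows the same pattern, dividing by `downsample_factor` per encoder stage. A quick check with an assumed example length of `T=1024`:

```python
# Sequence lengths per encoder stage for T=1024, downsample_factor=2,
# n_layers=3 (illustrative check; T=1024 is an assumed example value).
T, downsample_factor, n_layers = 1024, 2, 3
lengths = [T // downsample_factor ** i for i in range(n_layers + 1)]  # [1024, 512, 256, 128]
```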
Use Cases
- Audio denoising: Remove noise from audio signals (see aTENNuate tutorial)
- Speech enhancement: Improve speech quality in noisy conditions
- Signal restoration: Reconstruct clean signals from corrupted inputs
- Time series processing: Any sequence-to-sequence transformation task
See Also
- Audio Denoising Tutorial - Complete guide to using LRU_UNet for audio denoising
- LRU - Linear Recurrent Unit layer
