This tutorial walks through building a simple 1D U-Net-style audio denoiser that uses the Linear Recurrent Unit (LRU) as its core sequence model.

Overview

You’ll learn how to:
  • Create synthetic noisy audio data for training
  • Build an LRU U-Net model for denoising
  • Train the model and evaluate results
  • Visualize denoising performance
Import Dependencies

First, import the required libraries:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import Audio, display

from lrnnx.architectures.lru_unet import LRUUnet
Create the Dataset

The AudioDenoisingDataset generates synthetic mono audio clips where each sample is a pair of (noisy, clean) signals.

For every sample, it:
  • Creates a 1D time axis corresponding to max_length / sample_rate seconds of audio
  • Synthesizes a clean signal as a sum of a fundamental sinusoid (200–400 Hz) and its harmonics
  • Modulates the signal with an exponentially decaying envelope
  • Adds white Gaussian noise at a random level
  • Peak-normalizes both clean and noisy signals

class AudioDenoisingDataset(Dataset):
    """Synthetic noisy/clean 1D audio pairs."""

    def __init__(self, sample_rate: int = 16000, max_length: int = 16000, num_samples: int = 1000):
        self.sample_rate = sample_rate
        self.max_length = max_length
        self.num_samples = num_samples
        print(f"Using synthetic data with {num_samples} samples")

    def __len__(self) -> int:
        return self.num_samples

    def __getitem__(self, idx):
        # Seed per index so repeated fetches of the same idx return the same
        # pair; the visualization step re-fetches the sample used for training.
        np.random.seed(idx)
        return self._generate_synthetic_sample()

    def _generate_synthetic_sample(self):
        duration = self.max_length / self.sample_rate
        t = np.linspace(0, duration, self.max_length)

        # fundamental + harmonics
        freq1 = np.random.uniform(200, 400)
        freq2 = freq1 * 2
        freq3 = freq1 * 3

        clean = (
            np.sin(2 * np.pi * freq1 * t) +
            0.5 * np.sin(2 * np.pi * freq2 * t) +
            0.3 * np.sin(2 * np.pi * freq3 * t)
        )

        # exponential decay envelope
        envelope = np.exp(-t * np.random.uniform(0.5, 2.0))
        clean *= envelope

        # add white noise
        noise_level = np.random.uniform(0.1, 0.4)
        noise = np.random.normal(0, noise_level, len(clean))
        noisy = clean + noise

        # normalize
        clean = clean / (np.abs(clean).max() + 1e-8)
        noisy = noisy / (np.abs(noisy).max() + 1e-8)

        return (
            torch.tensor(noisy, dtype=torch.float32),
            torch.tensor(clean, dtype=torch.float32),
        )
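If you want to poke at a generated pair outside the Dataset wrapper, the same recipe can be run standalone (a condensed sketch of `_generate_synthetic_sample`, using the same frequency, envelope, and noise ranges):

```python
import numpy as np
import torch

# Condensed version of the generation recipe: harmonic stack, decaying
# envelope, additive white noise, peak normalization.
sr, n = 16000, 16000
t = np.linspace(0, n / sr, n)
f0 = np.random.uniform(200, 400)

clean = (np.sin(2 * np.pi * f0 * t)
         + 0.5 * np.sin(2 * np.pi * 2 * f0 * t)
         + 0.3 * np.sin(2 * np.pi * 3 * f0 * t))
clean *= np.exp(-t * np.random.uniform(0.5, 2.0))

noisy = clean + np.random.normal(0, np.random.uniform(0.1, 0.4), n)
clean = clean / (np.abs(clean).max() + 1e-8)
noisy = noisy / (np.abs(noisy).max() + 1e-8)

pair = (torch.tensor(noisy, dtype=torch.float32),
        torch.tensor(clean, dtype=torch.float32))
print(pair[0].shape)  # torch.Size([16000])
```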
    
Define the Training Function

The training function overfits on a single sample to verify the model can learn:

def train_audio_denoiser(
    model: nn.Module,
    dataset: Dataset,
    epochs: int = 1,
    batch_size: int = 1,
    learning_rate: float = 1e-2,
    device: str = "auto",
    steps: int = 1000,
):
    """
    Overfit a denoising model on a single synthetic audio sample to verify it can learn.

    Uses a fixed batch from the dataset for `steps` gradient updates and logs the
    loss history. Designed as a sanity check: the loss should drop toward zero on
    one example. Note that `epochs` is currently unused; training length is
    controlled by `steps`.
    """
    # Auto-detect device
    if device == "auto":
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(device)

    print(f"Training on device: {device}")

    # Create dataloader
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)

    # Setup model, loss, and optimizer
    model = model.to(device)
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Extract a single batch: (B, L) -> (B, 1, L)
    noisy, clean = next(iter(dataloader))
    noisy = noisy.unsqueeze(1).to(device)
    clean = clean.unsqueeze(1).to(device)

    # Track training loss
    history = {"losses": []}

    print(f"Overfitting on a single sample for {steps} steps...")
    model.train()
    for step in range(steps):
        optimizer.zero_grad()
        outputs = model(noisy)
        loss = criterion(outputs, clean)
        loss.backward()
        optimizer.step()

        history["losses"].append(loss.item())
        if (step + 1) % 50 == 0 or step == 0:
            print(f"Step {step+1}/{steps}, Loss: {loss.item():.8f}")

    return model, history
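The overfit-one-batch pattern itself is model-agnostic. Stripped to its skeleton with a toy Conv1d standing in for the U-Net (illustrative only, not part of the tutorial's pipeline), it looks like this:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-in: overfit a single Conv1d on one fixed (input, target) batch.
model = nn.Conv1d(1, 1, kernel_size=9, padding=4)
noisy = torch.randn(1, 1, 256)
clean = 0.5 * noisy  # arbitrary fixed target the conv can represent exactly

optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
criterion = nn.MSELoss()

losses = []
for step in range(200):
    optimizer.zero_grad()
    loss = criterion(model(noisy), clean)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(losses[0], losses[-1])  # the loss should drop sharply
```

If the loss does not fall on a single fixed batch, something is broken (wrong shapes, detached graph, bad learning rate) before any real training is attempted.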
    
Create Visualization Function

This function visualizes and evaluates denoising results:

def visualize_denoising(
    model: nn.Module,
    dataset: Dataset,
    sample_idx: int = 0,
    device: str = "auto",
    figsize=(12, 8),
):
    """
    Visualize and evaluate denoising results for a single audio sample.

    Plots waveforms (noisy input, clean target, denoised output) and plays audio clips.
    Computes MSE and SNR improvement metrics.
    """
    # Auto-detect device
    if device == "auto":
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(device)

    model.eval()
    model.to(device)

    # Fetch sample
    noisy, clean = dataset[sample_idx]

    # Run inference
    with torch.no_grad():
        noisy_batch = noisy.unsqueeze(0).unsqueeze(0).to(device)
        denoised = model(noisy_batch).squeeze().cpu()

    # Convert to NumPy
    noisy_np = noisy.numpy()
    clean_np = clean.numpy()
    denoised_np = denoised.numpy()

    # Create time axis
    sr = dataset.sample_rate
    t = np.linspace(0, len(noisy_np) / sr, len(noisy_np))

    # Plot waveforms
    fig, axes = plt.subplots(3, 1, figsize=figsize)

    axes[0].plot(t, noisy_np, "r-", alpha=0.7, label="Noisy")
    axes[0].set_title("Noisy audio")
    axes[0].set_ylabel("Amplitude")
    axes[0].grid(True, alpha=0.3)

    axes[1].plot(t, clean_np, "g-", alpha=0.7, label="Clean")
    axes[1].set_title("Clean audio (target)")
    axes[1].set_ylabel("Amplitude")
    axes[1].grid(True, alpha=0.3)

    axes[2].plot(t, denoised_np, "b-", alpha=0.7, label="Denoised")
    axes[2].set_title("Denoised audio (output)")
    axes[2].set_xlabel("Time (s)")
    axes[2].set_ylabel("Amplitude")
    axes[2].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    # Play audio clips (works in a notebook environment)
    for name, audio in [("Noisy", noisy_np), ("Clean", clean_np), ("Denoised", denoised_np)]:
        print(name)
        display(Audio(audio, rate=sr))

    # Compute metrics
    noise_power = np.mean((noisy_np - clean_np) ** 2)
    denoised_error = np.mean((denoised_np - clean_np) ** 2)
    signal_power = np.mean(clean_np ** 2)

    orig_snr = 10 * np.log10(signal_power / (noise_power + 1e-10))
    den_snr = 10 * np.log10(signal_power / (denoised_error + 1e-10))

    # Print metrics
    print("Evaluation metrics:")
    print(f"MSE: {denoised_error:.6f}")
    print(f"Original SNR: {orig_snr:.2f} dB")
    print(f"Denoised SNR: {den_snr:.2f} dB")
    print(f"SNR improvement: {den_snr - orig_snr:.2f} dB")
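The SNR-improvement metric is easy to sanity-check on a toy example where the error powers are known exactly (hypothetical numbers, chosen so the improvement comes out to a round value):

```python
import numpy as np

clean = np.sin(np.linspace(0, 2 * np.pi, 1000))
noisy = clean + 0.5      # error power 0.25
denoised = clean + 0.05  # error power 0.0025

signal_power = np.mean(clean ** 2)
orig_snr = 10 * np.log10(signal_power / np.mean((noisy - clean) ** 2))
den_snr = 10 * np.log10(signal_power / np.mean((denoised - clean) ** 2))

# Error power shrank by 100x, i.e. 10 * log10(100) = 20 dB improvement
print(round(den_snr - orig_snr, 2))  # 20.0
```

Note that the signal power cancels in the difference, so the improvement depends only on the ratio of the two error powers.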
    
Train the Model

Now put it all together and train the model:

def main():
    print("=== Audio Denoising with LRUUnet ===")

    # Create a single 10-second synthetic sample
    dataset = AudioDenoisingDataset(
        num_samples=1,
        sample_rate=16000,
        max_length=16000 * 10,  # 10 seconds
    )

    # Create the LRU U-Net model
    model = LRUUnet(
        in_channels=1,
        out_channels=1,
        channels=[4, 8],
        resample_factors=[4, 4],
        pre_conv=False,
        causal=True,
    )

    # Train the model
    model, history = train_audio_denoiser(
        model=model,
        dataset=dataset,
        batch_size=1,
        learning_rate=1e-2,
        device="auto",
        steps=1000,
    )

    # Visualize results
    visualize_denoising(model, dataset, sample_idx=0)

    return model, history

if __name__ == "__main__":
    trained_model, training_history = main()
    print("Audio denoising training complete")
    

Results Analysis

Visual Results

The three-panel visualization shows:

Noisy Audio Input (Top Panel - Red)
• The clean signal is almost entirely masked by pervasive white noise
• The underlying structure is barely discernible

Clean Audio Target (Middle Panel - Green)
• The pristine target signal: a decaying sinusoidal waveform
• This serves as the ground truth for denoising

Denoised Output (Bottom Panel - Blue)
• The denoised output closely resembles the clean audio
• The model successfully suppressed the white noise
• Minimal artifacts are visible, indicating strong signal reconstruction

Architecture Strengths

• Combines U-Net’s encoder-decoder structure with skip connections for local feature preservation
• Uses LRU blocks for sequence modeling in the bottleneck
• Custom Conv1d-based pooling provides differentiable down/upsampling
• Training converges quickly on small synthetic data due to the modest model size and strong inductive bias
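To illustrate the pooling point: a strided nn.Conv1d for downsampling paired with an nn.ConvTranspose1d for upsampling gives learnable, fully differentiable resampling by a fixed factor. (This sketches the general idea; the actual lrnnx pooling layers may be implemented differently.)

```python
import torch
import torch.nn as nn

# Learnable 4x down/upsampling with strided convolutions.
down = nn.Conv1d(1, 4, kernel_size=4, stride=4)         # (B, 1, L) -> (B, 4, L/4)
up = nn.ConvTranspose1d(4, 1, kernel_size=4, stride=4)  # (B, 4, L/4) -> (B, 1, L)

x = torch.randn(2, 1, 16000)
z = down(x)
y = up(z)
print(z.shape, y.shape)  # torch.Size([2, 4, 4000]) torch.Size([2, 1, 16000])
```

Because kernel size equals stride, the down/up pair tiles the signal without overlap and restores the original length exactly, which is convenient for U-Net skip connections.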

Conclusion

This U-Net with LRU integration offers a solid foundation for audio denoising, leveraging multi-scale feature extraction and state-space sequence modeling for effective noise reduction on synthetic waveforms. For real-world applications, you would:
• Train on larger, more diverse datasets
• Experiment with different model sizes and architectures
• Use perceptual loss functions for better audio quality
• Apply data augmentation techniques
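On the perceptual-loss point, one common choice is to compare STFT magnitudes instead of raw samples. A minimal single-resolution sketch (real systems usually combine several FFT resolutions):

```python
import torch

def stft_mag_loss(pred, target, n_fft=512, hop_length=128):
    """L1 distance between STFT magnitude spectrograms (single resolution)."""
    window = torch.hann_window(n_fft)
    p = torch.stft(pred, n_fft, hop_length=hop_length,
                   window=window, return_complex=True).abs()
    t = torch.stft(target, n_fft, hop_length=hop_length,
                   window=window, return_complex=True).abs()
    return torch.mean(torch.abs(p - t))

x = torch.randn(1, 16000)
print(stft_mag_loss(x, x).item())  # identical signals give zero loss
```

Such a loss could be added to (or blended with) the MSE criterion in train_audio_denoiser; whether it helps on this synthetic task is untested here.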
