Overview
You'll learn how to:
- Create synthetic noisy audio data for training
- Build an LRU U-Net model for denoising
- Train the model and evaluate results
- Visualize denoising performance
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import numpy as np
import librosa
import matplotlib.pyplot as plt
from tqdm import tqdm
from IPython.display import Audio, display
import soundfile as sf
from lrnnx.architectures.lru_unet import LRUUnet
# The AudioDenoisingDataset generates synthetic mono audio clips where each
# sample is a (noisy, clean) pair covering max_length / sample_rate seconds.
class AudioDenoisingDataset(Dataset):
    """Synthetic noisy/clean 1D audio pairs."""

    def __init__(self, sample_rate: int = 16000, max_length: int = 16000, num_samples: int = 1000):
        # sample_rate: samples per second; max_length: samples per clip;
        # num_samples: nominal dataset size (each access regenerates randomly).
        self.sample_rate = sample_rate
        self.max_length = max_length
        self.num_samples = num_samples
        print(f"Using synthetic data with {num_samples} samples")

    def __len__(self) -> int:
        return self.num_samples

    def __getitem__(self, idx):
        # Index is ignored: every access yields a freshly generated random pair.
        return self._generate_synthetic_sample()

    def _generate_synthetic_sample(self):
        """Build one (noisy, clean) pair: harmonic tone x decay + white noise."""
        seconds = self.max_length / self.sample_rate
        time_axis = np.linspace(0, seconds, self.max_length)
        # Random fundamental with second/third harmonics at fixed amplitudes.
        base = np.random.uniform(200, 400)
        clean = sum(
            amp * np.sin(2 * np.pi * (base * mult) * time_axis)
            for amp, mult in ((1.0, 1), (0.5, 2), (0.3, 3))
        )
        # Shape the tone with a random exponential decay envelope.
        clean = clean * np.exp(-time_axis * np.random.uniform(0.5, 2.0))
        # Corrupt with additive white Gaussian noise of random level.
        sigma = np.random.uniform(0.1, 0.4)
        noisy = clean + np.random.normal(0, sigma, clean.shape[0])
        # Peak-normalize both signals; epsilon guards against division by zero.
        clean = clean / (np.abs(clean).max() + 1e-8)
        noisy = noisy / (np.abs(noisy).max() + 1e-8)
        return (
            torch.tensor(noisy, dtype=torch.float32),
            torch.tensor(clean, dtype=torch.float32),
        )
def train_audio_denoiser(
    model: nn.Module,
    dataset: Dataset,
    epochs: int = 1,
    batch_size: int = 1,
    learning_rate: float = 1e-2,
    device: str = "auto",
    steps: int = 1000,
):
    """
    Overfit a denoising model on a single synthetic audio sample to verify it can learn.

    Pulls one fixed batch from ``dataset`` and performs ``steps`` MSE gradient
    updates on it, recording the loss at every step. Designed as a sanity
    check: the loss should drop toward zero on a single example.

    NOTE: ``epochs`` is accepted for interface compatibility but is not used;
    training length is controlled entirely by ``steps``.

    Returns:
        (model, history) where history["losses"] holds one float per step.
    """
    # Resolve the target device ("auto" prefers CUDA when available).
    if device == "auto":
        resolved = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        resolved = torch.device(device)
    print(f"Training on device: {resolved}")

    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)
    model = model.to(resolved)
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Grab one fixed batch and add a channel dimension: (B, L) -> (B, 1, L).
    noisy, clean = next(iter(loader))
    noisy = noisy.unsqueeze(1).to(resolved)
    clean = clean.unsqueeze(1).to(resolved)

    history = {"losses": []}
    print(f"Overfitting on a single sample for {steps} steps...")
    model.train()
    for step in range(steps):
        optimizer.zero_grad()
        loss = criterion(model(noisy), clean)
        loss.backward()
        optimizer.step()
        history["losses"].append(loss.item())
        # Log the first step and every 50th step thereafter.
        if step == 0 or (step + 1) % 50 == 0:
            print(f"Step {step+1}/{steps}, Loss: {loss.item():.8f}")
    return model, history
def visualize_denoising(
    model: nn.Module,
    dataset: Dataset,
    sample_idx: int = 0,
    device: str = "auto",
    figsize=(12, 8),
):
    """
    Visualize and evaluate denoising results for a single audio sample.

    Plots three stacked waveforms (noisy input, clean target, denoised output)
    and prints MSE plus SNR-improvement metrics.

    Args:
        model: Trained denoiser mapping a (1, 1, L) tensor to a like-shaped output.
        dataset: Dataset yielding (noisy, clean) 1D float tensors; must expose
            a ``sample_rate`` attribute (AudioDenoisingDataset does).
        sample_idx: Index of the sample to visualize.
        device: "auto" picks CUDA when available; otherwise a device string.
        figsize: Matplotlib figure size for the 3-panel plot.
    """
    # Auto-detect device
    if device == "auto":
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(device)
    model.eval()
    model.to(device)

    # Fetch sample and run inference without gradient tracking.
    noisy, clean = dataset[sample_idx]
    with torch.no_grad():
        noisy_batch = noisy.unsqueeze(0).unsqueeze(0).to(device)  # (1, 1, L)
        denoised = model(noisy_batch).squeeze().cpu()

    noisy_np = noisy.numpy()
    clean_np = clean.numpy()
    denoised_np = denoised.numpy()

    # Time axis in seconds.
    sr = dataset.sample_rate
    t = np.linspace(0, len(noisy_np) / sr, len(noisy_np))

    # Three stacked waveform panels: noisy / clean / denoised.
    fig, axes = plt.subplots(3, 1, figsize=figsize)
    panels = [
        (noisy_np, "r-", "Noisy", "Noisy audio"),
        (clean_np, "g-", "Clean", "Clean audio (target)"),
        (denoised_np, "b-", "Denoised", "Denoised audio (output)"),
    ]
    for ax, (signal, style, label, title) in zip(axes, panels):
        ax.plot(t, signal, style, alpha=0.7, label=label)
        ax.set_title(title)
        ax.set_ylabel("Amplitude")
        ax.grid(True, alpha=0.3)
    axes[2].set_xlabel("Time (s)")
    plt.tight_layout()
    plt.show()

    # Metrics. Note: the output MSE is also the residual error power, so it is
    # computed once and reused for the denoised-SNR (the original computed the
    # identical quantity twice as `mse` and `denoised_error`).
    mse = np.mean((denoised_np - clean_np) ** 2)
    noise_power = np.mean((noisy_np - clean_np) ** 2)
    signal_power = np.mean(clean_np ** 2)
    orig_snr = 10 * np.log10(signal_power / (noise_power + 1e-10))
    den_snr = 10 * np.log10(signal_power / (mse + 1e-10))

    print("Evaluation metrics:")
    print(f"MSE: {mse:.6f}")
    print(f"Original SNR: {orig_snr:.2f} dB")
    print(f"Denoised SNR: {den_snr:.2f} dB")
    print(f"SNR improvement: {den_snr - orig_snr:.2f} dB")
def main():
    """Overfit an LRUUnet denoiser on one synthetic clip, then visualize it."""
    print("=== Audio Denoising with LRUUnet ===")

    # A single 10-second synthetic sample is enough for the sanity check.
    dataset = AudioDenoisingDataset(
        num_samples=1,
        sample_rate=16000,
        max_length=16000 * 10,  # 10 seconds
    )

    # Small causal LRU U-Net: mono in/out, two encoder stages.
    model = LRUUnet(
        in_channels=1,
        out_channels=1,
        channels=[4, 8],
        resample_factors=[4, 4],
        pre_conv=False,
        causal=True,
    )

    model, history = train_audio_denoiser(
        model=model,
        dataset=dataset,
        epochs=1,
        batch_size=1,
        learning_rate=1e-2,
        device="auto",
        steps=1000,
    )

    # Plot waveforms and print evaluation metrics for the trained model.
    visualize_denoising(model, dataset, sample_idx=0)
    return model, history
# Script entry point: train and keep results bound for interactive inspection.
if __name__ == "__main__":
    trained_model, training_history = main()
    print("Audio denoising training complete")
Results Analysis

Visual Results
The three-panel visualization shows:

Noisy Audio Input (Top Panel - Red)
- The clean signal is almost entirely masked by pervasive white noise
- The underlying structure is barely discernible

Clean Audio Target (Middle Panel - Green)
- The pristine target signal: a decaying sinusoidal waveform
- This serves as the ground truth for denoising

Denoised Output (Bottom Panel - Blue)
- The denoised output closely resembles the clean audio
- The model successfully suppressed the white noise
- Minimal artifacts are visible, indicating strong signal reconstruction
Architecture Strengths
- Combines U-Net’s encoder-decoder structure with skip connections for local feature preservation
- Uses LRU blocks for sequence modeling in the bottleneck
- Custom Conv1d-based pooling provides differentiable down/upsampling
- Training converges quickly on small synthetic data due to modest model size and strong inductive bias
Conclusion
This U-Net with LRU integration offers a solid foundation for audio denoising, leveraging multi-scale feature extraction and state-space sequence modeling for effective noise reduction on synthetic waveforms. For real-world applications, you would:- Train on larger, more diverse datasets
- Experiment with different model sizes and architectures
- Use perceptual loss functions for better audio quality
- Apply data augmentation techniques
