Overview

Convolution operations for LRNN models, using the Fast Fourier Transform (FFT) for efficient computation. These functions implement optimized strategies for convolution-based forward passes in state space models. Reference: arXiv:2409.03377

Functions

opt_ssm_forward

opt_ssm_forward(x, K, B_, C) -> Tensor
Optimized FFT convolution with automatic strategy selection. Based on the tensor dimensions, this function chooses among two specialized computation strategies and a direct fallback to minimize computational cost. In the conditions below, H_in and H_out denote the input and output feature dimensions (both equal to H for this signature), B the batch size, and N the state dimension. Strategy Selection:
  1. Strategy 1: When (1/H_in + 1/H_out) > (1/B + 1/N) and H_in * H_out <= N
    • Precompute full kernel: kernel = einsum("on,nl,ni->loi", C, K, B_)
    • Apply convolution: fft_conv("bli,loi->blo", x, kernel)
  2. Strategy 2: When (1/H_in + 1/H_out) <= (1/B + 1/N) and N <= H_in
    • Project input: x_proj = einsum("blh,nh->bln", x, B_)
    • Convolve projected input: fft_conv("bln,ln->bln", x_proj, K.T)
    • Apply output projection: einsum("bln,hn->blh", x_conv, C)
  3. Fallback: When neither condition above holds
    • Direct computation: fft_conv("blh,nl,nh,on->blo", x, K, B_, C)
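The dimension-based selection can be sketched in plain Python. This is an illustrative sketch of the cost conditions only; the function name and structure are assumptions, not the library's internal code:

```python
def choose_strategy(B, N, H_in, H_out):
    """Pick a convolution strategy from the cost conditions above.

    Illustrative sketch: opt_ssm_forward applies these conditions
    internally and does not expose a chooser like this.
    """
    feat_term = 1 / H_in + 1 / H_out
    batch_state_term = 1 / B + 1 / N
    if feat_term > batch_state_term and H_in * H_out <= N:
        # Strategy 1: precompute the full (L, H_out, H_in) kernel.
        return "precompute_kernel"
    if feat_term <= batch_state_term and N <= H_in:
        # Strategy 2: project the input down to the state dimension first.
        return "project_input"
    # Fallback: direct fused computation.
    return "direct"

print(choose_strategy(B=32, N=128, H_in=64, H_out=64))  # direct
```

Intuitively, strategy 1 wins when the feature dimensions are small relative to the state dimension (the precomputed kernel stays cheap), while strategy 2 wins when projecting to a small state dimension shrinks the convolution.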
x
torch.Tensor
required
Input tensor, shape (B, L, H) where:
  • B is the batch size
  • L is the sequence length
  • H is the input dimension
K
torch.Tensor
required
Kernel tensor, shape (L, H, H) or (L, N) depending on the model configuration.
B_
torch.Tensor
required
Normalized input projection matrix, shape (N, H) where N is the state dimension.
C
torch.Tensor
required
Output projection matrix, shape (H, N).
output
torch.Tensor
Output tensor, shape (B, L, H), representing the convolved sequence.
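The two optimized strategies compute the same result by associativity of the contractions. A NumPy sketch checking this, with K assumed to have shape (L, N) as in the parameter docs (illustrative only; the library operates on torch tensors and routes through fft_conv):

```python
import numpy as np

def strat1(x, K, B_, C):
    """Strategy 1: precompute the full (L, H_out, H_in) kernel, then
    convolve along the length axis in the Fourier domain."""
    L = x.shape[1]
    kernel = np.einsum("on,ln,ni->loi", C, K, B_)  # (L, H_out, H_in)
    Xf = np.fft.rfft(x, n=2 * L, axis=1)           # (B, F, H_in)
    Kf = np.fft.rfft(kernel, n=2 * L, axis=0)      # (F, H_out, H_in)
    Yf = np.einsum("bfi,foi->bfo", Xf, Kf)
    return np.fft.irfft(Yf, n=2 * L, axis=1)[:, :L]

def strat2(x, K, B_, C):
    """Strategy 2: project the input to the state dimension, convolve
    each state channel independently, then project back out."""
    L = x.shape[1]
    x_proj = np.einsum("blh,nh->bln", x, B_)       # (B, L, N)
    Xf = np.fft.rfft(x_proj, n=2 * L, axis=1)      # (B, F, N)
    Kf = np.fft.rfft(K, n=2 * L, axis=0)           # (F, N)
    x_conv = np.fft.irfft(Xf * Kf[None], n=2 * L, axis=1)[:, :L]
    return np.einsum("bln,hn->blh", x_conv, C)     # (B, L, H_out)

B, L, H, N = 2, 16, 4, 8
rng = np.random.default_rng(0)
x = rng.standard_normal((B, L, H))
K = rng.standard_normal((L, N))
B_ = rng.standard_normal((N, H))
C = rng.standard_normal((H, N))
assert np.allclose(strat1(x, K, B_, C), strat2(x, K, B_, C))
```

Both paths realize the same linear map y[b,l,o] = Σ_{l'≤l} Σ_n Σ_i C[o,n] K[l-l',n] B_[n,i] x[b,l',i]; only the order of contraction, and hence the cost, differs.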

fft_conv

fft_conv(equation, input, *args) -> Tensor
FFT-based convolution operation with flexible einsum equations. This is a lower-level function used by opt_ssm_forward and supports multiple argument patterns.
equation
str
required
Einsum equation string specifying the contraction pattern (e.g., "bli,loi->blo").
input
torch.Tensor
required
Input tensor, shape (B, L, H) or (B, L, N).
*args
torch.Tensor
required
Variable arguments depending on the convolution pattern:
  • Single argument: Kernel tensor of shape (L, H, H)
  • Multiple arguments: Separate K, B_norm, and C tensors
output
torch.Tensor
Convolved output tensor, shape (B, L, H) or (B, L, N) depending on the input configuration.
Implementation Details:
  • Performs FFT with padding to 2*L to avoid circular convolution artifacts
  • Converts tensors to complex float (cfloat) for FFT operations
  • Returns real part after inverse FFT, truncated to original sequence length L
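The padding scheme can be illustrated in isolation. A minimal NumPy sketch of FFT-based linear convolution with zero padding to 2L (the library version works on complex torch tensors and general einsum equations; a real FFT suffices for this real-valued sketch):

```python
import numpy as np

def fft_conv_1d(x, k):
    """Causal linear convolution of x with kernel k (both length L),
    via FFT with zero padding to 2L to avoid circular wrap-around."""
    L = x.shape[-1]
    X = np.fft.rfft(x, n=2 * L)   # n=2L zero-pads before transforming
    K = np.fft.rfft(k, n=2 * L)
    y = np.fft.irfft(X * K, n=2 * L)
    return y[..., :L]             # truncate back to the original length

x = np.random.randn(1024)
k = np.random.randn(1024)
# Matches the first L samples of the direct linear convolution:
assert np.allclose(fft_conv_1d(x, k), np.convolve(x, k)[:1024])
```

Without the padding, the pointwise product in the frequency domain would implement circular convolution, wrapping late outputs back onto early positions.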

FFTConvS4 Module

Class

FFTConvS4(d_model, l_max=None, channels=1, swap_channels=False, 
          transposed=True, dropout=0.0, tie_dropout=False, 
          drop_kernel=0.0, kernel_type=None, param_config=None, 
          kernel=None, **kernel_args)
PyTorch module implementing FFT convolution around a learnable convolution kernel. This is the main building block for S4-style models.
d_model
int
required
Model dimension (in CNN terminology, the number of “channels”).
l_max
int
Maximum kernel length. Use None for a global kernel that adapts to input length.
channels
int
default:"1"
Number of “heads”; the SSM maps 1-dimensional input to C-dimensional output.
swap_channels
bool
default:"False"
Whether to swap channel ordering in the computation.
transposed
bool
default:"True"
Backbone axis ordering. If True, expects input shape (B, D, L). If False, expects (B, L, D).
dropout
float
default:"0.0"
Dropout probability applied to the output.
tie_dropout
bool
default:"False"
If True, ties the dropout mask across the sequence length dimension.
drop_kernel
float
default:"0.0"
Kernel dropout probability, applied to the convolution kernel.
kernel_type
str
Kernel algorithm specification:
  • "s4" - DPLR (Diagonal Plus Low-Rank) parameterization
  • "s4d" - Diagonal parameterization
Required when param_config is provided.
param_config
dict
Dictionary containing references to SSM parameters (A, B, C, dt, P, etc.). Used with kernel_type to configure the kernel.
kernel
str
Alternative kernel specification. Either this or param_config must be provided.
**kernel_args
dict
Additional keyword arguments forwarded to the kernel class constructor.

Methods

forward

forward(x, state=None, rate=1.0, **kwargs) -> tuple[Tensor, Tensor | None]
Forward pass through the FFTConvS4 module.
x
torch.Tensor
required
Input tensor. Shape depends on transposed parameter:
  • If transposed=True: (B, D, L)
  • If transposed=False: (B, L, D)
state
torch.Tensor
Recurrent state from previous time step. Used for stateful/recurrent processing.
rate
float
default:"1.0"
Rate parameter for kernel computation, useful for temporal downsampling.
**kwargs
dict
Additional keyword arguments (absorbs return_output, transformer source mask, etc.).
y
torch.Tensor
Convolution output, shape (B, C, H, L) where C is the number of channels.
next_state
torch.Tensor | None
Updated state for recurrent mode. Returns None if state was not provided.

step

step(x, state) -> tuple[Tensor, Tensor]
Advance the model by one time step in recurrent mode. Intended for use during validation or autoregressive generation.
x
torch.Tensor
required
Input tensor at current time step, shape (B, H).
state
torch.Tensor
required
Recurrent state, shape (B, H, N) where N is the state dimension.
y
torch.Tensor
Output at current time step, shape (B, C, H).
next_state
torch.Tensor
Updated state for next time step, shape (B, H, N).
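For intuition, a single recurrent update with these shapes can be sketched in NumPy. The diagonal parameterization and names here are assumptions for illustration; the actual step method dispatches to the configured kernel:

```python
import numpy as np

def ssm_step(x, state, A, B, C):
    """One recurrent step of a diagonal SSM (hypothetical sketch).

    x:       (batch, H)     input at the current time step
    state:   (batch, H, N)  hidden state, one length-N state per feature
    A, B, C: (H, N)         diagonal transition, input and output maps
    """
    next_state = A[None] * state + B[None] * x[..., None]  # (batch, H, N)
    y = (next_state * C[None]).sum(axis=-1)                # (batch, H)
    return y, next_state

batch, H, N = 4, 8, 16
rng = np.random.default_rng(0)
A = 0.5 * rng.standard_normal((H, N))   # scaled for a stable rollout
Bm = rng.standard_normal((H, N))
Cm = rng.standard_normal((H, N))
state = np.zeros((batch, H, N))
for _ in range(3):  # autoregressive-style rollout
    y, state = ssm_step(rng.standard_normal((batch, H)), state, A, Bm, Cm)
assert y.shape == (batch, H) and state.shape == (batch, H, N)
```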

setup_step

setup_step(**kwargs)
Prepare the module for step-by-step (recurrent) inference. Must be called before using the step method.

default_state

default_state(*batch_shape, device=None) -> Tensor
Create a default initial state for recurrent processing.
*batch_shape
int
Batch dimensions for the state tensor.
device
torch.device
Device on which to create the state tensor.
state
torch.Tensor
Initialized state tensor with appropriate shape and device.

Properties

d_output
int
Output dimension, computed as d_model * channels.

Example Usage

import torch
from lrnnx.core.convolution import opt_ssm_forward, FFTConvS4

# Low-level usage with opt_ssm_forward
B, L, H, N = 32, 1024, 64, 128
x = torch.randn(B, L, H)
K = torch.randn(L, N)
B_ = torch.randn(N, H)
C = torch.randn(H, N)

output = opt_ssm_forward(x, K, B_, C)  # (B, L, H)

# High-level usage with FFTConvS4 module
conv_layer = FFTConvS4(
    d_model=64,
    l_max=1024,
    channels=4,
    kernel="s4d",
    dropout=0.1
)

x = torch.randn(32, 64, 1024)  # (B, D, L) with transposed=True
y, _ = conv_layer(x)  # (B, 4, 64, 1024)

Performance Considerations

  • The opt_ssm_forward function automatically selects the most efficient computation strategy
  • FFT operations are performed with padding to avoid circular convolution
  • Kernel dropout can be used for regularization without recomputing the FFT
  • The @torch.compiler.disable decorator on fft_conv prevents torch compilation issues with FFT operations
