Overview
Convolution operations for LRNN models, using the FFT (Fast Fourier Transform) for efficient computation. These functions implement optimized strategies for convolution-based forward passes in state space models. Reference: arxiv:2409.03377

Functions
opt_ssm_forward
Selects one of three computation strategies based on the problem dimensions:

- Strategy 1: when `(1/H_in + 1/H_out) > (1/B + 1/N)` and `H_in * H_out <= N`
  - Precompute the full kernel: `kernel = einsum("on,nl,ni->loi", C, K, B_)`
  - Apply the convolution: `fft_conv("bli,loi->blo", x, kernel)`
- Strategy 2: when `(1/H_in + 1/H_out) <= (1/B + 1/N)` and `N <= H_in`
  - Project the input: `x_proj = einsum("blh,nh->bln", x, B_)`
  - Convolve the projected input: `fft_conv("bln,ln->bln", x_proj, K.T)`
  - Apply the output projection: `einsum("bln,hn->blh", x_conv, C)`
- Fallback: when neither strategy is optimal
  - Direct computation: `fft_conv("blh,nl,nh,on->blo", x, K, B_, C)`
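The dispatch rule above can be sketched as a small helper. This is illustrative only: `choose_strategy` and the string labels it returns are not part of the module's API.

```python
def choose_strategy(B, L, H_in, H_out, N):
    """Sketch of the documented dispatch rule; names are illustrative."""
    if (1 / H_in + 1 / H_out) > (1 / B + 1 / N) and H_in * H_out <= N:
        return "precompute_kernel"   # Strategy 1: build the (L, H_out, H_in) kernel once
    if (1 / H_in + 1 / H_out) <= (1 / B + 1 / N) and N <= H_in:
        return "project_input"       # Strategy 2: work in the smaller state dimension N
    return "direct"                  # Fallback: fuse everything into one fft_conv call
```

Note that Strategy 1 materializes an `(L, H_out, H_in)` kernel, which is why it is only chosen when `H_in * H_out` is small relative to the state dimension `N`.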
Parameters:

- `x`: Input tensor, shape `(B, L, H)`, where `B` is the batch size, `L` is the sequence length, and `H` is the input dimension.
- `K`: Kernel tensor, shape `(L, H, H)` or `(L, N)` depending on the model configuration.
- `B_`: Normalized input projection matrix, shape `(N, H)`, where `N` is the state dimension.
- `C`: Output projection matrix, shape `(H, N)`.

Returns: Output tensor, shape `(B, L, H)`, representing the convolved sequence.

fft_conv
Helper used by `opt_ssm_forward`; it supports multiple argument patterns.
Parameters:

- Einsum equation string specifying the contraction pattern (e.g., `"bli,loi->blo"`).
- `x`: Input tensor, shape `(B, L, H)` or `(B, L, N)`.
- Variable arguments depending on the convolution pattern:
  - Single argument: kernel tensor of shape `(L, H, H)`
  - Multiple arguments: separate `K`, `B_norm`, and `C` tensors
Returns: Convolved output tensor, shape `(B, L, H)` or `(B, L, N)` depending on the input configuration.

Notes:

- Performs the FFT with padding to `2*L` to avoid circular convolution artifacts
- Converts tensors to complex float (`cfloat`) for FFT operations
- Returns the real part after the inverse FFT, truncated to the original sequence length `L`
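The behavior described in these notes can be illustrated with a simplified stand-in. This is a sketch, not the actual implementation; it assumes that every operand carrying the `l` axis in the equation should be Fourier-transformed, so that convolution along `l` becomes a pointwise product.

```python
import torch

def fft_conv_sketch(eq, x, *args):
    """Contract in Fourier space: convolution along the shared `l` axis
    becomes a pointwise product after the FFT. Illustrative only."""
    subs = eq.split("->")[0].split(",")
    out = eq.split("->")[1]
    L = x.shape[subs[0].index("l")]
    n = 2 * L                                      # pad to 2L: no circular wrap-around
    ops = []
    for sub, op in zip(subs, (x, *args)):
        op = op.to(torch.cfloat)                   # complex float for the FFT
        if "l" in sub:
            op = torch.fft.fft(op, n=n, dim=sub.index("l"))
        ops.append(op)
    y = torch.fft.ifft(torch.einsum(eq, *ops), dim=out.index("l"))
    return y.real.narrow(out.index("l"), 0, L)     # real part, truncated to L
```

For example, `fft_conv_sketch("bln,ln->bln", x_proj, K.T)` matches a direct causal convolution along the length axis.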
FFTConvS4 Module
Class
Parameters:

- Model dimension (in CNN terminology, the number of "channels").
- Maximum kernel length. Use `None` for a global kernel that adapts to the input length.
- Number of "heads"; the SSM maps 1-dimensional input to C-dimensional output.
- Whether to swap channel ordering in the computation.
- Backbone axis ordering. If `True`, expects input shape `(B, D, L)`; if `False`, expects `(B, L, D)`.
- Dropout probability applied to the output.
- If `True`, ties the dropout mask across the sequence length dimension.
- Kernel dropout probability, applied to the convolution kernel.
- `kernel_type`: Kernel algorithm specification, used when `param_config` is provided:
  - `"s4"`: DPLR (Diagonal Plus Low-Rank) parameterization
  - `"s4d"`: Diagonal parameterization
- `param_config`: Dictionary containing references to SSM parameters (A, B, C, dt, P, etc.). Used with `kernel_type` to configure the kernel.
- Alternative kernel specification. Either this or `param_config` must be provided.
- Additional keyword arguments forwarded to the kernel class constructor.
Methods
forward
Parameters:

- Input tensor. Shape depends on the `transposed` parameter:
  - If `transposed=True`: `(B, D, L)`
  - If `transposed=False`: `(B, L, D)`
- `state`: Recurrent state from the previous time step. Used for stateful/recurrent processing.
- `rate`: Rate parameter for kernel computation, useful for temporal downsampling.
- Additional keyword arguments (absorbs `return_output`, the transformer source mask, etc.).

Returns:

- Convolution output, shape `(B, C, H, L)`, where `C` is the number of channels.
- Updated state for recurrent mode. Returns `None` if no state was provided.

step
Parameters:

- Input tensor at the current time step, shape `(B, H)`.
- Recurrent state, shape `(B, H, N)`, where `N` is the state dimension.

Returns:

- Output at the current time step, shape `(B, C, H)`.
- Updated state for the next time step, shape `(B, H, N)`.

setup_step

Prepares the module for recurrent inference; must be called before using the step method.
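The step update can be sketched as follows. This assumes a diagonal state transition; `dA`, `dB`, and `dC` are illustrative names for the discretized parameters, not the module's actual attributes.

```python
import torch

def step_sketch(x, state, dA, dB, dC):
    """One recurrent update with the documented shapes.
    x: (B, H), state: (B, H, N), dA/dB: (H, N), dC: (C, H, N)."""
    state = state * dA + x.unsqueeze(-1) * dB      # (B, H, N) state update
    y = torch.einsum("bhn,chn->bch", state, dC)    # (B, C, H) readout
    return y, state

B, H, N, C = 4, 8, 16, 2
y, s = step_sketch(torch.randn(B, H), torch.zeros(B, H, N),
                   torch.rand(H, N), torch.randn(H, N), torch.randn(C, H, N))
```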
default_state

Parameters:

- Batch dimensions for the state tensor.
- Device on which to create the state tensor.

Returns: Initialized state tensor with the appropriate shape and device.
Properties

- Output dimension, computed as `d_model * channels`.

Example Usage
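The original example was not preserved, so the sketch below uses a minimal stand-in that mirrors the documented interface. `TinyFFTConv` is a hypothetical illustration following the shape contract above (input `(B, L, D)` with `transposed=False`, output `(B, C, H, L)`), not the real `FFTConvS4`.

```python
import torch

class TinyFFTConv(torch.nn.Module):
    """Minimal stand-in mirroring the documented forward() contract."""
    def __init__(self, d_model, l_max, channels=1, transposed=False):
        super().__init__()
        self.transposed = transposed
        # One depthwise kernel per output channel (randomly initialized sketch).
        self.kernel = torch.nn.Parameter(torch.randn(channels, d_model, l_max) * 0.02)

    def forward(self, x, state=None, rate=1.0, **kwargs):
        if not self.transposed:                # (B, L, D) -> (B, D, L)
            x = x.transpose(-1, -2)
        L = x.shape[-1]
        n = 2 * L                              # pad to avoid circular convolution
        y = torch.fft.irfft(
            torch.fft.rfft(x, n=n).unsqueeze(1)
            * torch.fft.rfft(self.kernel[..., :L], n=n),
            n=n)[..., :L]                      # (B, C, D, L)
        return y, None                         # no recurrent state in this sketch

m = TinyFFTConv(d_model=8, l_max=32, channels=2)
y, next_state = m(torch.randn(4, 32, 8))       # (B, L, D) input, transposed=False
print(tuple(y.shape))                          # prints (4, 2, 8, 32)
```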
Performance Considerations
- The `opt_ssm_forward` function automatically selects the most efficient computation strategy
- FFT operations are performed with padding to avoid circular convolution
- Kernel dropout can be used for regularization without recomputing the FFT
- The `@torch.compiler.disable` decorator on `fft_conv` prevents torch compilation issues with FFT operations
See Also
- LRNN Base Class - Base class that uses these convolution operations
- Discretization Functions - Methods for discretizing continuous-time systems
