Linear RNNs are based on continuous-time state-space models:
$$h'(t) = A\,h(t) + B\,x(t)$$

$$y(t) = C\,h(t) + D\,x(t)$$
However, neural networks operate on discrete sequences (text tokens, audio samples, video frames). Discretization is the process of converting these continuous-time equations into discrete-time recurrence relations that can be computed on digital hardware:
```python
# After discretization
h[k+1] = A_bar @ h[k] + B_bar @ x[k]
y[k] = C @ h[k] + D @ x[k]
```
Here, A_bar and B_bar are the discretized matrices obtained from the continuous-time A and B.
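To make the recurrence concrete, here is a minimal sketch of unrolling the discrete-time system in plain PyTorch. The names `A_bar`, `B_bar`, `C`, and `D` follow the equations above; the function itself (`unroll_ssm`) is illustrative and not part of the library, which uses parallel scans instead of a Python loop:

```python
import torch


def unroll_ssm(A_bar, B_bar, C, D, x):
    """Unroll h[k+1] = A_bar h[k] + B_bar x[k], y[k] = C h[k] + D x[k].

    x has shape (seq_len, d_input); the state starts at zero.
    """
    d_state = A_bar.shape[0]
    h = torch.zeros(d_state)
    ys = []
    for x_k in x:
        y_k = C @ h + D @ x_k  # output uses the state before the update
        h = A_bar @ h + B_bar @ x_k
        ys.append(y_k)
    return torch.stack(ys)


# Tiny example: scalar state, scalar input, stable dynamics
A_bar = torch.tensor([[0.9]])
B_bar = torch.tensor([[1.0]])
C = torch.tensor([[1.0]])
D = torch.tensor([[0.0]])
x = torch.ones(5, 1)
y = unroll_ssm(A_bar, B_bar, C, D, x)
```

With |A_bar| < 1 the state converges geometrically, which is why discretization methods that preserve stability matter.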
The discretization method you choose affects the model’s stability, accuracy, and how it handles different temporal patterns.
ZOH is the most widely used discretization method in modern linear RNNs. It assumes the input signal is piecewise constant (held) between timesteps:

$$\bar{A} = \exp(\Delta A)$$

$$\bar{\gamma} = A^{-1}(\bar{A} - I)$$

where $\Delta$ is the discretization step size (learned or fixed), and $\bar{\gamma}$ is used to compute `B_bar = gamma_bar * B`.
```python
from typing import Optional

import torch
from torch import Tensor


def zoh(
    A: Tensor,
    delta: Tensor,
    integration_timesteps: Optional[Tensor] = None,
) -> tuple[Tensor, Tensor]:
    """
    Zero-Order Hold (ZOH) discretization method, used across most models.

    Args:
        A (torch.Tensor): The continuous-time state matrix (diagonal, stored
            as a vector, so the operations below are elementwise).
        delta (torch.Tensor): The discretization step size.
        integration_timesteps (torch.Tensor, optional): Not used in ZOH.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: A tuple containing:
            - A_bar: The discretized state matrix.
            - gamma_bar: The input normalizer.
    """
    Identity = torch.ones(A.shape[0], device=A.device)
    A_bar = torch.exp(delta * A)
    gamma_bar = (1 / A) * (A_bar - Identity)
    return A_bar, gamma_bar
```
The bilinear method (also called Tustin's method or the trapezoidal rule) was the original discretization used in S4:

$$\bar{A} = \left(I - \tfrac{\Delta}{2} A\right)^{-1} \left(I + \tfrac{\Delta}{2} A\right)$$

$$\bar{\gamma} = \left(I - \tfrac{\Delta}{2} A\right)^{-1} \Delta$$
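A sketch of the bilinear rule in the same shape as `zoh` above (assuming, as there, that `A` is the diagonal of the state matrix stored as a vector, so the matrix inverse reduces to elementwise division):

```python
from typing import Optional

import torch
from torch import Tensor


def bilinear(
    A: Tensor,
    delta: Tensor,
    integration_timesteps: Optional[Tensor] = None,
) -> tuple[Tensor, Tensor]:
    """Bilinear (Tustin / trapezoidal) discretization for diagonal A."""
    denom = 1.0 - 0.5 * delta * A        # (I - Delta/2 * A), elementwise
    A_bar = (1.0 + 0.5 * delta * A) / denom
    gamma_bar = delta / denom            # (I - Delta/2 * A)^{-1} * Delta
    return A_bar, gamma_bar
```

Unlike ZOH, the bilinear rule needs no matrix exponential and maps the entire stable half-plane Re(A) < 0 into the unit disk, which preserves stability for any step size.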
The Dirac method treats inputs as instantaneous impulses (Dirac delta functions) rather than sustained signals:

$$\bar{A} = \exp(\Delta A)$$

$$\bar{\gamma} = 1$$

Note that $\bar{\gamma} = 1$ is constant, unlike ZOH where it depends on $A$.
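Following the formulas above, a diagonal-case sketch in the style of `zoh` (the function name `dirac` and the example values are illustrative, not taken from the library):

```python
from typing import Optional

import torch
from torch import Tensor


def dirac(
    A: Tensor,
    delta: Tensor,
    integration_timesteps: Optional[Tensor] = None,
) -> tuple[Tensor, Tensor]:
    """Dirac discretization: same state decay as ZOH, constant input gain."""
    A_bar = torch.exp(delta * A)
    # gamma_bar = 1: the input is an impulse, so it is injected directly
    # rather than integrated (held) over the step as in ZOH.
    gamma_bar = torch.ones_like(A_bar)
    return A_bar, gamma_bar


# Example: one stable diagonal entry
A = torch.tensor([-1.0])
delta = torch.tensor(0.1)
A_bar, gamma_bar = dirac(A, delta)
```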
Asynchronous discretization allows different timesteps at each sequence position, useful for irregular event streams:

$$\bar{A}[t] = \exp(\Delta \cdot \mathrm{timesteps}[t] \cdot A)$$

$$\bar{\gamma}[t] = A^{-1}\left(\exp(\Delta \cdot A) - I\right)$$
Asynchronous discretization is only supported for LTV models. LTI models cannot use this method because it creates time-varying dynamics.
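One plausible reading of the formulas above as code, again for diagonal `A` stored as a vector (a sketch under assumed shapes, not the library's implementation: `delta` is a scalar and `integration_timesteps` has shape `(batch, seq_len)`):

```python
import torch
from torch import Tensor


def async_zoh(
    A: Tensor,
    delta: Tensor,
    integration_timesteps: Tensor,
) -> tuple[Tensor, Tensor]:
    """Asynchronous ZOH: per-position step sizes Delta * timesteps[t]."""
    # Effective step per position, broadcast over the state dimension:
    # (batch, seq, 1) * (d_state,) -> (batch, seq, d_state)
    eff_delta = delta * integration_timesteps.unsqueeze(-1)
    A_bar = torch.exp(eff_delta * A)
    # Per the formula above, the input normalizer uses Delta alone,
    # not the per-position timesteps.
    gamma_bar = (1 / A) * (torch.exp(delta * A) - 1.0)
    return A_bar, gamma_bar


A = torch.tensor([-1.0, -2.0])          # diagonal state matrix, d_state = 2
delta = torch.tensor(0.1)
timesteps = torch.ones(2, 4)            # batch = 2, seq_len = 4
A_bar, gamma_bar = async_zoh(A, delta, timesteps)
```

With all timesteps equal to 1, this reduces to ordinary ZOH at every position.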
```python
import torch
from lrnnx.models.ltv import Mamba

# Create model with async discretization
model = Mamba(
    d_model=64,
    d_state=16,
    d_conv=4,
    discretization="async",
)

# Provide timesteps (e.g., time differences between events)
x = torch.randn(2, 1024, 64)
timesteps = torch.abs(torch.randn(2, 1024))  # Variable time deltas

# Pass timesteps to forward
y = model(x, integration_timesteps=timesteps)
```
```python
from lrnnx.core.discretization import DISCRETIZE_FNS

DISCRETIZE_FNS["my_custom"] = my_custom_discretization

# Now you can use it
from lrnnx.models.lti import S5

model = S5(d_model=64, d_state=64, discretization="my_custom")
```
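`my_custom_discretization` above is a placeholder for any callable matching the `zoh` signature. As a hypothetical example (not part of lrnnx), a forward-Euler rule for diagonal `A` could look like this:

```python
from typing import Optional

import torch
from torch import Tensor


def my_custom_discretization(
    A: Tensor,
    delta: Tensor,
    integration_timesteps: Optional[Tensor] = None,
) -> tuple[Tensor, Tensor]:
    """Forward Euler: h[k+1] = (I + Delta*A) h[k] + Delta * B x[k].

    Hypothetical example; assumes diagonal A stored as a vector,
    matching the conventions of zoh above.
    """
    A_bar = 1.0 + delta * A                  # first-order expansion of exp(Delta*A)
    gamma_bar = delta * torch.ones_like(A_bar)  # B_bar = Delta * B
    return A_bar, gamma_bar
```

Forward Euler is simple but only conditionally stable (it requires |1 + ΔA| < 1), which is one reason ZOH and bilinear are preferred in practice.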