
Overview

The Block class is a fundamental building block for constructing neural architectures in lrnnx. It wraps a mixer module (e.g., attention, an LRNN) with layer normalization and residual connections, optionally adding an MLP module. Its ordering differs from that of a standard prenorm Transformer block:
  • Standard: LN → MHA/MLP → Add
  • This Block: Add → LN → Mixer
This design enables fusing the add and normalization operations for better performance.
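The contrast between the two orderings can be sketched with identity stand-ins for the norm and mixer (illustrative only; the real Block uses the modules built from norm_cls and mixer_cls):

```python
# Illustrative sketch of the two orderings, with identity stand-ins
# for the norm and mixer (the real Block uses norm_cls / mixer_cls).

def norm(x):   # stand-in for LayerNorm / RMSNorm
    return x

def mixer(x):  # stand-in for attention / LRNN
    return x

# Standard prenorm: normalize, mix, then add the residual.
def standard_block(x):
    return x + mixer(norm(x))

# This Block: add first, then normalize, then mix. The residual
# stream is threaded explicitly between blocks, which is what lets
# the add and the norm be fused into a single kernel.
def add_norm_mixer_block(hidden_states, residual):
    residual = hidden_states if residual is None else hidden_states + residual
    hidden_states = mixer(norm(residual))
    return hidden_states, residual
```

The two orderings compute the same residual stream overall; the add is simply deferred to the start of the next block, where it can be fused with that block's normalization.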

Class Definition

from lrnnx.layers.block import Block

block = Block(
    dim=768,
    mixer_cls=SomeMixerClass,
    mlp_cls=GatedMLP,
    norm_cls=nn.LayerNorm,
    fused_add_norm=True,
    residual_in_fp32=False
)

Parameters

dim
int
required
The hidden dimension size for the block.
mixer_cls
type
required
The mixer class to instantiate (e.g., MHA, LRNN). This class will be initialized with dim as its first argument.
mlp_cls
type
required
The MLP class to instantiate, or nn.Identity if no MLP is desired. When using an MLP, a second normalization layer is applied.
norm_cls
type
default:"nn.LayerNorm"
The normalization class to use. Supports nn.LayerNorm and RMSNorm (with Triton acceleration when available).
fused_add_norm
bool
default:"True"
Whether to use Triton fused add and normalization operations for improved performance. Only works with nn.LayerNorm and RMSNorm.
residual_in_fp32
bool
default:"False"
Whether to keep the residual connection in fp32 precision for numerical stability.
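The rationale for residual_in_fp32 can be illustrated with a toy accumulation. Everything here is made up for illustration: rounding to three decimals stands in for a 16-bit cast, and the update size is arbitrary.

```python
# Toy illustration of the residual_in_fp32 rationale: many tiny
# per-layer updates vanish when the residual is rounded to low
# precision after every add, but survive when kept in fp32.

def to_low_precision(x):
    return round(x, 3)  # crude stand-in for casting to fp16

delta = 0.00007   # small per-layer residual update (arbitrary)
n_layers = 1000

low, high = 0.0, 0.0
for _ in range(n_layers):
    low = to_low_precision(low + delta)  # residual kept in "fp16"
    high = high + delta                  # residual kept in fp32

# low stays 0.0 (every update rounds away); high accumulates ~0.07
```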

Methods

forward

hidden_states, residual = block.forward(
    hidden_states,
    residual=None,
    inference_params=None,
    **mixer_kwargs
)
Pass the input through the block with normalization, mixer, and optional MLP.

Parameters

hidden_states
torch.Tensor
required
The sequence input to the block of shape (batch_size, seq_len, dim).
residual
torch.Tensor
default:"None"
The residual connection from the previous block. If None, hidden_states itself becomes the residual; otherwise the two are first added. The computation is: residual = hidden_states + residual, then hidden_states = Mixer(LN(residual)).
inference_params
Any
default:"None"
Parameters used during autoregressive generation/inference. Passed to the mixer’s forward method if supported.
**mixer_kwargs
dict
Additional keyword arguments passed directly to the underlying mixer module’s forward method.

Returns

hidden_states
torch.Tensor
The output of the block after applying the mixer (and optionally MLP) of shape (batch_size, seq_len, dim).
residual
torch.Tensor
The updated residual tensor to be passed to the next block of shape (batch_size, seq_len, dim).
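Putting the pieces together, the unfused forward logic can be sketched as follows, with identity functions standing in for the norms, mixer, and MLP (illustrative, derived from the documented behavior, not the actual implementation):

```python
# Sketch of the unfused forward path, assumed from the documented
# behavior. Identity functions stand in for the real submodules.

def norm1(x):  # first norm (before the mixer)
    return x

def norm2(x):  # second norm (before the MLP, when mlp_cls is not nn.Identity)
    return x

def mixer(x):
    return x

def mlp(x):
    return x

def block_forward(hidden_states, residual=None):
    # Add: fold the incoming hidden states into the residual stream
    residual = hidden_states if residual is None else hidden_states + residual
    # Norm -> Mixer
    hidden_states = mixer(norm1(residual))
    # The optional MLP branch repeats the same add-first pattern
    residual = hidden_states + residual
    hidden_states = mlp(norm2(residual))
    return hidden_states, residual
```

When fused_add_norm=True, each add + norm pair above executes as a single Triton kernel rather than two separate operations.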

allocate_inference_cache

cache = block.allocate_inference_cache(
    batch_size=1,
    max_seqlen=2048,
    dtype=torch.float16
)
Allocate inference cache for the mixer module (e.g., KV cache for attention).

Parameters

batch_size
int
required
The batch size for inference.
max_seqlen
int
required
The maximum sequence length for inference.
dtype
torch.dtype
default:"None"
The data type for the cache tensors. If None, uses the mixer’s default dtype.
**kwargs
dict
Additional keyword arguments to pass to the mixer’s cache allocation method.

Returns

cache
Any
The allocated cache object returned by the mixer. The structure depends on the specific mixer implementation.
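The delegation pattern can be sketched with a stub mixer. StubMixer, StubBlock, and the dict cache are hypothetical stand-ins; a real attention mixer such as MHA would allocate actual KV-cache tensors here.

```python
# Sketch of how Block.allocate_inference_cache delegates to its mixer.
# StubMixer and its dict "cache" are hypothetical stand-ins; a real
# attention mixer would return KV-cache tensors instead.

class StubMixer:
    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return {"batch_size": batch_size, "max_seqlen": max_seqlen, "dtype": dtype}

class StubBlock:
    def __init__(self):
        self.mixer = StubMixer()

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        # The Block itself holds no cache state; the mixer owns it.
        return self.mixer.allocate_inference_cache(
            batch_size, max_seqlen, dtype=dtype, **kwargs
        )

cache = StubBlock().allocate_inference_cache(batch_size=1, max_seqlen=2048)
```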

Usage Example

import torch
import torch.nn as nn
from lrnnx.layers.block import Block
from lrnnx.layers.mha import MHA
from lrnnx.layers.mlp import GatedMLP

# Create a block with attention and MLP
block = Block(
    dim=768,
    mixer_cls=lambda dim: MHA(
        embed_dim=dim,
        num_heads=12,
        causal=True
    ),
    mlp_cls=lambda dim: GatedMLP(in_features=dim),
    norm_cls=nn.LayerNorm,
    fused_add_norm=True
)

# Forward pass
x = torch.randn(2, 128, 768)  # (batch, seq_len, dim)
hidden_states, residual = block(x)

# Stack multiple blocks (blocks: an iterable of Block instances)
hidden_states, residual = x, None
for block in blocks:
    hidden_states, residual = block(hidden_states, residual=residual)

Architecture Integration

The Block class is designed to be stacked in sequence to build deep architectures:
class MyModel(nn.Module):
    def __init__(self, dim=768, n_layers=24):
        super().__init__()
        self.layers = nn.ModuleList([
            Block(
                dim=dim,
                mixer_cls=lambda d: MHA(embed_dim=d, num_heads=12),
                mlp_cls=lambda d: GatedMLP(in_features=d),
            )
            for _ in range(n_layers)
        ])
        # Final norm for the residual stream
        self.norm_f = nn.LayerNorm(dim)

    def forward(self, x):
        residual = None
        for layer in self.layers:
            x, residual = layer(x, residual=residual)
        # The last block's output has not yet been folded into the
        # residual stream, so add and normalize before returning
        return self.norm_f(x + residual)

Notes

  • The residual connection pattern (Add → LN → Mixer) differs from standard prenorm Transformers to enable operator fusion
  • When fused_add_norm=True, the implementation uses optimized Triton kernels for better performance on GPUs
  • Some mixers (e.g., S4, S4D) may return (output, state) tuples; the Block automatically handles this by extracting the output
  • The MLP is applied after the mixer with its own normalization layer when mlp_cls is not nn.Identity
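The tuple-handling note above amounts to the following (assumed behavior; unwrap_mixer_output is a hypothetical helper name, not part of the lrnnx API):

```python
# Illustrative sketch of the tuple handling described in the Notes:
# if a mixer returns (output, state), only the output is kept for the
# residual stream. unwrap_mixer_output is a hypothetical helper name.

def unwrap_mixer_output(out):
    return out[0] if isinstance(out, tuple) else out
```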

See Also

  • MHA - Multi-head attention mixer
  • GatedMLP - Gated MLP implementation
  • LRNN - Linear recurrent neural network mixer
