Overview

The MLP module provides multi-layer perceptron implementations for use in neural architectures. The primary implementation is GatedMLP, which uses a gating mechanism with activation functions for improved expressiveness.

GatedMLP

A gated multi-layer perceptron that splits the hidden representation into two paths: a value path and a gate path, then multiplies them together.

Architecture

        ┌─────────────┐
        │   Input     │
        └──────┬──────┘

        ┌──────▼──────────┐
        │  fc1 (2×hidden) │
        └──────┬──────────┘

        ┌──────▼──────┐
        │ chunk(2)    │
        └──┬────────┬─┘
           │        │
      ┌────▼───┐ ┌──▼─────┐
      │ value  │ │  gate  │
      └────┬───┘ └──┬─────┘
           │        │
           │   ┌────▼──────┐
           │   │activation │
           │   └────┬──────┘
           │        │
        ┌──▼────────▼─┐
        │  multiply   │
        └──────┬──────┘

        ┌──────▼──────┐
        │     fc2     │
        └──────┬──────┘

        ┌──────▼──────┐
        │   Output    │
        └─────────────┘

Class Definition

import torch.nn.functional as F
from lrnnx.layers.mlp import GatedMLP

mlp = GatedMLP(
    in_features=768,
    hidden_features=2048,
    out_features=768,
    activation=F.silu,
    bias=False,
    multiple_of=128
)

Parameters

in_features
int
required
Number of input features.
hidden_features
int
default:"None"
Number of hidden features in the MLP. If None, defaults to int(8 * in_features / 3), which is commonly used in Transformer models (approximately 2.67× expansion).
out_features
int
default:"None"
Number of output features. If None, uses in_features for a residual-compatible architecture.
activation
callable
default:"F.silu"
Activation function to apply to the gate path. Common choices:
  • F.silu (Swish): Smooth, non-monotonic activation
  • F.gelu: Gaussian Error Linear Unit
  • F.relu: Rectified Linear Unit
bias
bool
default:"False"
Whether to include bias terms in the linear layers. Typically set to False when using layer normalization before the MLP.
multiple_of
int
default:"128"
Round hidden_features up to be a multiple of this value for optimal hardware utilization. Common values are 128 or 256 for GPU efficiency.
device
torch.device
default:"None"
Device to place tensors on (e.g., torch.device('cuda')).
dtype
torch.dtype
default:"None"
Data type for tensors (e.g., torch.float16, torch.bfloat16).

Methods

forward

output = mlp(x)  # invokes mlp.forward(x)
Forward pass through the gated MLP.
Parameters
x
torch.Tensor
required
Input tensor of shape (..., in_features). Can handle arbitrary batch dimensions.
Returns
output
torch.Tensor
Output tensor of shape (..., out_features).

Usage Examples

Basic Usage

import torch
from lrnnx.layers.mlp import GatedMLP
import torch.nn.functional as F

# Create a gated MLP
mlp = GatedMLP(
    in_features=768,
    hidden_features=2048,
    activation=F.silu
)

# Forward pass
x = torch.randn(2, 128, 768)  # (batch, seq_len, features)
output = mlp(x)  # (2, 128, 768)

With Default Hidden Size

# Let the module compute hidden_features automatically
mlp = GatedMLP(
    in_features=768,
    # hidden_features defaults to int(8 * 768 / 3) = 2048
)

x = torch.randn(2, 128, 768)
output = mlp(x)  # (2, 128, 768)

Different Activation Functions

import torch.nn.functional as F

# SiLU activation (default, smooth)
mlp_silu = GatedMLP(in_features=768, activation=F.silu)

# GELU activation
mlp_gelu = GatedMLP(in_features=768, activation=F.gelu)

# ReLU activation
mlp_relu = GatedMLP(in_features=768, activation=F.relu)

Custom Output Dimension

# Project to different output dimension
mlp_proj = GatedMLP(
    in_features=768,
    hidden_features=2048,
    out_features=512  # Different output size
)

x = torch.randn(2, 768)
output = mlp_proj(x)  # Shape: (2, 512)

Optimized for Hardware

# Round hidden dimension for GPU efficiency
mlp_optimized = GatedMLP(
    in_features=768,
    hidden_features=2000,  # Will be rounded to 2048
    multiple_of=256  # Round to nearest 256
)

# Check actual hidden dimension
print(mlp_optimized.fc1.out_features)  # 4096 (2 × 2048 for gating)

Integration with Block

import torch
import torch.nn as nn
import torch.nn.functional as F
from lrnnx.layers.block import Block
from lrnnx.layers.mha import MHA
from lrnnx.layers.mlp import GatedMLP

# Use GatedMLP in a Block
block = Block(
    dim=768,
    mixer_cls=lambda dim: MHA(
        embed_dim=dim,
        num_heads=12,
        causal=True
    ),
    mlp_cls=lambda dim: GatedMLP(
        in_features=dim,
        activation=F.silu
    ),
    norm_cls=nn.LayerNorm
)

x = torch.randn(2, 128, 768)
hidden_states, residual = block(x)

No MLP (Identity)

# Create block without MLP by passing nn.Identity
block_no_mlp = Block(
    dim=768,
    mixer_cls=SomeMixerClass,
    mlp_cls=nn.Identity  # No MLP applied
)

Implementation Details

Gating Mechanism

The gated MLP uses a multiplicative gating mechanism:
  1. Project input to 2 × hidden_features dimensions
  2. Split into two equal parts: value and gate
  3. Apply activation to gate path
  4. Multiply value by activated gate
  5. Project back to output dimension
Mathematically:
y = fc2(value ⊙ activation(gate))
where [value, gate] = chunk(fc1(x), 2)
This gating allows the network to control information flow dynamically.
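The five steps above can be sketched in plain PyTorch. This is an illustrative reference, not the lrnnx source; class and variable names here are made up for the example:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GatedMLPSketch(nn.Module):
    """Minimal gated MLP following the five steps above (illustrative)."""

    def __init__(self, in_features, hidden_features, out_features=None,
                 activation=F.silu, bias=False):
        super().__init__()
        out_features = out_features if out_features is not None else in_features
        # Step 1: a single projection to 2 * hidden_features
        self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.activation = activation

    def forward(self, x):
        # Step 2: split the projection into value and gate halves
        value, gate = self.fc1(x).chunk(2, dim=-1)
        # Steps 3-4: activate the gate and modulate the value path
        y = value * self.activation(gate)
        # Step 5: project back to the output dimension
        return self.fc2(y)

mlp = GatedMLPSketch(in_features=64, hidden_features=128)
out = mlp(torch.randn(2, 10, 64))
print(out.shape)  # torch.Size([2, 10, 64])
```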

Hidden Dimension Calculation

The default hidden dimension formula int(8 * in_features / 3) gives approximately 2.67× expansion:
  • For 768 features: 8 × 768 / 3 = 2048
  • For 1024 features: 8 × 1024 / 3 = 2730 → rounded to 2816 (with multiple_of=128)
This is based on common practice in Transformer architectures.
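The default-then-round arithmetic can be checked directly. The helper below is not part of the library; it just reproduces the round-up-to-multiple calculation implied by the examples above:

```python
def round_hidden(in_features: int, multiple_of: int = 128) -> int:
    """Reproduce GatedMLP's default hidden size plus round-up (illustrative)."""
    # Default expansion: roughly 8/3 x the input width
    hidden = int(8 * in_features / 3)
    # Round up to the nearest multiple of `multiple_of`
    return multiple_of * ((hidden + multiple_of - 1) // multiple_of)

print(round_hidden(768))   # 2048 (already a multiple of 128)
print(round_hidden(1024))  # 2816 (2730 rounded up)
```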

Memory Alignment

The multiple_of parameter ensures hidden dimensions are multiples of specified values (typically 128 or 256) for optimal GPU memory access patterns and tensor core utilization.

Performance Considerations

  • Gating vs. Standard MLP: Gated MLPs typically provide better performance with similar parameter counts
  • Activation Choice:
    • SiLU/Swish: Smooth, often better gradient flow
    • GELU: Similar to SiLU, common in BERT-style models
    • ReLU: Fastest but may have dying neuron issues
  • Memory Alignment: Always use multiple_of for production models to ensure optimal GPU utilization
  • Bias Terms: When using layer normalization, bias can often be disabled (bias=False) for efficiency

Notes

  • The gating mechanism doubles the number of parameters in fc1 compared to a standard MLP
  • Hidden features are automatically rounded up to the nearest multiple of multiple_of
  • The activation function is only applied to the gate path, not the value path
  • Works with arbitrary input tensor shapes, not just 2D or 3D tensors
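The fc1 parameter doubling noted above is easy to quantify. A quick back-of-the-envelope count for a bias-free 768 → 2048 → 768 configuration:

```python
in_f, hidden, out_f = 768, 2048, 768

# Standard MLP (no bias): fc1 (in -> hidden) plus fc2 (hidden -> out)
standard = in_f * hidden + hidden * out_f

# Gated MLP (no bias): fc1 projects to 2 * hidden; fc2 is unchanged
gated = in_f * 2 * hidden + hidden * out_f

print(standard, gated)  # 3145728 4718592
```

Only the fc1 term doubles, so the total grows by 50% here, not 100%.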

See Also

  • Block - Wrapper for using MLPs in layer architectures
  • MHA - Multi-head attention with optional integrated MLP
