
RMSNorm

Root Mean Square Layer Normalization (Zhang & Sennrich, 2019). RMSNorm is a simplified normalization layer that scales activations by their RMS (root mean square) rather than by the mean-centered variance used in LayerNorm. Skipping the mean subtraction reduces computation while maintaining training stability.

Mathematical formulation

y = x * γ / sqrt(mean(x²) + ε)
where γ is a learned weight vector.
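In plain PyTorch, the formula above can be sketched as follows (a reference sketch of the math, not necessarily the library's exact implementation):

```python
import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # mean of squares over the last (feature) dimension
    ms = x.pow(2).mean(dim=-1, keepdim=True)
    # divide by RMS (via rsqrt), then scale elementwise by the learned weight γ
    return x * torch.rsqrt(ms + eps) * weight

x = torch.randn(2, 4, 8)
gamma = torch.ones(8)  # γ initialized to ones, as in the module below
y = rms_norm(x, gamma)
```

With γ at its initialization of ones, the output has unit RMS along the feature dimension (up to the ε term).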

Constructor

from modern_llm.models.layers import RMSNorm

layer_norm = RMSNorm(hidden_dim=768, eps=1e-5)
• hidden_dim (int, required): Dimension of the input vectors. Must be positive.
• eps (float, default: 1e-5): Small constant added to the denominator for numerical stability. Must be positive.

Attributes

• hidden_dim (int): Input dimension size.
• eps (float): Epsilon value for numerical stability.
• weight (nn.Parameter): Learnable scale parameter γ of shape (hidden_dim,). Initialized to ones.

forward

def forward(self, x: Tensor) -> Tensor
Normalize input by RMS and scale by learned weights.
• x (Tensor, required): Input tensor of shape (…, hidden_dim). The last dimension must match hidden_dim.

Returns

• output (Tensor): Normalized tensor with the same shape as the input.

Example

import torch
from modern_llm.models.layers import RMSNorm

# Create normalization layer
norm = RMSNorm(hidden_dim=768, eps=1e-5)

# Normalize activations
x = torch.randn(2, 128, 768)
y = norm(x)
print(y.shape)  # torch.Size([2, 128, 768])

# Works with any shape ending in hidden_dim
x = torch.randn(4, 32, 16, 768)
y = norm(x)
print(y.shape)  # torch.Size([4, 32, 16, 768])

Complexity

O(hidden_dim) per token - linear in the feature dimension.

SwiGLU

SwiGLU feedforward network (Shazeer, 2020; Chowdhery et al., 2022). SwiGLU combines the Swish activation function with a gating mechanism (GLU - Gated Linear Unit). It has been shown to improve model quality compared to traditional GELU or ReLU feedforward networks.

Mathematical formulation

SwiGLU(x) = W_o[(W_g x) ⊙ swish(W_v x)]
where:
  • W_g and W_v are the gate and value projections
  • ⊙ is element-wise multiplication
  • swish(x) = x · sigmoid(x)
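This formulation can be written as a minimal functional sketch in PyTorch (the weight shapes and scaling here are illustrative, not the module's actual internals; F.silu computes the swish above):

```python
import torch
import torch.nn.functional as F

d, h = 8, 16
w_g = torch.randn(h, d) / d**0.5  # gate projection W_g (illustrative init)
w_v = torch.randn(h, d) / d**0.5  # value projection W_v
w_o = torch.randn(d, h) / h**0.5  # output projection W_o

def swiglu(x: torch.Tensor) -> torch.Tensor:
    # (W_g x) ⊙ swish(W_v x), then the output projection W_o
    return ((x @ w_g.T) * F.silu(x @ w_v.T)) @ w_o.T

x = torch.randn(2, 3, d)
y = swiglu(x)
```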

Constructor

from modern_llm.models.layers import SwiGLU

ffn = SwiGLU(
    in_features=768,
    hidden_features=3072,
    out_features=768,
    bias=True
)
• in_features (int, required): Input dimension. Must be positive.
• hidden_features (int, required): Hidden dimension for the intermediate projections. Must be positive; typically 4x in_features.
• out_features (Optional[int], default: None): Output dimension. Defaults to in_features if not specified.
• bias (bool, default: True): Whether to include bias terms in the linear layers.

Attributes

• in_features (int): Input dimension size.
• hidden_features (int): Hidden dimension for the gate/value projections.
• out_features (int): Output dimension size.
• gate (nn.Linear): Combined gate and value projection: in_features -> (hidden_features * 2).
• proj (nn.Linear): Output projection: hidden_features -> out_features.
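The fused gate/value projection described above implies a forward pass along these lines (a sketch assuming the fused output is split with chunk; the actual code may differ):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

gate = nn.Linear(8, 16 * 2)  # fused gate+value projection: in -> 2*hidden
proj = nn.Linear(16, 8)      # output projection: hidden -> out

def forward(x: torch.Tensor) -> torch.Tensor:
    # split the fused projection into gate and value halves
    g, v = gate(x).chunk(2, dim=-1)
    # (W_g x) ⊙ swish(W_v x), then project back out
    return proj(g * F.silu(v))

y = forward(torch.randn(2, 8))
```

Fusing the two input projections into one nn.Linear is a common efficiency choice: a single matmul produces both halves.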

forward

def forward(self, x: Tensor) -> Tensor
Apply SwiGLU feedforward transformation.
• x (Tensor, required): Input tensor of shape (…, in_features).

Returns

• output (Tensor): Output tensor of shape (…, out_features).

Example

import torch
from modern_llm.models.layers import SwiGLU

# Standard feedforward network (4x expansion)
ffn = SwiGLU(
    in_features=768,
    hidden_features=3072,  # 4x expansion
    out_features=768
)

x = torch.randn(2, 128, 768)
output = ffn(x)
print(output.shape)  # torch.Size([2, 128, 768])

# Custom dimensions
ffn = SwiGLU(
    in_features=512,
    hidden_features=2048,
    out_features=1024,
    bias=False
)

x = torch.randn(4, 64, 512)
output = ffn(x)
print(output.shape)  # torch.Size([4, 64, 1024])

Usage in transformer blocks

from modern_llm.models.layers import RMSNorm, SwiGLU
import torch.nn as nn

class TransformerFFN(nn.Module):
    def __init__(self, d_model=768, dropout=0.1):
        super().__init__()
        self.norm = RMSNorm(d_model)
        self.ffn = SwiGLU(
            in_features=d_model,
            hidden_features=d_model * 4,
            out_features=d_model
        )
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        # Pre-norm residual connection
        residual = x
        x = self.norm(x)
        x = self.ffn(x)
        x = self.dropout(x)
        return residual + x

Complexity

O(in_features · hidden_features) per token due to the linear projections.
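For concreteness, the per-layer parameter count implied by these projections can be tallied (assuming bias=True and the fused gate projection described above):

```python
in_f, hid, out_f = 768, 3072, 768

# fused gate+value projection: in_f -> 2*hid, plus bias
gate_params = in_f * (2 * hid) + 2 * hid
# output projection: hid -> out_f, plus bias
proj_params = hid * out_f + out_f

total = gate_params + proj_params
print(total)  # roughly 7.1M parameters for this FFN
```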
