
TopKRouter

Top-k gating distribution for routing tokens to experts. Implements the routing mechanism from Shazeer et al. (2017) that computes gating scores and selects the top-k experts for each token.

Mathematical formulation

router_logits = W_router · h
top_k_scores, top_k_indices = TopK(softmax(router_logits), k)
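The formulation above can be sketched directly in PyTorch. The final renormalization over the selected experts (so the k scores sum to 1 per token, as the Returns section below states) is an assumption about the implementation; the weight matrix here is a random stand-in for W_router:

```python
import torch
import torch.nn.functional as F

batch, seq_len, dim, num_experts, k = 2, 4, 16, 8, 2
hidden = torch.randn(batch, seq_len, dim)
w_router = torch.randn(num_experts, dim)  # stand-in for the router weight W_router

router_logits = hidden @ w_router.T                       # (batch, seq_len, num_experts)
probs = F.softmax(router_logits, dim=-1)
topk_scores, topk_indices = torch.topk(probs, k, dim=-1)  # (batch, seq_len, k) each
# Renormalize over the selected experts so the k scores sum to 1 per token
topk_scores = topk_scores / topk_scores.sum(dim=-1, keepdim=True)
```

Without the renormalization, the top-k softmax probabilities would sum to less than 1 whenever k < num_experts.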

Constructor

from modern_llm.models.moe import TopKRouter
from modern_llm.config.model_config import MoEConfig

moe_config = MoEConfig(
    num_experts=8,
    top_k=2,
    capacity_factor=1.25
)
router = TopKRouter(dim=768, config=moe_config)

Parameters

dim (int, required): Input dimension matching the hidden states. Must be positive.
config (MoEConfig, required): MoE configuration containing num_experts, top_k, and other routing parameters.

Attributes

dim (int): Input feature dimension.
config (MoEConfig): MoE configuration object.
router (nn.Linear): Linear layer projecting from dim to num_experts (no bias).

forward

def forward(self, hidden_states: Tensor) -> Tuple[Tensor, Tensor]
Compute top-k routing decisions for each token.
hidden_states (Tensor, required): Input tensor of shape (batch, seq_len, dim).

Returns

topk_scores (Tensor): Top-k gating scores of shape (batch, seq_len, top_k). Values sum to 1 per token.
topk_indices (Tensor): Expert indices of shape (batch, seq_len, top_k), with values in [0, num_experts).

Status

The routing logic will be implemented in a future phase. Currently raises NotImplementedError.

MixtureOfExperts

Sparse mixture of experts layer that routes tokens to specialized sub-networks. Follows the sparsely-activated design popularized by the Switch Transformer (Fedus et al., 2021): each token is processed by only its top-k experts rather than all experts, enabling efficient model scaling.

Architecture

For each token:
  1. Router computes gating scores for all experts
  2. Select top-k experts based on scores
  3. Process token through selected experts
  4. Combine expert outputs weighted by routing scores
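The four steps above can be sketched as a simple loop over experts. This is a reference formulation only, not the capacity-constrained dispatch a production layer would use; the router and expert MLPs here are stand-ins built to match the shapes documented below:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

dim, num_experts, top_k = 16, 4, 2
experts = nn.ModuleList(
    nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
    for _ in range(num_experts)
)
router = nn.Linear(dim, num_experts, bias=False)

x = torch.randn(2, 8, dim)                          # (batch, seq_len, dim)
probs = F.softmax(router(x), dim=-1)                # 1. gating scores for all experts
scores, indices = torch.topk(probs, top_k, dim=-1)  # 2. select top-k experts
scores = scores / scores.sum(dim=-1, keepdim=True)  # renormalize selected scores

output = torch.zeros_like(x)
for e, expert in enumerate(experts):                # 3-4. process and combine
    for slot in range(top_k):
        mask = indices[..., slot] == e              # tokens whose slot routed to expert e
        if mask.any():
            output[mask] += scores[..., slot][mask].unsqueeze(-1) * expert(x[mask])
```

Real implementations typically dispatch tokens to experts in a single batched scatter/gather with a per-expert capacity limit rather than looping; the loop keeps the combination logic explicit.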

Constructor

from modern_llm.models.moe import MixtureOfExperts
from modern_llm.config.model_config import MoEConfig

moe_config = MoEConfig(
    num_experts=8,
    top_k=2,
    capacity_factor=1.25,
    expert_capacity=None
)
moe = MixtureOfExperts(dim=768, moe_config=moe_config)

Parameters

dim (int, required): Model dimension for input/output. Must be positive.
moe_config (MoEConfig, required): MoE configuration containing:
  • num_experts: Number of expert networks
  • top_k: Number of experts to activate per token
  • capacity_factor: Expert capacity as a multiple of the average tokens per expert
  • expert_capacity: Optional fixed capacity per expert
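The interplay between capacity_factor and expert_capacity is presumably the usual one: a fixed capacity wins if given, otherwise capacity scales the even-split average. A sketch of that arithmetic (the helper name is hypothetical, not part of the library):

```python
import math

def resolve_capacity(num_tokens, num_experts, capacity_factor, expert_capacity=None):
    # A fixed per-expert capacity, if provided, overrides the scaled average.
    if expert_capacity is not None:
        return expert_capacity
    # Otherwise: headroom over the even split of tokens across experts.
    return math.ceil(capacity_factor * num_tokens / num_experts)

# 256 tokens over 8 experts at 1.25x headroom -> 40 slots per expert
print(resolve_capacity(256, 8, 1.25))  # 40
```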

Attributes

dim (int): Model dimension.
moe_config (MoEConfig): MoE configuration object.
router (TopKRouter): Token routing module that selects top-k experts.
experts (nn.ModuleList): List of num_experts feedforward networks. Each expert is a simple 2-layer MLP with GELU activation:
  • Linear: dim -> dim * 4
  • GELU activation
  • Linear: dim * 4 -> dim
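A single expert as described above can be built standalone (a sketch assuming the documented dim -> 4·dim -> dim shape):

```python
import torch
import torch.nn as nn

dim = 768
expert = nn.Sequential(
    nn.Linear(dim, dim * 4),  # expand: dim -> 4*dim
    nn.GELU(),
    nn.Linear(dim * 4, dim),  # project back: 4*dim -> dim
)
out = expert(torch.randn(2, 128, dim))  # shape preserved: (2, 128, 768)
```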

forward

def forward(self, hidden_states: Tensor) -> Tensor
Route tokens through experts and combine outputs.
hidden_states (Tensor, required): Input tensor of shape (batch, seq_len, dim).

Returns

output (Tensor): Expert-processed outputs of shape (batch, seq_len, dim).

Status

The forward pass with routing logic will be implemented in a future phase. Currently raises NotImplementedError.

Example usage (when implemented)

import torch
from modern_llm.models.moe import MixtureOfExperts
from modern_llm.config.model_config import MoEConfig

# Configure MoE layer
moe_config = MoEConfig(
    num_experts=8,    # 8 expert networks
    top_k=2,          # Activate 2 experts per token
    capacity_factor=1.25
)
moe = MixtureOfExperts(dim=768, moe_config=moe_config)

# Process hidden states
hidden_states = torch.randn(2, 128, 768)
output = moe(hidden_states)  # Will work when implemented

Integration with transformer

from modern_llm.config.model_config import ModernLLMConfig, MoEConfig
from modern_llm.models.transformer import ModernDecoderLM

# Enable MoE in decoder blocks
moe_config = MoEConfig(
    num_experts=8,
    top_k=2,
    capacity_factor=1.25
)

config = ModernLLMConfig(
    vocab_size=50257,
    d_model=768,
    n_layers=12,
    n_heads=12,
    use_moe=True,
    moe_config=moe_config
)

# Transformer uses MoE instead of SwiGLU in feedforward
model = ModernDecoderLM(config)

Benefits of MoE

  • Efficient scaling: Increase model capacity without proportional compute increase
  • Sparse computation: Each token uses only top-k experts (e.g., 2 out of 8)
  • Specialization: Different experts can specialize in different types of content
  • Memory efficiency: Experts can be distributed across devices

Complexity

Each token costs O(top_k · dim · 4·dim) = O(top_k · dim²), the same asymptotic order as a dense feedforward of the same hidden size scaled by top_k. The saving is in parameters touched per token: only a top_k / num_experts fraction of the layer's expert parameters is active. With 8 experts and top_k=2, each token passes through just 25% of the expert parameters, while the layer holds 8× the parameters of a single dense feedforward.
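The parameter-activation arithmetic is easy to check directly (counting only the expert MLP weight matrices; biases are ignored for simplicity):

```python
dim, num_experts, top_k = 768, 8, 2

# Each expert: two linear layers, dim -> 4*dim -> dim
params_per_expert = dim * (4 * dim) + (4 * dim) * dim
total_expert_params = num_experts * params_per_expert
active_params = top_k * params_per_expert

print(active_params / total_expert_params)  # 2/8 = 0.25 of expert parameters per token
```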
