Overview

Expert Parallelism (EP) distributes expert weights across multiple devices in Mixture-of-Experts (MoE) models. Instead of replicating every expert on each GPU, EP shards the experts across devices, removing the memory bottleneck of large-scale MoE models in which tokens are dynamically routed to specialized experts across GPUs.

Key Benefits

  • Reduced memory footprint per GPU by sharding expert weights
  • Higher throughput with optimized all-to-all communication
  • Better scalability for models with 100+ experts
  • Load balancing to minimize GPU utilization variance

When to Use Expert Parallelism

Use EP for:
  • Mixture-of-Experts models (DeepSeek, Mixtral, Qwen-MoE)
  • Models with 64+ experts that don’t fit on a single GPU
  • Large-scale deployments requiring maximum throughput
Typical EP models:
  • DeepSeek-V2, DeepSeek-V3, DeepSeek-R1
  • Mixtral-8x7B, Mixtral-8x22B
  • Qwen2-57B-A14B, Qwen3-235B-A22B

Architecture

How EP Works

In a typical MoE layer with EP:
  1. Token Routing: Each token is routed to top-K experts based on gating scores
  2. All-to-All Dispatch: Tokens are shuffled across GPUs to their assigned experts
  3. Expert Computation: Each GPU processes its local expert subset
  4. All-to-All Combine: Results are gathered back to original token positions
GPU 0 (Experts 0-63)    GPU 1 (Experts 64-127)   GPU 2 (Experts 128-191)   GPU 3 (Experts 192-255)
       ↓                        ↓                         ↓                          ↓
   All-to-All Dispatch (shuffle tokens to assigned experts)
       ↓                        ↓                         ↓                          ↓
Local Expert Computation
       ↓                        ↓                         ↓                          ↓
   All-to-All Combine (gather results back)
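The four steps above can be walked through in a toy, single-process sketch. The "GPUs" here are just Python buckets and expert_fn is a stand-in for the expert FFN, so this only illustrates the routing and data movement, not real all-to-all communication:

```python
# Toy single-process walk-through of the four EP steps. Illustrative
# only: no real devices or collectives are involved.

NUM_GPUS = 4
EXPERTS_PER_GPU = 2   # 8 experts total, sharded 2 per "GPU"
TOP_K = 2

def route(gate_scores):
    """Step 1: pick the top-K experts per token from gating scores."""
    return [sorted(range(len(s)), key=lambda e: -s[e])[:TOP_K]
            for s in gate_scores]

def dispatch(tokens, assignments):
    """Step 2: shuffle (token_idx, expert, value) to the owning 'GPU'."""
    buckets = {rank: [] for rank in range(NUM_GPUS)}
    for idx, experts in enumerate(assignments):
        for e in experts:
            buckets[e // EXPERTS_PER_GPU].append((idx, e, tokens[idx]))
    return buckets

def expert_fn(expert_id, x):
    """Step 3: stand-in expert computation (scale by expert_id + 1)."""
    return x * (expert_id + 1)

def combine(buckets, num_tokens):
    """Step 4: sum expert outputs back into original token positions."""
    out = [0.0] * num_tokens
    for items in buckets.values():
        for idx, e, x in items:
            out[idx] += expert_fn(e, x)
    return out

tokens = [1.0, 2.0]
gate_scores = [
    [0.9, 0.1, 0, 0, 0, 0, 0, 0],   # token 0 -> experts 0 and 1 (GPU 0)
    [0, 0, 0, 0, 0, 0, 0.5, 0.8],   # token 1 -> experts 7 and 6 (GPU 3)
]
result = combine(dispatch(tokens, route(gate_scores)), len(tokens))
print(result)  # [3.0, 30.0]
```

Note that a token routed to two experts on the same GPU still produces two entries in that GPU's bucket; real dispatchers deduplicate and pack these transfers.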

Configuration

Basic Setup

python -m sglang.launch_server \
  --model-path deepseek-ai/DeepSeek-V3 \
  --tp 8 \
  --ep 8 \
  --moe-a2a-backend deepep \
  --moe-runner-backend deep_gemm
Key parameters:
  • --tp: Tensor parallel size (intra-node parallelism)
  • --ep: Expert parallel size (typically equals tp)
  • --moe-a2a-backend: All-to-all communication backend
  • --moe-runner-backend: Expert computation backend

Multi-Node Setup

# Node 0
python -m sglang.launch_server \
  --model-path deepseek-ai/DeepSeek-V3 \
  --tp 16 --ep 16 \
  --nnodes 2 --node-rank 0 \
  --dist-init-addr <MASTER_NODE_IP>:29500 \
  --moe-a2a-backend deepep \
  --moe-runner-backend deep_gemm

# Node 1
python -m sglang.launch_server \
  --model-path deepseek-ai/DeepSeek-V3 \
  --tp 16 --ep 16 \
  --nnodes 2 --node-rank 1 \
  --dist-init-addr <MASTER_NODE_IP>:29500 \
  --moe-a2a-backend deepep \
  --moe-runner-backend deep_gemm

Communication Backends

All-to-All Backends (--moe-a2a-backend)

| Backend | Description | Use Case | Constraints |
|---|---|---|---|
| none (default) | Uses All-Reduce/All-Gather | Hybrid EP+TP (ep < tp) | — |
| deepep | DeepEP communication library | Large-scale EP deployments | ep == tp |
| mooncake | Elastic inference with RDMA | Elastic EP serving | ep == tp |
| mori | AMD ROCm-optimized all-to-all | AMD GPU deployments | ep == tp |
| flashinfer | FlashInfer all-to-all | Large-scale EP | — |
| ascend_fuseep | Ascend NPU fused operator | Ascend NPU (decode only) | ep == tp |

DeepEP Dispatch Modes

DeepEP supports two dispatch modes:
  • normal: Optimized for prefill workloads (high throughput)
  • low_latency: Optimized for decode workloads (low latency, CUDA Graph compatible)
Recommended setup:
python -m sglang.launch_server \
  --model-path deepseek-ai/DeepSeek-V3 \
  --tp 8 --ep 8 \
  --moe-a2a-backend deepep \
  --deepep-mode auto  # Automatically switches between modes

MoE Runner Backends (--moe-runner-backend)

| Backend | Description | Best For |
|---|---|---|
| auto (default) | Auto-selects based on hardware/model | General use |
| deep_gemm | DeepGEMM optimized GEMMs | FP8 block-wise quantization |
| triton | Triton-based grouped GEMMs | Custom kernel development |
| cutlass | CUTLASS-based GEMMs | NVIDIA architectures |
| flashinfer_trtllm | FlashInfer + TensorRT-LLM | Blackwell with TRT-LLM |
| flashinfer_cutlass | FlashInfer + CUTLASS | Blackwell with FP4/FP8 |
| flashinfer_mxfp4 | FlashInfer MXFP4 variant | MXFP4 models |
| flashinfer_cutedsl | FlashInfer with custom DSL | NVFP4 models |

Advanced Features

Two-Batch Overlap (TBO)

TBO splits requests into micro-batches, interleaving attention with dispatch/combine operations:
python -m sglang.launch_server \
  --model-path deepseek-ai/DeepSeek-V3 \
  --tp 8 --ep 8 \
  --moe-a2a-backend deepep \
  --enable-two-batch-overlap
Benefits:
  • Up to 2× throughput improvement
  • Hides communication latency behind computation
  • No peak memory increase
Implementation:
operations = [
    self._forward_attn,
    YieldOperation(),  # Overlap with dispatch of prior micro-batch
    self._forward_dispatch,
    self._forward_mlp,
    YieldOperation(),  # Overlap with combine
    self._forward_combine,
]
Details: Large-Scale EP Blog - TBO Section
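The YieldOperation pattern above can be mimicked with plain generators: two micro-batches advance in lockstep, and each yield marks a point where that batch's communication would be in flight while the other batch computes. This is an illustrative toy scheduler, not SGLang's actual implementation:

```python
# Toy two-batch-overlap scheduler. Each yield is a stand-in for a
# YieldOperation: the batch "waits" on dispatch/combine while the
# other micro-batch runs its compute stage.

def micro_batch(name, log):
    log.append(f"{name}:attn")
    yield "dispatch"            # a2a dispatch would be in flight here
    log.append(f"{name}:mlp")
    yield "combine"             # a2a combine would be in flight here
    log.append(f"{name}:done")

def run_two_batch_overlap(log):
    pending = [micro_batch("A", log), micro_batch("B", log)]
    while pending:
        still_running = []
        for gen in pending:
            try:
                next(gen)       # run this batch up to its next yield
                still_running.append(gen)
            except StopIteration:
                pass            # this micro-batch finished
        pending = still_running
    return log

trace = run_two_batch_overlap([])
print(trace)  # ['A:attn', 'B:attn', 'A:mlp', 'B:mlp', 'A:done', 'B:done']
```

The interleaved trace shows why TBO can approach 2× throughput: while batch A's dispatch is on the wire, batch B's attention runs, and vice versa.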

Single-Batch Overlap (SBO)

SBO enables overlapping operations within a single batch (e.g., shared experts with communication):
python -m sglang.launch_server \
  --model-path deepseek-ai/DeepSeek-V3 \
  --tp 8 --ep 8 \
  --moe-a2a-backend deepep \
  --enable-single-batch-overlap
SBO is built on a dispatcher-hook system for modularity. See PR #13327.

Expert Parallelism Load Balancer (EPLB)

EPLB addresses routing imbalances by analyzing expert activation statistics:
python -m sglang.launch_server \
  --model-path deepseek-ai/DeepSeek-V3 \
  --tp 8 --ep 8 \
  --moe-a2a-backend deepep \
  --enable-eplb
How it works:
  1. Collects expert activation statistics during inference
  2. Computes optimal expert arrangement to minimize variance
  3. Strategically places or replicates experts across GPUs
  4. Reduces idle cycles and improves load balance
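Steps 2–3 above can be sketched with a greedy longest-processing-time heuristic: place experts from hottest to coldest onto the currently least-loaded GPU. This is illustrative only; SGLang's EPLB also replicates hot experts, which this sketch does not:

```python
# Greedy expert placement to minimize load variance, given measured
# per-expert activation counts (hypothetical, simplified heuristic).

def balance(activation_counts, num_gpus):
    """Assign each expert to the least-loaded GPU, visiting experts
    from hottest to coldest (greedy LPT heuristic)."""
    loads = [0] * num_gpus
    placement = {}
    for expert in sorted(activation_counts,
                         key=activation_counts.get, reverse=True):
        gpu = loads.index(min(loads))
        placement[expert] = gpu
        loads[gpu] += activation_counts[expert]
    return placement, loads

counts = {0: 900, 1: 100, 2: 500, 3: 480}   # activations per expert
placement, loads = balance(counts, num_gpus=2)
# The balancedness ratio from the tuning notes: mean load / max load
# (1.0 means perfectly balanced).
ratio = (sum(loads) / len(loads)) / max(loads)
print(loads, round(ratio, 2))  # [1000, 980] 0.99
```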
Tuning:
  • Increase batch sizes for stable statistics
  • Configure periodic rebalancing (e.g., every 1000 requests)
  • Monitor load balancedness ratio (mean/max computation time)
Details: EPLB Repository

Hardware-Specific Configuration

NVIDIA GPUs

Standard setup:
python -m sglang.launch_server \
  --model-path deepseek-ai/DeepSeek-V3 \
  --tp 8 --ep 8 \
  --moe-a2a-backend deepep \
  --moe-runner-backend auto
Blackwell (B100/B200) with FP4:
python -m sglang.launch_server \
  --model-path nvidia/DeepSeek-R1-0528-NVFP4-v2 \
  --tp 8 --ep 8 \
  --moe-a2a-backend deepep \
  --moe-runner-backend flashinfer_trtllm

AMD GPUs (ROCm)

python -m sglang.launch_server \
  --model-path deepseek-ai/DeepSeek-V3 \
  --tp 8 --ep 8 \
  --moe-a2a-backend mori \
  --deepep-mode normal
Note: the MORI backend currently supports only normal mode.

Huawei Ascend NPUs

Prefill instance:
python -m sglang.launch_server \
  --model-path deepseek-ai/DeepSeek-V3 \
  --disaggregation-mode prefill \
  --tp 16 --ep 16 \
  --moe-a2a-backend deepep \
  --deepep-mode normal
Decode instance:
python -m sglang.launch_server \
  --model-path deepseek-ai/DeepSeek-V3 \
  --disaggregation-mode decode \
  --tp 16 --ep 16 \
  --moe-a2a-backend ascend_fuseep \
  --deepep-mode low_latency
DeepEP Ant-moving Function (for long sequences on Ascend):
# Enable ant-moving for dispatch and combine
export DEEPEP_NORMAL_LONG_SEQ_PER_ROUND_TOKENS=8192
export DEEPEP_NORMAL_LONG_SEQ_ROUND=16  # 8192 * 16 = 128K tokens
export DEEPEP_NORMAL_COMBINE_ENABLE_LONG_SEQ=1
export HCCL_BUFFSIZE=256  # Must satisfy the buffer-size formula below

python -m sglang.launch_server \
  --model-path deepseek-ai/DeepSeek-V3 \
  --tp 16 --ep 16 \
  --moe-a2a-backend deepep
Buffer size calculation:
# With ant-moving enabled
HCCL_BUFFSIZE >= 2 * (102 + 4 + PER_ROUND_TOKENS * (hidden_size + hidden_size + hidden_size) * topk) + 20

# Without ant-moving
HCCL_BUFFSIZE >= 2 * (102 + 4 + TOTAL_SEQ_LEN * (hidden_size + hidden_size) * topk) + 20
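The two formulas can be transcribed directly into a helper for sanity-checking a configuration. The function names are ours, and all quantities are in whatever units the formulas above use:

```python
# Direct transcription of the HCCL buffer-size formulas above
# (hypothetical helper names; units follow the guide's formulas).

def buffsize_with_ant_moving(per_round_tokens, hidden_size, topk):
    # hidden_size appears three times in the per-token term
    return 2 * (102 + 4 + per_round_tokens * (3 * hidden_size) * topk) + 20

def buffsize_without_ant_moving(total_seq_len, hidden_size, topk):
    # hidden_size appears twice in the per-token term
    return 2 * (102 + 4 + total_seq_len * (2 * hidden_size) * topk) + 20
```

Ant-moving bounds the per-token term by PER_ROUND_TOKENS rather than the full sequence length, which is why it keeps the required buffer size manageable for long sequences.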

Combining with Other Parallelism

EP + TP

Most common combination:
python -m sglang.launch_server \
  --model-path deepseek-ai/DeepSeek-V3 \
  --tp 8 --ep 8 \
  --moe-a2a-backend deepep

EP + DPA (Data Parallelism Attention)

For MLA-based MoE models like DeepSeek:
python -m sglang.launch_server \
  --model-path deepseek-ai/DeepSeek-V3 \
  --tp 8 --ep 8 \
  --dp-size 8 \
  --enable-dp-attention \
  --moe-a2a-backend deepep \
  --moe-runner-backend deep_gemm
See Data Parallelism for DPA details.

EP + PP (Pipeline Parallelism)

For very large models:
python -m sglang.launch_server \
  --model-path deepseek-ai/DeepSeek-V3.1 \
  --tp 8 --ep 8 \
  --pp-size 4 \
  --nnodes 4 --node-rank 0 \
  --moe-a2a-backend deepep \
  --chunked-prefill-size 4096

EP + Speculative Decoding

For speculative decoding with different precisions:
python -m sglang.launch_server \
  --model-path nvidia/DeepSeek-R1-0528-NVFP4-v2 \
  --tp 8 --ep 8 \
  --moe-runner-backend flashinfer_trtllm \
  --speculative-moe-runner-backend triton  # Draft uses BF16, target uses FP4

Performance Tuning

python -m sglang.launch_server \
  --model-path deepseek-ai/DeepSeek-V3 \
  --tp 8 --ep 8 \
  --moe-a2a-backend deepep \
  --moe-runner-backend deep_gemm \
  --deepep-mode auto \
  --enable-two-batch-overlap \
  --enable-eplb \
  --mem-fraction-static 0.85

Tuning Triton Backend

For custom kernel optimization:
# Generate tuned configurations
cd benchmark/kernels/fused_moe_triton
python benchmark.py --model deepseek-ai/DeepSeek-V3

python -m sglang.launch_server \
  --model-path deepseek-ai/DeepSeek-V3 \
  --tp 8 --ep 8 \
  --moe-runner-backend triton
See Triton MoE Tuning Guide.

Extending the EP Framework

SGLang’s EP framework is highly modular and extensible:

Architecture

[input_hidden_states]
        ↓
TopK.forward → select experts
        ↓
[TopKOutput]
        ↓
FusedMoE.forward
        ↓
Dispatcher.dispatch → DeepEP / bypass
        ↓
[DispatchOutput]
        ↓
quant_method.apply → MoeRunner.forward
        ↓
pre-permute + grouped_gemm + post-permute
        ↓
[CombineInput]
        ↓
Dispatcher.combine → DeepEP / bypass
        ↓
[final_hidden_states]

Adding New Backends

For new all-to-all dispatcher:
  1. Implement BaseDispatcher subclass with dispatch and combine methods
  2. Register via --moe-a2a-backend
For new MoE runner:
  1. Define MoeRunnerCore subclass for grouped GEMMs
  2. Register permute methods:
    • Fused mode (static, torch.compile-compatible): register_fused_func
    • Permute mode (dynamic): register_pre_permute and register_post_permute
  3. Register via --moe-runner-backend
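A dispatcher skeleton for step 1 above might look like the following. The BaseDispatcher class name and the dispatch/combine methods come from this guide, but the exact signatures in SGLang may differ, so treat this as a shape sketch rather than the real interface:

```python
# Sketch of a BaseDispatcher subclass (signatures are assumptions).

class BaseDispatcher:
    """Stand-in for SGLang's dispatcher base class."""
    def dispatch(self, hidden_states, topk_output):
        raise NotImplementedError
    def combine(self, expert_output):
        raise NotImplementedError

class BypassDispatcher(BaseDispatcher):
    """Trivial dispatcher: no cross-GPU shuffle (the bypass path in
    the architecture diagram, e.g. when all experts are local)."""
    def dispatch(self, hidden_states, topk_output):
        self._topk = topk_output      # keep routing info for combine
        return hidden_states          # tokens stay on this rank
    def combine(self, expert_output):
        return expert_output
```

A real all-to-all dispatcher would replace the two pass-throughs with the shuffle/unshuffle collectives and carry the routing metadata needed to restore original token order.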

Troubleshooting

Communication Backend Not Working

Symptom: Error initializing DeepEP/Mooncake.
Solution: Check backend constraints:
# DeepEP requires ep == tp
python -m sglang.launch_server \
  --model-path deepseek-ai/DeepSeek-V3 \
  --tp 8 --ep 8 \
  --moe-a2a-backend deepep

# For hybrid EP+TP (ep < tp), use 'none' backend
python -m sglang.launch_server \
  --model-path deepseek-ai/DeepSeek-V3 \
  --tp 8 --ep 4 \
  --moe-a2a-backend none

Poor Load Balance

Symptom: High variance in GPU utilization.
Solution: Enable EPLB and increase the batch size:
python -m sglang.launch_server \
  --model-path deepseek-ai/DeepSeek-V3 \
  --tp 8 --ep 8 \
  --enable-eplb \
  --max-running-requests 128

Low Throughput

Symptom: Lower than expected throughput.
Solution: Enable overlap optimizations:
python -m sglang.launch_server \
  --model-path deepseek-ai/DeepSeek-V3 \
  --tp 8 --ep 8 \
  --moe-a2a-backend deepep \
  --deepep-mode auto \
  --enable-two-batch-overlap \
  --enable-single-batch-overlap

Best Practices

  1. Set ep == tp for DeepEP/Mooncake backends
  2. Use --deepep-mode auto for automatic dispatch mode switching
  3. Enable TBO for maximum throughput (up to 2× improvement)
  4. Enable EPLB with large batch sizes for better load balance
  5. Monitor expert activation patterns to understand routing behavior
  6. Combine with DPA for MLA-based MoE models

Configuration Summary

| Parameter | Description | Default | Recommended |
|---|---|---|---|
| --ep | Expert parallel size | 1 | Same as --tp |
| --moe-a2a-backend | All-to-all backend | none | deepep |
| --moe-runner-backend | MoE computation backend | auto | auto or deep_gemm |
| --deepep-mode | DeepEP dispatch mode | normal | auto |
| --enable-two-batch-overlap | Enable TBO | False | Enable for throughput |
| --enable-eplb | Enable load balancer | False | Enable with large batches |