CUDA Graph is a powerful optimization technique that dramatically reduces kernel launch overhead by capturing a sequence of GPU operations and replaying them as a single unit. SGLang supports multiple CUDA graph implementations to accelerate different phases of LLM inference.

Why CUDA Graph?

In LLM inference, the model forward pass consists of hundreds or thousands of small kernel launches. Each kernel launch has overhead:
  • CPU-GPU synchronization
  • Kernel parameter setup
  • GPU scheduler overhead
CUDA Graph addresses this by:
  1. Capturing a sequence of GPU operations with fixed shapes and memory addresses into a graph
  2. Replaying the entire graph with a single launch, dramatically reducing overhead
  3. Improving GPU utilization through more compact scheduling
The benefits are most significant for:
  • Small batch sizes (where launch overhead is proportionally larger)
  • Decode phase (naturally has fixed batch size)
  • Models with many layers and fragmented operators

CUDA Graph for Decode

Overview

Decode phase has a natural advantage for CUDA graph: the batch composition is relatively stable, and each forward pass generates exactly one new token per request, so tensor shapes are fixed. SGLang enables CUDA graph for decode by default.

How It Works

  1. Initialization: Pre-allocate static buffers for various batch sizes
  2. Capture: For each batch size, run a dummy forward pass while recording GPU operations
  3. Replay: At runtime, copy inputs into static buffers and replay the captured graph
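The three steps above can be sketched with a pure-Python analogy (no GPU required; `ToyGraph` and the buffers are illustrative stand-ins, not SGLang or CUDA APIs). The key idea is that operations are recorded once against buffers with fixed addresses, and replay only ever touches those same buffers:

```python
class ToyGraph:
    """Toy stand-in for a CUDA graph: records ops bound to fixed buffers."""

    def __init__(self):
        self.ops = []

    def capture(self, op):
        self.ops.append(op)

    def replay(self):
        for op in self.ops:
            op()


# 1. Initialization: pre-allocate static buffers with fixed "addresses".
static_in = [0.0] * 4
static_out = [0.0] * 4

# 2. Capture: record a dummy forward pass; the op closes over the buffers.
g = ToyGraph()

def fake_kernel():
    static_out[:] = [x * 2.0 for x in static_in]

g.capture(fake_kernel)

# 3. Replay: copy real inputs into the static buffer, then replay the graph.
static_in[:] = [1.0, 2.0, 3.0, 4.0]
g.replay()
print(static_out)  # [2.0, 4.0, 6.0, 8.0]
```

With PyTorch on a CUDA device, the same pattern uses `torch.cuda.CUDAGraph()` with the `torch.cuda.graph(g)` capture context, then `copy_` into the static tensors followed by `g.replay()`.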

Configuration

--cuda-graph-max-bs
integer
Maximum batch size for which to capture CUDA graphs. By default, CUDA graph is enabled only for small batch sizes (e.g., a cap of 160 or 256). For some models, especially at large tensor-parallelism sizes, CUDA graph remains useful for batch sizes up to 512 or 768.
CUDA graph consumes additional GPU memory. If you increase --cuda-graph-max-bs, you may need to reduce --mem-fraction-static to prevent OOM errors.

Example

python3 -m sglang.launch_server \
  --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
  --cuda-graph-max-bs 512 \
  --mem-fraction-static 0.85

Piecewise CUDA Graph (PCG)

Overview

Standard CUDA graphs capture the entire model forward pass as a single graph, which works well for decode (fixed batch size) but not for extend/prefill where the number of tokens varies across iterations. Piecewise CUDA Graph (PCG) solves this by:
  • Splitting the model’s computation graph into pieces (roughly one per layer) at “split points” (e.g., MoE dispatch ops)
  • Capturing each piece as a separate CUDA graph for a set of pre-defined token lengths
  • At runtime, padding the input to the nearest captured size and replaying each piece
This eliminates kernel launch overhead for prefill/extend while still supporting dynamic shapes.
PCG is enabled by default for supported configurations. The old --enable-piecewise-cuda-graph flag is deprecated. Use --disable-piecewise-cuda-graph to turn it off.

Usage

PCG is enabled by default. No extra flags needed:
python3 -m sglang.launch_server \
  --model-path meta-llama/Meta-Llama-3.1-8B-Instruct

Disable PCG

python3 -m sglang.launch_server \
  --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
  --disable-piecewise-cuda-graph

Custom Capture Sizes

python3 -m sglang.launch_server \
  --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
  --piecewise-cuda-graph-max-tokens 2048

Configuration Parameters

--disable-piecewise-cuda-graph
boolean
default:"false"
Disable PCG for extend/prefill.
--enforce-piecewise-cuda-graph
boolean
default:"false"
Force-enable PCG, skipping all auto-disable conditions. For testing only.
--piecewise-cuda-graph-max-tokens
integer
Maximum token count to capture. Defaults to chunked_prefill_size (non-MLA) or 2048 (MLA).
--piecewise-cuda-graph-tokens
list
Explicit list of token lengths to capture. Auto-generated if not set.
--piecewise-cuda-graph-compiler
string
default:"eager"
Compiler backend for the captured subgraphs. Choices: eager, inductor.

How It Works

Torch Compile Backend

PCG uses torch.compile with a custom backend (SGLangBackend) to split and compile the model’s forward pass:
model.forward wrapper
→ torch.compile(..., backend=SGLangBackend)
→ FX graph
→ split_graph() at registered split ops
→ split_gm (top-level graph that chains the pieces)
→ replace capturable submodules with CUDAPiecewiseBackend
→ runtime dispatch: eager split ops + per-piece capture/replay
  1. Install: Replaces model.forward with a wrapper function that dispatches to the compiled callable when PCG is active
  2. Split: SGLangBackend receives the FX graph and cuts it at split points (attention ops, all-reduce ops, etc.)
  3. Replace: Each capturable submodule is compiled and replaced with a CUDAPiecewiseBackend instance
  4. Dispatch: At runtime, split-op submodules run eagerly, while CUDAPiecewiseBackend submodules go through:
  • Compile warmup — runs the general-shape compiled path
  • Capture — for each capture size, runs one warmup pass then records a CUDA graph
  • Steady-state replay — replays the captured CUDA graph
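The three-stage dispatch above can be sketched as a small state machine (pure Python; the class and return labels are illustrative, not the actual CUDAPiecewiseBackend implementation):

```python
class PieceDispatchSketch:
    """Per-piece dispatch: compile warmup -> per-size capture -> replay."""

    def __init__(self, capture_sizes):
        self.capture_sizes = set(capture_sizes)
        self.graphs = {}        # size -> "captured graph" (stub callable)
        self.warmed_up = False

    def __call__(self, compiled_fn, size):
        if not self.warmed_up:
            # Stage 1: run the general-shape compiled path once.
            self.warmed_up = True
            return compiled_fn(size), "compile-warmup"
        if size in self.capture_sizes and size not in self.graphs:
            # Stage 2: one warmup pass, then "record" a graph for this size.
            compiled_fn(size)
            self.graphs[size] = compiled_fn
            return self.graphs[size](size), "capture"
        if size in self.graphs:
            # Stage 3: steady state, replay the recorded graph.
            return self.graphs[size](size), "replay"
        return compiled_fn(size), "eager"   # uncaptured size falls back


backend = PieceDispatchSketch(capture_sizes=[16, 32])
double = lambda n: n * 2
print(backend(double, 16)[1])  # compile-warmup
print(backend(double, 16)[1])  # capture
print(backend(double, 16)[1])  # replay
```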

Piecewise CUDA Graph Runner

PiecewiseCudaGraphRunner orchestrates the full lifecycle:
  • Compile: Warms up JIT kernels, wraps model with torch.compile, triggers Dynamo tracing
  • Capture: Iterates over capture sizes in reverse order (largest first), captures CUDA graphs
  • Replay: At runtime, finds smallest captured size >= actual token count, copies inputs with zero-padding, replays graphs, slices outputs

Shape Configuration

The default capture schedule is auto-generated with increasing granularity:
Token range    Step size
4 – 32         4
48 – 256       16
288 – 512      32
576 – 1024     64
1280 – 4096    256
4096+          512
Sizes are capped at --piecewise-cuda-graph-max-tokens. If the token count exceeds the largest captured size, the runtime falls back to the normal (non-graph) forward path.
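The schedule and the padding lookup can be sketched as follows (an illustrative reconstruction of the table above, not the exact SGLang source):

```python
import bisect

def capture_sizes(max_tokens):
    """Generate the capture schedule from the table above, capped at max_tokens."""
    bands = [(4, 32, 4), (48, 256, 16), (288, 512, 32),
             (576, 1024, 64), (1280, 4096, 256)]
    sizes = [s for start, stop, step in bands
             for s in range(start, stop + 1, step)]
    sizes.extend(range(4608, max_tokens + 1, 512))   # the 4096+ band
    return [s for s in sizes if s <= max_tokens]

def padded_size(sizes, num_tokens):
    """Smallest captured size >= num_tokens; None means eager fallback."""
    i = bisect.bisect_left(sizes, num_tokens)
    return sizes[i] if i < len(sizes) else None

sizes = capture_sizes(2048)
print(padded_size(sizes, 100))   # 112  (100 tokens are zero-padded to 112)
print(padded_size(sizes, 3000))  # None (exceeds largest size: normal path)
```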

Memory Optimization

The memory cost comes from two parts:
  1. Torch memory allocator: Trivial overhead thanks to:
    • Global shared memory pool reused across all runners and capture sizes
    • Reverse-order capture (large to small) allows smaller graphs to reuse memory
    • Output tensors stored as weak references
  2. Non-torch memory: CUDA graph objects require GPU memory to store recorded kernel launch parameters. This scales with the number of captured sizes, which is why piecewise_cuda_graph_max_tokens is capped conservatively.
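The benefit of the shared pool can be seen with a toy model (illustrative only): one pool reused across all captures costs roughly the largest single graph's working set, whereas a private pool per captured size would cost the sum.

```python
def shared_pool_peak(sizes):
    """With one shared pool reused across captures (enabled by capturing
    largest-first), peak memory is the largest working set (toy model)."""
    return max(sizes)

def private_pools_total(sizes):
    """With a private pool per captured graph, the costs add up instead."""
    return sum(sizes)

working_sets = [4096, 1024, 256]   # toy per-graph working-set sizes, in bytes
print(shared_pool_peak(working_sets))     # 4096
print(private_pools_total(working_sets))  # 5376
```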

Compatibility

PCG is auto-disabled in the following scenarios:
  • Disabled model architectures (e.g., DeepseekV32ForCausalLM)
  • Speculative decoding
  • DP attention
  • Pipeline parallelism (pp_size > 1)
  • Non-CUDA hardware (AMD ROCm, Ascend NPU)
  • MoE A2A backend
  • LoRA
  • Multimodal / VLM models
  • DLLM (diffusion LLM)
  • Deterministic inference
  • PD disaggregation
  • Expert distribution recorder / EPLB
Use --enforce-piecewise-cuda-graph to skip all auto-disable checks (for testing/debugging only).

Bug Report

PCG is enabled by default but is still experimental. Since PCG relies on torch.compile to trace the model’s forward pass, most bugs are introduced by torch compile tracing failures (e.g., untraceable ops, dynamic control flow, or graph breaks).
If you encounter errors during server startup, you may see a message like:

  Piecewise CUDA Graph is enabled by default as an experimental feature.
  To work around this error, add --disable-piecewise-cuda-graph to your launch command.
  Please report this issue at https://github.com/sgl-project/sglang/issues/new/choose

As the message suggests, add --disable-piecewise-cuda-graph to your launch command as a workaround. When filing a bug report, please include:
  1. Full error traceback
  2. Model name and quantization method
  3. Launch command with all arguments
  4. GPU type and driver version

Developer Guide: Making Kernels Compatible

Since PCG relies on torch.compile, newly developed CUDA kernels are typically not compatible out of the box. To make a kernel compatible, register it as a custom op:
import torch

from sglang.srt.utils.custom_op import register_custom_op

# In-place operator (no return value)
@register_custom_op(mutates_args=["output_q", "output_s"])
def per_token_group_quant_8bit(
    input: torch.Tensor,
    output_q: torch.Tensor,
    output_s: torch.Tensor,
) -> None:
    ...  # kernel implementation

# Operator with a return value
@register_custom_op(mutates_args=["x"], out_shape=0)
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return x.add_(y)
For wrapping external library functions (e.g., FlashInfer kernels), use register_custom_op_from_extern. See python/sglang/srt/utils/custom_op.py for full API documentation.

CUDA Graph for Vision Transformers

Overview

In multimodal reasoning services, the visual encoder (ViT / Vision Transformer) typically has:
  • Many layers with fragmented operators (LN, QKV projections, attention, MLP, residuals)
  • Extremely frequent kernel launches
  • Server-side small batch / low latency scenarios where kernel launch overhead is significant
  • Variable input token count (different image/video resolutions)
CUDA Graph for ViT captures the “blocks + merger + deepstack merger (optional)” part of a vision transformer and replays it for identical shapes.

Usage

Enable CUDA Graph for ViT by setting the environment variable:
SGLANG_VIT_ENABLE_CUDA_GRAPH=1 \
python3 -m sglang.launch_server \
  --model Qwen/Qwen3-VL-8B-Instruct
Or combine with Piecewise CUDA Graph:
SGLANG_VIT_ENABLE_CUDA_GRAPH=1 \
python3 -m sglang.launch_server \
  --model Qwen/Qwen3-VL-8B-Instruct \
  --piecewise-cuda-graph-max-tokens 4096 \
  --piecewise-cuda-graph-compiler eager

Design Considerations

Dynamic inputs to fit static constraints:
  • Build a graph cache by sequence length S (graph_key = S)
  • First time creates and captures a new graph; afterwards replays it
  • Many distinct S values increase VRAM usage for graph-private memory pools
Stable addresses:
  • Everything parameter-like becomes a static buffer (block_input, block_ws, block_output, cu_full_len, sin_cos_ws)
  • During replay, tensor contents are modified but tensors are not swapped
Attention backend arguments:
  • Arguments are fixed inside the graph (cu_seqlens, max_len)
  • For the same graph_key = S, requires identical segmentation pattern in cu_seqlens
Rotary buffer management:
  • Reallocates larger sin_cos_ws when seq_len increases
  • max_content_len caps the maximum size of the rotary buffer that will be allocated
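The first consideration, a graph cache keyed by sequence length, can be sketched as follows (pure Python; `ViTGraphCache` and the capture callback are illustrative, not the SGLang API):

```python
class ViTGraphCache:
    """Cache one captured 'graph' per sequence length S: capture on first
    sight, replay afterwards (toy model, no GPU)."""

    def __init__(self, capture_fn):
        self.capture_fn = capture_fn   # builds a replayable callable for S
        self.graphs = {}               # graph_key = S -> captured callable
        self.captures = 0

    def run(self, seq_len, tokens):
        if seq_len not in self.graphs:             # first time: capture
            self.graphs[seq_len] = self.capture_fn(seq_len)
            self.captures += 1
        return self.graphs[seq_len](tokens)        # afterwards: replay


# Toy "capture": the shape S is baked into the callable, like a recorded graph.
cache = ViTGraphCache(lambda S: (lambda toks: [t * 2 for t in toks[:S]]))
cache.run(4, [1, 2, 3, 4])
cache.run(4, [5, 6, 7, 8])     # same S: replayed, no new capture
cache.run(8, list(range(8)))   # new S: one more capture (and more VRAM)
print(cache.captures)          # 2
```

This also makes the VRAM caveat concrete: every distinct S adds a cache entry, and in the real implementation each entry holds graph-private memory.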

Supported Models

  • Qwen2.5-VL
  • Qwen3-VL

Best Practices

Decode CUDA Graph

Increase --cuda-graph-max-bs for large TP sizes, but monitor memory usage and adjust --mem-fraction-static accordingly.

Piecewise CUDA Graph

Let PCG auto-enable. Only disable if you encounter issues. Report bugs to help improve the feature.

Memory Management

CUDA graphs consume additional memory. Balance between CUDA graph coverage and KV cache pool size.

Custom Kernels

Register custom ops with @register_custom_op to make kernels compatible with PCG.

Performance Impact

Decode Phase

CUDA graph for decode typically provides:
  • 10-30% speedup for small batch sizes (1-32)
  • 5-15% speedup for medium batch sizes (32-128)
  • Diminishing returns for large batch sizes (>256)
The benefits are most pronounced on:
  • Hopper and Blackwell GPUs (more efficient graph execution)
  • Models with many small kernels
  • High tensor parallelism (more communication ops to capture)

Prefill Phase (PCG)

Piecewise CUDA graph for prefill typically provides:
  • 15-40% speedup for short sequences (64-512 tokens)
  • 10-25% speedup for medium sequences (512-2048 tokens)
  • 5-15% speedup for long sequences (2048-4096 tokens)
The benefits are most significant for:
  • Small batch sizes where kernel launch overhead dominates
  • Models with many layers
  • Frequent prefill operations (e.g., chatbot workloads)

Vision Transformers

CUDA graph for ViT typically provides:
  • 20-50% speedup for small batches
  • Larger speedups on Hopper/Blackwell GPUs
  • Most beneficial when serving low-latency multimodal workloads

Troubleshooting

Out of Memory (OOM) Errors

If you encounter OOM errors after increasing --cuda-graph-max-bs:
  1. Reduce --mem-fraction-static by 0.01-0.05
  2. Reduce --cuda-graph-max-bs to a smaller value
  3. Monitor available_gpu_mem in logs (should be 5-8 GB)

PCG Capture Failures

If PCG fails to capture:
  1. Add --disable-piecewise-cuda-graph to work around
  2. Check if your model architecture is in the auto-disable list
  3. Report the issue with full error traceback
  4. For custom kernels, ensure they are registered with @register_custom_op

Performance Degradation

If CUDA graph degrades performance:
  1. Check if batch size exceeds --cuda-graph-max-bs (falls back to non-graph path)
  2. Verify memory bandwidth is not saturated
  3. Try different --piecewise-cuda-graph-compiler settings (eager vs inductor)
  4. Monitor for frequent graph breaks in PCG

Code Reference

  • python/sglang/srt/model_executor/cuda_graph_runner.py: Decode CUDA graph runner
  • python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py: Piecewise CUDA graph runner
  • python/sglang/srt/compilation/compile.py: install_torch_compiled trampoline
  • python/sglang/srt/compilation/backend.py: SGLangBackend, graph splitting
  • python/sglang/srt/compilation/cuda_piecewise_backend.py: Per-subgraph CUDA graph capture/replay
  • python/sglang/srt/utils/custom_op.py: register_custom_op for torch.compile compatibility

See Also