Why CUDA Graph?
In LLM inference, the model forward pass consists of hundreds or thousands of small kernel launches. Each kernel launch has overhead:
- CPU-GPU synchronization
- Kernel parameter setup
- GPU scheduler overhead

CUDA graphs reduce this overhead by:
- Capturing a sequence of GPU operations with fixed shapes and memory addresses into a graph
- Replaying the entire graph with a single launch, dramatically reducing overhead
- Improving GPU utilization through more compact scheduling

The benefit is largest for:
- Small batch sizes (where launch overhead is proportionally larger)
- The decode phase (which naturally has a fixed batch size)
- Models with many layers and fragmented operators
CUDA Graph for Decode
Overview
The decode phase has a natural advantage for CUDA graph: the batch size is relatively stable, and each forward step generates exactly one new token per sequence. SGLang enables CUDA graph for decode by default.

How It Works
- Initialization: Pre-allocate static buffers for various batch sizes
- Capture: For each batch size, run a dummy forward pass while recording GPU operations
- Replay: At runtime, copy inputs into static buffers and replay the captured graph
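The capture/replay flow above can be sketched in plain Python. This is a toy model of the runner's dispatch logic, not SGLang's actual API; a real runner records and launches `torch.cuda.CUDAGraph` objects instead of the stand-in computation used here.

```python
# Toy sketch of the decode CUDA-graph lifecycle: pre-allocate a static
# buffer per captured batch size, then "replay" by copying inputs into
# the static buffer, padding, and slicing the output. All class and
# method names are illustrative.

class DecodeGraphRunner:
    def __init__(self, capture_sizes):
        self.capture_sizes = sorted(capture_sizes)
        # "Initialization": one static input buffer per captured batch size.
        self.static_inputs = {bs: [0] * bs for bs in self.capture_sizes}

    def select_size(self, batch_size):
        # Smallest captured batch size that fits the actual batch.
        for bs in self.capture_sizes:
            if bs >= batch_size:
                return bs
        return None  # too large: fall back to the normal forward path

    def replay(self, inputs):
        bs = self.select_size(len(inputs))
        if bs is None:
            return [x * 2 for x in inputs]  # eager fallback
        buf = self.static_inputs[bs]
        buf[:len(inputs)] = inputs                     # copy into static buffer
        buf[len(inputs):] = [0] * (bs - len(inputs))   # zero-pad the tail
        # A real runner would launch the captured graph here; we just
        # "compute" on the padded buffer and slice off the padding.
        out = [x * 2 for x in buf]
        return out[:len(inputs)]
```

Note that inputs are never reallocated at replay time; only the contents of the pre-captured static buffers change.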
Configuration
`--cuda-graph-max-bs`: Maximum batch size for which to capture CUDA graphs. By default, CUDA graph is enabled only for small batch sizes (e.g., up to 160 or 256). For some models, especially at large tensor parallelism sizes, CUDA graph can be useful for batch sizes up to 512 or 768.
CUDA graph consumes additional GPU memory. If you increase `--cuda-graph-max-bs`, you may need to reduce `--mem-fraction-static` to prevent OOM errors.

Example
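A launch command that raises the capture range while trimming static memory might look like the following. The model path and the exact values are placeholders; tune them for your hardware.

```shell
# Raise the decode CUDA-graph capture range for a large-TP deployment,
# and lower the static memory fraction to leave room for the graphs.
python -m sglang.launch_server \
  --model-path meta-llama/Llama-3.1-8B-Instruct \
  --cuda-graph-max-bs 512 \
  --mem-fraction-static 0.85
```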
Piecewise CUDA Graph (PCG)
Overview
Standard CUDA graphs capture the entire model forward pass as a single graph, which works well for decode (fixed batch size) but not for extend/prefill, where the number of tokens varies across iterations. Piecewise CUDA Graph (PCG) solves this by:
- Splitting the model's computation graph into pieces (roughly one per layer) at "split points" (e.g., MoE dispatch ops)
- Capturing each piece as a separate CUDA graph for a set of pre-defined token lengths
- At runtime, padding the input to the nearest captured size and replaying each piece
PCG is enabled by default for supported configurations. The old `--enable-piecewise-cuda-graph` flag is deprecated; use `--disable-piecewise-cuda-graph` to turn it off.

Usage

PCG is enabled by default, so no extra flags are needed.

Disable PCG
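To turn PCG off, pass the disable flag at launch (the model path below is a placeholder):

```shell
python -m sglang.launch_server \
  --model-path meta-llama/Llama-3.1-8B-Instruct \
  --disable-piecewise-cuda-graph
```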
Custom Capture Sizes
Configuration Parameters
`--disable-piecewise-cuda-graph`: Disable PCG for extend/prefill.

`--enforce-piecewise-cuda-graph`: Force-enable PCG, skipping all auto-disable conditions. For testing only.

`--piecewise-cuda-graph-max-tokens`: Maximum token count to capture. Defaults to `chunked_prefill_size` (non-MLA) or 2048 (MLA).

Explicit list of token lengths to capture. Auto-generated if not set.

`--piecewise-cuda-graph-compiler`: Compiler backend for the captured subgraphs. Choices: `eager`, `inductor`.

How It Works
Torch Compile Backend
PCG uses `torch.compile` with a custom backend (`SGLangBackend`) to split and compile the model's forward pass:
- Install: `install_torch_compiled` replaces `model.forward` with a wrapper function that dispatches to the compiled callable when PCG is active
- Split: `SGLangBackend` receives the FX graph and cuts it at split points (attention ops, all-reduce ops, etc.)
- Replace: Each capturable submodule is compiled and replaced with a `CUDAPiecewiseBackend` instance
- Dispatch: At runtime, split-op submodules run eagerly, while `CUDAPiecewiseBackend` submodules go through:
  - Compile warmup: runs the general-shape compiled path
  - Capture: for each capture size, runs one warmup pass then records a CUDA graph
  - Steady-state replay: replays the captured CUDA graph
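The per-piece lifecycle above can be modeled as a small state machine. This is a simplified plain-Python sketch (it compresses the per-size warmup pass into the capture step); the real `CUDAPiecewiseBackend` operates on torch FX subgraphs and records actual CUDA graphs.

```python
# Toy state machine mirroring the per-piece lifecycle: the first call
# runs the general compiled path (warmup), the first call at each
# capture size records a "graph", and subsequent calls replay it.

class PiecewiseBackendSketch:
    def __init__(self, capture_sizes):
        self.capture_sizes = set(capture_sizes)
        self.warmed_up = False
        self.captured = {}   # size -> captured graph (stub string here)
        self.log = []        # record of what each call did

    def __call__(self, num_tokens):
        if not self.warmed_up:
            self.warmed_up = True
            self.log.append(("compile_warmup", num_tokens))
        elif num_tokens in self.capture_sizes and num_tokens not in self.captured:
            self.captured[num_tokens] = f"graph_{num_tokens}"
            self.log.append(("capture", num_tokens))
        else:
            self.log.append(("replay", num_tokens))
```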
Piecewise CUDA Graph Runner
PiecewiseCudaGraphRunner orchestrates the full lifecycle:
- Compile: Warms up JIT kernels, wraps the model with `torch.compile`, and triggers Dynamo tracing
- Capture: Iterates over capture sizes in reverse order (largest first) and captures a CUDA graph for each
- Replay: At runtime, finds the smallest captured size >= the actual token count, copies inputs with zero-padding, replays the graphs, and slices the outputs
Shape Configuration
The default capture schedule is auto-generated, with the step size growing as the token count increases:

| Token range | Step size |
|---|---|
| 4 – 32 | 4 |
| 48 – 256 | 16 |
| 288 – 512 | 32 |
| 576 – 1024 | 64 |
| 1280 – 4096 | 256 |
| 4096+ | 512 |
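The table above corresponds to a simple generator like the following. This is an illustrative reconstruction of the schedule, not SGLang's actual code.

```python
def default_capture_sizes(max_tokens):
    """Generate the PCG capture schedule from the table above:
    step 4 up to 32 tokens, step 16 up to 256, step 32 up to 512,
    step 64 up to 1024, step 256 up to 4096, then step 512."""
    sizes, s = [], 4
    while s <= max_tokens:
        sizes.append(s)
        if s < 32:
            s += 4
        elif s < 256:
            s += 16
        elif s < 512:
            s += 32
        elif s < 1024:
            s += 64
        elif s < 4096:
            s += 256
        else:
            s += 512
    return sizes
```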
The largest capture size is bounded by `--piecewise-cuda-graph-max-tokens`. If the token count exceeds the largest captured size, the runtime falls back to the normal (non-graph) forward path.
Memory Optimization
The memory cost comes from two parts:
- Torch memory allocator: trivial overhead, thanks to:
  - A global shared memory pool reused across all runners and capture sizes
  - Reverse-order capture (large to small), which lets smaller graphs reuse memory
  - Output tensors stored as weak references
- Non-torch memory: CUDA graph objects require GPU memory to store recorded kernel launch parameters. This scales with the number of captured sizes, which is why `piecewise_cuda_graph_max_tokens` is capped conservatively.
Compatibility
PCG is auto-disabled in the following scenarios:
- Disabled model architectures (e.g., `DeepseekV32ForCausalLM`)
- Speculative decoding
- DP attention
- Pipeline parallelism (`pp_size > 1`)
- Non-CUDA hardware (AMD ROCm, Ascend NPU)
- MoE A2A backend
- LoRA
- Multimodal / VLM models
- DLLM (diffusion LLM)
- Deterministic inference
- PD disaggregation
- Expert distribution recorder / EPLB
Use `--enforce-piecewise-cuda-graph` to skip all auto-disable checks (for testing/debugging only).
Bug Report
If you encounter errors during server startup, add `--disable-piecewise-cuda-graph` to your launch command.
When filing a bug report, please include:
- Full error traceback
- Model name and quantization method
- Launch command with all arguments
- GPU type and driver version
Developer Guide: Making Kernels Compatible
Since PCG relies on `torch.compile`, newly developed CUDA kernels are typically not compatible out of the box. To make a kernel compatible, register it as a custom op with `register_custom_op_from_extern`. See `python/sglang/srt/utils/custom_op.py` for the full API documentation.
CUDA Graph for Vision Transformers
Overview
In multimodal reasoning services, the visual encoder (ViT / Vision Transformer) typically has:
- Many layers with fragmented operators (LN, QKV projections, attention, MLP, residuals)
- Extremely frequent kernel launches
- Server-side small batch / low latency scenarios where kernel launch overhead is significant
- Variable input token count (different image/video resolutions)
Usage
Enable CUDA graph for ViT by setting the environment variable:

Design Considerations
The runner adapts dynamic inputs to fit CUDA graph's static constraints:
- Build a graph cache keyed by sequence length S (`graph_key = S`)
- First time creates and captures a new graph; afterwards replays it
- Many distinct S values increase VRAM usage for graph-private memory pools
- Everything parameter-like becomes a static buffer (block_input, block_ws, block_output, cu_full_len, sin_cos_ws)
- During replay, tensor contents are modified but tensors are not swapped
- Arguments are fixed inside the graph (cu_seqlens, max_len)
- For the same graph_key = S, requires identical segmentation pattern in cu_seqlens
- Reallocates larger sin_cos_ws when seq_len increases
- `max_content_len` caps the size of the allocated rotary buffer
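The per-sequence-length graph cache can be sketched as follows. The names and structure are illustrative, not the actual implementation; a real cache stores captured CUDA graphs rather than plain callables.

```python
# Toy graph cache keyed by sequence length S: the first request for a
# given S "captures" (here: stores) a graph; later requests replay it.
# Every distinct S keeps its own graph and graph-private memory pool,
# which is why many distinct S values increase VRAM usage.

class ViTGraphCache:
    def __init__(self):
        self.graphs = {}     # graph_key (seq len S) -> captured graph
        self.captures = 0    # how many captures have happened

    def run(self, seq_len, forward):
        if seq_len not in self.graphs:
            # First occurrence of this S: capture a new graph.
            self.captures += 1
            self.graphs[seq_len] = forward  # stand-in for a captured graph
        # Subsequent occurrences: replay the cached graph.
        return self.graphs[seq_len](seq_len)
```

Repeated calls with the same sequence length hit the cache and avoid re-capture; only a new sequence length triggers a fresh capture.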
Supported Models
- Qwen2.5-VL
- Qwen3-VL
Best Practices
Decode CUDA Graph
Increase `--cuda-graph-max-bs` for large TP sizes, but monitor memory usage and adjust `--mem-fraction-static` accordingly.

Piecewise CUDA Graph
Let PCG auto-enable. Only disable if you encounter issues. Report bugs to help improve the feature.
Memory Management
CUDA graphs consume additional memory. Balance between CUDA graph coverage and KV cache pool size.
Custom Kernels
Register custom ops with `@register_custom_op` to make kernels compatible with PCG.

Performance Impact
Decode Phase
CUDA graph for decode typically provides:
- 10-30% speedup for small batch sizes (1-32)
- 5-15% speedup for medium batch sizes (32-128)
- Diminishing returns for large batch sizes (>256)

The speedup is larger for:
- Hopper and Blackwell GPUs (more efficient graph execution)
- Models with many small kernels
- High tensor parallelism (more communication ops to capture)
Prefill Phase (PCG)
Piecewise CUDA graph for prefill typically provides:
- 15-40% speedup for short sequences (64-512 tokens)
- 10-25% speedup for medium sequences (512-2048 tokens)
- 5-15% speedup for long sequences (2048-4096 tokens)

It helps most with:
- Small batch sizes where kernel launch overhead dominates
- Models with many layers
- Frequent prefill operations (e.g., chatbot workloads)
Vision Transformers
CUDA graph for ViT typically provides:
- 20-50% speedup for small batches
- Larger speedups on Hopper/Blackwell GPUs
- Most beneficial when serving low-latency multimodal workloads
Troubleshooting
Out of Memory (OOM) Errors
If you encounter OOM errors after increasing `--cuda-graph-max-bs`:
- Reduce `--mem-fraction-static` by 0.01-0.05
- Reduce `--cuda-graph-max-bs` to a smaller value
- Monitor `available_gpu_mem` in the logs (it should be 5-8 GB)
PCG Capture Failures
If PCG fails to capture:
- Add `--disable-piecewise-cuda-graph` as a workaround
- Check whether your model architecture is in the auto-disable list
- Report the issue with the full error traceback
- For custom kernels, ensure they are registered with `@register_custom_op`
Performance Degradation
If CUDA graph degrades performance:
- Check whether the batch size exceeds `--cuda-graph-max-bs` (which falls back to the non-graph path)
- Verify that memory bandwidth is not saturated
- Try different `--piecewise-cuda-graph-compiler` settings (`eager` vs `inductor`)
- Monitor for frequent graph breaks in PCG
Code Reference
| File | Description |
|---|---|
| `python/sglang/srt/model_executor/cuda_graph_runner.py` | Decode CUDA graph runner |
| `python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py` | Piecewise CUDA graph runner |
| `python/sglang/srt/compilation/compile.py` | `install_torch_compiled` trampoline |
| `python/sglang/srt/compilation/backend.py` | `SGLangBackend`, graph splitting |
| `python/sglang/srt/compilation/cuda_piecewise_backend.py` | Per-subgraph CUDA graph capture/replay |
| `python/sglang/srt/utils/custom_op.py` | `register_custom_op` for `torch.compile` compatibility |
