This guide provides comprehensive strategies for optimizing SGLang performance across different workloads and deployment scenarios.

Achieving High Throughput for Offline Batch Inference

Achieving a large batch size is the most important factor for attaining high throughput in offline batch inference. When the server is running at full load in a steady state, look for log entries like:
Decode batch. #running-req: 233, #token: 370959, token usage: 0.82, 
cuda graph: True, gen throughput (token/s): 4594.01, #queue-req: 317
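These log lines are easy to turn into numbers you can plot or alert on. The following is an illustrative parsing sketch, not part of SGLang; the regex assumes the exact field order shown in the example above:

```python
import re

# Extract key metrics from an SGLang decode-batch log line.
# Field names and order match the example log line above.
DECODE_RE = re.compile(
    r"#running-req: (?P<running>\d+), #token: (?P<tokens>\d+), "
    r"token usage: (?P<usage>[\d.]+), cuda graph: (?P<graph>\w+), "
    r"gen throughput \(token/s\): (?P<tput>[\d.]+), #queue-req: (?P<queue>\d+)"
)

def parse_decode_log(line: str) -> dict:
    m = DECODE_RE.search(line)
    if m is None:
        raise ValueError("not a decode-batch log line")
    return {
        "running_req": int(m.group("running")),
        "tokens": int(m.group("tokens")),
        "token_usage": float(m.group("usage")),
        "cuda_graph": m.group("graph") == "True",
        "throughput": float(m.group("tput")),
        "queue_req": int(m.group("queue")),
    }

line = ("Decode batch. #running-req: 233, #token: 370959, token usage: 0.82, "
        "cuda graph: True, gen throughput (token/s): 4594.01, #queue-req: 317")
print(parse_decode_log(line))
```

Feeding these values into your monitoring stack makes the tuning steps below data-driven rather than guesswork.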

Control Queue Size

#queue-req (metric)
Number of requests in the queue. A healthy range is 100-2000.
Diagnosis:
  • #queue-req: 0 frequently: Client code is submitting requests too slowly. Increase request submission rate.
  • #queue-req > 2000 frequently: Too many queued requests increase scheduling overhead. Reduce request submission rate.
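One simple way to hold the queue in the healthy range is to cap the number of requests in flight on the client side. The sketch below is illustrative, not an SGLang API; the cap, worker count, and `send_request` body are placeholders to tune against the `#queue-req` values in the logs:

```python
import threading
from concurrent.futures import ThreadPoolExecutor

MAX_IN_FLIGHT = 1500          # assumed target; tune against #queue-req in the logs
slots = threading.Semaphore(MAX_IN_FLIGHT)

def send_request(prompt: str) -> str:
    # Placeholder for the actual HTTP call to the server.
    return f"completion for: {prompt}"

def submit(prompt: str) -> str:
    with slots:               # blocks when MAX_IN_FLIGHT requests are pending
        return send_request(prompt)

with ThreadPoolExecutor(max_workers=32) as pool:
    results = list(pool.map(submit, (f"prompt {i}" for i in range(100))))
print(len(results))
```

If `#queue-req` is frequently 0, raise the cap or worker count; if it climbs past 2000, lower them.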

Maximize Token Usage

token usage (metric)
KV cache memory utilization of the server. Target: > 0.9 for good utilization.
Diagnosis:
  • token usage < 0.9 and #queue-req > 0 frequently: Server is too conservative about taking new requests.
    • Solution: Decrease --schedule-conservativeness to a value like 0.3
    • Common cause: Users send many requests with large max_new_tokens but requests stop early due to EOS or stop strings
  • token usage very high and frequent warnings like:
    KV cache pool is full. Retract requests. #retracted_reqs: 1, #new_token_ratio: 0.9998 -> 1.0000
    
    • Solution: Increase --schedule-conservativeness to a value like 1.3
    • Note: Occasional retractions (~1 time per minute) are acceptable

Tune Memory Allocation

SGLang allocates memory as follows:
Total memory usage = model weights + KV cache pool + CUDA graph buffers + activations
--mem-fraction-static (float)
Determines memory allocation for the first two components:
mem_fraction_static = (model weights + KV cache pool) / GPU memory capacity
To support higher concurrency, maximize KV cache pool capacity by setting --mem-fraction-static as high as possible while reserving enough memory for activations and CUDA graph buffers.
Optimization Process:
  1. Check available GPU memory in logs before server is ready:
    max_total_num_tokens=665690, chunked_prefill_size=8192, max_prefill_tokens=16384, 
    max_running_requests=4096, context_len=65536, available_gpu_mem=13.50 GB
    
  2. Evaluate available_gpu_mem:
    • 5-8 GB: Good setting
    • 10-20 GB: Too much memory left idle; increase --mem-fraction-static to allocate more to the KV cache
    • < 5 GB: Too little headroom and a risk of OOM errors; decrease --mem-fraction-static
  3. Alternative approach: Increase --mem-fraction-static in increments of 0.01 until you encounter OOM errors for your workloads, then step back to the last stable value
As a rule of thumb, reserving 5–8 GB of memory for activations is typically sufficient.
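As a worked example of the formula above, here is a back-of-the-envelope sketch. All numbers are assumptions for illustration (a ~70 GB set of weights on a 141 GB GPU with a 6 GB reserve), not measurements of any deployment:

```python
# Solve the mem_fraction_static formula for a starting value, given a
# weights size, GPU capacity, and an activation/CUDA-graph reserve.
gpu_capacity_gb = 141.0     # assumed GPU memory capacity
weights_gb = 70.0           # assumed model weight footprint
reserve_gb = 6.0            # rule of thumb: keep 5-8 GB free

kv_pool_gb = gpu_capacity_gb - weights_gb - reserve_gb
mem_fraction_static = (weights_gb + kv_pool_gb) / gpu_capacity_gb
print(f"KV cache pool: {kv_pool_gb:.1f} GB")
print(f"--mem-fraction-static {mem_fraction_static:.2f}")
```

A value computed this way is only a starting point; validate it against `available_gpu_mem` in the logs as described above.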

Avoid Out-of-Memory Errors

If you encounter OOM errors:
  • Reduce --chunked-prefill-size to 4096 or 2048
    • Tradeoff: Saves memory but slows down prefill for long prompts
  • Lower --max-running-requests
    • Tradeoff: Limits maximum concurrency
  • Reduce --mem-fraction-static to 0.8 or 0.7
    • Tradeoff: Decreases KV cache capacity and limits peak throughput
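It can also help to sanity-check how many tokens a given KV cache pool can hold. The sketch below uses the standard per-token KV size for a multi-head/GQA layout; the model shape and pool size are example assumptions (roughly Llama-3-8B-like), not figures from the guide:

```python
# KV cache bytes per token = 2 (K and V) * layers * kv_heads * head_dim * dtype bytes
num_layers = 32             # assumed model shape
num_kv_heads = 8
head_dim = 128
dtype_bytes = 2             # bf16/fp16; 1 for --kv-cache-dtype fp8_e4m3

bytes_per_token = 2 * num_layers * num_kv_heads * head_dim * dtype_bytes
pool_gb = 60.0              # assumed KV cache pool size
max_tokens = int(pool_gb * 1024**3 / bytes_per_token)
print(f"{bytes_per_token} bytes/token -> ~{max_tokens:,} tokens in {pool_gb} GB")
```

This also shows why quantizing the KV cache (halving `dtype_bytes`) roughly doubles `max_total_num_tokens` for the same pool.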

Tune CUDA Graph Coverage

--cuda-graph-max-bs (integer)
Maximum batch size for CUDA graph capture. Default varies by model (typically 160-256).
By default, CUDA graph is enabled only for small batch sizes. However, for some models (especially at large tensor parallelism sizes), CUDA graph can be beneficial for batch sizes up to 512 or 768.
Recommendation:
  • Increase --cuda-graph-max-bs to a larger value (e.g., 512, 768)
  • Important: CUDA graph consumes more memory, so reduce --mem-fraction-static at the same time
python3 -m sglang.launch_server \
  --model-path your-model \
  --cuda-graph-max-bs 512 \
  --mem-fraction-static 0.85

Optimize Parallelism Strategy

--dp-size (integer)
Data parallelism size. Better for throughput than tensor parallelism when GPU memory allows.
--tp-size (integer)
Tensor parallelism size. Required for large models that don’t fit on a single GPU.
Guidelines:
  • Data parallelism is better for throughput: When there is enough GPU memory, always favor data parallelism
  • Use SGLang Model Gateway: Prefer it for managing data parallelism over the --dp-size parameter
  • Tensor parallelism: Use only when model doesn’t fit on a single GPU

Additional Optimizations

Torch Compile

Accelerates small models at small batch sizes.
--enable-torch-compile

FP8 Quantization

Reduces memory footprint and improves throughput.
--quantization fp8

Expert Parallelism

For MoE models, distribute experts across GPUs. See the Expert Parallelism blog.

DP Attention

For DeepSeek models with data parallelism.
--enable-dp-attention --dp-size 8

Use Longest Prefix Match Scheduling

--schedule-policy (string, default: "fcfs")
Scheduling policy for requests. Options: fcfs, lpm.
If the workload has many shared prefixes:
--schedule-policy lpm
Tradeoff:
  • lpm (Longest Prefix Match) reorders requests to encourage more cache hits
  • Introduces more scheduling overhead
  • Best for workloads with high prefix reuse (e.g., many similar prompts)
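The intuition behind longest-prefix-match ordering can be shown with a toy sketch. This is an illustration of the idea, not SGLang's actual scheduler: placing requests with shared prefixes next to each other means consecutive requests hit the cached prefix instead of recomputing it.

```python
# Toy illustration: lexicographic ordering groups shared prefixes together,
# so consecutive requests can reuse cached prefix KV entries.
pending = [
    "You are a translator. Translate to French: hello",
    "Summarize: the quick brown fox",
    "You are a translator. Translate to French: goodbye",
    "Summarize: lorem ipsum",
]
ordered = sorted(pending)
for p in ordered:
    print(p)
```

The reordering is where the extra scheduling overhead comes from, and also why fcfs remains the better default when prefix reuse is low.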

Optimizing for Different Workloads

Online Serving (Low Latency)

Priorities: Low latency, consistent response times
1. Reduce Batch Size

Keep batch sizes small to minimize queueing delay:
--max-running-requests 128
2. Enable CUDA Graph

Maximize CUDA graph coverage for small batches:
--cuda-graph-max-bs 256
3. Use Fast Attention Backend

Choose the fastest backend for your hardware:
  • Hopper: --attention-backend fa3
  • Blackwell: --attention-backend trtllm_mha or --attention-backend trtllm_mla
  • Ampere/Ada: --attention-backend flashinfer
4. Conservative Scheduling

Avoid retraction overhead:
--schedule-conservativeness 1.2

Offline Batch Processing (High Throughput)

Priorities: Maximum throughput, high GPU utilization
1. Maximize Batch Size

Increase KV cache pool and running requests:
--mem-fraction-static 0.90 \
--max-running-requests 4096
2. Aggressive Scheduling

Take more risks to maximize utilization:
--schedule-conservativeness 0.3
3. Large Chunked Prefill

Process long prompts efficiently:
--chunked-prefill-size 16384
4. Data Parallelism

Use data parallelism for maximum throughput:
--dp-size 8

Long-Context Workloads

Priorities: Support long sequences, maximize prefix reuse
1. Enable HiCache

Use hierarchical caching for long contexts:
--enable-hierarchical-cache \
--hicache-ratio 2 \
--hicache-storage-backend hf3fs
2. Optimize Page Size

Balance cache hit rate and memory efficiency:
--page-size 64
3. Use Prefix Match Scheduling

Maximize cache reuse:
--schedule-policy lpm
4. Large Chunked Prefill

Handle long prompts efficiently:
--chunked-prefill-size 16384
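Page size trades bookkeeping overhead against prefix-match granularity: a shared prefix is only reusable in whole pages, so reuse is rounded down to a multiple of the page size. A tiny sketch of the rounding effect (the prefix length is an assumed example):

```python
# With page size p, only whole cached pages of a shared prefix can be reused.
def reusable_tokens(shared_prefix_len: int, page_size: int) -> int:
    return (shared_prefix_len // page_size) * page_size

prefix = 1000   # assumed shared-prefix length in tokens
for p in (1, 64, 256):
    print(p, reusable_tokens(prefix, p))
```

Smaller pages waste fewer prefix tokens but cost more metadata per token, which is why --page-size 1 suits multi-turn reuse while larger pages suit long contexts.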

Multi-turn Conversations

Priorities: Reuse conversational context, low latency for follow-ups
1. Enable RadixAttention

RadixAttention is enabled by default; make sure it has not been disabled (e.g., via --disable-radix-cache)
2. Optimize Page Size

Token-level matching for maximum reuse:
--page-size 1
3. Use HiCache with PD Disaggregation

Share KV cache between prefill and decode:
--enable-hierarchical-cache \
--hicache-storage-backend hf3fs \
--disaggregation-mode decode \
--disaggregation-decode-enable-offload-kvcache

Monitoring and Metrics

Key Metrics to Monitor

gen throughput (token/s) (metric)
Generation throughput in tokens per second. The primary performance metric.
token usage (metric)
KV cache memory utilization. Target: > 0.9 for good utilization.
#running-req (metric)
Number of requests currently being processed. Should be close to --max-running-requests under load.
#queue-req (metric)
Number of requests in the queue. Healthy range: 100-2000.
cuda graph (boolean)
Whether CUDA graph is active for the current batch. Should be True for small batches.

Enable Metrics Collection

python3 -m sglang.launch_server \
  --model-path your-model \
  --enable-metrics \
  --enable-cache-report
Access metrics:
  • Prometheus endpoint: http://localhost:30000/metrics
  • Cache report: Periodic logs showing cache hit rates
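The Prometheus endpoint serves the standard text exposition format, which is straightforward to parse. The sketch below works on an inline sample; the metric name shown is a hypothetical example, so check the names your server version actually exposes. In practice you would replace `sample` with the response body fetched from http://localhost:30000/metrics.

```python
# Minimal parser for the Prometheus text exposition format.
sample = """\
# HELP sglang:gen_throughput Generation throughput (token/s)
# TYPE sglang:gen_throughput gauge
sglang:gen_throughput 4594.01
"""

def parse_prometheus(text: str) -> dict:
    metrics = {}
    for line in text.splitlines():
        if line.startswith("#") or not line.strip():
            continue                      # skip comments and blank lines
        name, value = line.rsplit(" ", 1)  # value is the last field
        metrics[name] = float(value)
    return metrics

print(parse_prometheus(sample))
```

Note this minimal parser ignores label sets and timestamps; a production setup would scrape the endpoint with Prometheus itself rather than hand-rolled code.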

Troubleshooting Performance Issues

Low Throughput

Symptoms: gen throughput significantly lower than expected
If token usage < 0.9:
  • Decrease --schedule-conservativeness
  • Increase --mem-fraction-static
If #queue-req: 0 frequently:
  • Increase request submission rate
  • Client is the bottleneck, not the server
If cuda graph: False for small batches:
  • Increase --cuda-graph-max-bs
  • Verify CUDA graph is not disabled
  • Verify using optimal backend for your hardware
  • Try different backends and benchmark

High Latency

Symptoms: Requests take longer than expected to complete
If batch size is too large:
  • Reduce --max-running-requests
  • Trade throughput for lower latency
If #queue-req is very high:
  • Reduce request submission rate
  • Requests are waiting too long in queue
If long prompts dominate:
  • Increase --chunked-prefill-size
  • Enable Piecewise CUDA Graph

Memory Issues

Symptoms: OOM errors, frequent retractions
  • Decrease --mem-fraction-static
  • Reduce --cuda-graph-max-bs
  • Reduce --chunked-prefill-size
  • Enable quantized KV cache: --kv-cache-dtype fp8_e4m3
  • Use FP8 weight quantization: --quantization fp8
If frequent retractions:
  • Increase --schedule-conservativeness

Best Practices Summary

Start Conservative

Begin with default settings and tune incrementally based on metrics.

Monitor Metrics

Enable metrics and cache reporting to make data-driven tuning decisions.

Workload-Specific Tuning

Optimize for your specific workload characteristics (online vs. offline, long vs. short context).

Benchmark Regularly

Test performance after each configuration change to validate improvements.

Balance Tradeoffs

Understand the tradeoffs between latency, throughput, and memory usage.

Use Latest Features

Leverage HiCache, PCG, and optimized attention backends for best performance.

See Also