Overlap scheduling is a performance optimization that overlaps CPU scheduling overhead with GPU computation, improving overall system throughput and GPU utilization.

Overview

Proposed in the NanoFlow paper, overlap scheduling addresses a key bottleneck in LLM serving: the CPU overhead of scheduling, memory management, and batch preparation can leave the GPU idle between forward passes. Mini-SGLang employs overlap scheduling by default to maximize GPU utilization and minimize latency.

Figure: Illustration of overlap scheduling, from the LMSYS blog.

The Problem: CPU Bottleneck

Traditional Sequential Execution

Without overlap scheduling, the execution is strictly sequential:
┌─────────────────────────────────────────────────────┐
│ CPU: Schedule Batch 1                               │
└─────────────────────────────────────────────────────┘

┌─────────────────────────────────────────────────────┐
│ GPU: Execute Batch 1                                │  ← GPU busy
└─────────────────────────────────────────────────────┘

┌─────────────────────────────────────────────────────┐
│ CPU: Process Results 1 + Schedule Batch 2           │  ← GPU idle!
└─────────────────────────────────────────────────────┘

┌─────────────────────────────────────────────────────┐
│ GPU: Execute Batch 2                                │  ← GPU busy
└─────────────────────────────────────────────────────┘
The GPU sits idle during CPU scheduling, wasting valuable compute resources.

With Overlap Scheduling

┌─────────────────────────────────────────────────────┐
│ CPU: Schedule Batch 1                               │
└─────────────────────────────────────────────────────┘

┌─────────────────────────────────────────────────────┐
│ GPU: Execute Batch 1         │ CPU: Schedule Batch 2│  ← Overlapped!
└─────────────────────────────────────────────────────┘

┌─────────────────────────────────────────────────────┐
│ GPU: Execute Batch 2         │ CPU: Process Res 1   │  ← Overlapped!
│                              │      Schedule Batch 3│
└─────────────────────────────────────────────────────┘
While the GPU executes Batch N, the CPU simultaneously processes results from Batch N-1 and schedules Batch N+1. The GPU stays continuously busy.
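The two timelines above can be compared with a rough timing model. The sketch below is illustrative only: it assumes a fixed CPU cost and a fixed GPU cost per batch (hypothetical numbers, not Mini-SGLang measurements) and ignores kernel launch latency, synchronization overhead, and the result processing of the final batch.

```python
def sequential_time(num_batches: int, cpu_ms: float, gpu_ms: float) -> float:
    """Every batch pays its CPU cost and its GPU cost back to back."""
    return num_batches * (cpu_ms + gpu_ms)


def overlapped_time(num_batches: int, cpu_ms: float, gpu_ms: float) -> float:
    """CPU work for the next batch runs while the GPU executes the current one."""
    if num_batches == 0:
        return 0.0
    # The first batch must be scheduled before any GPU work can start.
    first_schedule = cpu_ms
    # For all but the last batch, the slower side sets the pace.
    steady_state = (num_batches - 1) * max(cpu_ms, gpu_ms)
    # The last batch has no following CPU work to overlap with.
    last_execute = gpu_ms
    return first_schedule + steady_state + last_execute


seq = sequential_time(100, cpu_ms=2.0, gpu_ms=8.0)
ovl = overlapped_time(100, cpu_ms=2.0, gpu_ms=8.0)
print(f"sequential: {seq} ms, overlapped: {ovl} ms")
```

With these assumed costs the model gives 1000.0 ms sequentially and 802.0 ms overlapped, because the 2 ms of CPU work per batch is hidden behind the 8 ms of GPU work.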

Implementation

Dual Stream Architecture

Mini-SGLang uses two CUDA streams to enable overlap:
class Scheduler(SchedulerIOMixin):
    def __init__(self, config: SchedulerConfig):
        self.engine = Engine(config)
        
        # Main stream for metadata processing
        self.device = self.engine.device
        self.stream = torch.cuda.Stream(device=self.device)
        
        # Engine stream for GPU computation
        self.engine_stream_ctx = torch.cuda.stream(self.engine.stream)
        
        # Make the main stream the current stream
        torch.cuda.set_stream(self.stream)
  • Main stream (self.stream): The current stream while scheduling code runs. GPU operations issued during scheduling, memory allocation, and metadata preparation are queued here
  • Engine stream (self.engine.stream): Used for GPU computation (forward pass, attention, sampling)

Overlap Loop

The main event loop implements the overlap pattern:
def overlap_loop(self, last_data: ForwardData | None) -> ForwardData | None:
    """
    The main loop of overlapping scheduling and execution.
    
    It will overlap the execution of current batch and processing of last batch's results,
    which can effectively hide CPU latency and improve GPU utilization.
    """
    # Step 1: Receive new messages (non-blocking if we have work to do)
    blocking = not (
        last_data is not None
        or self.prefill_manager.runnable
        or self.decode_manager.runnable
    )
    for msg in self.receive_msg(blocking=blocking):
        self._process_one_msg(msg)
    
    # Step 2: Schedule next batch on main stream
    forward_input = self._schedule_next_batch()
    ongoing_data = None
    
    if forward_input is not None:
        # Step 3: Launch GPU computation on engine stream
        with self.engine_stream_ctx:
            self.engine.stream.wait_stream(self.stream)
            ongoing_data = (forward_input, self._forward(forward_input))
    
    # Step 4: Process last batch's results on main stream (overlapped!)
    self._process_last_data(last_data)
    
    return ongoing_data

Key Insights

  1. Asynchronous execution: GPU work is launched asynchronously on the engine stream
  2. Stream synchronization: engine.stream.wait_stream(self.stream) makes the engine stream wait for work already queued on the main stream, so scheduling finishes before the forward pass starts
  3. Overlap window: While GPU processes Batch N, CPU processes results from Batch N-1
  4. Data passing: ForwardData captures both input and output to be processed in the next iteration
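The data-passing pattern in the points above can be tried without a GPU. The following is a minimal sketch, not the Mini-SGLang implementation: a single worker thread stands in for the engine stream, and `toy_forward`, `toy_overlap_loop`, and `run` are hypothetical names introduced only for this illustration.

```python
from concurrent.futures import Future, ThreadPoolExecutor
from typing import List, Optional, Tuple

# Hypothetical stand-in for ForwardData: (batch id, pending result)
ToyForwardData = Tuple[int, Future]


def toy_forward(batch_id: int) -> int:
    """Pretend GPU work: returns a fake 'next token' for the batch."""
    return batch_id * 10


def toy_overlap_loop(
    pool: ThreadPoolExecutor,
    pending: List[int],
    last_data: Optional[ToyForwardData],
    processed: List[Tuple[int, int]],
) -> Optional[ToyForwardData]:
    ongoing = None
    if pending:
        batch_id = pending.pop(0)
        # Launch the current batch without waiting for it to finish.
        ongoing = (batch_id, pool.submit(toy_forward, batch_id))
    if last_data is not None:
        last_id, future = last_data
        # Process the previous batch while the current one is in flight.
        processed.append((last_id, future.result()))
    return ongoing


def run(batches: List[int]) -> List[Tuple[int, int]]:
    processed: List[Tuple[int, int]] = []
    pending = list(batches)
    data: Optional[ToyForwardData] = None
    with ThreadPoolExecutor(max_workers=1) as pool:
        # Each iteration hands its in-flight batch to the next iteration.
        while pending or data is not None:
            data = toy_overlap_loop(pool, pending, data, processed)
    return processed


print(run([1, 2, 3]))
```

The value returned by one iteration becomes `last_data` of the next, so results are always consumed one iteration after their batch was launched, and one extra iteration is needed to drain the final batch.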

ForwardData Structure

The overlap requires caching forward pass data:
class ForwardInput(NamedTuple):
    batch: Batch
    sample_args: BatchSamplingArgs
    input_tuple: Indice2D  # (token_mapping, positions)
    write_tuple: Indice2D  # (req_mapping, seq_lens or 0)

ForwardData: TypeAlias = "Tuple[ForwardInput, ForwardOutput]"
This captures all data needed to process results after the GPU completes.
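To make the shape of this pair concrete, here is a small sketch with simplified stand-in types. The `Toy*` classes are hypothetical, and the three output fields (`next_tokens_gpu`, `next_tokens_cpu`, `copy_done`) and their order are an assumption inferred from how the scheduler code on this page unpacks the output; plain lists replace tensors.

```python
from typing import Any, List, NamedTuple, Tuple


class ToyForwardInput(NamedTuple):
    batch: List[str]      # stand-in for Batch
    sample_args: dict     # stand-in for BatchSamplingArgs


class ToyForwardOutput(NamedTuple):
    next_tokens_gpu: List[int]   # stand-in for the on-device tensor
    next_tokens_cpu: List[int]   # stand-in for the host copy
    copy_done: Any               # stand-in for a CUDA event


ToyForwardData = Tuple[ToyForwardInput, ToyForwardOutput]

data: ToyForwardData = (
    ToyForwardInput(batch=["req-a", "req-b"], sample_args={}),
    ToyForwardOutput(next_tokens_gpu=[7, 9], next_tokens_cpu=[7, 9], copy_done=None),
)

# Named access on the input, positional unpacking on the output.
batch, (_, next_tokens_cpu, copy_done) = data[0].batch, data[1]
print(batch, next_tokens_cpu)
```

Keeping the input next to the output means the next loop iteration still knows which requests the tokens belong to.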

Processing Last Batch Results

While the GPU works on the current batch, the CPU processes the previous batch:
def _process_last_data(self, last_data: ForwardData | None) -> None:
    if last_data is None:
        return
    
    batch, (_, next_tokens_cpu, copy_done) = last_data[0].batch, last_data[1]
    copy_done.synchronize()  # Wait for GPU->CPU copy
    
    reply: List[DetokenizeMsg] = []
    new_finished_reqs: Set[Req] = set()
    
    with self.cache_manager.lazy_free_region():
        for i, req in enumerate(batch.reqs):
            if isinstance(req, ChunkedReq):
                continue
            
            next_token = next_tokens_cpu[i]
            req.append_host(next_token.unsqueeze(0))
            next_token = int(next_token.item())
            
            # Check for completion
            finished = not req.can_decode
            if not req.sampling_params.ignore_eos:
                finished |= next_token == self.eos_token_id
            
            reply.append(DetokenizeMsg(
                uid=req.uid,
                next_token=next_token,
                finished=finished
            ))
            
            # Free finished requests
            if finished and req not in self.finished_reqs:
                self.decode_manager.remove_req(req)
                self._free_req_resources(req)
                new_finished_reqs.add(req)
            elif batch.is_prefill:
                self.cache_manager.cache_req(req, finished=False)
    
    self.finished_reqs = new_finished_reqs
    self.send_result(reply)
This performs:
  • Token extraction from GPU results
  • Completion detection (EOS or max length)
  • Resource cleanup for finished requests
  • Detokenization message sending
  • Cache updates
All while the GPU is busy with the next batch!
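The completion check in the loop above can be isolated into a pure function for testing. This is a sketch under simplifying assumptions: `is_finished` is a hypothetical helper, and the EOS token id used in the example (2) is made up.

```python
def is_finished(next_token: int, can_decode: bool, ignore_eos: bool, eos_token_id: int) -> bool:
    """Mirror of the completion logic: out of decode budget, or EOS was sampled."""
    finished = not can_decode
    if not ignore_eos:
        finished = finished or next_token == eos_token_id
    return finished


EOS = 2  # hypothetical EOS token id for this example

print(is_finished(next_token=2, can_decode=True, ignore_eos=False, eos_token_id=EOS))
print(is_finished(next_token=2, can_decode=True, ignore_eos=True, eos_token_id=EOS))
print(is_finished(next_token=5, can_decode=False, ignore_eos=True, eos_token_id=EOS))
```

With `ignore_eos` set, a request stops only when it can no longer decode, which matches the order of the two checks in the loop.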

Stream Synchronization

Wait Pattern

The synchronization ensures correct ordering:
# Schedule batch on main stream
forward_input = self._schedule_next_batch()  # On self.stream

if forward_input is not None:
    with self.engine_stream_ctx:  # Switch to engine stream
        # Wait for main stream to finish scheduling
        self.engine.stream.wait_stream(self.stream)
        
        # Now safe to execute on engine stream
        ongoing_data = (forward_input, self._forward(forward_input))

# Back on main stream, process last batch (overlapped with GPU)
self._process_last_data(last_data)

Event Synchronization

The copy_done event ensures the GPU->CPU copy has completed before the CPU reads the result:
# In forward pass
forward_output = self.engine.forward_batch(batch, sample_args)
# This includes: copy_done = torch.cuda.Event()

# In process_last_data
copy_done.synchronize()  # Wait for CPU tensor to be ready
next_token = next_tokens_cpu[i]  # Now safe to access
Stream synchronization is critical for correctness. The wait_stream call makes the engine stream wait for work already queued on the main stream, so scheduling completes before GPU execution begins, while copy_done.synchronize() ensures the CPU can safely read GPU results.
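The "record an event, wait on it later" idea can be illustrated on the host alone. The sketch below is only an analogy and does not use CUDA: a `threading.Event` plays the role of `copy_done`, a background thread plays the role of the asynchronous copy, and `async_copy` is a hypothetical helper.

```python
import threading
from typing import List


def async_copy(src: List[int], dst: List[int], done: threading.Event) -> threading.Thread:
    """Start a background copy and signal `done` once dst is fully written."""
    def worker() -> None:
        dst.extend(src)
        done.set()

    thread = threading.Thread(target=worker)
    thread.start()
    return thread


gpu_tokens = [11, 22, 33]        # stand-in for next_tokens_gpu
cpu_tokens: List[int] = []       # stand-in for next_tokens_cpu
copy_done = threading.Event()

thread = async_copy(gpu_tokens, cpu_tokens, copy_done)
# ... other CPU work could run here while the copy is in flight ...
copy_done.wait()                 # analogous to copy_done.synchronize()
first = cpu_tokens[0]            # safe to read only after the wait
thread.join()
print(first)
```

Reading `cpu_tokens` before the wait would race with the copy. That is the hazard `copy_done.synchronize()` guards against for the real GPU->CPU transfer.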

Performance Benefits

Benchmark Results

From the Mini-SGLang README: Offline Inference Benchmark
  • Hardware: 1x H200 GPU
  • Model: Qwen3-0.6B, Qwen3-14B
  • Workload: 256 sequences, 100-1024 tokens in/out
To test without overlap scheduling:
MINISGL_DISABLE_OVERLAP_SCHEDULING=1 python benchmark/offline/bench.py
Figure: Offline Benchmark
Overlap scheduling provides measurable throughput improvements by keeping the GPU saturated.

When Overlap Helps Most

  1. High CPU overhead: Complex scheduling logic, large batches, many requests
  2. Fast GPU execution: Small models or short sequences where CPU overhead is significant relative to GPU time
  3. Mixed workloads: Combination of prefill and decode requests requiring dynamic scheduling

When Overlap Helps Less

  1. Very large models: GPU execution dominates, little CPU time to hide
  2. Simple scheduling: Minimal CPU overhead leaves little to optimize
  3. Memory bound: If memory copies dominate, overlap may not help significantly
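The two lists above follow from a simple ratio. As a rough sketch, assume a fixed per-batch CPU cost and GPU cost (hypothetical numbers, not measurements) and perfect overlap; `steady_state_speedup` is a hypothetical helper.

```python
def steady_state_speedup(cpu_ms: float, gpu_ms: float) -> float:
    """Per-batch time without overlap divided by per-batch time with overlap."""
    return (cpu_ms + gpu_ms) / max(cpu_ms, gpu_ms)


# Small model: CPU overhead is a large share of each step.
print(steady_state_speedup(cpu_ms=2.0, gpu_ms=8.0))
# Very large model: GPU time dominates, little left to hide.
print(steady_state_speedup(cpu_ms=0.5, gpu_ms=50.0))
# Balanced case: the upper bound of this model.
print(steady_state_speedup(cpu_ms=5.0, gpu_ms=5.0))
```

Under these assumptions the gain is 1.25x for the small-model case, about 1.01x when GPU time dominates, and at most 2x when CPU and GPU costs are equal.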

Disabling Overlap Scheduling

For debugging or comparison purposes, overlap scheduling can be disabled:
# Set environment variable
MINISGL_DISABLE_OVERLAP_SCHEDULING=1 python -m minisgl --model "Qwen/Qwen3-0.6B"
With overlap disabled, the system uses the normal loop:
def normal_loop(self) -> None:
    blocking = not (self.prefill_manager.runnable or self.decode_manager.runnable)
    for msg in self.receive_msg(blocking=blocking):
        self._process_one_msg(msg)
    
    forward_input = self._schedule_next_batch()
    ongoing_data = None
    if forward_input is not None:
        ongoing_data = (forward_input, self._forward(forward_input))
    
    # Process immediately, no overlap
    self._process_last_data(ongoing_data)
This processes each batch completely before moving to the next, running everything sequentially.
Disabling overlap scheduling can be useful for debugging race conditions or isolating performance issues, but it will reduce throughput in production.
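For reference, switching between the two loops on an environment variable can be sketched as follows. How Mini-SGLang actually parses `MINISGL_DISABLE_OVERLAP_SCHEDULING` is not shown on this page, so the accepted values below and the helper names `overlap_disabled` and `pick_loop_name` are assumptions made for illustration.

```python
import os
from typing import Mapping


def overlap_disabled(environ: Mapping[str, str] = os.environ) -> bool:
    """Assumed parsing: '1', 'true', or 'yes' (case-insensitive) disables overlap."""
    value = environ.get("MINISGL_DISABLE_OVERLAP_SCHEDULING", "")
    return value.strip().lower() in {"1", "true", "yes"}


def pick_loop_name(environ: Mapping[str, str] = os.environ) -> str:
    """Return which loop a scheduler would run under the given environment."""
    return "normal_loop" if overlap_disabled(environ) else "overlap_loop"


print(pick_loop_name({"MINISGL_DISABLE_OVERLAP_SCHEDULING": "1"}))
print(pick_loop_name({}))
```

Passing the environment as a mapping keeps the selection logic testable without touching the real process environment.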

Advanced: Extra Synchronization

For certain edge cases, additional synchronization can be enabled:
# Set environment variable for extra sync
MINISGL_OVERLAP_EXTRA_SYNC=1 python -m minisgl --model "Qwen/Qwen3-0.6B"
This adds an explicit stream sync before the forward pass:
def _forward(self, forward_input: ForwardInput) -> ForwardOutput:
    batch, sample_args, input_mapping, output_mapping = forward_input
    batch.input_ids = self.token_pool[input_mapping]
    
    if ENV.OVERLAP_EXTRA_SYNC:
        # Explicit synchronization for edge cases
        # See: https://github.com/sgl-project/mini-sglang/issues/58
        self.stream.synchronize()
    
    forward_output = self.engine.forward_batch(batch, sample_args)
    self.token_pool[output_mapping] = forward_output.next_tokens_gpu
    self.decode_manager.filter_reqs(forward_input.batch.reqs)
    return forward_output
This ensures complete quiescence of the main stream before GPU work begins, trading some performance for additional safety.
