Skip to main content

Scheduler

The scheduler is the core component that manages request batching, memory allocation, and execution orchestration in SGLang.

Overview

Location: `python/sglang/srt/managers/scheduler.py`

Key Responsibilities:
  • Request queueing and prioritization
  • Dynamic batch formation
  • Memory allocation via token-to-KV pool
  • Prefix cache management (RadixAttention)
  • Request lifecycle management

Request States

A request transitions through several states:
┌─────────┐
│ Waiting │  Initial state, in waiting queue
└────┬────┘
     │  scheduled into a batch
     ▼
┌─────────┐
│ Running │  Executing in a batch
└────┬────┘

     ├──→ ┌──────────┐
     │    │ Finished │  Generation complete
     │    └──────────┘

     └──→ ┌──────────┐
          │ Aborted  │  Cancelled by user
          └──────────┘

Scheduling Loop

Main Loop

def event_loop(self):
    """Main scheduling loop: ingest new requests, run one scheduling
    step, and ship results — forever."""
    while True:
        # Pull newly tokenized requests from the tokenizer manager
        # and enqueue them all.
        self.waiting_queue.extend(self.recv_requests())

        # Run one scheduling step only if there is any work at all.
        if self.running_batch or self.waiting_queue:
            self.process_batch()

        # Stream results onward to the detokenizer.
        self.send_results()

Batch Processing

def process_batch(self):
    """Run one scheduling step: form a batch, run the model, sample,
    update request state, and retire finished requests."""
    # 1. Form the next batch from running + waiting requests.
    batch = self.get_next_batch()

    # 2/3. Run the model forward pass, capturing the logits.
    # NOTE(review): sample_tokens() (defined below) takes the forward
    # pass logits as a second argument, so they must be captured here —
    # the previous version called self.sample_tokens(batch) with no
    # logits, which does not match sample_tokens' signature.
    logits = None
    if batch.has_prefill:
        logits = self.run_prefill(batch)
    if batch.has_decode:
        logits = self.run_decode(batch)

    # 4. Sample the next token for every request in the batch.
    next_tokens = self.sample_tokens(batch, logits)

    # 5. Append sampled tokens to each request's output.
    self.update_requests(batch, next_tokens)

    # 6. Remove requests that hit a stop condition.
    self.check_finished(batch)

Dynamic Batching

Batch Formation

The scheduler dynamically forms batches based on:
  • Available memory
  • Request priorities
  • Prefill chunking constraints
def get_next_batch(self):
    """Form the next batch from running and newly admitted requests."""
    batch = ScheduleBatch()

    # Every in-flight request continues in decode mode.
    for running_req in self.running_batch:
        batch.add_decode_req(running_req)

    # Admit waiting requests (FIFO) while KV-cache memory allows.
    while self.waiting_queue:
        candidate = self.waiting_queue[0]

        # Stop admitting as soon as memory is exhausted.
        if not self.can_allocate_kv_cache(candidate):
            break

        # Long prompts are prefilled in chunks; short ones in one shot.
        if self.should_chunk_prefill(candidate):
            batch.add_prefill_req(candidate, self.get_prefill_chunk_size())
        else:
            batch.add_prefill_req(candidate, len(candidate.input_ids))

        # Promote the request from waiting queue to running set.
        self.running_batch.append(self.waiting_queue.pop(0))

    return batch

Continuous Batching

Requests can join or leave batches at any time:
# Iteration 0
batch = [req1, req2, req3]  # Initial batch

# Iteration 1: req4 arrives, req2 finishes
batch = [req1, req3, req4]

# Iteration 2: req5, req6 arrive, req1 finishes
batch = [req3, req4, req5, req6]
This maximizes GPU utilization compared to static batching.

Chunked Prefill

Why Chunk?

Large prefills can block decode requests, increasing latency:
Without chunking:
[────── Long prefill (2s) ──────][Decode][Decode][Decode]
                                  ↑ High latency

With chunking:
[Prefill chunk][Decode][Prefill chunk][Decode][Prefill chunk][Decode]
                 ↑ Low latency

Implementation

def should_chunk_prefill(self, req):
    """Return True when the prompt exceeds the per-step prefill budget
    and must therefore be prefilled in chunks."""
    prompt_len = len(req.input_ids)
    return prompt_len > self.max_prefill_tokens

def get_prefill_chunk_size(self):
    """Pick a prefill chunk size based on decode pressure.

    With many concurrent decode requests (more than 10) use small
    chunks to keep decode latency low; otherwise use large chunks
    for throughput.
    """
    decode_pressure = len(self.running_batch)
    return 512 if decode_pressure > 10 else 2048

Memory Management

Token-to-KV Pool

The scheduler allocates KV cache via a memory pool:
class MemoryPool:
    """First-fit allocator over a contiguous region of KV-cache memory."""

    def __init__(self, total_size):
        self.total_size = total_size
        # At startup, a single free block spans the whole region.
        self.free_blocks = [Block(0, total_size)]

    def allocate(self, size):
        """Carve `size` units out of the first free block that fits.

        Returns the allocated Block, or None on out-of-memory.
        """
        for idx, candidate in enumerate(self.free_blocks):
            if candidate.size < size:
                continue
            # First fit: split the request off the front of the block...
            allocated = Block(candidate.start, size)
            leftover = Block(candidate.start + size, candidate.size - size)
            # ...and replace the free block with the (possibly empty) tail.
            del self.free_blocks[idx]
            if leftover.size > 0:
                self.free_blocks.append(leftover)
            return allocated
        return None  # OOM

    def free(self, block):
        """Return `block` to the free list and coalesce adjacent blocks."""
        self.free_blocks.append(block)
        self.merge_adjacent_blocks()  # Coalesce

Eviction Policy

When memory is full, the scheduler can evict cached prefixes:
def evict_cache(self, required_size):
    """Evict least-recently-used cached prefixes until `required_size`
    units of memory have been reclaimed.

    Returns True when enough memory was freed, False otherwise.
    """
    # Oldest access time first (LRU order).
    lru_order = sorted(
        self.cached_prefixes,
        key=lambda p: p.last_access_time,
    )

    reclaimed = 0
    for victim in lru_order:
        if reclaimed >= required_size:
            break
        # Release this prefix's KV cache and account for the space.
        self.free_kv_cache(victim)
        reclaimed += victim.cache_size

    return reclaimed >= required_size

RadixAttention (Prefix Caching)

Radix Tree Structure

The scheduler maintains a radix tree to track shared prefixes:
class RadixNode:
    """One node of the prefix radix tree (one edge per token)."""

    def __init__(self):
        # Outgoing edges: next token id -> child RadixNode.
        self.children = {}
        # Slot indices of this prefix's KV cache, or None if not cached.
        self.kv_cache_indices = None
        # How many live requests currently reference this prefix.
        self.ref_count = 0

Prefix Matching

When a new request arrives:
def match_prefix(self, tokens):
    """Walk the radix tree and return the KV-cache indices covering the
    longest cached prefix of `tokens`, or None when nothing matches."""
    node = self.radix_tree_root
    matched_len = 0

    # Descend one token at a time until the path diverges.
    for depth, token in enumerate(tokens, start=1):
        child = node.children.get(token)
        if child is None:
            break
        node = child
        matched_len = depth

    # A non-empty match lets the request reuse those KV-cache entries.
    return node.kv_cache_indices[:matched_len] if matched_len > 0 else None

Cache Insertion

After computing new KV cache:
def insert_cache(self, tokens, kv_indices):
    """Record a freshly computed prefix in the radix tree.

    Every node along the path stores the KV indices for the prefix
    ending at that node and gains one reference.
    """
    node = self.radix_tree_root

    for depth, token in enumerate(tokens, start=1):
        # Create the child lazily on first sight of this token.
        child = node.children.get(token)
        if child is None:
            child = RadixNode()
            node.children[token] = child
        node = child
        # The prefix of length `depth` is cached at kv_indices[:depth].
        node.kv_cache_indices = kv_indices[:depth]
        node.ref_count += 1

Request Prioritization

Priority Levels

Requests can have different priorities:
class Priority:
    """Request priority levels; a larger value is scheduled sooner."""

    LOW = 0
    NORMAL = 1
    HIGH = 2

Scheduling with Priority

def get_next_requests(self):
    """Return the next requests to schedule, highest priority first.

    Ties within the same priority level are broken by arrival time
    (FIFO), so earlier requests are not starved by later ones.
    """
    # BUG FIX: the previous key sorted (priority, arrival_time) with
    # reverse=True, which also reversed arrival_time and made scheduling
    # LIFO within a priority level. Negating priority keeps arrival_time
    # ascending (FIFO) while still putting high priority first.
    self.waiting_queue.sort(
        key=lambda req: (-req.priority, req.arrival_time)
    )

    # Greedily admit requests (skipping those that cannot be allocated)
    # until the batch is full.
    batch = []
    for req in self.waiting_queue:
        if self.can_allocate(req):
            batch.append(req)
            if len(batch) >= self.max_batch_size:
                break

    return batch

Sampling

Token Sampling

After model forward pass, sample next tokens:
def sample_tokens(self, batch, logits):
    """Sample one next token per request in the batch.

    `logits` is indexed as [request, position, vocab]; only the final
    position is used for sampling.
    """
    sampled = []

    for idx, req in enumerate(batch.reqs):
        # Logits at the last position for this request.
        last_logits = logits[idx, -1, :]

        # Apply frequency/presence/repetition penalties first.
        last_logits = self.apply_penalties(
            last_logits,
            req.output_ids,
            req.sampling_params,
        )

        # Then sample under the request's own sampling parameters.
        params = req.sampling_params
        sampled.append(
            self.sampler.sample(
                last_logits,
                temperature=params.temperature,
                top_p=params.top_p,
                top_k=params.top_k,
            )
        )

    return sampled

Penalties

def apply_penalties(self, logits, output_ids, params):
    """Apply frequency, presence, and repetition penalties to `logits`.

    Mutates and returns `logits`. Semantics follow the OpenAI-style
    convention: frequency scales with occurrence count, presence is a
    flat offset per distinct generated token, and repetition rescales
    the logit toward lower probability.
    """
    # Frequency penalty: subtract penalty * occurrence_count ONCE per
    # distinct token. (BUG FIX: the previous loop iterated over every
    # occurrence and subtracted penalty * count each time, applying a
    # quadratic penalty * count**2 in total.)
    if params.frequency_penalty != 0:
        for token_id in set(output_ids):
            count = output_ids.count(token_id)
            logits[token_id] -= params.frequency_penalty * count

    # Presence penalty: flat offset for any token that has appeared.
    if params.presence_penalty != 0:
        for token_id in set(output_ids):
            logits[token_id] -= params.presence_penalty

    # Repetition penalty: push seen tokens toward lower probability
    # regardless of sign (divide positive logits, multiply negative).
    if params.repetition_penalty != 1.0:
        for token_id in set(output_ids):
            if logits[token_id] < 0:
                logits[token_id] *= params.repetition_penalty
            else:
                logits[token_id] /= params.repetition_penalty

    return logits

Finish Conditions

Checking Completion

def check_finished(self, batch):
    """Remove finished requests from the running batch, release their
    resources, and return the list of finished requests."""
    # Collect first, then remove — never mutate while iterating.
    done = [req for req in batch.reqs if self.is_finished(req)]

    for req in done:
        self.running_batch.remove(req)
        self.free_request_resources(req)

    return done

def is_finished(self, req):
    """Check whether `req` has hit any stop condition.

    Sets `req.finish_reason` ("length" or "stop") and, for stop
    strings, `req.matched_stop`. Returns True when finished.
    """
    # Nothing generated yet: no stop condition can have triggered.
    # (BUG FIX: the EOS check below indexed output_ids[-1], which
    # raised IndexError on an empty output list.)
    if not req.output_ids:
        return False

    # Length limit.
    if len(req.output_ids) >= req.sampling_params.max_new_tokens:
        req.finish_reason = "length"
        return True

    # EOS token (unless the request asked to ignore it).
    if req.output_ids[-1] == req.tokenizer.eos_token_id:
        if not req.sampling_params.ignore_eos:
            req.finish_reason = "stop"
            return True

    # User-supplied stop strings require decoding the output so far.
    if req.sampling_params.stop:
        text = req.tokenizer.decode(req.output_ids)
        for stop_str in req.sampling_params.stop:
            if stop_str in text:
                req.finish_reason = "stop"
                req.matched_stop = stop_str
                return True

    return False

Performance Tuning

Key Parameters

class SchedulerConfig:
    """Tunable scheduler parameters."""

    # Maximum number of requests per batch.
    max_batch_size: int = 256

    # Chunked prefill: prompts longer than max_prefill_tokens are split
    # into prefill_chunk_size-token pieces.
    max_prefill_tokens: int = 16384
    prefill_chunk_size: int = 512

    # Fraction of GPU memory reserved for model weights + KV cache.
    mem_fraction_static: float = 0.9

    # RadixAttention prefix cache.
    enable_radix_cache: bool = True
    radix_cache_size: int = 1024 * 1024 * 1024  # 1 GiB

Monitoring

def get_stats(self):
    """Return a snapshot of scheduler metrics as a dict.

    Ratio metrics return 0.0 instead of raising ZeroDivisionError when
    no requests or batches have been processed yet (BUG FIX: stats may
    be queried before any work has been done).
    """
    total_requests = self.total_requests
    num_batches = self.num_batches
    return {
        "waiting_queue_len": len(self.waiting_queue),
        "running_batch_size": len(self.running_batch),
        "cache_hit_rate": self.cache_hits / total_requests if total_requests else 0.0,
        "avg_batch_size": self.total_batch_size / num_batches if num_batches else 0.0,
        "memory_usage": self.memory_pool.used / self.memory_pool.total,
    }

Advanced Features

Speculative Decoding

Use a small draft model to speculate future tokens:
def speculative_decode(self, req):
    """Speculatively extend `req` with the draft model, then keep only
    the tokens the target model agrees with."""
    # Draft K=5 candidate tokens cheaply with the small model.
    draft_tokens = self.draft_model.generate(req.input_ids, k=5)

    # Score all candidates with the full target model in one pass.
    target_logits = self.target_model(draft_tokens)

    # Keep the accepted prefix of the speculation.
    return self.verify_tokens(draft_tokens, target_logits)

Multi-Model Scheduling

Schedule across multiple model replicas:
def route_request(self, req):
    """Dispatch `req` to the replica with the shortest waiting queue
    and return that replica."""
    least_loaded = min(self.replicas, key=lambda rep: len(rep.waiting_queue))
    least_loaded.add_request(req)
    return least_loaded

Debugging

Enable Scheduler Logging

import logging

# Turn on verbose logging for the scheduler module only.
scheduler_logger = logging.getLogger("sglang.srt.managers.scheduler")
scheduler_logger.setLevel(logging.DEBUG)

Trace Request

def trace_request(self, rid, matched_len=None, cache_size=None, finish_reason=None):
    """Log the lifecycle milestones of request `rid`.

    BUG FIX: `matched_len`, `cache_size`, and `finish_reason` were
    referenced as free variables (NameError at runtime); they are now
    keyword parameters with defaults, keeping the call `trace_request(rid)`
    backward compatible.

    NOTE(review): assumes a module-level `logger` exists — confirm
    against the real scheduler module.
    """
    logger.info(f"Request {rid} added to waiting queue")
    logger.info(f"Request {rid} matched prefix of length {matched_len}")
    logger.info(f"Request {rid} allocated {cache_size} cache")
    logger.info(f"Request {rid} started execution")
    logger.info(f"Request {rid} finished with reason: {finish_reason}")

Resources

Next Steps