SGLang enables efficient serving of LoRA adapters with a base model. Using techniques from S-LoRA and Punica, SGLang can serve multiple LoRA adapters for different sequences within a single batch.

Quick Start

Basic LoRA Serving

Launch a server with a single LoRA adapter:
python -m sglang.launch_server \
    --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
    --enable-lora \
    --lora-paths lora0=algoprog/fact-generation-llama-3.1-8b-instruct-lora \
    --max-loras-per-batch 2
Make a request using the adapter:
import requests

response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "List 3 countries and their capitals.",
        "sampling_params": {"max_new_tokens": 32, "temperature": 0},
        "lora_path": "lora0",
    },
)

print(response.json()["text"])

Multiple Adapters

Serve multiple LoRA adapters simultaneously:
python -m sglang.launch_server \
    --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
    --enable-lora \
    --lora-paths \
        lora0=algoprog/fact-generation-llama-3.1-8b-instruct-lora \
        lora1=Nutanix/Meta-Llama-3.1-8B-Instruct_SFT_lora_4_alpha_16_humaneval_raw_json \
    --max-loras-per-batch 2
Batch requests with different adapters:
response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": [
            "List 3 countries and their capitals.",
            "Write a Python function to sort a list.",
        ],
        "sampling_params": {"max_new_tokens": 64, "temperature": 0},
        "lora_path": ["lora0", "lora1"],  # Different adapter per request
    },
)
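Assuming the batched /generate call returns one output object per input prompt (a list in the same order as the inputs), the results can be read back like this:
for output in response.json():
    print(output["text"])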

Configuration Parameters

Server Arguments

  • --enable-lora: Enable LoRA support. Default: auto-enabled if --lora-paths is provided.
  • --lora-paths: List of LoRA adapters to load at startup. Default: None.
  • --max-loras-per-batch: Maximum number of adapters per batch. Default: 8.
  • --max-lora-rank: Maximum LoRA rank to support. Default: auto-inferred from the provided adapters.
  • --lora-target-modules: Target modules for LoRA (e.g., q_proj, k_proj). Default: auto-inferred, or all.
  • --lora-backend: Backend, either triton or csgmv. Default: csgmv.
  • --max-loaded-loras: Maximum number of adapters kept in CPU memory. Default: unlimited.
  • --lora-eviction-policy: Eviction policy, lru or fifo. Default: lru.
  • --enable-lora-overlap-loading: Overlap host-to-device (H2D) transfers with compute. Default: False.
  • --max-lora-chunk-size: Chunk size for the ChunkedSGMV backend. Default: 16.

LoRA Path Formats

You can specify adapters in multiple formats:
# Simple path
--lora-paths /path/to/adapter

# Named adapter
--lora-paths adapter1=/path/to/adapter1

# JSON with pinning
--lora-paths '{"lora_name":"adapter1","lora_path":"/path","pinned":true}'

Dynamic Adapter Management

Load and unload adapters at runtime without restarting the server.

Initial Server Setup

python -m sglang.launch_server \
    --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
    --enable-lora \
    --max-loras-per-batch 2 \
    --max-lora-rank 256 \
    --lora-target-modules all
When using dynamic loading, explicitly specify --max-lora-rank and --lora-target-modules to ensure compatibility with all adapters you plan to load.

Load Adapter

import requests

response = requests.post(
    "http://localhost:30000/load_lora_adapter",
    json={
        "lora_name": "adapter1",
        "lora_path": "algoprog/fact-generation-llama-3.1-8b-instruct-lora",
        "pinned": False,  # Optional: pin adapter in GPU memory
    },
)

if response.status_code == 200:
    print("Adapter loaded:", response.json())

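Once loaded, the adapter can be used in generation requests right away, referenced by the name given at load time:
response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "List 3 countries and their capitals.",
        "sampling_params": {"max_new_tokens": 32, "temperature": 0},
        "lora_path": "adapter1",  # name assigned in the load request above
    },
)

print(response.json()["text"])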
Unload Adapter

response = requests.post(
    "http://localhost:30000/unload_lora_adapter",
    json={"lora_name": "adapter1"},
)

OpenAI-Compatible API

Use LoRA adapters through the OpenAI-compatible API by specifying the adapter name with a colon separator:
import openai

client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="None")

response = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct:lora0",  # base-model:adapter-name
    messages=[{"role": "user", "content": "Hello!"}],
    temperature=0,
    max_tokens=64,
)

print(response.choices[0].message.content)

Advanced Features

GPU Pinning

Pin frequently-used adapters to GPU memory to avoid repeated loading:
python -m sglang.launch_server \
    --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
    --enable-lora \
    --max-loras-per-batch 3 \
    --lora-paths \
        '{"lora_name":"lora0","lora_path":"path/to/lora0","pinned":true}' \
        lora1=path/to/lora1 \
        lora2=path/to/lora2
Pinned adapters occupy GPU memory slots permanently until unloaded. SGLang limits pinned adapters to max-loras-per-batch - 1 to prevent starvation.
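Pinning can also be requested at runtime through the dynamic loading API by setting the pinned field (see Dynamic Adapter Management above); for example, with an illustrative adapter name:
import requests

response = requests.post(
    "http://localhost:30000/load_lora_adapter",
    json={
        "lora_name": "hot_adapter",  # illustrative name
        "lora_path": "algoprog/fact-generation-llama-3.1-8b-instruct-lora",
        "pinned": True,  # keep this adapter resident in GPU memory
    },
)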

Backend Selection

SGLang supports two LoRA backends:

ChunkedSGMV (csgmv)

Default and recommended. Optimized for high concurrency, with 20-80% latency improvements over the Triton backend.

Triton

Basic Triton-based implementation. Use for compatibility if needed.
# Use ChunkedSGMV (default)
python -m sglang.launch_server \
    --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
    --enable-lora \
    --lora-backend csgmv \
    --max-loras-per-batch 16

# Use Triton backend
python -m sglang.launch_server \
    --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
    --enable-lora \
    --lora-backend triton

Overlap Loading

Overlap LoRA weight loading with GPU computation to hide data movement latency:
python -m sglang.launch_server \
    --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
    --enable-lora \
    --enable-lora-overlap-loading \
    --lora-paths lora0=path/to/lora0 lora1=path/to/lora1 \
    --max-loras-per-batch 2 \
    --max-loaded-loras 4
Enable when:
  • High adapter churn (frequently switching adapters)
  • Large adapter weights (high rank)
  • PCIe-bottlenecked workloads
Benchmarks show ~35% reduction in median TTFT under adversarial conditions.
Pros:
  • Reduces adapter load time impact
  • Hides H2D transfer latency
Cons:
  • Requires pinned CPU memory (limits max-loaded-loras to 2× max-loras-per-batch)
  • Reduces multi-adapter prefill batching (may increase TTFT when load time << prefill time)

Implementation Architecture

SGLang’s LoRA implementation consists of several key components:

LoRAManager

The LoRAManager class coordinates adapter lifecycle:
# Simplified from python/sglang/srt/lora/lora_manager.py:50
class LoRAManager:
    def __init__(
        self,
        base_model,
        max_loras_per_batch,
        lora_backend="triton",
        max_lora_rank=None,
        target_modules=None,
    ):
        # Initialize backend for GEMM kernels
        self.lora_backend = get_backend_from_name(lora_backend)
        
        # Memory pool for adapter weights
        self.memory_pool = LoRAMemoryPool(...)
        
        # Cached adapters
        self.loras = {}  # lora_id -> LoRAAdapter
Source: python/sglang/srt/lora/lora_manager.py:50

Memory Pool

The memory pool manages GPU memory allocation for adapter weights, implementing eviction policies (LRU or FIFO) when the pool is full.
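The actual pool lives in the SGLang source tree; as a rough illustration only (not SGLang's implementation), an LRU-evicting slot pool for adapter weights might look like this:
from collections import OrderedDict

class SimpleLoRAPool:
    """Illustrative LRU pool: maps adapter names to GPU slots, evicting the
    least-recently-used unpinned adapter when all slots are occupied."""

    def __init__(self, num_slots):
        self.num_slots = num_slots
        self.slots = OrderedDict()  # lora_name -> (slot_index, pinned)
        self.free = list(range(num_slots))

    def acquire(self, lora_name, pinned=False):
        if lora_name in self.slots:
            self.slots.move_to_end(lora_name)  # mark as most recently used
            return self.slots[lora_name][0]
        if not self.free:
            self._evict_lru()
        slot = self.free.pop()
        self.slots[lora_name] = (slot, pinned)
        return slot  # caller would copy adapter weights into this slot

    def _evict_lru(self):
        for name, (slot, pinned) in self.slots.items():  # oldest first
            if not pinned:
                del self.slots[name]
                self.free.append(slot)
                return
        raise RuntimeError("all slots are pinned; cannot evict")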

Adapter Format

Adapters must follow the PEFT format with:
  • adapter_config.json - Configuration (rank, target modules, alpha)
  • Weight files - Adapter matrices (A and B)
  • Optional added_tokens.json - Additional vocabulary tokens
Source: python/sglang/srt/lora/lora_config.py:22
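As a quick check that an adapter matches these expectations, the standard PEFT config fields can be read directly from the adapter directory (the path below is a placeholder):
import json
from pathlib import Path

adapter_dir = Path("/path/to/adapter")  # placeholder
config = json.loads((adapter_dir / "adapter_config.json").read_text())

# Standard PEFT fields that determine pool sizing and module matching.
print("rank (r):", config["r"])
print("alpha:", config["lora_alpha"])
print("target modules:", config["target_modules"])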

Tensor Parallelism

LoRA serving supports tensor parallelism for large models:
python -m sglang.launch_server \
    --model-path meta-llama/Meta-Llama-3.1-70B-Instruct \
    --enable-lora \
    --tp-size 4 \
    --lora-paths lora0=path/to/adapter
S-LoRA’s tensor sharding strategy partitions adapter matrices across GPUs to balance computation.
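As an illustration of the idea only (not SGLang's kernels): for a column-parallel base layer whose output dimension is split across ranks, the LoRA A matrix can be replicated while B is sharded along the output dimension, so each rank computes its own slice of the LoRA delta:
import numpy as np

tp_size, in_dim, out_dim, rank = 4, 32, 64, 8
x = np.random.randn(in_dim)
A = np.random.randn(rank, in_dim)   # replicated on every rank
B = np.random.randn(out_dim, rank)  # sharded along the output dimension

# Each rank holds one slice of B and computes its slice of the delta.
B_shards = np.split(B, tp_size, axis=0)
delta_shards = [B_i @ (A @ x) for B_i in B_shards]

# Concatenating the per-rank slices recovers the unsharded result.
assert np.allclose(np.concatenate(delta_shards), B @ (A @ x))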

Performance Best Practices

  • --max-loras-per-batch: set this based on your concurrency needs. Higher values support more concurrent adapters per batch but increase memory usage.
  • Pinning: pin adapters that appear in more than ~50% of requests to avoid repeated loading.
  • Backend: the csgmv backend provides 20-80% better latency than triton at high concurrency.
  • Eviction policy: use lru (the default) for workloads with temporal locality; use fifo for uniform access patterns.
  • Overlap loading: if adapter loading is a bottleneck, enable --enable-lora-overlap-loading.

Limitations

  • Adding tokens to vocabulary is not currently supported
  • All adapters must have compatible ranks ≤ max-lora-rank
  • Target modules must be subset of lora-target-modules

Future Development

Upcoming features tracked in GitHub Issue #2929:
  • Embedding layer LoRA
  • Unified paging for adapters
  • CUTLASS backend for improved performance
  • Expanded target module support