Adding Models

This guide covers how to add support for new model architectures to SGLang.

Prerequisites

  • Understanding of the model architecture you want to add
  • Access to the model’s Hugging Face implementation or source code
  • Familiarity with PyTorch and transformer models
  • SGLang development environment set up (Development Setup)

Overview

Adding a new model to SGLang typically involves:
  1. Creating a model implementation file
  2. Registering the model architecture
  3. Adding tests
  4. Updating documentation

Step 1: Create Model Implementation

File Location

Create a new file in python/sglang/srt/models/ named after your model (e.g., my_model.py).

Model Structure

A typical model implementation includes:
```python
from typing import Optional, Tuple

import torch
import torch.nn as nn

from sglang.srt.models.model_base import ModelBase, InputMetadata


class MyModelForCausalLM(ModelBase):
    """Implementation of MyModel for causal language modeling."""

    def __init__(self, config):
        super().__init__(config)
        # Initialize model layers
        self.model = MyModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
        **kwargs
    ) -> torch.Tensor:
        """Forward pass.

        Args:
            input_ids: Input token IDs [batch_size, seq_len]
            positions: Token positions [batch_size, seq_len]
            input_metadata: Metadata for attention and caching

        Returns:
            logits: [batch_size, seq_len, vocab_size]
        """
        hidden_states = self.model(
            input_ids=input_ids,
            positions=positions,
            input_metadata=input_metadata,
            **kwargs
        )
        logits = self.lm_head(hidden_states)
        return logits

    def load_weights(self, weights: dict):
        """Load weights from checkpoint."""
        # Implement weight loading logic
        pass
```

Key Components

Model Layers

Implement the core model architecture:
```python
class MyModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([
            MyDecoderLayer(config) for _ in range(config.num_hidden_layers)
        ])
        self.norm = nn.LayerNorm(config.hidden_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)

        for layer in self.layers:
            hidden_states = layer(
                hidden_states,
                positions,
                input_metadata,
            )

        return self.norm(hidden_states)
```

Attention Layer

Implement attention using SGLang’s optimized attention:
```python
from sglang.srt.layers.attention import Attention

class MyDecoderLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.self_attn = Attention(
            config.hidden_size,
            config.num_attention_heads,
            config.num_key_value_heads,
            head_dim=config.hidden_size // config.num_attention_heads,
        )
        self.mlp = MyMLP(config)
        self.input_layernorm = nn.LayerNorm(config.hidden_size)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        # Self attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            hidden_states,
            positions,
            input_metadata,
        )
        hidden_states = residual + hidden_states

        # MLP
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states
```
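`MyMLP` is referenced above but not defined in this guide. A minimal sketch, assuming a SwiGLU-style gated feed-forward (common in recent decoder models) and a config that provides `intermediate_size`:

```python
import torch
import torch.nn as nn

class MyMLP(nn.Module):
    """Gated feed-forward block (SwiGLU-style sketch)."""

    def __init__(self, config):
        super().__init__()
        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # down(silu(gate(x)) * up(x)); output shape matches the input shape
        return self.down_proj(
            self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states)
        )
```

In a real implementation you would swap the plain `nn.Linear` layers for SGLang's fused/parallel linear layers (see Step 6).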

Weight Loading

Implement weight loading from Hugging Face checkpoints:
```python
def load_weights(self, weights: dict):
    """Load weights from Hugging Face checkpoint."""
    params_dict = dict(self.named_parameters())

    for name, loaded_weight in weights.items():
        # Handle weight name mapping if needed
        if "qkv_proj" in name:
            # Split fused QKV weights. Note: chunk(3) assumes Q, K, and V
            # projections are the same size; with GQA, use torch.split
            # with the per-projection sizes instead.
            q, k, v = loaded_weight.chunk(3, dim=0)
            # Load individual weights
            params_dict[name.replace("qkv_proj", "q_proj")].data.copy_(q)
            params_dict[name.replace("qkv_proj", "k_proj")].data.copy_(k)
            params_dict[name.replace("qkv_proj", "v_proj")].data.copy_(v)
        else:
            param = params_dict[name]
            param.data.copy_(loaded_weight)
```

Step 2: Register Model

Add your model to the model registry in python/sglang/srt/model_loader/loader.py:
```python
from sglang.srt.models.my_model import MyModelForCausalLM

_MODEL_REGISTRY = {
    # ... existing models ...
    "MyModelForCausalLM": MyModelForCausalLM,
}
```

Step 3: Add Configuration

If your model has a custom configuration, create a config class:
```python
from transformers import PretrainedConfig

class MyModelConfig(PretrainedConfig):
    model_type = "my_model"

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=4096,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=8,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
```

Step 4: Add Tests

Create tests in test/srt/test_my_model.py:
```python
import unittest

from sglang import Engine
from sglang.test.test_utils import DEFAULT_PROMPTS

class TestMyModel(unittest.TestCase):
    def test_generate(self):
        """Test basic generation."""
        engine = Engine(
            model_path="org/my-model",
            trust_remote_code=True,
        )

        outputs = engine.generate(
            prompts=DEFAULT_PROMPTS[:2],
            sampling_params={"max_new_tokens": 32}
        )

        self.assertEqual(len(outputs), 2)
        for output in outputs:
            self.assertIn("text", output)
            self.assertGreater(len(output["text"]), 0)

    def test_batch_generation(self):
        """Test batched generation."""
        engine = Engine(model_path="org/my-model")

        outputs = engine.generate(
            prompts=DEFAULT_PROMPTS,
            sampling_params={"max_new_tokens": 16, "temperature": 0.8}
        )

        self.assertEqual(len(outputs), len(DEFAULT_PROMPTS))

if __name__ == "__main__":
    unittest.main()
```

Step 5: Test Your Model

Manual Testing

```bash
# Launch server
python -m sglang.launch_server \
  --model-path org/my-model \
  --trust-remote-code

# Test with OpenAI client
python -c "
import openai
client = openai.OpenAI(
    base_url='http://localhost:30000/v1',
    api_key='EMPTY'
)
response = client.chat.completions.create(
    model='org/my-model',
    messages=[{'role': 'user', 'content': 'Hello!'}]
)
print(response.choices[0].message.content)
"
```

Run Unit Tests

```bash
python -m pytest test/srt/test_my_model.py -v
```

Step 6: Optimize Performance

Use Fused Kernels

Replace standard operations with optimized kernels:
```python
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import QKVParallelLinear

class MyDecoderLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Use fused RMSNorm instead of LayerNorm
        self.input_layernorm = RMSNorm(config.hidden_size)

        # Use fused QKV projection
        self.qkv_proj = QKVParallelLinear(
            config.hidden_size,
            config.hidden_size // config.num_attention_heads,
            config.num_attention_heads,
            config.num_key_value_heads,
        )
```

Enable CUDA Graphs

Ensure your model supports CUDA graph capture by avoiding dynamic control flow and data-dependent tensor shapes in the forward pass.
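The underlying constraint is that every forward pass must execute the same sequence of kernels. A plain-Python illustration of the pattern (not SGLang API): replace value-dependent branches with branch-free arithmetic, so a captured graph can be replayed for any input.

```python
def relu_with_branch(xs):
    # Value-dependent control flow: which operations run depends on the
    # data, so the op sequence differs between calls. The tensor
    # equivalent (e.g. `if tensor.item() > 0: ...`) forces a host sync
    # and breaks CUDA graph capture.
    return [x if x > 0 else 0 for x in xs]

def relu_branch_free(xs):
    # Branch-free formulation: the same operations run for every input,
    # so a captured graph can be replayed unchanged.
    return [x * (x > 0) for x in xs]
```

Both compute the same result; only the second shape of computation is graph-capturable.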

Step 7: Add Documentation

Update the documentation:
  1. Add model to supported models list
  2. Create example usage in docs
  3. Document any special requirements or configuration

Example Documentation

## MyModel

**Architecture**: Transformer decoder with grouped-query attention

**Variants**:
- `org/my-model-7b` - 7B parameter model
- `org/my-model-13b` - 13B parameter model

**Example**:
```bash
python -m sglang.launch_server \
  --model-path org/my-model-7b \
  --trust-remote-code
```

**Special Features**:
- Supports GQA (Grouped-Query Attention)
- Requires `trust_remote_code=True`

Advanced Topics

Multimodal Models

For vision-language models, implement image processing:

```python
class MyVLMForCausalLM(ModelBase):
    def __init__(self, config):
        super().__init__(config)
        self.vision_model = MyVisionModel(config)
        self.language_model = MyLanguageModel(config)
        self.projector = nn.Linear(config.vision_hidden_size, config.hidden_size)
    
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
        pixel_values: Optional[torch.Tensor] = None,
        **kwargs
    ) -> torch.Tensor:
        # Process images
        if pixel_values is not None:
            vision_features = self.vision_model(pixel_values)
            vision_features = self.projector(vision_features)
            # Merge with text embeddings
            hidden_states = self.merge_vision_text(
                input_ids, vision_features, input_metadata
            )
        else:
            hidden_states = self.language_model.embed_tokens(input_ids)
        
        # Continue with language model
        return self.language_model(hidden_states, positions, input_metadata)
```

MoE Models

For Mixture-of-Experts models, use SGLang’s MoE layers:
```python
from sglang.srt.layers.moe import MoE

class MyMoELayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.moe = MoE(
            num_experts=config.num_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.moe(hidden_states)
```

Troubleshooting

Model Not Loading

  • Check model registration in loader.py
  • Verify model_type in config matches registry
  • Ensure trust_remote_code=True if needed
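The first two checks can be automated with a small helper (hypothetical, not part of SGLang; it assumes the registry is a plain dict keyed by architecture name, as in Step 2):

```python
import json

def find_registered_archs(config_path: str, registry: dict) -> list:
    """Return the architecture names from a checkpoint's config.json
    that have a matching entry in the model registry."""
    with open(config_path) as f:
        config = json.load(f)
    return [arch for arch in config.get("architectures", []) if arch in registry]
```

An empty result means the `architectures` field in the checkpoint's `config.json` matches no registry key.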

OOM (Out of Memory)

  • Reduce batch size
  • Enable memory optimizations:

    ```bash
    python -m sglang.launch_server \
      --model-path org/my-model \
      --mem-fraction-static 0.8
    ```

Slow Performance

  • Enable CUDA graphs: Remove --disable-cuda-graph
  • Use tensor parallelism: --tp-size 2
  • Profile with nsight: See Benchmark and Profiling

Checklist

Before submitting your model:
  • Model implementation complete
  • Weights load correctly from Hugging Face
  • Unit tests pass
  • Manual testing successful
  • Documentation updated
  • Pre-commit hooks pass
  • Performance acceptable
