
Kernel Development

This guide covers developing custom CUDA kernels for SGLang, including Triton kernels and CUDA C++ kernels.

Overview

SGLang uses highly optimized kernels for:
  • Attention: FlashAttention, FlashInfer
  • GEMM: Matrix multiplication (via cuBLAS, cutlass)
  • Elementwise ops: RMSNorm, SiLU, RoPE
  • Sampling: Top-k, top-p, softmax
Kernel Location:
  • Triton kernels: python/sglang/srt/layers/
  • CUDA kernels: sgl-kernel package (separate repository)

Why Custom Kernels?

Custom kernels provide:
  • Performance: 2-10x speedup over PyTorch ops
  • Memory efficiency: Fused operations reduce memory bandwidth
  • Flexibility: Implement custom operators not in PyTorch

Triton Kernels

Introduction to Triton

Triton is a Python-embedded DSL for writing GPU kernels. It is easier to write than CUDA C++ yet can approach hand-tuned performance.

Example: Fused RMSNorm

RMSNorm (Root Mean Square Layer Normalization) is commonly used in modern LLMs.

Unfused Implementation (PyTorch)

import torch

def rmsnorm_pytorch(x, weight, eps=1e-6):
    """RMSNorm using PyTorch ops."""
    variance = x.pow(2).mean(-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    return x * weight
Problem: each op launches a separate kernel and streams the tensor through global memory several times, so the sequence is memory-bandwidth bound.

Fused Triton Kernel

import torch
import triton
import triton.language as tl

@triton.jit
def rmsnorm_kernel(
    x_ptr,      # Pointer to input
    weight_ptr, # Pointer to weight
    output_ptr, # Pointer to output
    stride,     # Stride for input/output
    N,          # Hidden dimension
    eps,        # Epsilon for numerical stability
    BLOCK_SIZE: tl.constexpr,
):
    # Get program ID
    pid = tl.program_id(0)
    
    # Compute row offset
    row_start = pid * stride
    
    # Load input row
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < N
    x = tl.load(x_ptr + row_start + offsets, mask=mask, other=0.0)
    
    # Compute variance
    variance = tl.sum(x * x, axis=0) / N
    rstd = 1.0 / tl.sqrt(variance + eps)
    
    # Load weight
    weight = tl.load(weight_ptr + offsets, mask=mask, other=1.0)
    
    # Normalize and scale
    output = x * rstd * weight
    
    # Store output
    tl.store(output_ptr + row_start + offsets, output, mask=mask)

def rmsnorm_triton(x, weight, eps=1e-6):
    """RMSNorm using Triton kernel."""
    batch_size, hidden_dim = x.shape
    output = torch.empty_like(x)
    
    # Launch kernel
    BLOCK_SIZE = triton.next_power_of_2(hidden_dim)
    grid = (batch_size,)
    
    rmsnorm_kernel[grid](
        x, weight, output,
        stride=hidden_dim,
        N=hidden_dim,
        eps=eps,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    
    return output
Performance: roughly 3x faster than the unfused PyTorch version (the exact speedup depends on shape, dtype, and hardware).

Triton Best Practices

1. Use Power-of-2 Block Sizes

BLOCK_SIZE = triton.next_power_of_2(N)  # Good
BLOCK_SIZE = N  # Bad if N is not power of 2
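To make the rounding rule concrete, here is a pure-Python equivalent of triton.next_power_of_2 (a sketch for illustration; in kernels, use the Triton helper itself):

```python
def next_power_of_2(n: int) -> int:
    """Smallest power of two >= n, for n >= 1 (mirrors triton.next_power_of_2)."""
    return 1 << (n - 1).bit_length()
```

For example, a hidden dimension of 3000 would be padded to a BLOCK_SIZE of 4096, with the excess lanes masked out at load/store time.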

2. Coalesce Memory Accesses

# Good: Consecutive threads access consecutive memory
offsets = tl.arange(0, BLOCK_SIZE)
data = tl.load(ptr + offsets)

# Bad: Strided access
offsets = tl.arange(0, BLOCK_SIZE) * stride
data = tl.load(ptr + offsets)

3. Minimize Synchronization

# Avoid barriers if possible
tl.debug_barrier()  # Use sparingly

4. Optimize Occupancy

# Tune BLOCK_SIZE for occupancy
for BLOCK_SIZE in [128, 256, 512, 1024]:
    benchmark(BLOCK_SIZE)
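The tuning loop above can be wrapped in a small harness. This is an illustrative sketch: `run_kernel` is a hypothetical callable that would launch the kernel with the given block size and synchronize the device (Triton's built-in `triton.autotune` decorator automates the same idea).

```python
import time

def sweep_block_sizes(run_kernel, candidates=(128, 256, 512, 1024), iters=50):
    """Return the candidate BLOCK_SIZE with the lowest average wall-clock time.

    `run_kernel(block_size)` is a hypothetical launcher; in practice it would
    call the Triton kernel and then torch.cuda.synchronize().
    """
    timings = {}
    for block_size in candidates:
        start = time.perf_counter()
        for _ in range(iters):
            run_kernel(block_size)
        timings[block_size] = (time.perf_counter() - start) / iters
    return min(timings, key=timings.get)
```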

CUDA C++ Kernels

For maximum performance, write CUDA C++ kernels in the sgl-kernel package.

Example: Fused Add + ReLU

#include <cuda_runtime.h>
#include <cuda_fp16.h>

// Kernel: Fused element-wise add and ReLU
__global__ void fused_add_relu_kernel(
    const half* x,
    const half* y,
    half* out,
    int N
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    
    if (idx < N) {
        half sum = __hadd(x[idx], y[idx]);  // FP16 add
        out[idx] = __hmax(sum, __float2half(0.0f));  // ReLU
    }
}

// Host function
void fused_add_relu(
    const half* x,
    const half* y,
    half* out,
    int N
) {
    int threads = 256;
    int blocks = (N + threads - 1) / threads;
    
    fused_add_relu_kernel<<<blocks, threads>>>(x, y, out, N);
}
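The `(N + threads - 1) / threads` expression in the host function is integer ceiling division, which guarantees enough blocks to cover all N elements. The same idiom in Python:

```python
def ceil_div(n: int, block: int) -> int:
    """Number of blocks of size `block` needed to cover n elements."""
    return (n + block - 1) // block
```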

PyTorch Binding

#include <torch/extension.h>

void fused_add_relu(
    const half* x,
    const half* y,
    half* out,
    int N
);

torch::Tensor fused_add_relu_torch(
    torch::Tensor x,
    torch::Tensor y
) {
    auto out = torch::empty_like(x);
    
    fused_add_relu(
        reinterpret_cast<const half*>(x.data_ptr()),
        reinterpret_cast<const half*>(y.data_ptr()),
        reinterpret_cast<half*>(out.data_ptr()),
        x.numel()
    );
    
    return out;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("fused_add_relu", &fused_add_relu_torch);
}
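Once the extension is built, its output can be checked element by element against a reference. A pure-Python reference for the fused add + ReLU (illustrative; a real test would compare tensors with torch.testing.assert_close, as in the testing section below):

```python
def add_relu_reference(xs, ys):
    """Reference for the fused kernel: out[i] = max(x[i] + y[i], 0)."""
    return [max(x + y, 0.0) for x, y in zip(xs, ys)]
```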

Build System

# setup.py
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name="my_kernels",
    ext_modules=[
        CUDAExtension(
            "my_kernels",
            ["my_kernels.cu", "bindings.cpp"],
            extra_compile_args={
                "cxx": ["-O3"],
                "nvcc": ["-O3", "--use_fast_math"],
            },
        )
    ],
    cmdclass={"build_ext": BuildExtension},
)

FlashAttention Integration

SGLang uses FlashInfer for optimized attention.

Using FlashInfer

from flashinfer import single_prefill_with_kv_cache, batch_decode_with_padded_kv_cache

# Prefill
output = single_prefill_with_kv_cache(
    q=query,           # [seq_len, num_heads, head_dim]
    k=key_cache,       # [seq_len, num_kv_heads, head_dim]
    v=value_cache,     # [seq_len, num_kv_heads, head_dim]
    causal=True,
)

# Decode
output = batch_decode_with_padded_kv_cache(
    q=query,           # [batch_size, num_heads, head_dim]
    k=key_cache,       # [batch_size, max_seq_len, num_kv_heads, head_dim]
    v=value_cache,     # [batch_size, max_seq_len, num_kv_heads, head_dim]
    seq_lens=seq_lens, # [batch_size]
)

Custom Attention Backend

To add a new attention backend:
  1. Create attention class:
# python/sglang/srt/layers/attention/my_attention.py

class MyAttention:
    def __init__(self, num_heads, head_dim, num_kv_heads):
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.num_kv_heads = num_kv_heads
    
    def forward(self, q, k, v, **kwargs):
        # Compute attention and return the output tensor
        raise NotImplementedError
  2. Register backend:
# python/sglang/srt/layers/attention/__init__.py

from sglang.srt.layers.attention.my_attention import MyAttention

ATTENTION_BACKENDS = {
    "flashinfer": FlashInferAttention,
    "flashattn": FlashAttention,
    "my_backend": MyAttention,
}
  3. Use it:
python -m sglang.launch_server \
  --model-path meta-llama/Llama-3.1-8B \
  --attention-backend my_backend
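The forward method of a backend must compute scaled dot-product attention. As a reference for the math only, here is a minimal single-query version in pure Python (`naive_attention` is a hypothetical name for illustration; a real backend operates on batched GPU tensors and a paged KV cache):

```python
import math

def naive_attention(q, ks, vs):
    """Single-query attention: softmax(q . k / sqrt(d)) weighted sum of values.

    q: list[float] of length head_dim; ks, vs: lists of such vectors.
    """
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in ks]
    m = max(scores)                            # subtract max for stability
    exps = [math.exp(s - m) for s in scores]
    z = sum(exps)
    weights = [e / z for e in exps]
    return [sum(w * v[d] for w, v in zip(weights, vs))
            for d in range(len(vs[0]))]
```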

Kernel Optimization Techniques

1. Tiling

Break computation into tiles that fit in shared memory:
#define TILE_SIZE 16  // tile edge length; each shared buffer holds TILE_SIZE x TILE_SIZE floats

__global__ void matmul_tiled(
    const float* A,
    const float* B,
    float* C,
    int M, int N, int K
) {
    __shared__ float As[TILE_SIZE][TILE_SIZE];
    __shared__ float Bs[TILE_SIZE][TILE_SIZE];
    
    int row = blockIdx.y * TILE_SIZE + threadIdx.y;
    int col = blockIdx.x * TILE_SIZE + threadIdx.x;
    float sum = 0.0f;
    
    // Loop over tiles
    for (int t = 0; t < (K + TILE_SIZE - 1) / TILE_SIZE; t++) {
        // Load tiles into shared memory
        if (row < M && t * TILE_SIZE + threadIdx.x < K)
            As[threadIdx.y][threadIdx.x] = A[row * K + t * TILE_SIZE + threadIdx.x];
        else
            As[threadIdx.y][threadIdx.x] = 0.0f;
        
        if (col < N && t * TILE_SIZE + threadIdx.y < K)
            Bs[threadIdx.y][threadIdx.x] = B[(t * TILE_SIZE + threadIdx.y) * N + col];
        else
            Bs[threadIdx.y][threadIdx.x] = 0.0f;
        
        __syncthreads();
        
        // Compute partial sum
        for (int k = 0; k < TILE_SIZE; k++)
            sum += As[threadIdx.y][k] * Bs[k][threadIdx.x];
        
        __syncthreads();
    }
    
    if (row < M && col < N)
        C[row * N + col] = sum;
}
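The loop structure can be modeled in pure Python to see what tiling changes: the K dimension is consumed in tile-sized chunks, with each chunk's partial products accumulated into C (illustrative only; it reproduces the traversal order, not the shared-memory staging or parallelism):

```python
def matmul_tiled_py(A, B, tile=2):
    """Tiled matmul model: accumulate C over tile-sized chunks of K."""
    M, K = len(A), len(A[0])
    N = len(B[0])
    C = [[0.0] * N for _ in range(M)]
    for t in range(0, K, tile):                    # loop over K tiles
        for i in range(M):
            for j in range(N):
                C[i][j] += sum(A[i][k] * B[k][j]
                               for k in range(t, min(t + tile, K)))
    return C
```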

2. Vectorized Loads

Load multiple elements per thread:
// Load 4 floats at once using float4 (pointers must be 16-byte aligned)
__global__ void vector_add(
    const float* a,
    const float* b,
    float* c,
    int N
) {
    int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 4;
    
    if (idx + 3 < N) {
        float4 a_vec = *reinterpret_cast<const float4*>(&a[idx]);
        float4 b_vec = *reinterpret_cast<const float4*>(&b[idx]);
        float4 c_vec;
        
        c_vec.x = a_vec.x + b_vec.x;
        c_vec.y = a_vec.y + b_vec.y;
        c_vec.z = a_vec.z + b_vec.z;
        c_vec.w = a_vec.w + b_vec.w;
        
        *reinterpret_cast<float4*>(&c[idx]) = c_vec;
    } else {
        // Scalar tail for the last N % 4 elements
        for (int i = idx; i < N; i++)
            c[i] = a[i] + b[i];
    }
}

3. Warp Shuffle

Communicate within a warp without shared memory:
// Warp-level reduction
__device__ float warp_reduce_sum(float val) {
    for (int offset = 16; offset > 0; offset /= 2) {
        val += __shfl_down_sync(0xffffffff, val, offset);
    }
    return val;
}

__global__ void reduce_sum(
    const float* input,
    float* output,
    int N
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    float sum = (idx < N) ? input[idx] : 0.0f;
    
    // Warp-level reduction
    sum = warp_reduce_sum(sum);
    
    // First thread in warp writes result
    if (threadIdx.x % 32 == 0) {
        atomicAdd(output, sum);
    }
}
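To see why the shuffle loop leaves the full sum in lane 0, here is a pure-Python model of the 32-lane reduction (out-of-range lanes are modeled as contributing zero; only lane 0's result matters in the kernel above):

```python
def simulate_warp_reduce_sum(vals):
    """Model of the __shfl_down_sync reduction over a 32-lane warp."""
    assert len(vals) == 32
    vals = list(vals)
    offset = 16
    while offset > 0:
        # Lane i adds the value held by lane i + offset.
        vals = [vals[i] + (vals[i + offset] if i + offset < 32 else 0.0)
                for i in range(32)]
        offset //= 2
    return vals[0]  # lane 0 now holds the sum of all 32 inputs
```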

Profiling Kernels

Nsight Compute

# Profile specific kernel
ncu --kernel-name "my_kernel" --metrics sm__throughput.avg.pct_of_peak_sustained_elapsed \
    python -m sglang.bench_one_batch --model meta-llama/Llama-3.2-1B

# Full metrics
ncu --set full -o profile python script.py

# Open in GUI
ncu-ui profile.ncu-rep

Key Metrics

  • SM Throughput: Streaming Multiprocessor utilization
  • Memory Throughput: DRAM bandwidth utilization
  • Occupancy: Active warps / max warps
  • Register Usage: Registers per thread
  • Shared Memory Usage: Bytes per block
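Occupancy ties several of these metrics together: registers per thread bound how many blocks fit on an SM. A rough register-limited estimate, using illustrative assumed limits (65536 registers and 64 warps per SM; real limits vary by architecture, and Nsight Compute reports the true achieved occupancy):

```python
def estimate_occupancy(regs_per_thread, threads_per_block,
                       regs_per_sm=65536, max_warps_per_sm=64):
    """Register-limited occupancy = active warps / max warps (rough model)."""
    regs_per_block = regs_per_thread * threads_per_block
    blocks_per_sm = regs_per_sm // regs_per_block
    warps_per_block = threads_per_block // 32
    active_warps = min(blocks_per_sm * warps_per_block, max_warps_per_sm)
    return active_warps / max_warps_per_sm
```

Under these assumptions, doubling register use from 32 to 128 per thread at 256 threads/block drops the estimate from full occupancy to 25%.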

Testing Kernels

Correctness Test

import torch
import my_kernels

def test_correctness():
    x = torch.randn(1024, 4096, dtype=torch.float16, device="cuda")
    weight = torch.randn(4096, dtype=torch.float16, device="cuda")
    
    # Reference (PyTorch)
    ref = rmsnorm_pytorch(x, weight)
    
    # Custom kernel
    out = my_kernels.rmsnorm(x, weight)
    
    # Check
    torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-3)
    print("Correctness test passed!")

test_correctness()

Performance Test

import time

import torch

def benchmark_kernel(fn, *args, warmup=10, iters=100):
    # Warmup
    for _ in range(warmup):
        fn(*args)
    torch.cuda.synchronize()
    
    # Benchmark
    start = time.time()
    for _ in range(iters):
        fn(*args)
    torch.cuda.synchronize()
    elapsed = time.time() - start
    
    return elapsed / iters

# Compare
time_pytorch = benchmark_kernel(rmsnorm_pytorch, x, weight)
time_triton = benchmark_kernel(my_kernels.rmsnorm, x, weight)

print(f"PyTorch: {time_pytorch*1000:.3f} ms")
print(f"Triton:  {time_triton*1000:.3f} ms")
print(f"Speedup: {time_pytorch/time_triton:.2f}x")

Adding Kernels to sgl-kernel

See Contribution Guide for the multi-PR workflow.

Step 1: Add Kernel Implementation

cd sglang/sgl-kernel
# Add your kernel
vim csrc/my_kernel.cu

Step 2: Submit PR

Submit PR to sgl-kernel without using it yet.

Step 3: Bump Version

Submit another PR to bump sgl-kernel version. This triggers PyPI release.

Step 4: Use Kernel

Update pyproject.toml in sglang and use the new kernel.
