TensorRT-LLM allows you to write custom CUDA kernels to optimize specific operations for your use case.

Overview

Custom kernels can be integrated in two ways:
  1. Custom Operations: PyTorch custom ops for the PyTorch backend
  2. TensorRT Plugins: Native TensorRT plugins for the TensorRT backend

Custom PyTorch Operations

For the PyTorch backend, build and load the kernel with PyTorch’s C++/CUDA extension utilities:
import torch
from torch.utils.cpp_extension import load

# Load custom CUDA kernel
custom_op = load(
    name="custom_op",
    sources=["custom_kernel.cu", "custom_kernel.cpp"],
    extra_cuda_cflags=["-O3"]
)

# Use in your model
class MyModule(torch.nn.Module):
    def forward(self, x):
        return custom_op.my_kernel(x)

Example: Custom Attention Kernel

// custom_attention.cu
#include <torch/extension.h>

__global__ void custom_attention_kernel(
    const float* query,
    const float* key,
    const float* value,
    float* output,
    int seq_len,
    int hidden_dim
) {
    // Implement custom attention logic
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    
    if (idx < seq_len * hidden_dim) {
        // Optimized attention computation goes here; a placeholder
        // write keeps the example compilable.
        output[idx] = query[idx];
    }
}

torch::Tensor custom_attention(
    torch::Tensor query,
    torch::Tensor key,
    torch::Tensor value
) {
    TORCH_CHECK(query.is_cuda() && query.is_contiguous(),
                "query must be a contiguous CUDA tensor");  // same for key/value
    TORCH_CHECK(query.scalar_type() == torch::kFloat32,
                "query must be float32");

    auto output = torch::zeros_like(query);
    
    int threads = 256;
    int blocks = (query.numel() + threads - 1) / threads;
    
    custom_attention_kernel<<<blocks, threads>>>(
        query.data_ptr<float>(),
        key.data_ptr<float>(),
        value.data_ptr<float>(),
        output.data_ptr<float>(),
        query.size(0),
        query.size(1)
    );
    TORCH_CHECK(cudaGetLastError() == cudaSuccess, "kernel launch failed");
    
    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("custom_attention", &custom_attention, "Custom attention kernel");
}

Integration with TensorRT-LLM

Add your custom op to the auto_deploy custom_ops directory:
tensorrt_llm/_torch/auto_deploy/custom_ops/
├── my_custom_op/
│   ├── __init__.py
│   ├── my_kernel.cu
│   └── my_kernel.cpp
Register the operation:
# __init__.py
import torch

torch.library.define("myops::custom_kernel", "(Tensor input) -> Tensor")

@torch.library.impl("myops::custom_kernel", "cuda")
def custom_kernel_impl(input):
    # Call your CUDA kernel
    return _custom_kernel_cuda(input)

TensorRT Plugins (Legacy)

For the TensorRT backend, implement a TensorRT plugin:
from tensorrt_llm.plugin import python_plugin

@python_plugin("MyCustomPlugin", outputs=["output"])
def my_custom_plugin(input_tensor):
    """Custom TensorRT plugin implementation."""
    # Implement plugin logic
    return output_tensor

Performance Considerations

Ensure memory accesses are coalesced for maximum bandwidth:
// Good: Coalesced access
__global__ void coalesced_kernel(float* data, int n) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        data[idx] = data[idx] * 2.0f;  // Adjacent threads access adjacent memory
    }
}
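
On the Python side, the same concern surfaces as tensor contiguity: a kernel that indexes `data_ptr()` linearly assumes a row-major, contiguous layout, which views such as transposes violate. A small sketch:

```python
import torch

x = torch.randn(64, 64)
t = x.t()  # transpose is a view: same storage, swapped strides

# A kernel reading data_ptr() linearly would traverse t in the wrong
# order; .contiguous() copies it into row-major layout first.
flag_before = t.is_contiguous()
tc = t.contiguous()
flag_after = tc.is_contiguous()
```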
Use shared memory for frequently accessed data:
__global__ void shared_mem_kernel(float* input, float* output, int n) {
    __shared__ float shared_data[256];
    
    int tid = threadIdx.x;
    int idx = blockIdx.x * blockDim.x + tid;
    if (idx < n) {
        shared_data[tid] = input[idx];
    }
    __syncthreads();
    
    // Use shared_data for computation; here it is simply copied back out
    if (idx < n) {
        output[idx] = shared_data[tid];
    }
}
Maximize occupancy by balancing register usage and thread blocks:
# Profile with Nsight Compute
ncu --metrics sm__warps_active.avg.pct_of_peak_sustained_active ./my_program

Testing Custom Kernels

import torch

# Test correctness
def test_custom_kernel():
    input = torch.randn(1024, 512, device='cuda')
    
    # Reference implementation
    expected = reference_impl(input)
    
    # Custom kernel
    output = custom_kernel(input)
    
    # Compare
    assert torch.allclose(output, expected, rtol=1e-3, atol=1e-5)
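
For the attention example above, a naive softmax(QKᵀ/√d)·V computation makes a convenient reference. A runnable sketch in which `custom_attention` is a stand-in (PyTorch's built-in `scaled_dot_product_attention` substitutes for the compiled extension so the test runs without building it):

```python
import torch
import torch.nn.functional as F

# Stand-in for the compiled custom_attention op built earlier.
def custom_attention(q, k, v):
    return F.scaled_dot_product_attention(q, k, v)

q = torch.randn(1, 8, 128, 64)  # (batch, heads, seq_len, head_dim)
k = torch.randn(1, 8, 128, 64)
v = torch.randn(1, 8, 128, 64)

# Reference: softmax(Q K^T / sqrt(d)) @ V, no mask, no dropout
scores = q @ k.transpose(-2, -1) / (q.size(-1) ** 0.5)
expected = scores.softmax(dim=-1) @ v

output = custom_attention(q, k, v)
```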

# Test performance
def benchmark_custom_kernel():
    input = torch.randn(1024, 512, device='cuda')
    
    # Warmup
    for _ in range(10):
        _ = custom_kernel(input)
    
    # Benchmark
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    
    start.record()
    for _ in range(100):
        _ = custom_kernel(input)
    end.record()
    
    torch.cuda.synchronize()
    print(f"Time: {start.elapsed_time(end) / 100:.2f} ms")
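
As an alternative to hand-rolled CUDA-event timing, `torch.utils.benchmark.Timer` handles warmup and (for CUDA inputs) synchronization automatically. A sketch with a CPU stand-in for the kernel so it runs anywhere:

```python
import torch
import torch.utils.benchmark as benchmark

# Stand-in for the custom kernel; swap in the real op and a CUDA
# tensor when measuring for real.
def custom_kernel(x):
    return x * 2.0

x = torch.randn(1024, 512)

timer = benchmark.Timer(
    stmt="custom_kernel(x)",
    globals={"custom_kernel": custom_kernel, "x": x},
)
measurement = timer.timeit(100)  # 100 timed runs after internal warmup
print(f"Mean time: {measurement.mean * 1e3:.3f} ms")
```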

Examples in TensorRT-LLM

Study existing custom ops:
  • tensorrt_llm/_torch/custom_ops/fused_moe/ - Fused MoE kernel
  • tensorrt_llm/_torch/cuda_tile_kernels/ - CUDA tile kernels
  • tensorrt_llm/kernels/ - C++ CUDA kernels

Next Steps

  • Profiling - Profile your custom kernels
  • Optimization Guide - Optimize kernel performance
