Python plugins allow you to define custom TensorRT operations using Python, which are then compiled into the TensorRT engine.
Python plugins are primarily used with the TensorRT backend. For the PyTorch backend, use custom CUDA kernels instead.

Overview

Python plugins provide a way to extend TensorRT-LLM with custom operations without writing C++ code.

Creating a Python Plugin

Use the @python_plugin decorator from tensorrt_llm.plugin:
from tensorrt_llm.plugin import python_plugin
import torch

@python_plugin("MyCustomOp", outputs=["output"])
def my_custom_op(input_tensor, scale: float = 1.0):
    """
    Custom operation that scales input tensor.
    
    Args:
        input_tensor: Input tensor to scale
        scale: Scaling factor
        
    Returns:
        Scaled tensor
    """
    return input_tensor * scale

Plugin Registration

The @python_plugin decorator handles:
  • Automatic registration with TensorRT plugin registry
  • Type inference for inputs and outputs
  • Shape inference during engine build
  • Serialization for engine save/load

Advanced Plugin Example

import math
from typing import Optional

import torch
from tensorrt_llm.plugin import python_plugin

@python_plugin(
    "CustomAttention",
    outputs=["attention_output", "attention_weights"],
    workspace_size=1024 * 1024  # 1 MB workspace
)
def custom_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    mask: Optional[torch.Tensor] = None
):
    """
    Custom attention mechanism with optional masking.
    """
    # Compute attention scores
    scores = torch.matmul(query, key.transpose(-2, -1))
    scores = scores / math.sqrt(query.size(-1))
    
    # Apply mask if provided
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    
    # Softmax and weighted sum
    weights = torch.softmax(scores, dim=-1)
    output = torch.matmul(weights, value)
    
    return output, weights

Using Plugins in Models

import torch
from torch import nn

from tensorrt_llm.functional import python_plugin_call

class MyModel(nn.Module):
    def forward(self, x):
        # Call Python plugin
        output = python_plugin_call(
            "MyCustomOp",
            inputs=[x],
            scale=2.0
        )
        return output

Plugin Parameters

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| name | string | required | Unique name for the plugin |
| outputs | List[string] | required | List of output names |
| workspace_size | integer | 0 | Workspace memory size in bytes for intermediate calculations |
| plugin_namespace | string | "trtllm" | Plugin namespace for organization |
| plugin_version | string | "1" | Plugin version identifier |
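Putting the optional parameters together, a fully specified registration might look like the following. The plugin name, workspace size, and function body here are illustrative values, not a required configuration:

```python
from tensorrt_llm.plugin import python_plugin

@python_plugin(
    "ScaledResidual",           # name: must be unique within the namespace
    outputs=["output"],
    workspace_size=4096,        # bytes of scratch memory for intermediates
    plugin_namespace="trtllm",  # the default namespace
    plugin_version="1",         # bump when the op's behavior changes
)
def scaled_residual(x, residual, scale: float = 1.0):
    return x * scale + residual
```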

Shape Inference

Implement shape inference for dynamic shapes:
@python_plugin("MyOp", outputs=["output"])
class MyOpPlugin:
    def __call__(self, input_tensor):
        return input_tensor * 2
    
    def get_output_shapes(self, input_shapes):
        """Define output shapes based on input shapes."""
        return [input_shapes[0]]  # Output has same shape as input
    
    def get_output_dtypes(self, input_dtypes):
        """Define output data types."""
        return [input_dtypes[0]]  # Output has same dtype as input
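The identity case above is the simplest; for an op that changes shape — say a sum over the last dimension — the hooks must compute the new shape. A dependency-free sketch of just the inference methods (the @python_plugin decorator and the tensor math are omitted so the shape logic stands alone):

```python
class SumLastDimPlugin:
    """Shape/dtype hooks for an op that reduces over the last dimension."""

    def get_output_shapes(self, input_shapes):
        # Drop the last dimension: (B, S, H) -> (B, S)
        return [input_shapes[0][:-1]]

    def get_output_dtypes(self, input_dtypes):
        # A sum reduction keeps the input dtype.
        return [input_dtypes[0]]
```

Getting these hooks right matters: TensorRT uses them during the engine build, before any tensor data exists, so a mistake surfaces as a build-time shape mismatch rather than a runtime error.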

Best Practices

Keep computations on GPU to avoid CPU-GPU transfers:
@python_plugin("EfficientOp", outputs=["output"])
def efficient_op(gpu_tensor):
    # All operations stay on GPU
    return gpu_tensor.relu().sum(dim=-1)
Use workspace for temporary buffers:
@python_plugin("WorkspaceOp", outputs=["output"], workspace_size=1024 * 1024)
def workspace_op(input_tensor, workspace):
    # Slice the preallocated workspace instead of allocating a new buffer
    temp_buffer = workspace[:input_tensor.numel()].view_as(input_tensor)
    torch.relu(input_tensor, out=temp_buffer)
    return temp_buffer
Annotate types for better error messages:
from typing import Optional
import torch

@python_plugin("TypedOp", outputs=["output"])
def typed_op(
    input: torch.Tensor,
    scale: float,
    bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
    result = input * scale
    if bias is not None:
        result = result + bias
    return result

Plugin Examples

Custom Activation Function

@python_plugin("SwishActivation", outputs=["output"])
def swish_activation(x: torch.Tensor, beta: float = 1.0):
    """Swish activation: x * sigmoid(beta * x)"""
    return x * torch.sigmoid(beta * x)

Fused Operations

import torch.nn.functional as F

@python_plugin("FusedLayerNormGELU", outputs=["output"])
def fused_ln_gelu(input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor):
    """Fused LayerNorm + GELU activation."""
    # LayerNorm
    normalized = F.layer_norm(input, input.shape[-1:], weight, bias)
    
    # GELU
    return F.gelu(normalized)

Debugging Plugins

Enable plugin debugging:
import os
os.environ['TRTLLM_DEBUG_PLUGINS'] = '1'

# Your plugin will now print debug information
View plugin calls during inference:
from tensorrt_llm.logger import logger
logger.set_level('debug')

# Plugin execution will be logged
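If you need more targeted tracing than the global flags above, a plain wrapper around the plugin function works too. This is a minimal, dependency-free sketch; the log format is an assumption, not a library feature:

```python
import functools

def trace_plugin(fn):
    """Print each call's arguments and result while debugging a plugin."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        print(f"[plugin:{fn.__name__}] args={args} kwargs={kwargs}")
        result = fn(*args, **kwargs)
        print(f"[plugin:{fn.__name__}] -> {result}")
        return result
    return wrapper

# Apply below @python_plugin so the wrapped function is what gets registered.
@trace_plugin
def my_custom_op(input_tensor, scale: float = 1.0):
    return input_tensor * scale
```

Remove the wrapper before building production engines, since the prints run on every invocation.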

Limitations

  • Performance: Python plugins may be slower than native C++ plugins due to Python interpreter overhead.
  • Compatibility: Python plugins are only supported with the TensorRT backend.
  • Deployment: Python plugins require the Python runtime at inference time.

Migration to C++ Plugins

For production deployments, consider migrating to C++ plugins:
// custom_plugin.cpp
class MyPluginCreator : public BaseCreator {
    const char* getPluginName() const noexcept override {
        return "MyCustomOp";
    }

    IPluginV2* createPlugin(const char* name, const PluginFieldCollection* fc) noexcept override {
        // Parse the plugin fields from fc and construct the plugin instance
        return new MyPlugin(/* ... */);
    }
};

REGISTER_TENSORRT_PLUGIN(MyPluginCreator);

Next Steps

Custom Kernels

Learn about CUDA kernel development

Model Architecture

Understand TensorRT-LLM architecture
