```python
@python_plugin("MyOp", outputs=["output"])
class MyOpPlugin:
    def __call__(self, input_tensor):
        return input_tensor * 2

    def get_output_shapes(self, input_shapes):
        """Define output shapes based on input shapes."""
        return [input_shapes[0]]  # Output has same shape as input

    def get_output_dtypes(self, input_dtypes):
        """Define output data types."""
        return [input_dtypes[0]]  # Output has same dtype as input
```
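The three methods form a simple contract: `__call__` performs the computation, while `get_output_shapes` and `get_output_dtypes` let the runtime infer output metadata without executing the op. As a plain-Python illustration (assuming here that the `@python_plugin` decorator is registration-only, so an undecorated copy behaves the same):

```python
class PlainMyOp:
    """Undecorated copy of MyOpPlugin, for illustration only."""

    def __call__(self, input_tensor):
        return input_tensor * 2

    def get_output_shapes(self, input_shapes):
        return [input_shapes[0]]

    def get_output_dtypes(self, input_dtypes):
        return [input_dtypes[0]]

op = PlainMyOp()
# Metadata can be queried without running the op itself:
assert op.get_output_shapes([(4, 8)]) == [(4, 8)]
assert op.get_output_dtypes(["float32"]) == ["float32"]
assert op(3) == 6
```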
Keep computations on GPU to avoid CPU-GPU transfers:
```python
@python_plugin("EfficientOp", outputs=["output"])
def efficient_op(gpu_tensor):
    # All operations stay on GPU
    return gpu_tensor.relu().sum(dim=-1)
```
Workspace Management
Use workspace for temporary buffers:
```python
@python_plugin("WorkspaceOp", outputs=["output"], workspace_size=1024 * 1024)
def workspace_op(input, workspace):
    # Use workspace for intermediate results
    temp_buffer = workspace[:input.numel()]
    # ... computations
    return output
```
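The exact layout of `workspace` is backend-defined. As a sketch, assuming it arrives as a flat byte (`uint8`) tensor, an intermediate buffer can be carved out of it and reinterpreted in place, so the hot path allocates nothing for temporaries:

```python
import torch

def double_plus_one(x, workspace):
    # Assumption for this sketch: workspace is a flat uint8 buffer.
    nbytes = x.numel() * x.element_size()
    # Reinterpret the first nbytes as an x-shaped tensor of x's dtype;
    # this is a view into the workspace, no new memory is allocated.
    temp = workspace[:nbytes].view(x.dtype).reshape(x.shape)
    torch.mul(x, 2, out=temp)  # intermediate result lives in the workspace
    return temp + 1            # final result gets its own storage

workspace = torch.empty(1024 * 1024, dtype=torch.uint8)
x = torch.arange(4, dtype=torch.float32)
assert torch.equal(double_plus_one(x, workspace),
                   torch.tensor([1.0, 3.0, 5.0, 7.0]))
```

Returning `temp + 1` rather than `temp` matters: the workspace is scratch memory the runtime may reuse, so outputs must not alias it.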
Type Safety
Annotate types for better error messages:
```python
from typing import Optional

import torch

@python_plugin("TypedOp", outputs=["output"])
def typed_op(
    input: torch.Tensor,
    scale: float,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    result = input * scale
    if bias is not None:
        result = result + bias
    return result
```
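Beyond runtime error messages, the annotations let static checkers such as mypy flag a bad `scale` or `bias` before the plugin ever reaches the runtime. A usage sketch (using an undecorated copy, on the assumption that the decorator leaves the function directly callable):

```python
from typing import Optional

import torch

def typed_op(
    input: torch.Tensor,
    scale: float,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # Undecorated copy of TypedOp above, for direct use in this sketch.
    result = input * scale
    if bias is not None:
        result = result + bias
    return result

x = torch.tensor([1.0, 2.0])
assert torch.equal(typed_op(x, 2.0), torch.tensor([2.0, 4.0]))
assert torch.equal(typed_op(x, 2.0, bias=torch.tensor([0.5, 0.5])),
                   torch.tensor([2.5, 4.5]))
```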
Performance: Python plugins may be slower than native C++ plugins due to Python overhead.

Compatibility: Python plugins are only supported with the TensorRT backend.

Deployment: Python plugins require the Python runtime at inference time.