Overview
Adding a custom model involves four main steps:

- Model Configuration - Define the configuration class
- Model Definition - Implement the model architecture
- Weight Loading - Handle checkpoint loading and conversion
- Model Registration - Register the model for auto-discovery
If your model is already supported in HuggingFace Transformers, you can reuse the HuggingFace configuration and adapt the modeling code.
Prerequisites
- Working TensorRT-LLM installation
- Understanding of PyTorch and Transformer architectures
- Familiarity with your target model architecture
Step 1: Model Configuration
Create a configuration class that defines all model hyperparameters. There are two options:

- Reuse HuggingFace Config
- Custom Config

If your model exists in HuggingFace Transformers, reuse its config directly, or extend it with custom parameters in a file such as configuration_mymodel.py.
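As a starting point, a configuration can be sketched as a plain dataclass. The class name `MyModelConfig` and the exact field set are illustrative (the field names follow common HuggingFace conventions); adapt them to your architecture, or subclass the HuggingFace config instead when one exists.

```python
from dataclasses import dataclass

# Hypothetical configuration for a custom decoder model. Field names mirror
# common HuggingFace conventions; adjust them to your architecture.
@dataclass
class MyModelConfig:
    vocab_size: int = 32000
    hidden_size: int = 4096
    num_hidden_layers: int = 32
    num_attention_heads: int = 32
    num_key_value_heads: int = 8      # fewer KV heads than attention heads -> GQA
    intermediate_size: int = 11008
    max_position_embeddings: int = 4096
    rms_norm_eps: float = 1e-6

    @property
    def head_dim(self) -> int:
        # The per-head dimension must divide the hidden size evenly.
        assert self.hidden_size % self.num_attention_heads == 0
        return self.hidden_size // self.num_attention_heads
```

Keeping derived quantities such as `head_dim` as properties avoids storing redundant, possibly inconsistent values in checkpoints.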
Step 2: Model Definition
Implement your model architecture using TensorRT-LLM's base classes and modules.

Model Structure
A typical decoder model consists of:

- Attention Layer - Inherits from Attention
- Decoder Layer - Inherits from DecoderLayer
- Model - Inherits from DecoderModel
- Model for Causal LM - Inherits from DecoderModelForCausalLM
Implementation Example
modeling_mymodel.py
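The four-level hierarchy above can be sketched structurally as follows. The empty base classes here are plain-Python stand-ins so the sketch is self-contained; in real code you would import and subclass TensorRT-LLM's actual Attention, DecoderLayer, DecoderModel, and DecoderModelForCausalLM, whose constructors take framework-specific arguments.

```python
# Stand-ins for TensorRT-LLM's base classes (import the real ones instead).
class Attention: pass
class DecoderLayer: pass
class DecoderModel: pass
class DecoderModelForCausalLM: pass


class MyAttention(Attention):
    """Per-layer attention; the real base class manages KV cache, RoPE, TP."""
    def __init__(self, config, layer_idx):
        self.layer_idx = layer_idx


class MyDecoderLayer(DecoderLayer):
    """One transformer block: attention + MLP + normalization layers."""
    def __init__(self, config, layer_idx):
        self.self_attn = MyAttention(config, layer_idx)


class MyModel(DecoderModel):
    """Embedding + stack of decoder layers + final norm."""
    def __init__(self, config):
        self.layers = [MyDecoderLayer(config, i)
                       for i in range(config["num_hidden_layers"])]


class MyModelForCausalLM(DecoderModelForCausalLM):
    """Wraps the backbone and adds the LM head / generation entry point."""
    def __init__(self, config):
        self.model = MyModel(config)
```

The point of the sketch is the containment structure: each level owns the level below it, and the causal-LM wrapper is what the runtime instantiates.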
Key Components
Attention Module

The Attention module handles all attention computation. Features:

- Automatic KV cache management
- Support for GQA, MQA, MHA
- RoPE, ALiBi position embeddings
- Flash Attention integration
- Tensor parallelism
Optimized Modules

Use TensorRT-LLM's optimized modules for best performance. Benefits:
- Automatic tensor parallelism
- Quantization support
- Optimized CUDA kernels
- Reduced memory usage
Input Format

Packed Mode: Input tensors are in packed format, where the first dimension is the total number of tokens in the batch.

AttentionMetadata: Contains batching and KV cache metadata:
- Sequence lengths
- Cumulative sequence lengths
- KV cache locations
- Attention mask information
Pass attn_metadata to all attention modules.

Step 3: Weight Loading
Implement weight loading to convert checkpoint weights to your model format.

Default Weight Loading

The base DecoderModelForCausalLM provides automatic weight loading that works for most models.
Custom Weight Loading
If you need custom weight mapping, override load_weights:
Weight Loading Examples
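In practice, a custom load_weights often reduces to renaming checkpoint keys into the model's parameter names. The rename rules below are illustrative (GPT-2-style HuggingFace names on the left); derive the real pairs by diffing your checkpoint's keys against the model's state dict:

```python
def remap_weight_names(hf_state_dict):
    """Rename HuggingFace-style checkpoint keys to this model's parameter
    names. The rule pairs here are illustrative; derive the real ones by
    comparing checkpoint keys with model.state_dict().keys()."""
    rules = [
        ("transformer.h.", "model.layers."),
        (".attn.c_attn.", ".self_attn.qkv_proj."),
        (".mlp.c_fc.", ".mlp.up_proj."),
    ]
    out = {}
    for name, tensor in hf_state_dict.items():
        for old, new in rules:
            name = name.replace(old, new)
        out[name] = tensor
    return out
```

Applying ordered string-replace rules keeps the mapping readable and easy to extend when a new layer type is added.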
Step 4: Model Registration
Register your model so it can be auto-discovered by TensorRT-LLM. There are two approaches:

- Out-of-Tree Model (Recommended)
- Core Model (Advanced)

For custom models, use out-of-tree registration without modifying TensorRT-LLM:

1. Create your model files (e.g. modeling_mymodel.py).
2. Add the @register_auto_model decorator.
3. Import the model in your script (e.g. main.py).
Using Your Custom Model
Once registered, use your model like any other TensorRT-LLM model.

Complete Example

See the full out-of-tree model example.

Testing Your Model
Best Practices
Model Architecture

- Inherit from base classes: Use DecoderModel, DecoderLayer, Attention for compatibility
- Use optimized modules: Prefer TensorRT-LLM modules over PyTorch for better performance
- Handle packed input: Ensure your model handles packed token sequences correctly
- Pass attn_metadata: Always pass AttentionMetadata to attention modules
Weight Loading

- Use module-level load_weights: Call Linear.load_weights(), Embedding.load_weights() for TP/quantization
- Handle weight fusion: Fuse QKV, gate/up projections for better performance
- Test with different TP sizes: Verify weight loading works with tensor parallelism
- Support quantization: Ensure weight loader handles quantized checkpoints
Performance Optimization

- Fuse operations: Combine QKV, gate/up projections into single Linear layers
- Use RMSNorm: Faster than LayerNorm when applicable
- Enable tensor parallelism: Use Linear(..., tp_group=...) for automatic TP
- Optimize memory: Use torch.utils.checkpoint for gradient checkpointing during training
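The QKV fusion point above amounts to stacking the three projection matrices row-wise so attention issues one GEMM instead of three. A minimal sketch, with weight matrices represented as plain lists of rows:

```python
def fuse_qkv(q_rows, k_rows, v_rows):
    """Fuse separate Q/K/V weight matrices (lists of rows) into one matrix,
    so the attention input projection is a single GEMM instead of three.
    With tensor parallelism, fuse per rank after sharding each projection."""
    return q_rows + k_rows + v_rows

fused = fuse_qkv([[1, 0]], [[0, 1]], [[1, 1]])
# fused has (q_out + k_out + v_out) rows: [[1, 0], [0, 1], [1, 1]]
```

The weight loader then writes checkpoint Q, K, and V tensors into the corresponding row ranges of the fused parameter.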
Testing
- Unit test components: Test attention, layer, model separately
- Compare with reference: Validate outputs against HuggingFace/original implementation
- Test all TP configurations: Test with TP=1, 2, 4, 8
- Benchmark performance: Measure throughput and latency vs baselines
Common Issues
Shape Mismatches

Problem: Tensor shape errors in the forward pass.

Solution: Remember that inputs are in packed mode.
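The usual mistake is indexing as if the input were [batch, seq_len, hidden]. In packed mode it is [total_tokens, hidden], and each request owns a contiguous row range given by the cumulative sequence lengths. A small sketch (tensors modeled as lists of per-token vectors):

```python
def split_packed(hidden_states, cu_seqlens):
    """Recover per-request slices from a packed tensor. hidden_states has
    shape [total_tokens, hidden], NOT [batch, seq_len, hidden]; request i
    owns rows cu_seqlens[i]:cu_seqlens[i+1]."""
    return [hidden_states[cu_seqlens[i]:cu_seqlens[i + 1]]
            for i in range(len(cu_seqlens) - 1)]

packed = [[0.0], [0.1], [0.2], [0.3], [0.4]]   # 5 tokens total, hidden = 1
per_request = split_packed(packed, [0, 2, 5])  # requests of length 2 and 3
```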
Weight Loading Fails

Problem: Weights don't load correctly.

Solution: Check weight names and shapes.
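A quick way to check names is to diff the model's expected parameter names against the checkpoint's keys; the two set differences pinpoint renamed or missing weights. A stdlib-only sketch (the key names in the example are hypothetical):

```python
def diff_weight_names(model_keys, checkpoint_keys):
    """Compare expected parameter names with checkpoint keys. Names in
    'missing_in_checkpoint' need a rename rule or a fused-weight handler;
    'unexpected_in_checkpoint' shows what the checkpoint calls them."""
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    return {
        "missing_in_checkpoint": sorted(model_keys - checkpoint_keys),
        "unexpected_in_checkpoint": sorted(checkpoint_keys - model_keys),
    }

report = diff_weight_names(
    ["model.embed_tokens.weight", "lm_head.weight"],
    ["model.embed_tokens.weight", "transformer.wte.weight"],
)
# report["missing_in_checkpoint"] == ["lm_head.weight"]
```

Checking shapes the same way (name -> tensor shape on both sides) catches transposed or un-fused weights that load without error but produce garbage.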
TP Not Working

Problem: Tensor parallelism produces wrong results.

Solution: Ensure all Linear layers have the TP configuration set.
Example Models
Study these reference implementations:

LLaMA
Standard decoder-only model with RoPE and GQA
GPT
Classic GPT architecture with learned positions
Gemma
Model with custom normalization and sliding window
OPT (Out-of-tree)
Complete out-of-tree model example
Next Steps
Model Configuration
Configure your custom model
Quantization
Add quantization support
Deployment
Deploy your model to production