TokenizerBase
The TokenizerBase class defines the tokenizer interface used by TensorRT-LLM. It extends the Hugging Face PreTrainedTokenizerBase protocol.
Overview
TensorRT-LLM uses tokenizers to convert text to token IDs (encoding) and token IDs back to text (decoding). The default implementation wraps Hugging Face transformers tokenizers.
from tensorrt_llm import LLM
llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct")
# Access the tokenizer
tokenizer = llm.tokenizer
# Encode text
token_ids = tokenizer.encode("Hello, world!")
print(token_ids)  # model-specific, e.g. [128000, 9906, 11, 1917, 0]
# Decode token IDs
text = tokenizer.decode(token_ids)
print(text) # "Hello, world!"
Loading Tokenizers
Automatic Loading
Tokenizers are automatically loaded when creating an LLM instance:
# Load from Hugging Face Hub
llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct")
# Load from local directory
llm = LLM(model="/path/to/model")
# Custom tokenizer path
llm = LLM(
    model="/path/to/model",
    tokenizer="/path/to/tokenizer"
)
Skip Tokenizer Initialization
If you plan to work with token IDs directly:
llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    skip_tokenizer_init=True
)
# tokenizer will be None
assert llm.tokenizer is None
# Provide token IDs directly
token_ids = [128000, 9906, 11, 1917, 0]  # "Hello, world!" (IDs are model-specific)
output = llm.generate(token_ids)
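With the tokenizer skipped, both directions of the conversion are your responsibility. A minimal stand-alone sketch of that ID-in/ID-out contract (the toy vocab below is a stand-in for a real tokenizer; llm.generate would sit between the encode and decode steps):

```python
# Toy stand-in for an external tokenizer; vocab and IDs are illustrative.
VOCAB = {"Hello": 1, "world": 2, "!": 3}
INV = {v: k for k, v in VOCAB.items()}

def encode(text):
    return [VOCAB[w] for w in text.split()]

def decode(ids):
    return " ".join(INV[i] for i in ids)

prompt_ids = encode("Hello world !")  # you pass these to llm.generate(...)
# ...generation would return new token IDs; decoding them is also up to you:
round_trip = decode(prompt_ids)
print(round_trip)  # Hello world !
```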
Using a Pre-loaded Tokenizer
from transformers import AutoTokenizer
from tensorrt_llm import LLM
# Load tokenizer separately
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
# Pass to LLM
llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    tokenizer=tokenizer
)
Core Methods
encode()
Convert text to token IDs.
token_ids = tokenizer.encode(
    "Hello, world!",
    add_special_tokens=True
)
print(token_ids)  # model-specific, e.g. [128000, 9906, 11, 1917, 0]
Parameters
add_special_tokens
Whether to add special tokens (BOS, EOS) during encoding.
Returns
A list of token IDs.
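The effect of add_special_tokens can be pictured with a small sketch (the BOS ID and token IDs below are placeholders, not guaranteed values for any model):

```python
# Sketch of add_special_tokens: prepend BOS when enabled (IDs illustrative).
BOS_ID = 128000

def encode_ids(body_ids, add_special_tokens=True):
    return ([BOS_ID] + body_ids) if add_special_tokens else list(body_ids)

body = [9906, 11, 1917, 0]
with_special = encode_ids(body)
without_special = encode_ids(body, add_special_tokens=False)
print(with_special)     # [128000, 9906, 11, 1917, 0]
print(without_special)  # [9906, 11, 1917, 0]
```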
decode()
Convert token IDs back to text.
text = tokenizer.decode(
    [128000, 9906, 11, 1917, 0],
    skip_special_tokens=True
)
print(text)  # "Hello, world!"
Parameters
skip_special_tokens
Whether to remove special tokens (BOS, EOS, PAD) from the output.
spaces_between_special_tokens
Whether to add spaces between special tokens in the output.
Returns
The decoded text string.
batch_encode_plus()
Encode multiple texts in a batch.
encoded = tokenizer.batch_encode_plus(
    ["Hello, world!", "How are you?"],
    padding=True,
    return_tensors="pt"
)
print(encoded["input_ids"])
Parameters
Accepts the standard Hugging Face batch_encode_plus options, including padding, truncation, max_length, and return_tensors, as shown above.
Returns
Dictionary containing:
input_ids: Token IDs
attention_mask: Attention mask
Other tokenizer-specific outputs
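What padding=True produces can be sketched in plain Python (the function name and pad ID here are illustrative, not the library's internals): shorter sequences are right-padded to the longest length, and the attention mask marks real tokens with 1.

```python
# Toy sketch of batch padding: pad to max length, mask real tokens with 1.
def pad_batch(batches, pad_id=0):
    max_len = max(len(b) for b in batches)
    input_ids = [b + [pad_id] * (max_len - len(b)) for b in batches]
    attention_mask = [[1] * len(b) + [0] * (max_len - len(b)) for b in batches]
    return {"input_ids": input_ids, "attention_mask": attention_mask}

enc = pad_batch([[5, 6, 7], [8, 9]])
print(enc["input_ids"])       # [[5, 6, 7], [8, 9, 0]]
print(enc["attention_mask"])  # [[1, 1, 1], [1, 1, 0]]
```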
decode_incrementally()
Incremental decoding for streaming generation. This method is optimized for streaming scenarios where tokens are generated one at a time.
prev_text = ""
states = None
for new_token_ids in streaming_tokens:
    text, states = tokenizer.decode_incrementally(
        new_token_ids,
        prev_text=prev_text,
        states=states,
        skip_special_tokens=True
    )
    print(text[len(prev_text):], end="", flush=True)  # Print only new text
    prev_text = text
Parameters
new_token_ids
Incremental token IDs to decode.
prev_text
Previously decoded text. None for the first iteration.
states
Internal decoding state from the previous iteration. None for the first iteration.
flush
Force flush pending tokens to the output.
skip_special_tokens
Whether to skip special tokens in the output.
spaces_between_special_tokens
Whether to add spaces between special tokens.
stream_interval
Iteration interval for streaming updates.
Returns
Tuple of:
text: Current decoded text
states: Updated decoding state (pass to next iteration)
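The reason decoding must carry state across iterations is that a single character can span multiple tokens, so decoding tokens independently can corrupt the text at the boundary. A pure-Python illustration, with raw UTF-8 byte chunks standing in for token pieces:

```python
# A multi-byte character split across two "tokens":
chunks = [b"\xe4\xb8", b"\xad"]  # the UTF-8 bytes of '中'

# Decoding each chunk independently yields replacement characters:
per_token = [c.decode("utf-8", errors="replace") for c in chunks]
# Decoding the accumulated bytes (what stateful incremental decoding
# effectively does) recovers the character:
joined = b"".join(chunks).decode("utf-8")
print(per_token)  # first piece is not a valid character on its own
print(joined)     # 中
```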
Properties
eos_token_id
End-of-sequence token ID.
print(tokenizer.eos_token_id)  # model-specific
pad_token_id
Padding token ID (may be None if the model defines no padding token).
print(tokenizer.pad_token_id)
name_or_path
Model name or path the tokenizer was loaded from.
print(tokenizer.name_or_path)  # "meta-llama/Llama-3.1-8B-Instruct"
is_fast
Whether this is a fast (Rust-based) tokenizer.
print(tokenizer.is_fast)  # True
Chat Templates
apply_chat_template()
Format a conversation using the model’s chat template.
conversation = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is machine learning?"},
    {"role": "assistant", "content": "Machine learning is..."},
    {"role": "user", "content": "Can you explain more?"},
]
prompt = tokenizer.apply_chat_template(
    conversation,
    tokenize=False,
    add_generation_prompt=True
)
output = llm.generate(prompt)
Parameters
conversation
List[Dict[str, str]], required. List of message dictionaries with "role" and "content" keys.
tokenize
If True, return token IDs. If False, return the formatted string.
add_generation_prompt
Add the prompt for the next assistant message.
Returns
Formatted prompt as string or token IDs (depending on tokenize parameter).
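Real chat templates are model-specific Jinja templates shipped with the tokenizer. The following hand-rolled, Llama-3-style sketch only illustrates the shape of the output and what add_generation_prompt changes (the header tokens and function name are illustrative):

```python
# Hedged sketch of chat formatting; not the library's implementation.
def format_chat(messages, add_generation_prompt=True):
    parts = ["<|begin_of_text|>"]
    for m in messages:
        parts.append(f"<|start_header_id|>{m['role']}<|end_header_id|>\n\n"
                     f"{m['content']}<|eot_id|>")
    if add_generation_prompt:
        # Open an assistant turn so the model continues as the assistant
        parts.append("<|start_header_id|>assistant<|end_header_id|>\n\n")
    return "".join(parts)

prompt = format_chat([{"role": "user", "content": "Hi"}])
print(prompt)
```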
Custom Tokenizers
You can implement custom tokenizers by inheriting from TokenizerBase:
from tensorrt_llm.llmapi.tokenizer import TokenizerBase
from typing import List
class CustomTokenizer(TokenizerBase):
    def __init__(self, vocab_file: str):
        # Load your custom vocabulary
        self.vocab = self._load_vocab(vocab_file)
        # Build the reverse mapping used by decode()
        self.vocab_inv = {tid: word for word, tid in self.vocab.items()}
        self._eos_token_id = 2
        self._pad_token_id = 0

    @property
    def eos_token_id(self) -> int:
        return self._eos_token_id

    @property
    def pad_token_id(self) -> int:
        return self._pad_token_id

    def encode(self, text: str, **kwargs) -> List[int]:
        # Your encoding logic
        return [self.vocab.get(word, 0) for word in text.split()]

    def decode(self, token_ids: List[int], **kwargs) -> str:
        # Your decoding logic
        return " ".join(self.vocab_inv.get(tid, "<unk>") for tid in token_ids)

# Use the custom tokenizer
custom_tokenizer = CustomTokenizer("vocab.txt")
llm = LLM(
    model="/path/to/model",
    tokenizer=custom_tokenizer
)
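The encode/decode pair above can be exercised stand-alone (the vocabulary is inlined here instead of loaded from a file, so no TensorRT-LLM install is required):

```python
# Stand-alone version of the whitespace tokenizer's core logic.
vocab = {"<unk>": 0, "hello": 1, "world": 2}
vocab_inv = {v: k for k, v in vocab.items()}

def encode(text):
    return [vocab.get(word, 0) for word in text.split()]

def decode(token_ids):
    return " ".join(vocab_inv.get(tid, "<unk>") for tid in token_ids)

print(decode(encode("hello world")))  # hello world
print(decode(encode("hello there")))  # hello <unk>  (unknown word maps to ID 0)
```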
TransformersTokenizer
The default tokenizer implementation wraps Hugging Face transformers tokenizers:
from tensorrt_llm.llmapi.tokenizer import TransformersTokenizer
from transformers import AutoTokenizer
# Create from HF tokenizer
hf_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
tokenizer = TransformersTokenizer(hf_tokenizer)
# Or use from_pretrained class method
tokenizer = TransformersTokenizer.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    trust_remote_code=True
)
Utility Functions
load_hf_tokenizer()
Load a Hugging Face tokenizer directly:
from tensorrt_llm.llmapi.tokenizer import load_hf_tokenizer
tokenizer = load_hf_tokenizer(
    model_path="meta-llama/Llama-3.1-8B-Instruct",
    trust_remote_code=True,
    use_fast=True
)
Environment Variables
TLLM_INCREMENTAL_DETOKENIZATION_BACKEND
Backend for incremental detokenization:
'HF': Use Hugging Face tokenizers backend (faster for small stream intervals)
'TRTLLM': Use TensorRT-LLM backend
TLLM_STREAM_INTERVAL_THRESHOLD
Threshold for switching between HF and TRTLLM incremental detokenization backends.
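These are ordinary process environment variables, so they must be set before the engine is created (the values shown are illustrative; consult the release documentation for accepted values):

```python
import os

# Set before constructing the LLM; values here are illustrative.
os.environ["TLLM_INCREMENTAL_DETOKENIZATION_BACKEND"] = "HF"
os.environ["TLLM_STREAM_INTERVAL_THRESHOLD"] = "4"

print(os.environ["TLLM_INCREMENTAL_DETOKENIZATION_BACKEND"])  # HF
```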
Usage Examples
Basic Encoding and Decoding
from tensorrt_llm import LLM
llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct")
tokenizer = llm.tokenizer
# Encode
text = "The quick brown fox jumps over the lazy dog"
token_ids = tokenizer.encode(text)
print(f"Token IDs: {token_ids}")
print(f"Number of tokens: {len(token_ids)}")
# Decode
decoded = tokenizer.decode(token_ids, skip_special_tokens=True)
print(f"Decoded: {decoded}")
Streaming with Incremental Decoding
from tensorrt_llm import LLM
from tensorrt_llm.sampling_params import SamplingParams
llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct")
tokenizer = llm.tokenizer
# Generate with streaming
future = llm.generate_async(
    "Write a poem about AI",
    sampling_params=SamplingParams(max_tokens=200),
    streaming=True
)
for partial_output in future:
    # text_diff uses incremental decoding internally
    new_text = partial_output.outputs[0].text_diff
    print(new_text, end="", flush=True)
Generation with a Chat Template
from tensorrt_llm import LLM
llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct")
messages = [
    {"role": "system", "content": "You are a helpful AI assistant."},
    {"role": "user", "content": "What's the capital of France?"},
]
# Format with chat template
prompt = llm.tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)
print("Formatted prompt:")
print(prompt)
# Generate response
output = llm.generate(prompt)
print("\nResponse:")
print(output.outputs[0].text)
Batch Encoding
texts = [
    "Hello, world!",
    "How are you today?",
    "Machine learning is fascinating."
]
encoded = tokenizer.batch_encode_plus(
    texts,
    padding=True,
    truncation=True,
    max_length=512,
    return_tensors="pt"
)
print(f"Input IDs shape: {encoded['input_ids'].shape}")
print(f"Attention mask shape: {encoded['attention_mask'].shape}")
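Given the padded batch, the original per-sequence token lists can be recovered from the attention mask. A plain-Python sketch of the idea (a tensor version would use masked operations instead):

```python
# Illustrative padded batch: 0 is the pad ID, mask marks real tokens.
input_ids = [[5, 6, 7, 0], [8, 9, 0, 0]]
attention_mask = [[1, 1, 1, 0], [1, 1, 0, 0]]

# Keep only positions where the mask is 1:
unpadded = [
    [tid for tid, m in zip(ids, mask) if m]
    for ids, mask in zip(input_ids, attention_mask)
]
print(unpadded)  # [[5, 6, 7], [8, 9]]
```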
See Also