Overview
The Engine class provides an efficient interface for autoregressive token generation with KV caching and tool use support.
Engine Class
from nanochat.engine import Engine
from nanochat.gpt import GPT
from nanochat.tokenizer import get_tokenizer
from nanochat.checkpoint_manager import load_model
# Load model and tokenizer
model = load_model("~/.cache/nanochat/checkpoints/d26_sft")
tokenizer = get_tokenizer()
# Create engine
engine = Engine(model, tokenizer)
Constructor
model: The GPT model instance to use for generation.
tokenizer: Tokenizer instance with encode(), decode(), and encode_special() methods (encode_special() is needed for tool use).
Generation Method
generate()
Generator that yields tokens one at a time for autoregressive generation.
# Tokenize input
tokens = tokenizer.encode("The quick brown fox")
# Generate tokens
for token_column, token_masks in engine.generate(
    tokens,
    num_samples=1,
    max_tokens=100,
    temperature=1.0,
    top_k=None,
    seed=42
):
    token = token_column[0]  # Extract from batch dimension
    token_text = tokenizer.decode([token])
    print(token_text, end="", flush=True)
tokens: Input token IDs as a list of integers. Must be pre-tokenized.
num_samples: Number of parallel samples to generate. Uses KV cache cloning for efficiency.
max_tokens: Maximum number of tokens to generate. If None, generates until the model outputs an end token or reaches the context limit.
temperature: Sampling temperature. Higher values (e.g., 1.5) make output more random; lower values (e.g., 0.7) make it more deterministic. Set to 0.0 for greedy decoding.
top_k: If set, only sample from the top-k most likely tokens. None means no top-k filtering.
seed: Random seed for reproducible generation.
Yields
The generator yields (token_column, token_masks) tuples:
token_column: List of length num_samples containing the next token ID for each sample.
token_masks: List of length num_samples with values:
1 if the token was sampled from the model
0 if the token was forced (e.g., tool use output)
Tool Use
The engine automatically detects and executes Python expressions in tool blocks:
When model generates <|python_start|>, enters tool mode
Collects tokens until <|python_end|>
Evaluates the Python expression safely
Forces <|output_start|>result<|output_end|> tokens
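The four steps above can be sketched as a small state machine over the token stream. Everything here is an illustrative assumption, not Engine's actual internals: the token IDs, the encode/decode callables, the run_with_tools name, and the bare eval (a stand-in for the sandboxed evaluator described under "Safety features" below).

```python
# Hypothetical sketch of the tool-use loop; token IDs and all names
# here are assumptions, not nanochat's actual code.
PYTHON_START, PYTHON_END = 1001, 1002   # assumed special-token IDs
OUTPUT_START, OUTPUT_END = 1003, 1004

def run_with_tools(next_sampled_token, encode, decode):
    """Yield (token, mask) pairs; mask=1 means sampled, mask=0 means forced."""
    tool_buffer = None  # None = normal mode; a list = collecting an expression
    forced = []         # queue of tokens that must be emitted before sampling again
    while True:
        if forced:
            yield forced.pop(0), 0
            continue
        tok = next_sampled_token()
        yield tok, 1
        if tok == PYTHON_START:
            tool_buffer = []                     # enter tool mode
        elif tok == PYTHON_END and tool_buffer is not None:
            expr = decode(tool_buffer)           # collected expression text
            result = str(eval(expr, {"__builtins__": {}}))  # stand-in for the safe evaluator
            forced = [OUTPUT_START] + encode(result) + [OUTPUT_END]
            tool_buffer = None
        elif tool_buffer is not None:
            tool_buffer.append(tok)              # still collecting the expression
```

The mask values line up with the generate() contract above: sampled tokens carry mask 1, the forced output block carries mask 0.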
Supported expressions:
Basic arithmetic: 2 + 2, 10 * 5
String methods: 'strawberry'.count('r')
Safety features:
3-second timeout per expression
No dangerous operations (import, exec, eval, file access)
Limited character set for string operations
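One way the safety rules above could look is sketched below. This is not nanochat's actual sandbox: the character whitelist, the banned-substring list, and the SIGALRM-based timeout are all illustrative assumptions (and SIGALRM is Unix-only).

```python
# Illustrative sketch of restricted expression evaluation, assuming a
# character whitelist + banned substrings + SIGALRM timeout.
_ALLOWED_CHARS = set("0123456789+-*/%(). '\"abcdefghijklmnopqrstuvwxyz_")
_BANNED = ("import", "exec", "eval", "open", "__")  # no dangerous operations

def safe_eval(expr: str, timeout_s: int = 3) -> str:
    # Limited character set and banned-keyword screening
    if not all(c in _ALLOWED_CHARS for c in expr) or any(b in expr for b in _BANNED):
        raise ValueError(f"rejected expression: {expr!r}")
    import signal

    def _on_timeout(signum, frame):
        raise TimeoutError(f"expression exceeded {timeout_s}s")

    old = signal.signal(signal.SIGALRM, _on_timeout)
    signal.alarm(timeout_s)  # 3-second timeout per expression
    try:
        # Evaluate with builtins stripped so import/open/etc. are unreachable
        return str(eval(expr, {"__builtins__": {}}, {}))
    finally:
        signal.alarm(0)
        signal.signal(signal.SIGALRM, old)
```

Under this sketch, safe_eval("2 + 2") returns "4" and safe_eval("'strawberry'.count('r')") returns "3", while anything containing import, exec, or double underscores is rejected up front.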
KV Cache
The engine automatically manages a key-value cache for efficient generation:
Prefill phase: Processes input tokens in a single forward pass with batch size 1.
Clone cache: Replicates the KV cache for num_samples parallel generations.
Decode loop: Generates one token at a time, updating the cache incrementally.
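The three phases can be illustrated with a toy stand-in for the model. Everything here is an assumption for illustration: toy_forward, the dict-based cache, and generate_sketch do not match Engine's real code, but the prefill → clone → decode shape is the same.

```python
import torch

def toy_forward(ids, cache):
    """Fake transformer step: appends K entries for ids to the cache, returns dummy logits."""
    B, T = ids.shape
    k_new = torch.randn(B, T, 4, 8)            # (B, T, n_kv_head, head_dim)
    cache = {"k": torch.cat([cache["k"], k_new], dim=1)} if cache else {"k": k_new}
    logits = torch.randn(B, 16)                # (B, vocab_size)
    return logits, cache

def generate_sketch(tokens, num_samples, max_tokens):
    # Phase 1: prefill -- the whole prompt in one forward pass, batch size 1
    ids = torch.tensor([tokens])
    logits, cache = toy_forward(ids, None)
    # Phase 2: clone -- replicate the single cache row for each parallel sample
    cache = {"k": cache["k"].expand(num_samples, -1, -1, -1).contiguous()}
    next_ids = logits.argmax(dim=-1).expand(num_samples)
    # Phase 3: decode -- one token per step, cache grows incrementally
    for _ in range(max_tokens):
        yield next_ids.tolist()                # one token column per step
        logits, cache = toy_forward(next_ids[:, None], cache)
        next_ids = logits.argmax(dim=-1)
```

The key point the sketch shows: the expensive prompt pass happens once, and the per-sample cost is only the clone plus one-token decode steps.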
Benefits
Speed: Dramatically faster by caching past key/value states
Efficient Sampling: Clone the cache once for multiple parallel samples
Memory Optimized: Only stores compressed KV states, not full activations
Flash Attention 3: Optimized for FA3’s flash_attn_with_kvcache API
Cache Structure
The KVCache stores tensors in Flash Attention 3’s native (B, T, H, D) layout:
Keys: (batch_size, seq_len, n_kv_head, head_dim)
Values: (batch_size, seq_len, n_kv_head, head_dim)
Position tracking via cache_seqlens tensor allows FA3 to update cache in-place.
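A minimal sketch of this layout, assuming illustrative names (insert_kv and these variable names are not real nanochat or FA3 functions): each sequence writes its new K/V entry at the position given by cache_seqlens, which is then advanced in place.

```python
import torch

# Sketch of a (B, T, H, D) KV cache with per-sequence position tracking.
B, T_max, n_kv_head, head_dim = 2, 128, 4, 64
k_cache = torch.zeros(B, T_max, n_kv_head, head_dim)   # keys
v_cache = torch.zeros(B, T_max, n_kv_head, head_dim)   # values
cache_seqlens = torch.zeros(B, dtype=torch.int32)      # current length per sequence

def insert_kv(k_step, v_step):
    """Write one decode step's K/V at each sequence's current position, then advance."""
    for b in range(B):
        pos = int(cache_seqlens[b])
        k_cache[b, pos] = k_step[b]    # k_step: (B, n_kv_head, head_dim)
        v_cache[b, pos] = v_step[b]
    cache_seqlens.add_(1)              # in-place update, as FA3 does with the cache
```

In the real path, flash_attn_with_kvcache performs this write and the position bump itself, which is why the cache never needs to be re-copied between decode steps.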
Helper Functions
sample_next_token()
Samples the next token from logits:
from nanochat.engine import sample_next_token
import torch
logits = torch.randn(4, 32768)  # (batch_size, vocab_size)
rng = torch.Generator()
rng.manual_seed(42)
# Sample with temperature
next_tokens = sample_next_token(logits, rng, temperature=0.8)
# Greedy decoding
next_tokens = sample_next_token(logits, rng, temperature=0.0)
# Top-k sampling
next_tokens = sample_next_token(logits, rng, temperature=1.0, top_k=50)
logits: Logits tensor of shape (B, vocab_size).
rng: Random number generator for sampling.
temperature: Sampling temperature. 0.0 for greedy decoding.
top_k: If set, only sample from the top-k tokens.
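For intuition, here is a minimal sketch of what temperature and top-k sampling look like; nanochat's actual sample_next_token may differ in details, and the name below is marked as a sketch to avoid confusion.

```python
import torch

def sample_next_token_sketch(logits, rng, temperature=1.0, top_k=None):
    """Sample one token per row of logits (B, vocab_size) -> (B,)."""
    if temperature == 0.0:
        return logits.argmax(dim=-1)                 # greedy decoding
    logits = logits / temperature                    # sharpen or flatten the distribution
    if top_k is not None:
        # Restrict sampling to the k most likely tokens
        topv, topi = torch.topk(logits, top_k, dim=-1)
        probs = torch.softmax(topv, dim=-1)
        choice = torch.multinomial(probs, 1, generator=rng)
        return topi.gather(-1, choice).squeeze(-1)   # map back to vocab indices
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, 1, generator=rng).squeeze(-1)
```

Note the ordering: temperature scales the logits before the top-k cut, so the two options compose rather than conflict.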
Usage Examples
Basic Generation
from nanochat.engine import Engine
from nanochat.checkpoint_manager import load_model
from nanochat.tokenizer import get_tokenizer
# Load
model = load_model("~/.cache/nanochat/checkpoints/d26_sft")
tokenizer = get_tokenizer()
engine = Engine(model, tokenizer)
# Prepare input
prompt = "Once upon a time"
tokens = tokenizer.encode(prompt)
# Generate
response_tokens = []
for token_column, token_masks in engine.generate(
    tokens,
    max_tokens=200,
    temperature=1.0
):
    token = token_column[0]
    response_tokens.append(token)
    print(tokenizer.decode([token]), end="", flush=True)
print()
Chat Conversation
# Build conversation tokens manually
bos = tokenizer.encode_special("<|bos|>")
user_start = tokenizer.encode_special("<|user_start|>")
user_end = tokenizer.encode_special("<|user_end|>")
assistant_start = tokenizer.encode_special("<|assistant_start|>")
assistant_end = tokenizer.encode_special("<|assistant_end|>")
# Format: <|bos|><|user_start|>message<|user_end|><|assistant_start|>
tokens = [bos, user_start]
tokens.extend(tokenizer.encode("What is the capital of France?"))
tokens.extend([user_end, assistant_start])
# Generate response
for token_column, _ in engine.generate(tokens, max_tokens=100):
    token = token_column[0]
    if token == assistant_end:
        break
    print(tokenizer.decode([token]), end="", flush=True)
Multiple Samples
# Generate 4 completions in parallel
tokens = tokenizer.encode("The meaning of life is")
samples = [[] for _ in range(4)]
for token_column, token_masks in engine.generate(
    tokens,
    num_samples=4,
    max_tokens=50,
    temperature=1.2,
    seed=42
):
    for i, token in enumerate(token_column):
        samples[i].append(token)
# Decode all samples
for i, sample_tokens in enumerate(samples):
    text = tokenizer.decode(sample_tokens)
    print(f"Sample {i + 1}: {text}")
Tool Use Example
# The model can use the Python calculator
tokens = tokenizer.encode("Calculate 15 * 8 = ")
for token_column, token_masks in engine.generate(tokens, max_tokens=50):
    token = token_column[0]
    was_sampled = token_masks[0] == 1
    token_text = tokenizer.decode([token])
    print(token_text, end="", flush=True)
    # Token mask shows whether this token was sampled or forced by the tool
    if not was_sampled:
        print(" [tool output]", end="")
See Also