
Overview

The Engine class provides an efficient interface for autoregressive token generation with KV caching and tool use support.

Engine Class

from nanochat.engine import Engine
from nanochat.gpt import GPT
from nanochat.tokenizer import get_tokenizer
from nanochat.checkpoint_manager import load_model

# Load model and tokenizer
model = load_model("~/.cache/nanochat/checkpoints/d26_sft")
tokenizer = get_tokenizer()

# Create engine
engine = Engine(model, tokenizer)

Constructor

  • model (GPT, required): The GPT model instance to use for generation.
  • tokenizer (object, required): Tokenizer instance providing encode(), decode(), and encode_special() methods (encode_special() is needed for tool use).

Generation Method

generate()

Generator that yields tokens one at a time for autoregressive generation.
# Tokenize input
tokens = tokenizer.encode("The quick brown fox")

# Generate tokens
for token_column, token_masks in engine.generate(
    tokens,
    num_samples=1,
    max_tokens=100,
    temperature=1.0,
    top_k=None,
    seed=42
):
    token = token_column[0]  # Extract from batch dimension
    token_text = tokenizer.decode([token])
    print(token_text, end="", flush=True)
  • tokens (list[int], required): Input token IDs as a list of integers. Must be pre-tokenized.
  • num_samples (int, default 1): Number of parallel samples to generate. Uses KV cache cloning for efficiency.
  • max_tokens (int, default None): Maximum number of tokens to generate. If None, generates until the model outputs an end token or reaches the context limit.
  • temperature (float, default 1.0): Sampling temperature. Higher values (e.g., 1.5) make output more random; lower values (e.g., 0.7) more deterministic. Set to 0.0 for greedy decoding.
  • top_k (int, default None): If set, only sample from the top-k most likely tokens. None means no top-k filtering.
  • seed (int, default 42): Random seed for reproducible generation.
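The effect of temperature and top_k can be sketched in plain Python. This is an illustrative toy sampler over plain floats, not the library's implementation (the real sampling in sample_next_token operates on torch tensors):

```python
import math
import random

def sample_sketch(logits, temperature=1.0, top_k=None, rng=None):
    """Toy sampler illustrating temperature and top-k (not nanochat's code)."""
    if temperature == 0.0:
        # Greedy decoding: always take the highest-logit token
        return max(range(len(logits)), key=lambda i: logits[i])
    # Rank token indices by logit and optionally keep only the top-k
    candidates = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    if top_k is not None:
        candidates = candidates[:top_k]
    # Temperature-scaled softmax weights over the surviving candidates
    scaled = [logits[i] / temperature for i in candidates]
    m = max(scaled)
    weights = [math.exp(s - m) for s in scaled]
    rng = rng or random.Random(42)
    return rng.choices(candidates, weights=weights, k=1)[0]

logits = [0.1, 2.5, -1.0, 0.7]
print(sample_sketch(logits, temperature=0.0))  # greedy: index 1
```

Higher temperature flattens the distribution, spreading probability mass across more tokens; top_k=1 behaves like greedy decoding.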

Yields

The generator yields (token_column, token_masks) tuples:
  • token_column (list[int]): List of length num_samples containing the next token ID for each sample.
  • token_masks (list[int]): List of length num_samples, where each value is:
      • 1 if the token was sampled from the model
      • 0 if the token was forced (e.g., tool use output)

Tool Use

The engine automatically detects and executes Python expressions in tool blocks:
  1. When the model generates <|python_start|>, the engine enters tool mode
  2. It collects tokens until <|python_end|>
  3. It evaluates the Python expression safely
  4. It forces the <|output_start|>result<|output_end|> tokens into the stream
Supported expressions:
  • Basic arithmetic: 2 + 2, 10 * 5
  • String methods: 'strawberry'.count('r')
Safety features:
  • 3-second timeout per expression
  • No dangerous operations (import, exec, eval, file access)
  • Limited character set for string operations
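The restriction logic can be sketched as follows. This is a hedged approximation: the character whitelist, blocked keywords, and helper name are illustrative assumptions, not nanochat's actual implementation, and the real engine's per-expression timeout is omitted here:

```python
# Illustrative restricted evaluator; names and limits are assumptions,
# not nanochat's actual tool-use code. The real engine also enforces a
# 3-second timeout per expression, omitted in this sketch.
ALLOWED_CHARS = set("0123456789+-*/(). '\"_abcdefghijklmnopqrstuvwxyz")

BLOCKED = ("import", "exec", "eval", "open", "__")

def safe_eval(expr: str):
    """Evaluate a small arithmetic or string expression with no builtins."""
    if not set(expr) <= ALLOWED_CHARS:
        raise ValueError("disallowed characters")
    if any(bad in expr for bad in BLOCKED):
        raise ValueError("disallowed operation")
    # An empty __builtins__ blocks open(), __import__, and friends
    return eval(expr, {"__builtins__": {}}, {})

print(safe_eval("2 + 2"))                    # 4
print(safe_eval("'strawberry'.count('r')"))  # 3
```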

KV Cache

The engine automatically manages a key-value cache for efficient generation:
  1. Prefill phase: processes the input tokens in a single forward pass with batch size 1.
  2. Clone cache: replicates the KV cache for num_samples parallel generations.
  3. Decode loop: generates one token at a time, updating the cache incrementally.
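The three phases can be sketched with plain Python stand-ins (dicts in place of the real model and tensors; every name here is illustrative, not the Engine's internals):

```python
import copy

def prefill(tokens):
    """Phase 1: one pass over the whole prompt (batch size 1) builds the cache."""
    return {"tokens": list(tokens), "seqlen": len(tokens)}

def decode_step(cache, next_token):
    """Phase 3: extend the cache incrementally by a single token."""
    cache["tokens"].append(next_token)
    cache["seqlen"] += 1

prompt = [101, 102, 103]
base = prefill(prompt)                            # 1. prefill once
caches = [copy.deepcopy(base) for _ in range(2)]  # 2. clone for num_samples=2
for c in caches:                                  # 3. decode loop, one token/step
    decode_step(c, 999)

print([c["seqlen"] for c in caches])  # [4, 4]
```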

Benefits

  • Speed: dramatically faster generation by caching past key/value states.
  • Efficient sampling: the cache is cloned once for multiple parallel samples.
  • Memory optimized: stores only the per-layer KV states, not full activations.
  • Flash Attention 3: optimized for FA3’s flash_attn_with_kvcache API.

Cache Structure

The KVCache stores tensors in Flash Attention 3’s native (B, T, H, D) layout:
  • Keys: (batch_size, seq_len, n_kv_head, head_dim)
  • Values: (batch_size, seq_len, n_kv_head, head_dim)
Position tracking via the cache_seqlens tensor allows FA3 to update the cache in place.
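A toy version of this layout and position tracking, with nested lists standing in for tensors (the class and method names are illustrative, not nanochat's KVCache API):

```python
class ToyKVCache:
    """Illustrative (B, T, H, D) cache; lists stand in for FA3 tensors."""

    def __init__(self, batch_size, max_seq_len, n_kv_head, head_dim):
        def zeros():  # one (T, H, D) block of zeros
            return [[[0.0] * head_dim for _ in range(n_kv_head)]
                    for _ in range(max_seq_len)]
        self.k = [zeros() for _ in range(batch_size)]
        self.v = [zeros() for _ in range(batch_size)]
        # Next write position per batch row, like FA3's cache_seqlens
        self.cache_seqlens = [0] * batch_size

    def insert(self, b, k_vec, v_vec):
        """Write one (H, D) key/value slot in place and advance the position."""
        t = self.cache_seqlens[b]
        self.k[b][t] = k_vec
        self.v[b][t] = v_vec
        self.cache_seqlens[b] = t + 1

cache = ToyKVCache(batch_size=2, max_seq_len=4, n_kv_head=1, head_dim=2)
cache.insert(0, [[1.0, 2.0]], [[3.0, 4.0]])
print(cache.cache_seqlens)  # [1, 0]
```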

Helper Functions

sample_next_token()

Samples the next token from logits:
from nanochat.engine import sample_next_token
import torch

logits = torch.randn(4, 32768)  # (batch_size, vocab_size)
rng = torch.Generator()
rng.manual_seed(42)

# Sample with temperature
next_tokens = sample_next_token(logits, rng, temperature=0.8)

# Greedy decoding
next_tokens = sample_next_token(logits, rng, temperature=0.0)

# Top-k sampling
next_tokens = sample_next_token(logits, rng, temperature=1.0, top_k=50)
  • logits (torch.Tensor, required): Logits tensor of shape (B, vocab_size).
  • rng (torch.Generator, required): Random number generator used for sampling.
  • temperature (float, default 1.0): Sampling temperature; 0.0 for greedy decoding.
  • top_k (int, default None): If set, only sample from the top-k most likely tokens.

Usage Examples

Basic Generation

from nanochat.engine import Engine
from nanochat.checkpoint_manager import load_model
from nanochat.tokenizer import get_tokenizer

# Load
model = load_model("~/.cache/nanochat/checkpoints/d26_sft")
tokenizer = get_tokenizer()
engine = Engine(model, tokenizer)

# Prepare input
prompt = "Once upon a time"
tokens = tokenizer.encode(prompt)

# Generate
response_tokens = []
for token_column, token_masks in engine.generate(
    tokens,
    max_tokens=200,
    temperature=1.0
):
    token = token_column[0]
    response_tokens.append(token)
    print(tokenizer.decode([token]), end="", flush=True)

print()

Chat Conversation

# Build conversation tokens manually
bos = tokenizer.encode_special("<|bos|>")
user_start = tokenizer.encode_special("<|user_start|>")
user_end = tokenizer.encode_special("<|user_end|>")
assistant_start = tokenizer.encode_special("<|assistant_start|>")
assistant_end = tokenizer.encode_special("<|assistant_end|>")

# Format: <|bos|><|user_start|>message<|user_end|><|assistant_start|>
tokens = [bos, user_start]
tokens.extend(tokenizer.encode("What is the capital of France?"))
tokens.extend([user_end, assistant_start])

# Generate response
for token_column, _ in engine.generate(tokens, max_tokens=100):
    token = token_column[0]
    if token == assistant_end:
        break
    print(tokenizer.decode([token]), end="", flush=True)

Multiple Samples

# Generate 4 completions in parallel
tokens = tokenizer.encode("The meaning of life is")

samples = [[] for _ in range(4)]
for token_column, token_masks in engine.generate(
    tokens,
    num_samples=4,
    max_tokens=50,
    temperature=1.2,
    seed=42
):
    for i, token in enumerate(token_column):
        samples[i].append(token)

# Decode all samples
for i, sample_tokens in enumerate(samples):
    text = tokenizer.decode(sample_tokens)
    print(f"Sample {i+1}: {text}")

Tool Use

# The model can use Python calculator
tokens = tokenizer.encode("Calculate 15 * 8 = ")

for token_column, token_masks in engine.generate(tokens, max_tokens=50):
    token = token_column[0]
    was_sampled = token_masks[0] == 1
    
    token_text = tokenizer.decode([token])
    print(token_text, end="", flush=True)
    
    # Token mask shows whether this was sampled or forced by tool
    if not was_sampled:
        print(" [tool output]", end="")
