The Engine class provides efficient inference for NanoChat models using KV caching and supports advanced features like tool use and multi-sample generation.

Overview

The engine is designed for maximum efficiency:
  • KV Cache: Stores key-value pairs from previous tokens to avoid recomputation
  • Streaming Generation: Yields tokens one at a time for real-time output
  • Batch Generation: Generate multiple samples in parallel
  • Tool Use: Built-in calculator tool with automatic result injection

Basic Usage

import torch
from nanochat.engine import Engine
from nanochat.checkpoint_manager import load_model

# Pick a device and load the model and tokenizer
device = "cuda" if torch.cuda.is_available() else "cpu"
model, tokenizer, meta = load_model("sft", device, phase="eval")

# Create engine
engine = Engine(model, tokenizer)

# Generate tokens
bos_token_id = tokenizer.get_bos_token_id()
prompt_tokens = tokenizer.encode("What is 2+2?", prepend=bos_token_id)
for token_column, token_masks in engine.generate(prompt_tokens, num_samples=1, max_tokens=100):
    token = token_column[0]
    print(tokenizer.decode([token]), end="", flush=True)

Generation Methods

Streaming Generation

generate(tokens, num_samples=1, max_tokens=None, temperature=1.0, top_k=None, seed=42)
A streaming generator that yields tokens one at a time.
Parameters:
  • tokens (list[int]): Input token sequence
  • num_samples (int): Number of parallel samples to generate (default: 1)
  • max_tokens (int): Maximum tokens to generate (default: None = unlimited)
  • temperature (float): Sampling temperature, 0.0 = greedy (default: 1.0)
  • top_k (int): Top-k sampling parameter (default: None)
  • seed (int): Random seed (default: 42)
Yields:
  • token_column (list[int]): Next token for each sample (length = num_samples)
  • token_masks (list[int]): 1 if sampled, 0 if forced by tool (length = num_samples)
Example:
for token_column, token_masks in engine.generate(
    prompt_tokens,
    num_samples=4,  # Generate 4 samples in parallel
    max_tokens=256,
    temperature=0.8,
    top_k=50,
    seed=12345
):
    for i, (token, mask) in enumerate(zip(token_column, token_masks)):
        if mask == 1:
            print(f"Sample {i}: {tokenizer.decode([token])}")
        else:
            print(f"Sample {i}: [FORCED] {tokenizer.decode([token])}")

Batch Generation

generate_batch(tokens, num_samples=1, **kwargs)
Non-streaming batch generation that returns the complete token sequences.
Returns:
  • results (list[list[int]]): Token sequences for each sample
  • masks (list[list[int]]): Mask sequences (1=sampled, 0=forced)
Example:
results, masks = engine.generate_batch(
    prompt_tokens,
    num_samples=4,
    max_tokens=128,
    temperature=0.7
)

for i, (tokens, mask) in enumerate(zip(results, masks)):
    text = tokenizer.decode(tokens)
    print(f"Sample {i}: {text}")

KV Cache

The KV cache stores key-value pairs from attention layers to avoid recomputing them for previous tokens.

Architecture

From nanochat/engine.py:83-133:
class KVCache:
    """
    KV Cache designed for Flash Attention 3's flash_attn_with_kvcache API.
    
    Key differences from FA2-style cache:
    - Tensors are (B, T, H, D) not (B, H, T, D)
    - FA3 updates the cache in-place during flash_attn_with_kvcache
    - Position tracked per batch element via cache_seqlens tensor
    """
    
    def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers, device, dtype):
        self.batch_size = batch_size
        self.max_seq_len = seq_len
        self.n_layers = num_layers
        self.n_heads = num_heads
        self.head_dim = head_dim
        # Pre-allocate cache tensors: (n_layers, B, T, H, D)
        self.k_cache = torch.zeros(num_layers, batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
        self.v_cache = torch.zeros(num_layers, batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
        # Current sequence length per batch element (FA3 needs int32)
        self.cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)

Key Methods

  • reset(): Reset cache to empty state
  • get_pos(): Get current position (assumes all batch elements at same position)
  • get_layer_cache(layer_idx): Return (k_cache, v_cache) views for a specific layer
  • advance(num_tokens): Advance the cache position by num_tokens
  • prefill(other): Copy cached KV from another cache (used for multi-sample generation)
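The position bookkeeping behind reset(), get_pos(), and advance() can be illustrated with a pure-Python mock (the real class stores torch tensors; this sketch keeps only the cache_seqlens counters):

```python
from typing import List

class MockKVCache:
    """Pure-Python mock of the KVCache position bookkeeping (no tensors)."""
    def __init__(self, batch_size: int, seq_len: int):
        self.batch_size = batch_size
        self.max_seq_len = seq_len
        # One current-length counter per batch row, like cache_seqlens
        self.cache_seqlens: List[int] = [0] * batch_size

    def reset(self) -> None:
        # Return the cache to an empty state
        self.cache_seqlens = [0] * self.batch_size

    def get_pos(self) -> int:
        # Assumes all rows are at the same position
        assert all(p == self.cache_seqlens[0] for p in self.cache_seqlens)
        return self.cache_seqlens[0]

    def advance(self, num_tokens: int) -> None:
        # Move every row forward by num_tokens
        assert self.get_pos() + num_tokens <= self.max_seq_len
        self.cache_seqlens = [p + num_tokens for p in self.cache_seqlens]

cache = MockKVCache(batch_size=2, seq_len=16)
cache.advance(5)        # e.g. after prefilling a 5-token prompt
print(cache.get_pos())  # 5
cache.advance(1)        # one decode step
print(cache.get_pos())  # 6
cache.reset()
print(cache.get_pos())  # 0
```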

Prefill-then-Decode Pattern

The engine uses an efficient two-phase approach:
  1. Prefill: Run the entire prompt through the model once at batch size 1
  2. Decode: Replicate the KV cache across num_samples rows and generate in parallel
From nanochat/engine.py:194-218:
# 1) Run a batch 1 prefill of the prompt tokens
m = self.model.config
kv_model_kwargs = {"num_heads": m.n_kv_head, "head_dim": m.n_embd // m.n_head, "num_layers": m.n_layer}
kv_cache_prefill = KVCache(
    batch_size=1,
    seq_len=len(tokens),
    device=device,
    dtype=dtype,
    **kv_model_kwargs,
)
ids = torch.tensor([tokens], dtype=torch.long, device=device)
logits = self.model.forward(ids, kv_cache=kv_cache_prefill)
logits = logits[:, -1, :].expand(num_samples, -1)  # (num_samples, vocab_size)

# 2) Replicate the KV cache for each sample/row
kv_length_hint = (len(tokens) + max_tokens) if max_tokens is not None else self.model.config.sequence_len
kv_cache_decode = KVCache(
    batch_size=num_samples,
    seq_len=kv_length_hint,
    device=device,
    dtype=dtype,
    **kv_model_kwargs,
)
kv_cache_decode.prefill(kv_cache_prefill)
del kv_cache_prefill  # no need to keep this memory around
This approach processes the prompt once and then generates multiple diverse samples efficiently.
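The replication step can be sketched in plain Python, with lists standing in for the (B, T, H, D) tensors (a toy illustration of the semantics, not the engine's code):

```python
import copy

def replicate_prefill(prefill_row, num_samples):
    """Return num_samples independent copies of the single prefill row's KV entries."""
    return [copy.deepcopy(prefill_row) for _ in range(num_samples)]

# KV entries for a 3-token prompt, processed once at batch size 1
prompt_kv = [("k0", "v0"), ("k1", "v1"), ("k2", "v2")]
decode_rows = replicate_prefill(prompt_kv, num_samples=4)

# Each row can now diverge independently during decoding
decode_rows[0].append(("k3a", "v3a"))
decode_rows[1].append(("k3b", "v3b"))
print(len(decode_rows))  # 4
```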

Token Sampling

The engine uses a custom sampling function that supports temperature and top-k sampling.
From nanochat/engine.py:135-152:
@torch.inference_mode()
def sample_next_token(logits, rng, temperature=1.0, top_k=None):
    """Sample a single next token from given logits of shape (B, vocab_size). Returns (B, 1)."""
    assert temperature >= 0.0, "temperature must be non-negative"
    if temperature == 0.0:
        return torch.argmax(logits, dim=-1, keepdim=True)
    if top_k is not None and top_k > 0:
        k = min(top_k, logits.size(-1))
        vals, idx = torch.topk(logits, k, dim=-1)
        vals = vals / temperature
        probs = F.softmax(vals, dim=-1)
        choice = torch.multinomial(probs, num_samples=1, generator=rng)
        return idx.gather(1, choice)
    else:
        logits = logits / temperature
        probs = F.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1, generator=rng)
Sampling Modes:
  • temperature=0.0: Greedy decoding (always pick most likely token)
  • temperature=1.0: Standard sampling from full distribution
  • temperature>1.0: More random/creative (flattens distribution)
  • temperature<1.0: More focused/deterministic (sharpens distribution)
  • top_k: Only sample from top-k most likely tokens
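The effect of temperature can be seen with a tiny pure-Python softmax (a toy illustration of the scaling, separate from the engine's torch-based sampler):

```python
import math

def softmax_with_temperature(logits, temperature):
    """Convert logits to probabilities after dividing by the temperature."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.1]
cold = softmax_with_temperature(logits, 0.5)  # sharper: mass concentrates on the argmax
warm = softmax_with_temperature(logits, 2.0)  # flatter: mass spreads out
print(max(cold) > max(warm))  # True: lower temperature sharpens the distribution
```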

Tool Use: Calculator

The engine includes built-in support for a calculator tool. When the model generates special tokens, the engine automatically evaluates expressions and injects results.

How It Works

  1. Model generates <|python_start|> token
  2. Engine enters “python block” mode and accumulates tokens
  3. Model generates <|python_end|> token
  4. Engine evaluates the expression using use_calculator()
  5. If successful, engine forces <|output_start|> + result + <|output_end|> tokens
  6. Model continues generation with the result in context
From nanochat/engine.py:251-267:
# Handle tool logic
if next_token == python_start:
    state.in_python_block = True
    state.python_expr_tokens = []
elif next_token == python_end and state.in_python_block:
    state.in_python_block = False
    if state.python_expr_tokens:
        expr = self.tokenizer.decode(state.python_expr_tokens)
        result = use_calculator(expr)
        if result is not None:
            result_tokens = self.tokenizer.encode(str(result))
            state.forced_tokens.append(output_start)
            state.forced_tokens.extend(result_tokens)
            state.forced_tokens.append(output_end)
    state.python_expr_tokens = []
elif state.in_python_block:
    state.python_expr_tokens.append(next_token)
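The state machine above can be simulated with integer stand-ins for the special tokens and a one-character-per-token toy tokenizer (evaluate() here is a hypothetical stand-in for use_calculator):

```python
from collections import deque

# Hypothetical token ids standing in for the special tokens
PYTHON_START, PYTHON_END, OUTPUT_START, OUTPUT_END = 100, 101, 102, 103

def run_tool_step(state, next_token, decode, encode, evaluate):
    """One step of the python-block state machine from the snippet above."""
    if next_token == PYTHON_START:
        state["in_python_block"] = True
        state["python_expr_tokens"] = []
    elif next_token == PYTHON_END and state["in_python_block"]:
        state["in_python_block"] = False
        if state["python_expr_tokens"]:
            result = evaluate(decode(state["python_expr_tokens"]))
            if result is not None:
                state["forced_tokens"].append(OUTPUT_START)
                state["forced_tokens"].extend(encode(str(result)))
                state["forced_tokens"].append(OUTPUT_END)
        state["python_expr_tokens"] = []
    elif state["in_python_block"]:
        state["python_expr_tokens"].append(next_token)

# Toy "tokenizer": one character per token id
decode = lambda toks: "".join(chr(t) for t in toks)
encode = lambda s: [ord(c) for c in s]
evaluate = lambda expr: eval(expr) if all(c in "0123456789+*-/. " for c in expr) else None

state = {"in_python_block": False, "python_expr_tokens": [], "forced_tokens": deque()}
for tok in [PYTHON_START] + encode("2+2") + [PYTHON_END]:
    run_tool_step(state, tok, decode, encode, evaluate)

print(list(state["forced_tokens"]))  # [102, 52, 103]: <|output_start|>, "4", <|output_end|>
```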

Supported Expressions

The calculator supports:
  • Math expressions: 2+2, 3.14*10, 100/5
  • String operations: "hello".count("l"), "world".count("o")
Safety features:
  • Timeout after 3 seconds
  • No access to builtins or dangerous operations
  • Disallows power operator **
  • Sanitizes input to prevent code injection
From nanochat/engine.py:47-80:
def use_calculator(expr):
    """
    Evaluate a Python expression safely.
    Supports both math expressions and string operations like .count()
    """
    # Remove commas from numbers
    expr = expr.replace(",", "")
    
    # Check if it's a pure math expression
    if all([x in "0123456789*+-/.() " for x in expr]):
        if "**" in expr:  # disallow power operator
            return None
        return eval_with_timeout(expr)
    
    # Check if it's a string operation we support
    allowed_chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'\"()._ "
    if not all([x in allowed_chars for x in expr]):
        return None
    
    # Disallow dangerous patterns
    dangerous_patterns = ['__', 'import', 'exec', 'eval', 'compile', 'open', 'file',
                         'input', 'raw_input', 'globals', 'locals', 'vars', 'dir',
                         'getattr', 'setattr', 'delattr', 'hasattr']
    expr_lower = expr.lower()
    if any(pattern in expr_lower for pattern in dangerous_patterns):
        return None
    
    # Only allow .count() method for now
    if '.count(' not in expr:
        return None
    
    return eval_with_timeout(expr)
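A quick sanity check of this behavior, using a self-contained copy of the logic with plain eval() standing in for eval_with_timeout (no timeout, illustration only):

```python
def calc(expr):
    """Simplified copy of use_calculator, with eval() in place of eval_with_timeout."""
    expr = expr.replace(",", "")  # strip commas from numbers
    if all(x in "0123456789*+-/.() " for x in expr):
        if "**" in expr:  # power operator disallowed
            return None
        return eval(expr)
    allowed = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'\"()._ "
    if not all(x in allowed for x in expr):
        return None
    dangerous = ['__', 'import', 'exec', 'eval', 'compile', 'open', 'file',
                 'input', 'raw_input', 'globals', 'locals', 'vars', 'dir',
                 'getattr', 'setattr', 'delattr', 'hasattr']
    if any(p in expr.lower() for p in dangerous):
        return None
    if '.count(' not in expr:  # only the .count() method is allowed
        return None
    return eval(expr)

print(calc("2+2"))                 # 4
print(calc("1,000 + 24"))          # 1024
print(calc("2**10"))               # None (power operator rejected)
print(calc('"hello".count("l")'))  # 2
print(calc('__import__("os")'))    # None (dangerous pattern rejected)
```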

Row State Tracking

When generating multiple samples in parallel, the engine maintains per-row state to track tool use independently.
From nanochat/engine.py:155-162:
class RowState:
    # Per-row state tracking during generation
    def __init__(self, current_tokens=None):
        self.current_tokens = current_tokens or []  # Current token sequence for this row
        self.forced_tokens = deque()  # Queue of tokens to force inject
        self.in_python_block = False  # Whether we are inside a python block
        self.python_expr_tokens = []  # Tokens of the current python expression
        self.completed = False  # Whether this row has completed generation
Each sample maintains:
  • current_tokens: Full token history
  • forced_tokens: Queue of tokens to inject (from tool results)
  • in_python_block: Whether currently inside <|python_start|><|python_end|>
  • python_expr_tokens: Accumulated expression tokens
  • completed: Whether generation has ended for this sample
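The interplay between the forced_tokens queue and sampling can be sketched as a small per-step decision (sample_fn is a hypothetical stand-in for the engine's sampler; the mask convention matches token_masks above):

```python
from collections import deque

class Row:
    """Minimal stand-in for RowState: just the forced-token queue."""
    def __init__(self):
        self.forced_tokens = deque()

def next_token_for_row(row, sample_fn):
    """Return (token, mask): forced tokens drain first (mask=0), else sample (mask=1)."""
    if row.forced_tokens:
        return row.forced_tokens.popleft(), 0
    return sample_fn(), 1

row = Row()
row.forced_tokens.extend([7, 8, 9])  # e.g. injected calculator output tokens
sample_fn = lambda: 42               # pretend the model sampled token 42

steps = [next_token_for_row(row, sample_fn) for _ in range(4)]
print(steps)  # [(7, 0), (8, 0), (9, 0), (42, 1)]
```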

Performance Testing

The engine includes a built-in test to verify correctness and benchmark performance.
python -m nanochat.engine
This compares the engine’s output against the model’s naive generation function and reports timing.
From nanochat/engine.py:302-357:
if __name__ == "__main__":
    """
    Quick inline test to make sure that the naive/slow model.generate function
    is equivalent to the faster Engine.generate function here.
    """
    # Load model
    model, tokenizer, meta = load_model("base", device, phase="eval")
    bos_token_id = tokenizer.get_bos_token_id()
    kwargs = dict(max_tokens=64, temperature=0.0)
    prompt_tokens = tokenizer.encode("The chemical formula of water is", prepend=bos_token_id)
    
    # Generate with reference implementation
    generated_tokens = []
    torch.cuda.synchronize()
    t0 = time.time()
    stream = model.generate(prompt_tokens, **kwargs)
    with autocast_ctx:
        for token in stream:
            generated_tokens.append(token)
    torch.cuda.synchronize()
    t1 = time.time()
    print(f"Reference time: {t1 - t0:.2f}s")
    reference_ids = generated_tokens
    
    # Generate with Engine
    generated_tokens = []
    engine = Engine(model, tokenizer)
    stream = engine.generate(prompt_tokens, num_samples=1, **kwargs)
    torch.cuda.synchronize()
    t0 = time.time()
    with autocast_ctx:
        for token_column, token_masks in stream:
            token = token_column[0]
            generated_tokens.append(token)
    torch.cuda.synchronize()
    t1 = time.time()
    print(f"Engine time: {t1 - t0:.2f}s")
    
    # Compare
    print(f"Match: {reference_ids == generated_tokens}")

Complete Example: Multi-Sample Generation

import torch
from nanochat.engine import Engine
from nanochat.checkpoint_manager import load_model
from nanochat.common import compute_init, autodetect_device_type

# Initialize
device_type = autodetect_device_type()
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
model, tokenizer, meta = load_model("sft", device, phase="eval")
engine = Engine(model, tokenizer)

# Prepare prompt
bos = tokenizer.get_bos_token_id()
user_start = tokenizer.encode_special("<|user_start|>")
user_end = tokenizer.encode_special("<|user_end|>")
assistant_start = tokenizer.encode_special("<|assistant_start|>")

tokens = [bos, user_start]
tokens.extend(tokenizer.encode("Tell me a joke"))
tokens.extend([user_end, assistant_start])

# Generate 4 different jokes in parallel
results, masks = engine.generate_batch(
    tokens,
    num_samples=4,
    max_tokens=200,
    temperature=1.0,
    top_k=50,
    seed=42
)

for i, (result_tokens, mask) in enumerate(zip(results, masks)):
    # Only decode the assistant's response (after assistant_start)
    response_start = len(tokens)
    response_tokens = result_tokens[response_start:]
    text = tokenizer.decode(response_tokens)
    print(f"\n=== Sample {i+1} ===")
    print(text)
This efficiently generates 4 diverse responses by processing the prompt once and then sampling 4 times in parallel.
