
Overview

The nanochat.core_eval module provides functions for evaluating language models on the CORE benchmark, as described in the DCLM paper.

evaluate_core

Evaluate a model on a CORE benchmark task.
def evaluate_core(
    model: GPT,
    tokenizer: Tokenizer,
    data: list[dict],
    device: torch.device,
    task_meta: dict
) -> float
This is an alias for evaluate_task, the main entry point for CORE evaluation.

Core Functions

evaluate_task

Evaluate one task across many examples with distributed support.
def evaluate_task(
    model: GPT,
    tokenizer: Tokenizer,
    data: list[dict],
    device: torch.device,
    task_meta: dict
) -> float

Parameters

model
GPT
required
Language model to evaluate
tokenizer
Tokenizer
required
Tokenizer for encoding prompts
data
list[dict]
required
List of task examples. Format depends on task_type.
Multiple choice:
{
    'query': str,
    'choices': list[str],
    'gold': int  # index of correct choice
}
Schema:
{
    'context_options': list[str],
    'continuation': str,
    'gold': int  # index of correct context
}
Language modeling:
{
    'context': str,
    'continuation': str
}
device
torch.device
required
Device to run evaluation on
task_meta
dict
required
Task metadata containing:
  • task_type (str): One of 'multiple_choice', 'schema', or 'language_modeling'
  • num_fewshot (int): Number of few-shot examples to include
  • continuation_delimiter (str): Delimiter between context and continuation (e.g., ' ' or '\n')

Returns

accuracy
float
Mean accuracy across all examples (0.0 to 1.0)

evaluate_example

Evaluate a single example.
@torch.no_grad()
def evaluate_example(
    idx: int,
    model: GPT,
    tokenizer: Tokenizer,
    data: list[dict],
    device: torch.device,
    task_meta: dict
) -> bool

Parameters

Same as evaluate_task, plus:
idx
int
required
Index of the example to evaluate in data

Returns

is_correct
bool
Whether the model’s prediction was correct

Task Types

Multiple Choice

Model chooses among multiple options based on which has the lowest average loss. Format:
{
    'query': 'What is the capital of France?',
    'choices': ['London', 'Paris', 'Berlin', 'Madrid'],
    'gold': 1  # Paris
}
Evaluation:
  • Render all choices with the query prefix
  • Forward each choice through the model
  • Select choice with lowest average loss on the continuation tokens
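The selection step can be sketched in plain Python. This is a minimal illustration, not nanochat's actual implementation; the helper names and the `cont_starts` argument (the index where each choice's continuation tokens begin) are hypothetical:

```python
def mean_continuation_loss(token_losses, cont_start):
    """Mean cross-entropy loss over the continuation tokens only."""
    cont = token_losses[cont_start:]
    return sum(cont) / len(cont)

def select_choice(per_choice_losses, cont_starts):
    """Return the index of the choice with the lowest mean continuation loss."""
    means = [mean_continuation_loss(losses, start)
             for losses, start in zip(per_choice_losses, cont_starts)]
    return min(range(len(means)), key=means.__getitem__)
```

Averaging (rather than summing) the loss avoids penalizing longer choices for having more tokens.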

Schema

Model selects the correct context that leads to a given continuation. Format:
{
    'context_options': ['Context A', 'Context B', 'Context C'],
    'continuation': 'the expected continuation',
    'gold': 1  # Context B
}
Evaluation:
  • Render all contexts with the same continuation
  • Forward each option through the model
  • Select context with lowest average loss on the continuation tokens
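The pairing step can be sketched as follows (an illustrative helper, not part of the module's API):

```python
def schema_prompts(item, delimiter):
    """Pair every context option with the same continuation; the option
    whose continuation gets the lowest mean loss is the prediction."""
    return [ctx + delimiter + item['continuation']
            for ctx in item['context_options']]
```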

Language Modeling

Model must correctly predict all tokens in the continuation. Format:
{
    'context': 'Once upon a time',
    'continuation': ' there was a princess'
}
Evaluation:
  • Render context + continuation
  • Forward through the model
  • Check if argmax predictions match all continuation tokens
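The all-tokens-correct check can be sketched as (a hypothetical helper for illustration; `cont_start` marks where the continuation tokens begin):

```python
def lm_example_correct(predictions, targets, cont_start):
    """Correct only if the argmax prediction matches every continuation token."""
    return all(p == t for p, t in
               zip(predictions[cont_start:], targets[cont_start:]))
```

Note this is stricter than multiple choice: a single wrong token anywhere in the continuation makes the whole example incorrect.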

Prompt Rendering

render_prompts_mc

Render prompts for multiple choice questions.
def render_prompts_mc(
    item: dict,
    continuation_delimiter: str,
    fewshot_examples: list[dict] | None = None
) -> list[str]
Returns one prompt per choice.
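A simplified rendering sketch under the assumption that few-shot examples are rendered with their gold answers before the query (the exact few-shot separator in nanochat may differ):

```python
def render_mc(item, delimiter, fewshot_examples=None):
    """Render one prompt per choice: few-shot examples first,
    then the query joined to the candidate choice."""
    prefix = ""
    for ex in (fewshot_examples or []):
        prefix += ex["query"] + delimiter + ex["choices"][ex["gold"]] + "\n\n"
    return [prefix + item["query"] + delimiter + choice
            for choice in item["choices"]]
```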

render_prompts_schema

Render prompts for schema questions.
def render_prompts_schema(
    item: dict,
    continuation_delimiter: str,
    fewshot_examples: list[dict] | None = None
) -> list[str]
Returns one prompt per context option.

render_prompts_lm

Render prompts for language modeling tasks.
def render_prompts_lm(
    item: dict,
    continuation_delimiter: str,
    fewshot_examples: list[dict] | None = None
) -> list[str]
Returns [prompt_without_continuation, prompt_with_continuation].

Utility Functions

find_common_length

Find the length of common prefix or suffix across token sequences.
def find_common_length(
    token_sequences: list[list[int]],
    direction: str = 'left'
) -> int
Parameters:
  • token_sequences: List of tokenized sequences
  • direction: 'left' for prefix, 'right' for suffix
Returns: Length of common prefix/suffix
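The logic can be sketched in plain Python (a minimal equivalent, not the module's actual code). Finding the shared prefix lets the evaluator avoid recomputing losses on tokens that are identical across all prompts:

```python
def common_length(token_sequences, direction='left'):
    """Longest prefix ('left') or suffix ('right') shared by all sequences."""
    seqs = (token_sequences if direction == 'left'
            else [seq[::-1] for seq in token_sequences])
    n = 0
    for column in zip(*seqs):  # zip stops at the shortest sequence
        if len(set(column)) != 1:
            break
        n += 1
    return n
```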

stack_sequences

Stack token sequences with padding.
def stack_sequences(
    tokens: list[list[int]],
    pad_token_id: int
) -> torch.Tensor
Pads sequences to the longest length and returns a tensor of shape (B, max_len).
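The padding behavior can be illustrated with plain lists (the real function returns a torch.Tensor of shape (B, max_len)):

```python
def stack_lists(tokens, pad_token_id):
    """Right-pad each token sequence to the length of the longest one."""
    max_len = max(len(t) for t in tokens)
    return [t + [pad_token_id] * (max_len - len(t)) for t in tokens]
```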

forward_model

Forward model and compute losses and predictions.
@torch.no_grad()
def forward_model(
    model: GPT,
    input_ids: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]
Returns: (losses, predictions) where:
  • losses: (B, T) tensor of cross-entropy losses
  • predictions: (B, T) tensor of argmax predictions
Note: Last column of losses is set to NaN (no autoregressive target).
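The shift-by-one indexing can be sketched as follows. This is an illustration of the loss layout, not the actual implementation; here `neg_logprobs[t][v]` stands in for the cross-entropy of predicting token `v` from position `t`:

```python
import math  # math.isnan can verify the final NaN position

def shifted_token_losses(neg_logprobs, input_ids):
    """losses[t] scores the prediction of token t+1 from position t;
    the last position has no autoregressive target, hence NaN."""
    T = len(input_ids)
    losses = [neg_logprobs[t][input_ids[t + 1]] for t in range(T - 1)]
    losses.append(float('nan'))
    return losses
```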

Example Usage

import torch
from nanochat.core_eval import evaluate_task
from nanochat.checkpoint_manager import load_model

# Load model
model, tokenizer, _ = load_model(
    source='base',
    device=torch.device('cuda'),
    phase='eval'
)

# Prepare task data
data = [
    {
        'query': 'The capital of France is',
        'choices': [' London', ' Paris', ' Berlin'],
        'gold': 1
    },
    # ... more examples
]

task_meta = {
    'task_type': 'multiple_choice',
    'num_fewshot': 5,
    'continuation_delimiter': ''
}

# Evaluate
accuracy = evaluate_task(
    model=model,
    tokenizer=tokenizer,
    data=data,
    device=torch.device('cuda'),
    task_meta=task_meta
)

print(f"Accuracy: {accuracy * 100:.2f}%")

Distributed Evaluation

When running with torchrun, the evaluation automatically distributes examples across ranks:
import torch
import torch.distributed as dist

# Each rank processes examples[rank::world_size]
rank = dist.get_rank() if dist.is_initialized() else 0
world_size = dist.get_world_size() if dist.is_initialized() else 1

correct = torch.zeros(1, device=device)
for idx in range(rank, len(data), world_size):
    correct += evaluate_example(idx, model, tokenizer, data, device, task_meta)

# Results are synced across ranks via all_reduce
if world_size > 1:
    dist.all_reduce(correct, op=dist.ReduceOp.SUM)

Notes

  • Few-shot examples are sampled randomly with seed 1234 + idx for reproducibility
  • Models with max_seq_len attribute will have prompts truncated to that length
  • Sequences are truncated from the left to preserve the continuation tokens
  • Uses BOS token as pad token during batching
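The left-truncation rule from the notes above can be sketched as:

```python
def truncate_left(ids, max_seq_len):
    """Drop tokens from the left so the continuation at the end survives."""
    return ids[-max_seq_len:] if len(ids) > max_seq_len else ids
```

Truncating from the right instead would cut off the continuation tokens, which are exactly what the evaluation scores.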
