evaluate_pipeline_stages
def evaluate_pipeline_stages(
state: PipelineState,
config: PipelineConfig,
max_eval_batches: Optional[int] = 100,
) -> Path
Evaluate all pipeline stages (Base, SFT, DPO) and save comparison results.
This function loads each checkpoint from the pipeline state, computes perplexity on WikiText-2 validation data, and generates both JSON and CSV reports comparing the stages.
Parameters
state
PipelineState
Pipeline state containing checkpoint paths for pretrain, SFT, and DPO stages.
config
PipelineConfig
Pipeline configuration with tokenizer name, sequence length, and run name.
max_eval_batches
Optional[int]
default:"100"
Maximum number of batches to evaluate per stage. Set to None to evaluate the full dataset.
Returns
Path to the saved JSON results file (typically experiments/results/{run_name}_eval.json).
Side effects
- Creates experiments/results/ directory if it doesn’t exist
- Saves JSON results to experiments/results/{run_name}_eval.json
- Saves CSV comparison table to experiments/results/{run_name}_comparison.csv
- Logs perplexity and loss for each stage
Usage
from pathlib import Path
from modern_llm.alignment.alignment_pipeline import AlignmentPipeline
from modern_llm.evaluation.pipeline_eval import evaluate_pipeline_stages
from modern_llm.config import PipelineConfig
# After running alignment pipeline
config = PipelineConfig(
run_name="modern-llm-experiment",
tokenizer_name="gpt2",
max_seq_len=512,
)
pipeline = AlignmentPipeline(config)
state = pipeline.run()
# Evaluate all stages
results_path = evaluate_pipeline_stages(
state=state,
config=config,
max_eval_batches=50, # Faster evaluation
)
print(f"Results saved to: {results_path}")
# Output: Results saved to: experiments/results/modern-llm-experiment_eval.json
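The saved artifacts are plain JSON and CSV, so they can be inspected with the standard library. A minimal sketch, assuming the default paths listed under Side effects and the run name from the example above:
import csv
import json
from pathlib import Path
# Load the JSON report written by evaluate_pipeline_stages
with Path("experiments/results/modern-llm-experiment_eval.json").open() as f:
    report = json.load(f)
# Print the per-stage rows of the CSV comparison table
with Path("experiments/results/modern-llm-experiment_comparison.csv").open() as f:
    for row in csv.reader(f):
        print(row)
The JSON schema follows PipelineEvalResults.to_dict(), documented later on this page.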
compute_perplexity
def compute_perplexity(
model: ModernDecoderLM,
dataloader: DataLoader,
device: torch.device,
max_batches: Optional[int] = None,
) -> tuple[float, float]
Compute perplexity on a dataset by averaging cross-entropy loss.
Mathematical definition
perplexity = exp((1/N) ∑_{i=1}^{N} L_i)
where L_i is the cross-entropy loss for batch i and N is the number of batches.
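For intuition, here is the formula applied to a few hypothetical per-batch losses (values invented for illustration):
import math
batch_losses = [3.21, 3.05, 3.40]  # hypothetical cross-entropy losses L_i
avg_loss = sum(batch_losses) / len(batch_losses)  # (1/N) * sum of L_i
perplexity = math.exp(avg_loss)  # exp of the mean loss
print(f"{perplexity:.2f}")  # ~25.03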
Parameters
model
ModernDecoderLM
Language model to evaluate. Must be on the specified device and in eval mode.
dataloader
DataLoader
DataLoader providing batches with input_ids, attention_mask, and labels.
device
torch.device
Device to run evaluation on (typically torch.device("cuda") or torch.device("cpu")).
max_batches
Optional[int]
default:"None"
Maximum number of batches to evaluate. If None, evaluates the entire dataset.
Returns
Computed perplexity value (first element of the tuple). Returns inf if all batches had NaN/inf loss.
Average cross-entropy loss across batches (second element). Returns inf if all batches had NaN/inf loss.
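Because inf is used as a sentinel rather than raising an exception, callers may want to guard on it. A minimal sketch, reusing the model, eval_dataloader, and device names from the usage example below:
import math
perplexity, loss = compute_perplexity(model, eval_dataloader, device)
if math.isinf(perplexity):
    raise RuntimeError("all evaluation batches produced NaN/inf loss")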
Preconditions
- Model must be on the specified device and in eval mode
- DataLoader batches must contain valid input tensors
Complexity
- Time: O(N × L × D²) where N is number of batches, L is sequence length, D is model dimension
- Space: O(B × L) where B is batch size
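To make the time bound concrete, a back-of-the-envelope count with hypothetical settings (all numbers invented for illustration):
n_batches, seq_len, d_model = 100, 512, 768  # N, L, D from the bound above
ops = n_batches * seq_len * d_model**2  # O(N × L × D²)
print(f"~{ops:.1e} matmul operations")  # ~3.0e+10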
Usage
from pathlib import Path
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from modern_llm.models.transformer import ModernDecoderLM
from modern_llm.data.lm_datasets import load_causal_lm_dataset, LanguageModelingDatasetConfig
from modern_llm.evaluation.pipeline_eval import compute_perplexity
from modern_llm.utils.checkpointing import load_checkpoint
from modern_llm.config import ModernLLMConfig
# Load model from checkpoint
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ckpt = load_checkpoint(Path("checkpoints/final.pt"))
model_config = ModernLLMConfig(**ckpt["config"])
model = ModernDecoderLM(model_config)
model.load_state_dict(ckpt["model_state"])
model.to(device)
model.eval()
# Prepare evaluation data
tokenizer = AutoTokenizer.from_pretrained("gpt2")
eval_config = LanguageModelingDatasetConfig(
dataset_name="wikitext",
dataset_config_name="wikitext-2-raw-v1",
split="validation",
max_length=512,
)
eval_dataset = load_causal_lm_dataset(eval_config, tokenizer)
eval_dataloader = DataLoader(eval_dataset, batch_size=4, shuffle=False)
# Compute perplexity
perplexity, loss = compute_perplexity(
model=model,
dataloader=eval_dataloader,
device=device,
max_batches=100,
)
print(f"Perplexity: {perplexity:.2f}")
print(f"Average Loss: {loss:.4f}")
StageMetrics
@dataclass
class StageMetrics:
stage: str
perplexity: float
loss: float
gsm8k_em: float = 0.0
gsm8k_em_verifier: float = 0.0
num_params: int = 0
Metrics for a single pipeline stage (Base, SFT, or DPO).
Attributes
stage
Stage name (“base”, “sft”, or “dpo”).
perplexity
Perplexity on WikiText-2 validation set.
loss
Average cross-entropy loss on WikiText-2 validation set.
gsm8k_em
GSM8K exact match accuracy (0-1 range).
gsm8k_em_verifier
GSM8K exact match accuracy with verifier reranking (0-1 range).
num_params
Total number of trainable parameters in the model.
PipelineEvalResults
@dataclass
class PipelineEvalResults:
base: Optional[StageMetrics] = None
sft: Optional[StageMetrics] = None
dpo: Optional[StageMetrics] = None
verifier_accuracy: float = 0.0
Full evaluation results across all pipeline stages.
Attributes
base
Metrics for the pretrained base model.
sft
Metrics for the supervised fine-tuned model.
dpo
Metrics for the DPO-aligned model.
verifier_accuracy
Verifier model accuracy on answer correctness classification.
Methods
to_dict
def to_dict(self) -> dict
Convert results to a dictionary for serialization.
save
def save(self, path: Path) -> None
Save results to a JSON file. Creates parent directories if needed.
to_csv
def to_csv(self, path: Path) -> None
Save stage comparison table as a CSV file.
Usage
from pathlib import Path
from modern_llm.evaluation.pipeline_eval import (
PipelineEvalResults,
StageMetrics,
)
# Create results
results = PipelineEvalResults(
base=StageMetrics(
stage="base",
perplexity=28.5,
loss=3.35,
num_params=124_000_000,
),
sft=StageMetrics(
stage="sft",
perplexity=22.1,
loss=3.10,
gsm8k_em=0.45,
num_params=124_000_000,
),
dpo=StageMetrics(
stage="dpo",
perplexity=24.3,
loss=3.19,
gsm8k_em=0.52,
num_params=124_000_000,
),
verifier_accuracy=0.78,
)
# Save as JSON
results.save(Path("experiments/results/eval.json"))
# Save as CSV for comparison
results.to_csv(Path("experiments/results/comparison.csv"))
# Access metrics
print(f"Base PPL: {results.base.perplexity:.1f}")
print(f"SFT PPL: {results.sft.perplexity:.1f}")
print(f"DPO PPL: {results.dpo.perplexity:.1f}")
print(f"PPL improvement: {results.base.perplexity - results.sft.perplexity:.1f}")