
Overview

Nanochat implements automatic compute-optimal scaling based on empirical scaling laws. The core insight: one parameter (depth) controls everything. As you increase --depth, the training script automatically adjusts:
  • Model size (parameters)
  • Training tokens (data)
  • Batch size
  • Learning rates
  • Weight decay
This follows the muP (maximal update parameterization) philosophy: tune hyperparameters at a small scale, then transfer to larger models.

The Depth Dial

Model size is controlled by a single parameter:
--depth=12  # ~100M params, reference model
--depth=20  # ~300M params  
--depth=28  # ~600M params
Model dimension scales linearly:
base_dim = depth × aspect_ratio  # aspect_ratio=64 by default
model_dim = round_to_multiple(base_dim, head_dim)  # head_dim=128
num_heads = model_dim / head_dim
Example: depth=12 → base_dim=768 → model_dim=768 → num_heads=6
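The width calculation above can be sketched as a small helper (the exact rounding direction in `round_to_multiple` is an assumption of this sketch):

```python
def model_geometry(depth, aspect_ratio=64, head_dim=128):
    """Derive model width and head count from depth alone."""
    base_dim = depth * aspect_ratio
    # Round up to the nearest multiple of head_dim so heads divide evenly
    # (whether nanochat rounds up or to nearest is assumed here).
    model_dim = ((base_dim + head_dim - 1) // head_dim) * head_dim
    num_heads = model_dim // head_dim
    return model_dim, num_heads

print(model_geometry(12))  # → (768, 6)
print(model_geometry(20))  # → (1280, 10)
```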

Scaling Law 1: Optimal Training Tokens

The compute-optimal data:param ratio is determined empirically:
optimal_tokens = target_param_data_ratio × scaling_params
# scaling_params = transformer_matrices + lm_head parameters
Default: target_param_data_ratio = 10.5. Why 10.5? It was derived empirically from scaling-laws experiments (see runs/scaling_laws.sh). This differs from Chinchilla (20:1) because:
  • Smaller models are compute-optimal with fewer tokens per parameter
  • Different architecture (sliding windows, value embeddings)
  • Empirically optimal for nanochat’s parameter count range

Parameter Counting

Only certain parameters count toward the scaling ratio:
scaling_params = (
    transformer_matrices +  # Q, K, V, O, MLP weights
    lm_head                 # output projection
)
# Excluded: embeddings, value_embeds, scalars
This gives cleaner scaling laws (see dev/LOG.md Jan 27, 2026).
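The partition above, plus Scaling Law 1, can be sketched with a toy dict of parameter shapes (the names and shapes here are illustrative, not nanochat's actual module names; the 65,536 vocab is likewise an assumption):

```python
import math

# Toy parameter shapes for one transformer block plus heads/embeddings.
param_shapes = {
    "blocks.0.attn.qkvo": (4 * 768, 768),
    "blocks.0.mlp.fc":    (4 * 768, 768),
    "blocks.0.mlp.proj":  (768, 4 * 768),
    "lm_head":            (65536, 768),
    "embedding":          (65536, 768),   # excluded from the ratio
    "value_embeds":       (65536, 768),   # excluded
    "scalars":            (12,),          # excluded
}

# Only transformer matrices and lm_head count toward the data:param ratio.
scaling_params = sum(
    math.prod(shape) for name, shape in param_shapes.items()
    if name.startswith("blocks.") or name == "lm_head"
)
excluded = sum(math.prod(s) for s in param_shapes.values()) - scaling_params
optimal_tokens = 10.5 * scaling_params  # target_param_data_ratio
```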

Scaling Law 2: Optimal Batch Size

Follows the Power Lines paper (arXiv:2505.13738):
B_opt = B_ref × (D / D_ref)^0.383
Where:
  • B_ref = 524,288 tokens (optimal batch size at d12)
  • D_ref = optimal training tokens for d12
  • D = optimal training tokens for current depth
The exponent 0.383 means:
  • Doubling training tokens → 1.3× larger batch size
  • 10× more tokens → 2.4× larger batch size
Result is rounded to nearest power of 2 for efficiency.
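A minimal sketch of the batch-size rule, interpreting "nearest power of 2" as nearest in log space (the exact rounding rule is an assumption):

```python
import math

def optimal_batch_size(D, D_ref, B_ref=524_288):
    """Power-law batch-size scaling, rounded to a power of 2."""
    B = B_ref * (D / D_ref) ** 0.383
    return 2 ** round(math.log2(B))

print(optimal_batch_size(D=10.0, D_ref=1.0))  # 10× more tokens → 1048576
```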

Why This Matters

Using too small a batch size:
  • Wastes wall-clock time (more iterations needed)
  • Hurts convergence (noisy gradients)
Using too large a batch size:
  • Hurts generalization (“generalization gap”)
  • Wastes compute (diminishing returns)
The optimal batch size balances these trade-offs.

Scaling Law 3: Learning Rate Scaling

When batch size changes, learning rates scale as:
lr_scale = √(B / B_ref)
This follows the standard square-root scaling rule for AdamW (and is assumed to hold for Muon). Example: if batch size doubles from 524K to 1M tokens, LRs increase by √2 ≈ 1.41×. Why square root?
  • SGD: Linear scaling (lr ∝ B)
  • AdamW: Square root scaling (lr ∝ √B)
  • Muon: Assumed same as AdamW (not studied carefully)
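The √2 example above as a one-line sketch:

```python
def lr_scale(B, B_ref=524_288):
    """Square-root learning-rate scaling for AdamW (assumed for Muon)."""
    return (B / B_ref) ** 0.5

# Doubling the batch from 524,288 to 1,048,576 tokens:
print(round(lr_scale(1_048_576), 3))  # → 1.414
```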

Scaling Law 4: Weight Decay Scaling

Follows the T_epoch framework (arXiv:2405.13698):
λ = λ_ref × √(B / B_ref) × (D_ref / D)
Where T_epoch = B / (η · λ · D) is kept constant. Intuition: As you train longer (larger D), you need less regularization (smaller λ). Example:
  • d12: λ = 0.2
  • d20 (2.5× more tokens): λ ≈ 0.08
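The d20 example can be reproduced with the formula, under the assumption (made for this illustration) that the batch size is unchanged after power-of-2 rounding:

```python
def scaled_weight_decay(B, D_ratio, B_ref=524_288, wd_ref=0.2):
    """λ = λ_ref · √(B/B_ref) · (D_ref/D), with D_ratio = D / D_ref."""
    return wd_ref * (B / B_ref) ** 0.5 * (1.0 / D_ratio)

# d20-style example: 2.5× more training tokens, batch size unchanged:
print(round(scaled_weight_decay(B=524_288, D_ratio=2.5), 2))  # → 0.08
```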

Reference Model (d12)

All scaling is anchored to depth=12:
D_REF = target_param_data_ratio × get_scaling_params(d12_model)
B_REF = 524_288  # 2^19 tokens
Hyperparameters are tuned at d12, then transferred via scaling laws.

Running Scaling Laws Experiments

The runs/scaling_laws.sh script trains models at different depths and FLOP budgets:
# Edit these arrays in scaling_laws.sh
FLOPS_BUDGETS=(1e18 2.15e18 4.64e18 1e19)
DEPTHS=(8 10 12 14 16 18 20)

# Run the sweep
./runs/scaling_laws.sh
For each (FLOP budget, depth) pair:
  1. Train: --target-flops=$flops --depth=$d
  2. Evaluate: Validation loss and CORE metric
  3. Log: Results saved to CSV

Output

Results saved to $NANOCHAT_BASE_DIR/scaling_laws_results_${LABEL}/results.csv:
flops_budget,depth,model_dim,params_total,num_iterations,val_bpb,core_score,train_time_sec
1e18,8,512,45000000,1200,2.156,0.234,1800
1e18,12,768,103000000,800,2.089,0.267,1950
...

Analysis

Plot validation loss vs FLOPs for different depths:
import pandas as pd
import matplotlib.pyplot as plt

df = pd.read_csv("results.csv")

for depth in df['depth'].unique():
    subset = df[df['depth'] == depth]
    plt.plot(subset['flops_budget'], subset['val_bpb'], 
             marker='o', label=f'd={depth}')

plt.xscale('log')
plt.xlabel('Compute Budget (FLOPs)')
plt.ylabel('Validation Loss (bits/byte)')
plt.legend()
plt.show()
Expected result: Compute-optimal frontier where each FLOP budget has an optimal depth.

Compute-Optimal Training

The goal: For a fixed compute budget (FLOPs), find the optimal (model_size, training_tokens) pair.

Fixed FLOPs Constraint

FLOPs = model_size × training_tokens × C
# where C ≈ 6 FLOPs/param/token (2 for the forward pass, 4 for backprop)
For fixed FLOPs:
  • Larger model + fewer tokens
  • Smaller model + more tokens
These are equivalent in compute but have different loss!
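The trade-off can be sketched by solving the constraint for tokens at a fixed budget (using C ≈ 6):

```python
def tokens_for_budget(flops, params, c=6):
    """D = FLOPs / (c · N), with c ≈ 6 FLOPs per parameter per token."""
    return flops / (c * params)

budget = 1e18
# Equal compute, different (model size, data) trade-offs:
for params in (50e6, 100e6, 200e6):
    d = tokens_for_budget(budget, params)
    print(f"{params/1e6:.0f}M params → {d/1e9:.2f}B tokens")
# → 50M params → 3.33B tokens, 100M → 1.67B, 200M → 0.83B
```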

Chinchilla vs Nanochat

Chinchilla (DeepMind, 2022):
  • Ratio: 20 tokens per parameter
  • Example: 10B params → 200B tokens
Nanochat:
  • Ratio: 10.5 tokens per parameter
  • Example: 100M params → 1B tokens
Why the difference?
  • Chinchilla studied 400M-70B param models
  • Nanochat studies 30M-600M param models
  • Smaller models benefit from more parameters per token (less data needed to amortize params)

Automatic Hyperparameter Scaling

When you run:
torchrun --nproc_per_node=8 -m scripts.base_train --depth=20
The script automatically:
  1. Calculates model size: params = f(depth, aspect_ratio, head_dim, ...)
  2. Determines optimal tokens: D = 10.5 × scaling_params
  3. Computes optimal batch size: B = B_ref × (D/D_ref)^0.383
  4. Scales learning rates: lr = lr_ref × √(B/B_ref)
  5. Adjusts weight decay: λ = λ_ref × √(B/B_ref) × (D_ref/D)
  6. Calculates num_iterations: num_iters = D / B
You can override any of these with explicit flags:
--num-iterations=5000           # override training length
--total-batch-size=262144       # override batch size
--matrix-lr=0.03                # override learning rate
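Steps 2-6 can be chained end-to-end (a sketch using the formulas from this page; the `D_REF` value, `lr_ref` default, and power-of-2 rounding rule are assumptions, not nanochat's actual constants):

```python
import math

B_REF = 524_288   # reference batch size at d12 (tokens)
D_REF = 1.08e9    # assumed: 10.5 × scaling params of the d12 model

def auto_hparams(scaling_params, lr_ref=0.02, wd_ref=0.2):
    """Derive training hyperparameters from depth-derived scaling_params."""
    D = 10.5 * scaling_params                                 # step 2
    B = 2 ** round(math.log2(B_REF * (D / D_REF) ** 0.383))   # step 3
    lr = lr_ref * (B / B_REF) ** 0.5                          # step 4
    wd = wd_ref * (B / B_REF) ** 0.5 * (D_REF / D)            # step 5
    return dict(tokens=D, batch=B, lr=lr, wd=wd,
                iters=math.ceil(D / B))                       # step 6

print(auto_hparams(103e6))  # roughly d12-sized scaling params
```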

Logged Output Example

Model config:
  n_layer: 20
  n_embd: 1280
  n_head: 10

Parameter counts:
  transformer_matrices: 245,760,000
  lm_head: 41,943,040
  scaling_params: 287,703,040
  total: 323,472,384

Auto-computed optimal batch size: 524,288 tokens
Calculated number of iterations: 6,112
Total training tokens: 3,203,743,744
Tokens : Scaling params ratio: 11.14

Scaling LRs by 1.0000 for batch size 524,288
Scaling weight decay from 0.200000 to 0.200000 for depth 20

Validation: Bits Per Byte

Instead of raw cross-entropy loss (per token), nanochat reports bits per byte:
loss_per_token = cross_entropy(logits, targets)   # nats per token
bytes_per_token = token_bytes[targets]            # lookup table
bits_per_byte = loss_per_token.sum() / (log(2) × bytes_per_token.sum())
Why? Bits per byte is invariant to tokenizer vocabulary size, making it easier to compare models with different tokenizers.
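A self-contained sketch of the conversion (toy numbers, no tokenizer required):

```python
import math

def bits_per_byte(token_losses_nats, token_byte_lengths):
    """Convert summed per-token cross-entropy (nats) into bits per byte."""
    total_bits = sum(token_losses_nats) / math.log(2)   # nats → bits
    return total_bits / sum(token_byte_lengths)

# Toy example: 3 tokens spanning 10 bytes of raw text
print(round(bits_per_byte([2.0, 1.5, 2.5], [4, 3, 3]), 3))  # → 0.866
```

A coarser tokenizer produces fewer tokens with higher per-token loss, but the per-byte figure stays comparable.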

Override Parameters

You can override automatic scaling with explicit flags:

Override training length

--num-iterations=10000                    # explicit step count
--target-flops=1e19                       # train to fixed FLOPs
--target-param-data-ratio=20              # use Chinchilla ratio

Override batch size

--total-batch-size=262144  # half the default

Override learning rates

--matrix-lr=0.03 \
--embedding-lr=0.4 \
--unembedding-lr=0.005
Learning rate scaling still applies unless you disable it.

Muon Momentum Schedule

Muon optimizer uses a momentum warmup (independent of depth):
momentum = 0.85 → 0.95  (over the first 300 steps)
This is NOT scaled with depth (based on optimizer dynamics, not model size).
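A sketch of the warmup schedule (the linear ramp shape is assumed from the 0.85 → 0.95 notation):

```python
def muon_momentum(step, start=0.85, end=0.95, warmup_steps=300):
    """Momentum warmup over the first warmup_steps optimizer steps."""
    frac = min(step / warmup_steps, 1.0)
    return start + frac * (end - start)

print(muon_momentum(0))    # → 0.85
print(muon_momentum(500))  # → 0.95 (held constant after warmup)
```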

Summary

Scaling laws in nanochat:
| What | Scales as | Reference |
| --- | --- | --- |
| Model dim | depth × 64 | Linear |
| Training tokens | 10.5 × scaling_params | Empirical |
| Batch size | B_ref × (D/D_ref)^0.383 | Power Lines |
| Learning rate | lr_ref × √(B/B_ref) | AdamW theory |
| Weight decay | λ_ref × √(B/B_ref) × (D_ref/D) | T_epoch |
| Num iterations | tokens / batch_size | Derived |
All anchored to d12 reference model.

Further Reading

  • Chinchilla paper - Original compute-optimal scaling laws
  • Power Lines paper - Optimal batch size scaling
  • T_epoch paper - Weight decay scaling
  • muP paper - Hyperparameter transfer across model sizes
  • dev/LOG.md in nanochat source - Detailed scaling laws experiments and findings
