A standard bf16 GPT with a model dimension of 512 and 9 layers would be far too large to fit in 16 MB. The baseline achieves a viable artifact size by quantizing model weights to int8 and compressing the result with zlib.

Why quantization is necessary

A bf16 parameter occupies 2 bytes. Even a small model with ~8 M parameters would require ~16 MB in raw bf16 — exactly at the limit with nothing left for the code. int8 cuts that to 1 byte per parameter, and zlib compression on top typically yields a further 2–3× reduction for structured weight tensors. The baseline numbers illustrate this:
  • Raw bf16 model: significantly over 16 MB
  • After int8 + zlib (level 9): 15,815,847 bytes
  • Code (train_gpt.py): 47,642 bytes
  • Total: 15,863,489 bytes — under the 16,000,000-byte cap
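The bf16 arithmetic is easy to verify with a back-of-envelope sketch (the 8M parameter count is illustrative, not the exact baseline count):

```python
# bf16 is 2 bytes per parameter, int8 is 1 byte per parameter.
n_params = 8_000_000                # illustrative, not the exact baseline count
bf16_bytes = n_params * 2           # 16,000,000: at the 16 MB cap before any code
int8_bytes = n_params * 1           # 8,000,000 before zlib compression
print(bf16_bytes, int8_bytes)
```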

The compression pipeline

1. Train in bf16/fp32

The model trains with weight matrices stored in fp32 (cast to bf16 at matmul time via CastedLinear). Control tensors (resid_mix, attn_scale, mlp_scale, q_gain, skip_weights) are kept in fp32 throughout.
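CastedLinear itself is not shown in this section; a minimal sketch consistent with the description above (fp32 master weights, cast to bf16 only at matmul time) might look like:

```python
import torch
import torch.nn.functional as F
from torch import nn

class CastedLinear(nn.Linear):
    # Sketch, not the actual implementation: master weights stay in fp32
    # and are cast to bf16 only for the matmul.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        weight = self.weight.to(torch.bfloat16)
        bias = self.bias.to(torch.bfloat16) if self.bias is not None else None
        return F.linear(x, weight, bias)
```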
2. Quantize state dict to int8

After training, quantize_state_dict_int8() processes every tensor in the state dict. 2D float tensors get per-row int8 quantization; small tensors are kept in fp16; control tensors stay in fp32.
3. Serialize with torch.save

The quantized dict is serialized into an in-memory io.BytesIO buffer using torch.save.
4. Compress with zlib level 9

The serialized bytes are compressed with zlib.compress(quant_raw, level=9) and written to final_model.int8.ptz.
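Serialization and compression together are only a few lines; a self-contained sketch with a stand-in for the quantized dict:

```python
import io
import zlib

import torch

# Stand-in for the real quantized state dict.
quant_sd = {"w": torch.zeros(128, 128, dtype=torch.int8)}

buf = io.BytesIO()
torch.save(quant_sd, buf)                    # serialize entirely in memory
quant_raw = buf.getvalue()
compressed = zlib.compress(quant_raw, level=9)

# All-zero weights compress far better than real ones (~2-3x is typical
# for trained weight tensors).
print(len(quant_raw), len(compressed))
```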
5. Size check

os.path.getsize("final_model.int8.ptz") + len(code.encode("utf-8")) must be under 16,000,000.
6. Roundtrip evaluation

The artifact is decompressed and dequantized via dequantize_state_dict_int8(), loaded back into the model, and the official final_int8_zlib_roundtrip val_bpb is computed.

quantize_state_dict_int8()

The function applies three different treatments depending on the tensor (abridged; assembly and return of the output dict are omitted):
def quantize_state_dict_int8(state_dict: dict[str, Tensor]):
    quantized: dict[str, Tensor] = {}
    scales: dict[str, Tensor] = {}
    passthrough: dict[str, Tensor] = {}
    passthrough_orig_dtypes: dict[str, str] = {}

    for name, tensor in state_dict.items():
        t = tensor.detach().to("cpu").contiguous()

        if not t.is_floating_point():
            # Non-float tensors (e.g. integer buffers): exact passthrough
            passthrough[name] = t
            continue

        # Small float tensors (<= 65536 elements): passthrough via
        # keep_float_tensor, which stores most as fp16 but keeps control
        # tensors (patterns below) in fp32
        if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL:
            kept = keep_float_tensor(name, t, passthrough_orig_dtypes)
            passthrough[name] = kept
            continue

        # Large float tensors: quantize to int8
        q, s = quantize_float_tensor(t)
        quantized[name] = q
        scales[name] = s

Per-row int8 for 2D tensors

2D tensors (weight matrices) receive one scale per output row. This tracks per-channel magnitude variation much better than a single tensor-wide scale:
def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]:
    t32 = t.float()
    if t32.ndim == 2:
        # Per-row clip threshold at the INT8_CLIP_Q quantile of |w| (defined below)
        clip_abs = torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1)
        clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None])
        scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0)
        q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8)
        return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous()
    # Non-2D tensors fall through to a single per-tensor scale (branch not shown)
Scales are stored as fp16 (INT8_PER_ROW_SCALE_DTYPE = torch.float16).
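The payoff of per-row scales can be checked with a quick roundtrip. The following is a self-contained sketch of the same per-row scheme without the outlier clip, so each scale is simply the row max divided by 127:

```python
import torch

def per_row_int8(t32: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # One scale per output row, derived from the row max (no outlier clip here).
    scale = (t32.abs().amax(dim=1) / 127.0).clamp_min(1.0 / 127.0)
    q = torch.clamp(torch.round(t32 / scale[:, None]), -127, 127).to(torch.int8)
    return q, scale

torch.manual_seed(0)
w = torch.randn(256, 512)
q, s = per_row_int8(w)
w_hat = q.float() * s[:, None]

# Worst-case rounding error is half a quantization step in each row.
max_err = (w - w_hat).abs().max().item()
print(max_err)
```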

Outlier clipping

A high-percentile clip is applied before quantizing to suppress weight outliers without discarding most values:
INT8_CLIP_PERCENTILE = 99.99984
INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0
With 99.99984%, roughly 1 in 625,000 values is clipped per row — aggressive enough to remove extreme outliers while preserving nearly all weight information.
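The clip fraction can be sanity-checked on random data, e.g. one synthetic row of a million values:

```python
import torch

INT8_CLIP_Q = 99.99984 / 100.0

torch.manual_seed(0)
row = torch.randn(1, 1_000_000)
clip_abs = torch.quantile(row.abs(), INT8_CLIP_Q, dim=1)
n_clipped = int((row.abs() > clip_abs[:, None]).sum())
# Expect on the order of 1-2 clipped values per million (1.6 on average).
print(n_clipped)
```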

Control tensor treatment

The following tensor name patterns are never quantized — they are kept in fp32 regardless of size:
CONTROL_TENSOR_NAME_PATTERNS = (
    "attn_scale", "attn_scales",
    "mlp_scale",  "mlp_scales",
    "resid_mix",  "resid_mixes",
    "q_gain",
    "skip_weight", "skip_weights",
)
These patterns are also used during training to keep those parameters in fp32 via restore_low_dim_params_to_fp32().
You can override the control tensor patterns via the CONTROL_TENSOR_NAME_PATTERNS environment variable (comma-separated). INT8_KEEP_FLOAT_FP32_NAME_PATTERNS controls which patterns are kept in fp32 during quantization specifically (defaults to the same set).
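A plausible shape for that override logic is sketched below; the substring matching and the `is_control_tensor` helper name are assumptions for illustration, not the actual code:

```python
import os

_DEFAULT_PATTERNS = (
    "attn_scale", "attn_scales", "mlp_scale", "mlp_scales",
    "resid_mix", "resid_mixes", "q_gain", "skip_weight", "skip_weights",
)

# Comma-separated override via the environment, else the defaults above.
_env = os.environ.get("CONTROL_TENSOR_NAME_PATTERNS")
CONTROL_TENSOR_NAME_PATTERNS = (
    tuple(p.strip() for p in _env.split(",") if p.strip()) if _env else _DEFAULT_PATTERNS
)

def is_control_tensor(name: str) -> bool:
    # Assumed: a pattern matches if it appears anywhere in the tensor name.
    return any(pat in name for pat in CONTROL_TENSOR_NAME_PATTERNS)
```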

dequantize_state_dict_int8()

The roundtrip restores each tensor to its original dtype:
def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]:
    out: dict[str, Tensor] = {}
    qmeta = obj.get("qmeta", {})
    passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {})

    for name, q in obj["quantized"].items():
        dtype = getattr(torch, obj["dtypes"][name])
        s = obj["scales"][name]
        if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0:
            s = s.to(dtype=torch.float32)
            out[name] = (
                q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))
            ).to(dtype=dtype).contiguous()
        else:
            scale = float(s.item())
            out[name] = (q.float() * scale).to(dtype=dtype).contiguous()

    for name, t in obj["passthrough"].items():
        out_t = t.detach().to("cpu").contiguous()
        orig_dtype = passthrough_orig_dtypes.get(name)
        if isinstance(orig_dtype, str):
            out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous()
        out[name] = out_t

    return out
For per-row quantized tensors, the row scale is broadcast back across the trailing dimensions before casting to the original dtype.
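Although the per-row scheme is only applied to 2D weights, the dequantization path is written generally; the broadcast itself is just a view. For illustration, with a 3D tensor and one scale per leading row:

```python
import torch

q = torch.randint(-127, 128, (4, 8, 3), dtype=torch.int8)
s = torch.rand(4, dtype=torch.float16)

# view reshapes (4,) -> (4, 1, 1) so each row scale broadcasts over trailing dims
deq = q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))
print(deq.shape)
```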

Artifact format

The file final_model.int8.ptz contains a zlib-compressed torch.save payload with the following top-level keys:
  • __quant_format__: format identifier "int8_clean_per_row_v1"
  • quantized: dict of int8 tensors (large float parameters)
  • scales: dict of fp16 row scales, one per quantized tensor
  • dtypes: original dtype string for each quantized tensor
  • passthrough: dict of non-quantized tensors (fp32 control params, fp16 small params, non-floats)
  • qmeta: per-tensor quantization scheme metadata (per_row)
  • passthrough_orig_dtypes: original dtype for passthrough tensors downcast from bf16/fp32
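The format can be exercised end to end with a toy payload; the tensor contents below are illustrative, only the key layout matches the format described above:

```python
import io
import zlib

import torch

payload = {
    "__quant_format__": "int8_clean_per_row_v1",
    "quantized": {"w": torch.zeros(2, 4, dtype=torch.int8)},
    "scales": {"w": torch.ones(2, dtype=torch.float16)},
    "dtypes": {"w": "float32"},
    "passthrough": {},
    "qmeta": {"w": {"scheme": "per_row"}},
    "passthrough_orig_dtypes": {},
}

buf = io.BytesIO()
torch.save(payload, buf)
blob = zlib.compress(buf.getvalue(), level=9)    # what gets written to the .ptz file

# Reading it back: decompress, then torch.load from memory.
obj = torch.load(io.BytesIO(zlib.decompress(blob)), weights_only=False)
print(obj["__quant_format__"])
```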

Size calculation

The training script logs all sizing information at the end of a run:
Serialized model int8+zlib: 15815847 bytes (payload:... raw_torch:... payload_ratio:...)
Total submission size int8+zlib: 15863489 bytes
You can reproduce the size check with:
import os
from pathlib import Path

code_bytes = len(Path("train_gpt.py").read_text(encoding="utf-8").encode("utf-8"))
model_bytes = os.path.getsize("final_model.int8.ptz")
assert code_bytes + model_bytes < 16_000_000, f"Over limit: {code_bytes + model_bytes}"
The submission size check uses the on-disk file size of final_model.int8.ptz, not the size of the in-memory compressed bytes. Always verify with os.path.getsize() after writing the file.