## Why quantization is necessary
A bf16 parameter occupies 2 bytes. Even a small model with ~8M parameters would require ~16 MB in raw bf16, exactly at the limit with nothing left for the code. int8 cuts that to 1 byte per parameter, and zlib compression on top typically yields a further 2–3× reduction for structured weight tensors. The baseline numbers illustrate this:

- Raw bf16 model: significantly over 16 MB
- After int8 + zlib (level 9): 15,815,847 bytes
- Code (`train_gpt.py`): 47,642 bytes
- Total: 15,863,489 bytes, under the 16,000,000-byte cap
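To make the byte accounting concrete, here is a minimal sketch (not the repo's actual script) of the quantize → serialize → compress path on a toy tensor. The per-tensor scale here is a simplification; the real pipeline uses per-row scales.

```python
import io
import zlib

import torch

# Toy "weight": a smooth 2D tensor (real model is ~8M params)
w = torch.linspace(-1.0, 1.0, 256 * 256).reshape(256, 256)

# Symmetric int8 quantization with a single scale (simplified)
scale = w.abs().max() / 127.0
q = torch.clamp((w / scale).round(), -127, 127).to(torch.int8)

# Serialize to an in-memory buffer, then compress with zlib level 9
buf = io.BytesIO()
torch.save({"w": q, "scale": scale.to(torch.float16)}, buf)
raw = buf.getvalue()
compressed = zlib.compress(raw, level=9)

bf16_bytes = w.numel() * 2  # 2 bytes per parameter in bf16
int8_bytes = w.numel() * 1  # 1 byte per parameter after quantization
print(bf16_bytes, int8_bytes, len(compressed))
```

For a structured tensor like this one, the zlib pass shrinks the serialized int8 payload further; for high-entropy trained weights the gain is smaller.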
## The compression pipeline

1. **Train in bf16/fp32.** The model trains with weight matrices stored in fp32 (cast to bf16 at matmul time via `CastedLinear`). Control tensors (`resid_mix`, `attn_scale`, `mlp_scale`, `q_gain`, `skip_weights`) are kept in fp32 throughout.
2. **Quantize the state dict to int8.** After training, `quantize_state_dict_int8()` processes every tensor in the state dict: 2D float tensors get per-row int8 quantization, small tensors are kept in fp16, and control tensors stay in fp32.
3. **Serialize with `torch.save`.** The quantized dict is serialized into an in-memory `io.BytesIO` buffer using `torch.save`.
4. **Compress with zlib level 9.** The serialized bytes are compressed with `zlib.compress(quant_raw, level=9)` and written to `final_model.int8.ptz`.
5. **Size check.** `os.path.getsize("final_model.int8.ptz") + len(code.encode("utf-8"))` must be under 16,000,000.

## quantize_state_dict_int8()
The function applies three different treatments depending on the tensor:
### Per-row int8 for 2D tensors

2D tensors (weight matrices) receive one scale per output row. This tracks per-channel magnitude variation much better than a single tensor-wide scale. The row scales themselves are stored in fp16 (`INT8_PER_ROW_SCALE_DTYPE = torch.float16`).
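A minimal sketch of per-row quantization, illustrative rather than the repo's exact code:

```python
import torch

def quantize_per_row(w: torch.Tensor):
    """Quantize a 2D float tensor to int8 with one fp16 scale per output row."""
    assert w.dim() == 2
    # One scale per row: the row's max |value| maps to 127
    scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-12) / 127.0
    q = torch.clamp((w / scale).round(), -127, 127).to(torch.int8)
    return q, scale.to(torch.float16)  # INT8_PER_ROW_SCALE_DTYPE

w = torch.randn(4, 8)
q, scale = quantize_per_row(w)

# Dequantize and measure the worst-case roundtrip error
w_hat = q.to(torch.float32) * scale.to(torch.float32)
err = (w - w_hat).abs().max()
```

Because each row gets its own scale, a row of tiny values is not crushed to zero just because another row contains a large weight.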
### Outlier clipping

A high-percentile clip is applied before quantizing to suppress weight outliers without discarding most values. At a clip quantile of 99.99984%, roughly 1 in 625,000 values is clipped per row, aggressive enough to remove extreme outliers while preserving nearly all weight information.
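The clip can be sketched as follows; the `kthvalue`-based per-row quantile here is an assumption about the implementation, not the repo's exact code:

```python
import torch

CLIP_QUANTILE = 0.9999984  # ~1 in 625,000 values clipped per row

def clip_outliers_per_row(w: torch.Tensor) -> torch.Tensor:
    """Clamp each row of a 2D tensor at its CLIP_QUANTILE absolute-value quantile."""
    n = w.shape[1]
    # Rank of the quantile in the sorted |row| (kthvalue is 1-indexed)
    k = min(n, max(1, round(CLIP_QUANTILE * n)))
    thresh = w.abs().kthvalue(k, dim=1, keepdim=True).values
    return w.clamp(-thresh, thresh)

# Expected clipped fraction per row
frac = 1.0 - CLIP_QUANTILE  # ~1.6e-06, i.e. about 1 in 625,000
```

For rows much shorter than 625,000 elements, the threshold lands on the row maximum and the clip is a no-op, which is the desired behavior: only genuinely extreme tails are affected.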
### Control tensor treatment

Tensors whose names match the control patterns (`resid_mix`, `attn_scale`, `mlp_scale`, `q_gain`, `skip_weights`) are never quantized; they are kept in fp32 regardless of size. On the loading side, `restore_low_dim_params_to_fp32()` ensures these parameters end up back in fp32.
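The dispatch between the three treatments can be sketched like this; the pattern names come from the pipeline description, while the small-tensor size threshold is an illustrative assumption:

```python
import torch

CONTROL_TENSOR_NAME_PATTERNS = (
    "resid_mix", "attn_scale", "mlp_scale", "q_gain", "skip_weights",
)
SMALL_TENSOR_NUMEL = 4096  # assumption: below this, float tensors stay fp16

def treatment(name: str, t: torch.Tensor) -> str:
    """Decide how a state-dict entry is stored in the artifact."""
    if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS):
        return "fp32_passthrough"      # control tensors: never quantized
    if not t.is_floating_point():
        return "passthrough"           # ints, bools, etc. stored as-is
    if t.dim() == 2 and t.numel() > SMALL_TENSOR_NUMEL:
        return "int8_per_row"          # large weight matrices
    return "fp16_passthrough"          # small float tensors
```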
You can override the control tensor patterns via the `CONTROL_TENSOR_NAME_PATTERNS` environment variable (comma-separated). `INT8_KEEP_FLOAT_FP32_NAME_PATTERNS` controls which patterns are kept in fp32 during quantization specifically (defaults to the same set).

## dequantize_state_dict_int8()
The roundtrip restores each tensor to its original dtype.
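A minimal sketch of the per-tensor roundtrip, assuming dtype strings are stored as bare torch names like `"float32"`:

```python
import torch

def dequantize_per_row(q: torch.Tensor, scale: torch.Tensor, dtype_str: str) -> torch.Tensor:
    """Invert per-row int8 quantization: widen, apply the row scale, restore dtype."""
    w = q.to(torch.float32) * scale.to(torch.float32)
    return w.to(getattr(torch, dtype_str))

q = torch.tensor([[127, -64], [10, 0]], dtype=torch.int8)
scale = torch.tensor([[0.01], [0.02]], dtype=torch.float16)
w = dequantize_per_row(q, scale, "float32")
```

The reconstruction error per element is bounded by half a quantization step, i.e. `scale / 2` for that row, plus the small fp16 rounding of the scale itself.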
## Artifact format
The file `final_model.int8.ptz` contains a zlib-compressed `torch.save` payload with the following top-level keys:
| Key | Contents |
|---|---|
| `__quant_format__` | Format identifier: `"int8_clean_per_row_v1"` |
| `quantized` | Dict of int8 tensors (large float parameters) |
| `scales` | Dict of fp16 row scales, one per quantized tensor |
| `dtypes` | Original dtype string for each quantized tensor |
| `passthrough` | Dict of non-quantized tensors (fp32 control params, fp16 small params, non-floats) |
| `qmeta` | Per-tensor quantization scheme metadata (`per_row`) |
| `passthrough_orig_dtypes` | Original dtype for passthrough tensors downcast from bf16/fp32 |
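Putting the pieces together, loading the artifact might look like the sketch below. The key names come from the table above; the loader logic itself is an illustrative assumption, not the repo's code.

```python
import io
import zlib

import torch

def load_artifact(path: str) -> dict:
    """Decompress, torch.load, and dequantize the int8 artifact into a state dict."""
    with open(path, "rb") as f:
        payload = torch.load(io.BytesIO(zlib.decompress(f.read())))
    assert payload["__quant_format__"] == "int8_clean_per_row_v1"
    state = dict(payload["passthrough"])           # control / small / non-float tensors
    for name, q in payload["quantized"].items():   # int8 tensors get dequantized
        w = q.to(torch.float32) * payload["scales"][name].to(torch.float32)
        state[name] = w.to(getattr(torch, payload["dtypes"][name]))
    return state

# Round-trip demo on a tiny hand-built artifact
payload = {
    "__quant_format__": "int8_clean_per_row_v1",
    "quantized": {"w": torch.tensor([[127, 0]], dtype=torch.int8)},
    "scales": {"w": torch.tensor([[0.01]], dtype=torch.float16)},
    "dtypes": {"w": "float32"},
    "passthrough": {"attn_scale": torch.ones(2)},
    "qmeta": {"w": "per_row"},
    "passthrough_orig_dtypes": {},
}
buf = io.BytesIO()
torch.save(payload, buf)
with open("demo.int8.ptz", "wb") as f:
    f.write(zlib.compress(buf.getvalue(), level=9))

state = load_artifact("demo.int8.ptz")
```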
