
Overview

The quantization module provides functions to convert neural network models to lower precision formats (FP16 and INT8) to reduce memory footprint and improve inference speed on edge devices.

Functions

to_fp16

Converts a PyTorch model to 16-bit floating point (FP16) precision for reduced memory usage and faster inference.
from edge_opt.quantization import to_fp16

fp16_model = to_fp16(model)

Parameters

model
nn.Module
required
The input PyTorch model to convert. Can be any torch.nn.Module instance.

Returns

fp16_model
nn.Module
A deep copy of the input model with all parameters and buffers converted to FP16 (half precision). The model is set to evaluation mode.

Properties:
  • All weights and biases are stored as torch.float16
  • Memory usage is approximately 50% of FP32
  • Model is in evaluation mode (.eval() called)
  • Independent copy - modifications don’t affect the original
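These properties are easy to verify directly. The sketch below uses a stand-in `nn.Linear` (any `nn.Module` behaves the same) and the documented `deepcopy(...).half().eval()` conversion rather than `to_fp16` itself, so it runs without `edge_opt` installed:

```python
import torch
import torch.nn as nn
from copy import deepcopy

# Stand-in model for illustration; to_fp16 accepts any nn.Module
model = nn.Linear(4, 2)

# The documented conversion: deep copy, half precision, eval mode
fp16_model = deepcopy(model).half().eval()

# All parameters are now float16
assert all(p.dtype == torch.float16 for p in fp16_model.parameters())
# The converted model is in evaluation mode
assert not fp16_model.training
# The original model is an independent copy, still float32
assert all(p.dtype == torch.float32 for p in model.parameters())
```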

Implementation Details

The function performs the following operations:
  1. Deep Copy: Creates a complete copy of the model using deepcopy(model)
  2. Precision Conversion: Applies .half() to convert all parameters to FP16
  3. Evaluation Mode: Sets the model to evaluation mode with .eval()
fp16_model = deepcopy(model).half().eval()
The returned model requires FP16 inputs during inference. Convert input tensors using inputs.half() before passing them to the model.
FP16 models may experience numerical instability with very small or very large values due to reduced precision range. Monitor accuracy metrics when deploying FP16 models.
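The reduced range is easy to demonstrate: `float16` tops out at 65504, so a value that fits comfortably in FP32 overflows to infinity when converted.

```python
import torch

# float16 has a far narrower range than float32
print(torch.finfo(torch.float16).max)  # 65504.0
print(torch.finfo(torch.float32).max)  # ~3.4e38

# A moderately large value overflows in FP16
x = torch.tensor(70000.0)
print(x.half())  # tensor(inf, dtype=torch.float16)
```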

Example

import torch
from edge_opt.model import SmallCNN
from edge_opt.quantization import to_fp16

# Create and train a model
model = SmallCNN()
# ... training code ...

# Convert to FP16
fp16_model = to_fp16(model)

# Inference with FP16
inputs = torch.randn(1, 1, 28, 28).half()  # Convert inputs to FP16
with torch.no_grad():
    outputs = fp16_model(inputs)

# Check memory savings
original_memory = sum(p.numel() * p.element_size() for p in model.parameters())
fp16_memory = sum(p.numel() * p.element_size() for p in fp16_model.parameters())
print(f"Memory reduction: {(1 - fp16_memory/original_memory) * 100:.1f}%")

to_int8

Converts a PyTorch model to 8-bit integer (INT8) quantization using post-training static quantization with calibration data.
from edge_opt.quantization import to_int8

int8_model = to_int8(model, calibration_loader, calibration_batches=10)

Parameters

model
nn.Module
required
The input PyTorch model to quantize. Can be any torch.nn.Module instance compatible with PyTorch’s FX quantization.
calibration_loader
DataLoader
required
A PyTorch DataLoader containing representative input data for calibration. The calibration process observes the range of activations to determine optimal quantization parameters.

Requirements:
  • Must yield batches in the format (inputs, targets)
  • Should contain diverse, representative samples from the dataset
  • At least calibration_batches batches should be available
calibration_batches
int
default:10
The number of batches from the calibration loader to use for calibration. More batches provide better quantization accuracy but take longer to process.

Typical values:
  • 10: Fast calibration, suitable for most cases
  • 50-100: Higher accuracy for critical applications
  • 1: Minimal calibration (not recommended)

Returns

quantized_model
nn.Module
A quantized version of the input model using INT8 weights and activations. The model uses PyTorch’s “fbgemm” backend, optimized for x86 CPUs.

Properties:
  • Weights stored as INT8 (1 byte per parameter)
  • Activations computed in INT8
  • Memory usage is approximately 25% of FP32
  • Significant speedup on supported hardware
  • Calibration scales and zero-points embedded in the model
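The roughly 25% figure follows from byte widths alone: one byte per INT8 parameter versus four per FP32 parameter. The back-of-the-envelope sketch below ignores the small overhead of stored scales and zero-points (and the fact that biases typically remain FP32 in practice):

```python
# Rough memory estimate for a hypothetical 128x128 linear layer
n_params = 128 * 128 + 128     # weights + biases
fp32_bytes = n_params * 4      # 4 bytes per float32 parameter
int8_bytes = n_params * 1      # 1 byte per int8 parameter

print(f"INT8 / FP32 memory ratio: {int8_bytes / fp32_bytes:.2f}")  # 0.25
```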

Implementation Details

The function implements post-training static quantization using PyTorch’s FX graph mode:
  1. Preparation: Creates a deep copy in evaluation mode and prepares for quantization
    float_model = deepcopy(model).eval()
    qconfig_mapping = get_default_qconfig_mapping("fbgemm")
    
  2. FX Preparation: Inserts observers to track activation ranges
    prepared = prepare_fx(float_model, qconfig_mapping, example_inputs=(example_inputs,))
    
  3. Calibration: Runs calibration data through the model to collect statistics
    for index, (inputs, _) in enumerate(calibration_loader):
        _ = prepared(inputs)
        if index + 1 >= calibration_batches:
            break
    
  4. Conversion: Converts the calibrated model to INT8
    quantized = convert_fx(prepared)
    
The “fbgemm” backend is optimized for x86 CPUs. For ARM devices, consider using “qnnpack” backend by modifying the get_default_qconfig_mapping() call.
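Assuming the backend is selected through the standard `torch.ao.quantization` API (the `to_int8` signature shown above does not expose a backend parameter, so this is a sketch of the change, not a documented option), the ARM-oriented swap would look roughly like this:

```python
import torch
from torch.ao.quantization import get_default_qconfig_mapping

# "qnnpack" targets ARM CPUs; "fbgemm" targets x86
qconfig_mapping = get_default_qconfig_mapping("qnnpack")

# The quantized engine should match the qconfig backend at inference time,
# when that engine is compiled into the current torch build
if "qnnpack" in torch.backends.quantized.supported_engines:
    torch.backends.quantized.engine = "qnnpack"
```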
Calibration is Critical: The quality of INT8 quantization heavily depends on calibration data. Ensure the calibration loader contains diverse, representative samples covering the full range of expected inputs.
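One simple way to build such a loader (a sketch, using a synthetic stand-in for the real validation dataset) is to draw a fixed random subset, so calibration sees a spread of inputs rather than just the first few batches in dataset order:

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset

# Hypothetical validation dataset; substitute your real one
val_dataset = TensorDataset(torch.randn(5000, 1, 28, 28),
                            torch.randint(0, 10, (5000,)))

# Sample 1000 indices without replacement; seed for reproducible calibration
generator = torch.Generator().manual_seed(0)
indices = torch.randperm(len(val_dataset), generator=generator)[:1000]
calib_loader = DataLoader(Subset(val_dataset, indices.tolist()), batch_size=50)

print(len(calib_loader))  # 20 batches of 50 samples
```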

Example

import torch
from torch.utils.data import DataLoader, TensorDataset
from edge_opt.model import SmallCNN
from edge_opt.quantization import to_int8

# Create model and data
model = SmallCNN()
# ... training code ...

# Prepare calibration data (representative subset)
calib_inputs = torch.randn(1000, 1, 28, 28)
calib_targets = torch.randint(0, 10, (1000,))
calib_dataset = TensorDataset(calib_inputs, calib_targets)
calib_loader = DataLoader(calib_dataset, batch_size=32)

# Quantize to INT8
int8_model = to_int8(
    model,
    calibration_loader=calib_loader,
    calibration_batches=20  # Use 20 batches for calibration
)

# Inference with INT8 (inputs remain FP32)
inputs = torch.randn(1, 1, 28, 28)
with torch.no_grad():
    outputs = int8_model(inputs)

# Measure memory savings
from edge_opt.metrics import model_memory_mb
original_mb = model_memory_mb(model)
quantized_mb = model_memory_mb(int8_model)
print(f"Memory: {original_mb:.2f} MB -> {quantized_mb:.2f} MB")
print(f"Reduction: {(1 - quantized_mb/original_mb) * 100:.1f}%")

Calibration Best Practices

# Use diverse samples from validation set
calib_batches = 20  # with a batch size of 50, this covers 1000 samples

int8_model = to_int8(model, val_loader, calibration_batches=calib_batches)
