## Overview

The quantization module provides functions to convert neural network models to lower-precision formats (FP16 and INT8) to reduce memory footprint and improve inference speed on edge devices.

## Functions
### to_fp16
Converts a PyTorch model to 16-bit floating point (FP16) precision for reduced memory usage and faster inference.

#### Parameters

- `model`: the input PyTorch model to convert. Can be any `torch.nn.Module` instance.

#### Returns
A deep copy of the input model with all parameters and buffers converted to FP16 (half precision). The model is set to evaluation mode.

Properties:

- All weights and biases are stored as `torch.float16`
- Memory usage is approximately 50% of FP32
- Model is in evaluation mode (`.eval()` called)
- Independent copy; modifications don't affect the original
#### Implementation Details

The function performs the following operations:

1. Deep Copy: Creates a complete copy of the model using `deepcopy(model)`
2. Precision Conversion: Applies `.half()` to convert all parameters to FP16
3. Evaluation Mode: Sets the model to evaluation mode with `.eval()`
The returned model requires FP16 inputs during inference. Convert input tensors using `inputs.half()` before passing them to the model.

#### Example
### to_int8
Converts a PyTorch model to 8-bit integer (INT8) quantization using post-training static quantization with calibration data.

#### Parameters

- `model`: the input PyTorch model to quantize. Can be any `torch.nn.Module` instance compatible with PyTorch's FX quantization.
- The calibration loader: a PyTorch DataLoader containing representative input data for calibration. The calibration process observes the range of activations to determine optimal quantization parameters. Requirements:
  - Must yield batches in the format `(inputs, targets)`
  - Should contain diverse, representative samples from the dataset
  - At least `calibration_batches` batches should be available
- `calibration_batches`: the number of batches from the calibration loader to use for calibration. More batches provide better quantization accuracy but take longer to process. Typical values:
  - `10`: fast calibration, suitable for most cases
  - `50-100`: higher accuracy for critical applications
  - `1`: minimal calibration (not recommended)
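For illustration, a calibration loader meeting the requirements above can be built from an in-memory dataset. The sample count, shapes, and class count here are hypothetical:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical representative samples: 64 images, 10 classes.
inputs = torch.randn(64, 3, 32, 32)
targets = torch.randint(0, 10, (64,))

# Yields (inputs, targets) batches, as the calibration loop expects.
calibration_loader = DataLoader(TensorDataset(inputs, targets), batch_size=8)
```

With `batch_size=8`, this loader provides 8 batches, comfortably covering the common `calibration_batches` values listed above.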
#### Returns
A quantized version of the input model using INT8 weights and activations. The model uses PyTorch's "fbgemm" backend, optimized for x86 CPUs.

Properties:

- Weights stored as INT8 (1 byte per parameter)
- Activations computed in INT8
- Memory usage is approximately 25% of FP32
- Significant speedup on supported hardware
- Calibration scales and zero points embedded in the model
#### Implementation Details

The function implements post-training static quantization using PyTorch's FX graph mode:

1. Preparation: Creates a deep copy in evaluation mode and prepares it for quantization
2. FX Preparation: Inserts observers to track activation ranges
3. Calibration: Runs calibration data through the model to collect statistics
4. Conversion: Converts the calibrated model to INT8
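The four steps above can be sketched with PyTorch's FX quantization APIs. The function body is an assumed reconstruction rather than the module's actual source, and the `calibration_loader` parameter name is also an assumption:

```python
import copy

import torch
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx


def to_int8(model, calibration_loader, calibration_batches=10):
    """Illustrative sketch of post-training static quantization via FX."""
    # 1. Preparation: deep copy in evaluation mode
    model = copy.deepcopy(model).eval()

    # 2. FX Preparation: insert observers to track activation ranges
    qconfig_mapping = get_default_qconfig_mapping("fbgemm")
    example_inputs, _ = next(iter(calibration_loader))
    prepared = prepare_fx(model, qconfig_mapping, (example_inputs,))

    # 3. Calibration: run representative batches to collect statistics
    with torch.no_grad():
        for i, (inputs, _) in enumerate(calibration_loader):
            if i >= calibration_batches:
                break
            prepared(inputs)

    # 4. Conversion: produce the INT8 model
    return convert_fx(prepared)
```

The returned model accepts the same FP32 inputs as the original; quantize/dequantize steps are inserted automatically at the graph boundaries.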
The "fbgemm" backend is optimized for x86 CPUs. For ARM devices, consider using the "qnnpack" backend by modifying the `get_default_qconfig_mapping()` call.
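The backend swap for ARM might look like the following sketch. Whether the engine switch succeeds depends on the PyTorch build, so it is guarded here:

```python
import torch
from torch.ao.quantization import get_default_qconfig_mapping

# Request qnnpack observers/qconfigs instead of the x86-oriented fbgemm ones.
qconfig_mapping = get_default_qconfig_mapping("qnnpack")

# The runtime engine must match the backend used at quantization time;
# only switch if this build of PyTorch supports it.
if "qnnpack" in torch.backends.quantized.supported_engines:
    torch.backends.quantized.engine = "qnnpack"
```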