
Overview

The Edge AI Hardware Optimization framework implements two complementary optimization techniques:
  1. Structured Channel Pruning: Removes entire convolutional channels to reduce parameters and compute while maintaining dense tensor operations
  2. Precision Quantization: Converts models to lower-precision formats (FP16, INT8) to reduce memory footprint and accelerate inference
Both techniques are implemented in src/edge_opt/pruning.py and src/edge_opt/quantization.py respectively.

Structured Channel Pruning

Algorithm Overview

Structured channel pruning removes whole output channels from convolutional layers based on L1-norm importance scores. Unlike unstructured weight pruning, this approach:
  • Produces dense tensors compatible with standard hardware accelerators
  • Reduces actual runtime memory and FLOPs (not just parameter count)
  • Requires no specialized sparse kernel support
  • Maintains straightforward deployment compatibility

Implementation: structured_channel_prune

The pruning implementation in src/edge_opt/pruning.py:14 removes channels from both convolutional layers while preserving connectivity:
def structured_channel_prune(model: SmallCNN, pruning_level: float) -> SmallCNN:
    """
    Remove channels from conv1 and conv2 based on L1-norm importance.
    
    Args:
        model: Trained SmallCNN instance
        pruning_level: Fraction of channels to remove [0.0, 1.0)
    
    Returns:
        New SmallCNN with reduced channel counts
    """
    if not 0.0 <= pruning_level < 1.0:
        raise ValueError("pruning_level must be in [0.0, 1.0).")
    
    # Step 1: Score channels by L1-norm of weights
    conv1_scores = model.conv1.weight.data.abs().sum(dim=(1, 2, 3))
    keep1 = _topk_indices(conv1_scores, pruning_level)
    
    conv2_scores = model.conv2.weight.data.abs().sum(dim=(1, 2, 3))
    keep2 = _topk_indices(conv2_scores, pruning_level)
    
    # Step 2: Create new model with reduced channels
    pruned = SmallCNN(
        conv1_channels=len(keep1), 
        conv2_channels=len(keep2)
    )
    
    # Step 3: Copy selected channels (see below)
    ...
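The pruning level is the fraction of channels to remove, not keep: pruning_level=0.3 removes about 30% of the channels in each convolutional layer and keeps about 70%. The removed fraction is approximate because the kept count is rounded to a whole number of channels and never drops below one.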
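Rounding changes how many channels a given level actually removes. To see the effect, the keep-count formula used in _topk_indices can be evaluated on its own. This is a plain-Python sketch; kept_channels is a hypothetical helper, not part of edge_opt.

```python
def kept_channels(total: int, pruning_level: float) -> int:
    """Channels kept for one layer, using the same formula as _topk_indices."""
    if not 0.0 <= pruning_level < 1.0:
        raise ValueError("pruning_level must be in [0.0, 1.0).")
    # Round to the nearest whole channel, but never keep fewer than one
    return max(1, round(total * (1.0 - pruning_level)))


for level in (0.0, 0.3, 0.5, 0.95, 0.99):
    kept = kept_channels(16, level)
    print(f"pruning_level={level:.2f}: keep {kept}/16, removed {1 - kept / 16:.1%}")
```

For a 16-channel layer, pruning_level=0.3 keeps 11 channels, so 31.25% are removed rather than exactly 30%. Very high levels collapse to a single surviving channel.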

Channel Selection Strategy

The _topk_indices helper function selects channels to keep based on importance scores:
import torch

def _topk_indices(channel_scores: torch.Tensor, pruning_level: float) -> torch.Tensor:
    total = channel_scores.numel()
    keep = max(1, int(round(total * (1.0 - pruning_level))))  # At least 1 channel
    # Indices of the `keep` highest scores, re-sorted into ascending channel order
    return torch.topk(channel_scores, keep, largest=True).indices.sort().values
Key design choices:
  • L1-norm scoring: Sum of absolute weights across spatial dimensions (H, W) and input channels
  • Top-k selection: Keep channels with highest L1-norm scores
  • Minimum retention: Always keep at least 1 channel even with high pruning levels
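  • Sorted indices: torch.topk returns indices in score order. Sorting them restores the original channel order, so kept channels keep their relative positions in conv2 and the classifier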
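The scoring and selection steps can be mirrored in plain Python, without torch. This makes explicit both the reduction over dims (1, 2, 3) and the final re-sort. The helpers and toy weights below are hypothetical. Tie-breaking between equal scores may differ from torch.topk.

```python
from typing import List

# weight[out_channel][in_channel][row][col], i.e. PyTorch's [C_out, C_in, H, W] layout
Weight = List[List[List[List[float]]]]


def l1_scores(weight: Weight) -> List[float]:
    """Per-output-channel L1 norm: sum of |w| over input channels, rows and columns."""
    return [
        sum(abs(v) for in_ch in out_ch for row in in_ch for v in row)
        for out_ch in weight
    ]


def topk_sorted_indices(scores: List[float], pruning_level: float) -> List[int]:
    """Indices of the highest-scoring channels, returned in original channel order."""
    keep = max(1, round(len(scores) * (1.0 - pruning_level)))
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return sorted(ranked[:keep])


# Toy conv weight: 4 output channels, 1 input channel, 1x2 kernels
toy_weight = [
    [[[0.25, -0.25]]],
    [[[1.0, -2.0]]],
    [[[0.0, 0.125]]],
    [[[-0.75, 0.5]]],
]
scores = l1_scores(toy_weight)
print(scores)
print(topk_sorted_indices(scores, pruning_level=0.5))
```

Channel 1 has the largest score, but the kept set is reported as [1, 3] in ascending order rather than score order. This is what lets the same index list slice conv2 and the classifier consistently.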

Weight Transfer Process

The pruned model is initialized with weights from selected channels:
Step 1: Conv1 Layer Transfer

Copy weights and biases for selected conv1 output channels:
with torch.no_grad():
    # Conv1: Select output channels
    pruned.conv1.weight.copy_(model.conv1.weight[keep1])
    pruned.conv1.bias.copy_(model.conv1.bias[keep1])
Shape transformation: [C_out_original, C_in, H, W] → [C_out_pruned, C_in, H, W]
Step 2: Conv2 Layer Transfer

Copy weights for selected conv2 output channels AND input channels (must match conv1 outputs):
with torch.no_grad():  # in-place copy_ into parameters that require grad fails outside no_grad
    # Conv2: Select both output channels (keep2) and input channels (keep1)
    conv2_w = model.conv2.weight[keep2][:, keep1, :, :]
    pruned.conv2.weight.copy_(conv2_w)
    pruned.conv2.bias.copy_(model.conv2.bias[keep2])
Shape transformation: [C2_out_original, C1_out_original, H, W] → [C2_out_pruned, C1_out_pruned, H, W]
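The double indexing weight[keep2][:, keep1] works in two stages. It first picks conv2's surviving output channels, then, within each, keeps only the kernels that read from surviving conv1 channels. A hypothetical nested-list sketch with labeled kernels makes the selection visible:

```python
from typing import List, Sequence


def select_conv2(weight: List[List[str]], keep_out: Sequence[int],
                 keep_in: Sequence[int]) -> List[List[str]]:
    """Nested-list analogue of weight[keep_out][:, keep_in].

    weight[o][i] stands in for the whole kernel that connects conv1 output
    channel i to conv2 output channel o.
    """
    rows = [weight[o] for o in keep_out]                  # weight[keep_out]
    return [[row[i] for i in keep_in] for row in rows]    # [:, keep_in]


# Label every kernel of a 3-output, 4-input conv2 by its original position
original = [[f"o{o}i{i}" for i in range(4)] for o in range(3)]
pruned = select_conv2(original, keep_out=[0, 2], keep_in=[1, 3])
for row in pruned:
    print(row)
```

Kernels that read from removed conv1 channels (i0, i2) disappear along with the removed output channel o1. This is why conv2's input dimension must be sliced with the same keep1 used for conv1.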
Step 3: Classifier Layer Transfer

Remap fully-connected layer inputs to match flattened conv2 outputs:
with torch.no_grad():
    # Map each kept conv2 channel to its contiguous block of flattened classifier inputs
    features_per_channel = 7 * 7  # 28x28 input halved twice by max-pooling: 28 -> 14 -> 7
    fc_indices = []
    for channel in keep2.tolist():
        start = channel * features_per_channel
        fc_indices.extend(range(start, start + features_per_channel))
    fc_idx = torch.tensor(fc_indices, dtype=torch.long)

    # Select corresponding input features
    pruned.classifier.weight.copy_(model.classifier.weight[:, fc_idx])
    pruned.classifier.bias.copy_(model.classifier.bias)  # Output classes unchanged
Shape transformation: [num_classes, C2_out_original × 7 × 7] → [num_classes, C2_out_pruned × 7 × 7]
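The index arithmetic relies on the classifier input being flattened channel-major, so channel c owns the contiguous block [c × 49, (c + 1) × 49). The sketch below checks that property in plain Python on a toy 3×2×2 feature map. It assumes SmallCNN flattens NCHW activations with torch.flatten or an equivalent view. The index code implies this, but the page does not show it.

```python
from typing import List, Sequence


def flatten_chw(feature_map: List[List[List[int]]]) -> List[int]:
    """Channel-major flatten of a [C][H][W] map (the order torch.flatten uses for NCHW)."""
    return [v for channel in feature_map for row in channel for v in row]


def fc_input_indices(keep2: Sequence[int], features_per_channel: int) -> List[int]:
    """Same construction as fc_indices in the classifier transfer."""
    indices: List[int] = []
    for channel in keep2:
        start = channel * features_per_channel
        indices.extend(range(start, start + features_per_channel))
    return indices


# Toy conv2 output: 3 channels of 2x2; each value encodes channel * 100 + position
C, H, W = 3, 2, 2
fmap = [[[c * 100 + h * W + w for w in range(W)] for h in range(H)] for c in range(C)]
flat = flatten_chw(fmap)

keep2 = [0, 2]
idx = fc_input_indices(keep2, H * W)
picked = [flat[i] for i in idx]

# Selecting columns of the unpruned flattened vector gives exactly what a
# network with only the kept channels would flatten to.
assert picked == flatten_chw([fmap[c] for c in keep2])
print(idx)
print(picked)
```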

Usage Example

from edge_opt.model import SmallCNN
from edge_opt.pruning import structured_channel_prune

# Train baseline model
base_model = SmallCNN(conv1_channels=16, conv2_channels=32)
# ... training code ...

# Apply 50% channel pruning
pruned_model = structured_channel_prune(base_model, pruning_level=0.5)

print(f"Original conv1 channels: {base_model.conv1.out_channels}")  # 16
print(f"Pruned conv1 channels: {pruned_model.conv1.out_channels}")    # 8
print(f"Original conv2 channels: {base_model.conv2.out_channels}")  # 32
print(f"Pruned conv2 channels: {pruned_model.conv2.out_channels}")    # 16
Structured pruning is applied without retraining. The pruned model inherits weights from the trained baseline, so accuracy degradation depends on the importance of removed channels. Fine-tuning after pruning can recover some lost accuracy.

Precision Quantization

Supported Precision Modes

The framework supports three precision modes for model evaluation:

FP32

32-bit floating point. Baseline precision with no conversion overhead, used as the reference for accuracy and performance comparisons.

FP16

16-bit floating point. Halves parameter memory with minimal accuracy impact. Latency gains require hardware FP16 support.

INT8

8-bit integer. Cuts weight memory by roughly 75%, with potential accuracy loss. Static quantization requires calibration data to estimate activation ranges.

FP16 Conversion: to_fp16

The FP16 conversion is straightforward using PyTorch’s .half() method:
from copy import deepcopy

from torch import nn

def to_fp16(model: nn.Module) -> nn.Module:
    """
    Convert model to half-precision (FP16).
    
    Args:
        model: Trained model in FP32
    
    Returns:
        Deep copy of model in FP16 evaluation mode
    """
    fp16_model = deepcopy(model).half().eval()
    return fp16_model
Implementation details:
  • Uses deepcopy to avoid modifying the original model
  • Converts all floating-point parameters and buffers to FP16
  • Sets the model to evaluation mode (disables dropout; batch norm uses its running statistics)
  • Input tensors must also be converted to FP16 during inference
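FP16 here means IEEE 754 half precision. It stores 2 bytes per value, carries about three significant decimal digits, and tops out at 65504. Python's struct module can round-trip floats through the same format, which gives a quick feel for the rounding that .half() applies to weights and activations. This illustrates the number format only, not PyTorch's conversion path; through_fp16 is a hypothetical helper.

```python
import struct


def through_fp16(x: float) -> float:
    """Round a float to the nearest IEEE 754 half-precision value ('e' format)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


print("bytes per value:", len(struct.pack("<e", 1.0)))
print("0.1 ->", through_fp16(0.1))          # small relative rounding error
print("65504.0 ->", through_fp16(65504.0))  # largest finite FP16 value
try:
    struct.pack("<e", 70000.0)              # outside the FP16 range
except OverflowError as err:
    print("70000.0 ->", type(err).__name__)
```

The rounding error on typical weight magnitudes is tiny, which is consistent with the small accuracy impact. The limited range is the more common source of trouble when activations grow large.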

INT8 Quantization: to_int8

INT8 quantization uses PyTorch’s FX graph mode quantization with post-training static quantization (PTQ):
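from copy import deepcopy

import torch
from torch import nn
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx
from torch.utils.data import DataLoader

def to_int8(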
    model: nn.Module, 
    calibration_loader: DataLoader, 
    calibration_batches: int = 10
) -> nn.Module:
    """
    Apply post-training static quantization (PTQ) to INT8.
    
    Args:
        model: Trained model in FP32
        calibration_loader: DataLoader for activation range estimation
        calibration_batches: Number of batches for calibration
    
    Returns:
        Quantized model with INT8 weights and activations
    """
    float_model = deepcopy(model).eval()
    
    # Configure quantization for x86 CPU (fbgemm backend)
    qconfig_mapping = get_default_qconfig_mapping("fbgemm")
    
    # Prepare model for quantization (insert observers)
    example_inputs, _ = next(iter(calibration_loader))
    prepared = prepare_fx(float_model, qconfig_mapping, example_inputs=(example_inputs,))
    
    # Calibration: Collect activation statistics
    with torch.no_grad():
        for index, (inputs, _) in enumerate(calibration_loader):
            _ = prepared(inputs)
            if index + 1 >= calibration_batches:
                break
    
    # Convert to quantized model
    quantized = convert_fx(prepared)
    return quantized

Quantization Workflow

Step 1: Model Preparation

Create a deep copy in evaluation mode and configure the fbgemm quantization backend for x86 CPU targets.
qconfig_mapping = get_default_qconfig_mapping("fbgemm")
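For ARM targets, use the "qnnpack" backend instead of "fbgemm", and set torch.backends.quantized.engine to the same backend before running the quantized model.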
Step 2: Observer Insertion

Insert activation observers using FX graph mode preparation:
prepared = prepare_fx(float_model, qconfig_mapping, example_inputs=(example_inputs,))
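Observers record value ranges during calibration, and each tensor's quantization scale and zero-point are derived from them. With the default fbgemm configuration, activations use a histogram-based observer and weights use per-channel min/max observers.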
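The scale and zero-point follow from an observed range by the standard affine (asymmetric) mapping. The sketch below applies it to a plain min/max range on the full uint8 grid 0..255. It is a simplified illustration with hypothetical function names. PyTorch's own observers differ in detail, so the parameters it computes will not match these numbers exactly.

```python
from typing import Tuple


def affine_qparams(x_min: float, x_max: float, qmin: int = 0, qmax: int = 255) -> Tuple[float, int]:
    """Scale and zero-point for asymmetric uint8 quantization of [x_min, x_max]."""
    # Widen the range to contain 0 so that zero maps to an exact integer
    x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)
    scale = (x_max - x_min) / (qmax - qmin)
    if scale == 0.0:
        scale = 1.0  # degenerate all-zero range; any positive scale works
    zero_point = qmin - round(x_min / scale)
    zero_point = max(qmin, min(qmax, zero_point))
    return scale, zero_point


def quantize(x: float, scale: float, zero_point: int, qmin: int = 0, qmax: int = 255) -> int:
    """Map a real value onto the integer grid, clamping values outside the observed range."""
    return max(qmin, min(qmax, round(x / scale) + zero_point))


def dequantize(q: int, scale: float, zero_point: int) -> float:
    """Map an integer back to the real value it represents."""
    return (q - zero_point) * scale


# Suppose an observer saw activations spanning [-1.0, 3.0] during calibration
scale, zp = affine_qparams(-1.0, 3.0)
print(scale, zp)
for x in (0.0, 1.234, 5.0):  # 5.0 lies outside the calibrated range and gets clipped
    q = quantize(x, scale, zp)
    print(x, q, dequantize(q, scale, zp))
```

Values inside the observed range come back within half a quantization step. Values outside it are clipped to the range edge, which is why calibration coverage matters.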
Step 3: Calibration

Run inference on calibration data to collect activation statistics:
with torch.no_grad():
    for index, (inputs, _) in enumerate(calibration_loader):
        _ = prepared(inputs)  # Forward pass to update observers
        if index + 1 >= calibration_batches:
            break
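Calibration quality directly affects quantized accuracy. Too few batches can miss extreme activations, so the estimated ranges clip values seen at deployment. calibration_batches counts batches, not samples, so the number of calibration samples also depends on the loader's batch size. The default of 10 batches is a compromise between calibration time and range coverage.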
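The effect of too few calibration batches can be illustrated with a toy running min/max tracker. Here itertools.islice reproduces the same early stop as the index + 1 >= calibration_batches check. MinMaxTracker and calibrate are hypothetical stand-ins, not PyTorch observers, and the batch values are made up.

```python
from itertools import islice
from typing import Iterable, List, Optional, Tuple


class MinMaxTracker:
    """Toy stand-in for an activation observer: keeps the running min and max."""

    def __init__(self) -> None:
        self.min_val: Optional[float] = None
        self.max_val: Optional[float] = None

    def update(self, values: List[float]) -> None:
        lo, hi = min(values), max(values)
        self.min_val = lo if self.min_val is None else min(self.min_val, lo)
        self.max_val = hi if self.max_val is None else max(self.max_val, hi)


def calibrate(batches: Iterable[List[float]],
              calibration_batches: int) -> Tuple[Optional[float], Optional[float]]:
    """Observe at most `calibration_batches` batches, then stop."""
    tracker = MinMaxTracker()
    for batch in islice(batches, calibration_batches):
        tracker.update(batch)
    return tracker.min_val, tracker.max_val


# Toy activations: the extreme values only appear in the third batch
batches = [[0.1, 0.5], [0.2, 0.9], [-0.3, 2.5]]
print(calibrate(batches, calibration_batches=2))
print(calibrate(batches, calibration_batches=3))
```

With two batches the tracked range is [0.1, 0.9], so a deployed activation of 2.5 would be clipped. A third batch widens the range to cover it.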
Step 4: Model Conversion

Convert the prepared model to quantized format:
quantized = convert_fx(prepared)
This replaces the observed FP32 ops with INT8 kernels. Patterns such as Conv + ReLU are fused during prepare_fx and become single quantized modules after conversion.

Precision Selection in Sweep

The experiments.precision_variant function handles precision conversion during sweeps:
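from copy import deepcopy

from torch import nn
from torch.utils.data import DataLoader

from edge_opt.quantization import to_fp16, to_int8

def precision_variant(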
    model: nn.Module, 
    precision: str, 
    calibration_loader: DataLoader, 
    calibration_batches: int
) -> tuple[nn.Module, str]:
    """
    Create precision variant and return metric precision hint.
    
    Returns:
        (converted_model, metric_precision)
        - metric_precision: "fp32" or "fp16" for input tensor conversion
    """
    if precision == "fp32":
        return deepcopy(model).eval(), "fp32"
    if precision == "fp16":
        return to_fp16(model), "fp16"
    if precision == "int8":
        return to_int8(model, calibration_loader, calibration_batches), "fp32"
    raise ValueError(f"Unsupported precision '{precision}'")
INT8 models use metric_precision="fp32" because input tensors remain in FP32 format. The quantized model internally converts inputs to INT8.

Trade-off Analysis

The hardware.precision_tradeoff_table function aggregates sweep results by precision mode:
import pandas as pd

def precision_tradeoff_table(sweep_df: pd.DataFrame) -> pd.DataFrame:
    grouped = sweep_df.groupby("precision", as_index=False).agg(
        accuracy_mean=("accuracy", "mean"),
        latency_ms_mean=("latency_ms", "mean"),
        memory_mb_mean=("memory_mb", "mean"),
        energy_proxy_j_mean=("energy_proxy_j", "mean"),
        accepted_ratio=("accepted", "mean"),  # Fraction passing memory budget
    )
    return grouped.sort_values("latency_ms_mean").reset_index(drop=True)
This produces precision_tradeoffs.csv: one row per precision mode, averaged over all pruning levels and sorted by mean latency (fastest first).
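The same aggregation can be sketched with the standard library. The steps are: group rows by precision, average each metric, treat accepted as 0/1 so its mean is the acceptance ratio, and sort by mean latency. precision_tradeoffs is a hypothetical helper. The row values are invented for illustration and are not sweep results.

```python
from collections import defaultdict
from statistics import mean
from typing import Dict, List

METRICS = ("accuracy", "latency_ms", "memory_mb", "energy_proxy_j")


def precision_tradeoffs(rows: List[dict]) -> List[dict]:
    """Per-precision means plus acceptance ratio, sorted by mean latency."""
    groups: Dict[str, List[dict]] = defaultdict(list)
    for row in rows:
        groups[row["precision"]].append(row)
    table = []
    for precision, group in groups.items():
        entry: dict = {"precision": precision}
        for metric in METRICS:
            entry[f"{metric}_mean"] = mean(r[metric] for r in group)
        # Booleans averaged as 0/1 give the fraction that passed the budget
        entry["accepted_ratio"] = mean(1.0 if r["accepted"] else 0.0 for r in group)
        table.append(entry)
    return sorted(table, key=lambda e: e["latency_ms_mean"])


# Invented rows for illustration only (not measured results)
keys = ("precision", "accuracy", "latency_ms", "memory_mb", "energy_proxy_j", "accepted")
toy = [
    ("fp32", 0.99, 4.0, 0.080, 0.40, False),
    ("fp32", 0.98, 3.0, 0.040, 0.30, True),
    ("fp16", 0.99, 2.5, 0.040, 0.25, False),
    ("fp16", 0.98, 2.5, 0.020, 0.20, True),
    ("int8", 0.97, 2.0, 0.020, 0.15, True),
    ("int8", 0.96, 1.0, 0.010, 0.10, True),
]
rows = [dict(zip(keys, values)) for values in toy]
for entry in precision_tradeoffs(rows):
    print(entry)
```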

Expected Trade-offs

Accuracy

  • FP32: Baseline accuracy (no degradation)
  • FP16: Minimal accuracy loss (<0.1% typical for CNNs)
  • INT8: 0.5-2% accuracy degradation depending on calibration quality
More aggressive pruning compounds quantization accuracy loss.

Memory

  • FP32: 4 bytes per parameter
  • FP16: 2 bytes per parameter (50% reduction)
  • INT8: 1 byte per parameter (75% reduction)
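Pruning and quantization savings multiply. A model with half as many parameters stored in INT8 occupies 0.5 × 0.25 = 12.5% of the original FP32 bytes, an 87.5% reduction. pruning_level=0.5 typically removes more than half of the parameters, because conv2 loses both input and output channels.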
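Channel pruning removes more than its nominal fraction of parameters. This can be checked by working out the parameter count for a SmallCNN-like network. The layer shapes are assumptions: one input channel, 3×3 kernels with biases, 10 output classes, and a 7×7 map before the classifier. The real SmallCNN may differ, and small_cnn_params is a hypothetical helper. The byte estimate also ignores quantization metadata and the FP32 biases that PyTorch's quantized modules keep.

```python
BYTES_PER_PARAM = {"fp32": 4, "fp16": 2, "int8": 1}


def small_cnn_params(c1: int, c2: int, in_ch: int = 1, k: int = 3,
                     spatial: int = 7 * 7, num_classes: int = 10) -> int:
    """Parameter count of an assumed conv -> conv -> linear CNN, biases included."""
    conv1 = c1 * in_ch * k * k + c1
    conv2 = c2 * c1 * k * k + c2          # shrinks with both c1 and c2
    classifier = num_classes * c2 * spatial + num_classes
    return conv1 + conv2 + classifier


base = small_cnn_params(16, 32)
pruned = small_cnn_params(8, 16)          # pruning_level=0.5 on both conv layers
print(f"parameters: {base} -> {pruned} ({1 - pruned / base:.1%} removed)")

fp32_bytes = base * BYTES_PER_PARAM["fp32"]
for precision, nbytes in BYTES_PER_PARAM.items():
    size = pruned * nbytes
    print(f"pruned {precision}: {size} bytes, {1 - size / fp32_bytes:.1%} below FP32 baseline")
```

Under these assumptions, 50% channel pruning removes about 56% of the parameters. Combined with INT8, that gives a reduction somewhat beyond the 87.5% of the simple estimate.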
Latency

  • FP32: Baseline latency
  • FP16: 1.5-2x speedup on hardware with FP16 SIMD support
  • INT8: 2-4x speedup on CPUs with AVX-512 VNNI or equivalent
Actual speedups depend on hardware support and operation fusion opportunities.

Budget Acceptance

The accepted_ratio metric shows the fraction of configurations passing the active memory budget:
  • Tighter budgets favor INT8 and aggressive pruning
  • FP32 with low pruning typically fails strict memory constraints
  • FP16 provides a middle ground between accuracy and memory

Best Practices

Pruning First, Then Quantize

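Apply pruning before quantization. structured_channel_prune reads FP32 weight tensors and builds a new FP32 SmallCNN, so it cannot operate on a converted INT8 model. Pruning first also means calibration runs on the smaller network, and its observers record the activations of the architecture that is actually deployed.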

Calibration Data Quality

Use representative calibration data matching the deployment distribution. Poor calibration leads to clipped activations and accuracy degradation.

Validate Acceptance Ratio

Check the acceptance ratios in precision_tradeoffs.csv to ensure enough candidates pass the memory budget. Low ratios suggest the budget is too strict for the explored pruning levels and precision modes.

Monitor P95 Latency

Evaluate P95 latency distributions, not just mean latency. Tail latencies often dominate real-time deployment constraints.
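A P95 value is simple to compute from raw per-inference timings. The sketch below uses the nearest-rank definition. Other definitions, such as the linear interpolation that numpy.percentile uses by default, can give slightly different values on small samples. percentile_nearest_rank is a hypothetical helper, the timings are toy numbers, and this is not necessarily how the framework computes its own latency percentiles.

```python
import math
from statistics import mean
from typing import Sequence


def percentile_nearest_rank(values: Sequence[float], p: float) -> float:
    """p-th percentile (0 < p <= 100) by the nearest-rank method."""
    if not values:
        raise ValueError("values must be non-empty")
    ordered = sorted(values)
    rank = math.ceil(p * len(ordered) / 100)  # 1-based rank into the sorted data
    return ordered[max(rank, 1) - 1]


# Toy per-inference timings: mostly 10 ms, with two 40 ms stalls
latencies_ms = [10.0] * 18 + [40.0] * 2
print(f"mean = {mean(latencies_ms):.1f} ms")
print(f"p95  = {percentile_nearest_rank(latencies_ms, 95):.1f} ms")
```

Here the mean (13 ms) hides the 40 ms stalls that a real-time deadline would have to absorb, while P95 exposes them.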

Next Steps

Hardware Constraints

Learn about memory budgets and constraint filtering

System Architecture

Understand the full optimization pipeline
