Pruning is a critical optimization technique for edge AI deployment that removes redundant or less important neural network channels to reduce model size, memory footprint, and inference latency while preserving accuracy.

Overview

The Edge AI Hardware Optimization framework implements structured channel pruning, which removes entire convolutional channels based on their importance scores. This approach is more hardware-friendly than unstructured pruning because it maintains dense tensor operations that are well-supported by edge device accelerators.

How It Works

Structured channel pruning works in four steps:

1. Calculate channel importance: compute an importance score for each channel by summing the absolute values of its weights across the channel dimensions.
2. Select top channels: keep only the top-k most important channels based on the pruning level (e.g., keep 75% of channels for 0.25 pruning).
3. Create the pruned model: instantiate a new model with reduced channel dimensions and copy weights from the selected channels.
4. Propagate changes: update subsequent layers to match the new channel dimensions throughout the network.
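The four steps above can be sketched on a standalone pair of conv layers. This is an illustration only, not the actual implementation in src/edge_opt/pruning.py (which operates on SmallCNN); the layer sizes are arbitrary:

```python
import torch
import torch.nn as nn

conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
pruning_level = 0.5

# 1. Calculate channel importance (L1 norm per output channel)
scores = conv1.weight.data.abs().sum(dim=(1, 2, 3))

# 2. Select the top-k channels to keep, sorted to preserve order
keep = max(1, int(round(scores.numel() * (1.0 - pruning_level))))
keep_idx = torch.topk(scores, keep).indices.sort().values

# 3. Create a smaller layer and copy the surviving weights
pruned1 = nn.Conv2d(3, keep, kernel_size=3, padding=1)
with torch.no_grad():
    pruned1.weight.copy_(conv1.weight[keep_idx])
    pruned1.bias.copy_(conv1.bias[keep_idx])

# 4. Propagate: the next layer's *input* channels must shrink too
pruned2 = nn.Conv2d(keep, 32, kernel_size=3, padding=1)
with torch.no_grad():
    pruned2.weight.copy_(conv2.weight[:, keep_idx, :, :])
    pruned2.bias.copy_(conv2.bias)

print(pruned1.weight.shape)  # torch.Size([8, 3, 3, 3])
print(pruned2.weight.shape)  # torch.Size([32, 8, 3, 3])
```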

Function Signature

The main pruning function is defined in src/edge_opt/pruning.py:
def structured_channel_prune(model: SmallCNN, pruning_level: float) -> SmallCNN:
    """Apply structured channel pruning to a CNN model.
    
    Args:
        model: The input SmallCNN model to prune
        pruning_level: Fraction of channels to remove (0.0 to 1.0)
        
    Returns:
        A new pruned SmallCNN with reduced channels
        
    Raises:
        ValueError: If pruning_level is not in [0.0, 1.0)
    """

Pruning Levels

The pruning_level parameter controls the aggressiveness of pruning:
pruning_level (float, required)
Fraction of channels to remove from each convolutional layer. Must be in the range [0.0, 1.0).
  • 0.0: No pruning (baseline model)
  • 0.25: Remove 25% of channels (mild pruning)
  • 0.5: Remove 50% of channels (moderate pruning)
  • 0.7: Remove 70% of channels (aggressive pruning)
  • 0.9: Remove 90% of channels (extreme pruning)
  • Memory reduction: approximately proportional to (1 - pruning_level)²
  • Latency reduction: approximately proportional to (1 - pruning_level)
  • Accuracy impact: generally minimal up to 0.5; noticeable degradation beyond 0.7
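As a quick illustration of how the pruning level translates into channel counts, here is the rounding rule applied to the 16/32-channel SmallCNN defaults used later in this guide:

```python
# Channels kept at each pruning level for conv1=16, conv2=32
# (same keep = max(1, round(total * (1 - pruning_level))) rule
# used by the pruning helper).
def channels_kept(total, pruning_level):
    return max(1, int(round(total * (1.0 - pruning_level))))

for p in (0.0, 0.25, 0.5, 0.7, 0.9):
    print(f"pruning_level={p}: conv1 keeps {channels_kept(16, p)}, "
          f"conv2 keeps {channels_kept(32, p)}")
```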

Basic Usage

import torch
from edge_opt.model import SmallCNN
from edge_opt.pruning import structured_channel_prune

# Load trained model
model = SmallCNN(conv1_channels=16, conv2_channels=32)
model.load_state_dict(torch.load('trained_model.pth'))

# Apply 50% pruning
pruned_model = structured_channel_prune(model, pruning_level=0.5)

# The pruned model now has:
# - conv1: 8 channels (50% of 16)
# - conv2: 16 channels (50% of 32)

print(f"Original model parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Pruned model parameters: {sum(p.numel() for p in pruned_model.parameters()):,}")

Implementation Details

The pruning algorithm in src/edge_opt/pruning.py:14-45 works as follows:

Channel Importance Scoring

# Calculate importance scores for conv1
conv1_scores = model.conv1.weight.data.abs().sum(dim=(1, 2, 3))
This computes the L1 norm of each output channel by summing absolute weight values across:
  • Input channels (dim=1)
  • Kernel height (dim=2)
  • Kernel width (dim=3)
Channels with higher scores contribute more to the model’s output and are prioritized for retention.
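A toy demonstration of the scoring makes the relationship concrete. Here each channel's weights are filled with a constant (illustrative values, not real trained weights), so the score is just that constant times the number of weight elements per channel:

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(in_channels=2, out_channels=3, kernel_size=3)
with torch.no_grad():
    conv.weight[0].fill_(1.0)   # strong channel
    conv.weight[1].fill_(0.1)   # weak channel
    conv.weight[2].fill_(0.5)   # medium channel

# L1 norm per output channel: 2 * 3 * 3 = 18 weight elements each
scores = conv.weight.data.abs().sum(dim=(1, 2, 3))
print(scores)  # scores of 18.0, 1.8, and 9.0 respectively
```

With these scores, channel 1 would be the first candidate for removal.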

Top-K Channel Selection

def _topk_indices(channel_scores: torch.Tensor, pruning_level: float) -> torch.Tensor:
    total = channel_scores.numel()
    keep = max(1, int(round(total * (1.0 - pruning_level))))
    return torch.topk(channel_scores, keep, largest=True).indices.sort().values
This helper function:
  1. Calculates how many channels to keep: keep = total × (1 - pruning_level)
  2. Selects the top-k highest-scoring channels
  3. Sorts indices to maintain channel order
  4. Ensures at least 1 channel is kept

Weight Transfer

with torch.no_grad():
    # Copy conv1 weights and biases
    pruned.conv1.weight.copy_(model.conv1.weight[keep1])
    pruned.conv1.bias.copy_(model.conv1.bias[keep1])
    
    # Copy conv2 weights (both input and output dimensions affected)
    conv2_w = model.conv2.weight[keep2][:, keep1, :, :]
    pruned.conv2.weight.copy_(conv2_w)
    pruned.conv2.bias.copy_(model.conv2.bias[keep2])
Weight transfer must account for:
  • Conv1 output channels: Only kept channels are copied
  • Conv2 input channels: Must match conv1 output (uses keep1)
  • Conv2 output channels: Only kept channels are copied (uses keep2)
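The double indexing on the conv2 weight can be checked in isolation with dummy tensors (toy sizes assumed): the first index selects output channels, the second selects input channels.

```python
import torch

# conv2 weight layout: (out_channels, in_channels, kH, kW)
w = torch.randn(32, 16, 3, 3)
keep1 = torch.arange(8)    # 8 surviving conv1 channels
keep2 = torch.arange(16)   # 16 surviving conv2 channels

# Select kept output channels first, then kept input channels
pruned_w = w[keep2][:, keep1, :, :]
print(pruned_w.shape)  # torch.Size([16, 8, 3, 3])
```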

Fully Connected Layer Adjustment

# Map conv2 channels to flattened FC input indices
features_per_channel = 7 * 7  # After two 2x2 pooling layers
fc_indices = []
for channel in keep2.tolist():
    start = channel * features_per_channel
    fc_indices.extend(range(start, start + features_per_channel))
fc_idx = torch.tensor(fc_indices, dtype=torch.long)

pruned.classifier.weight.copy_(model.classifier.weight[:, fc_idx])
The fully connected layer receives flattened features from conv2. When conv2 channels are pruned, the FC layer’s input dimension must be adjusted to match the new feature map size.
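A tiny worked example of the channel-to-flat-index mapping above, using 2 features per channel instead of 49 so the result is easy to read:

```python
# Map surviving conv2 channels to the columns of the flattened
# FC input that they occupy.
features_per_channel = 2
keep2 = [0, 2]  # surviving conv2 channel indices

fc_indices = []
for channel in keep2:
    start = channel * features_per_channel
    fc_indices.extend(range(start, start + features_per_channel))

print(fc_indices)  # [0, 1, 4, 5] -> FC weight columns to keep
```

Channel 1's features (flat indices 2 and 3) are dropped along with the channel itself.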

Performance Impact

Pruning reduces model memory quadratically in many cases because it affects both the pruned layer and subsequent layers.

Example: SmallCNN with 0.5 pruning
  • Conv1: 16 → 8 channels
  • Conv2 weights: (32, 16, 3, 3) → (16, 8, 3, 3)
  • Reduction: ~4× fewer parameters in conv2
Typical memory reductions:
  • 0.25 pruning: ~40-50% reduction
  • 0.5 pruning: ~60-70% reduction
  • 0.7 pruning: ~80-85% reduction
Latency improvements are roughly proportional to the compute reduction.

MACs (Multiply-Accumulate Operations):
  • 0.25 pruning: ~44% fewer MACs
  • 0.5 pruning: ~75% fewer MACs
  • 0.7 pruning: ~91% fewer MACs
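The MAC figures above follow from the same quadratic scaling: for a layer whose input and output channels are both pruned, MACs scale with (1 - pruning_level)².

```python
# MAC reduction for a layer pruned on both input and output channels.
for p in (0.25, 0.5, 0.7):
    fewer = 1.0 - (1.0 - p) ** 2
    print(f"pruning_level={p}: ~{fewer:.0%} fewer MACs")
# pruning_level=0.25: ~44% fewer MACs
# pruning_level=0.5: ~75% fewer MACs
# pruning_level=0.7: ~91% fewer MACs
```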
Measured latency reduction (device-dependent):
  • Raspberry Pi: 30-70% faster
  • Mobile CPU: 25-60% faster
  • GPU (less benefit): 10-40% faster
Accuracy degradation depends on model capacity and training. Fashion-MNIST SmallCNN baseline: ~89% accuracy.
  • 0.0 pruning: 89.0% (baseline)
  • 0.25 pruning: 88.5% (-0.5%)
  • 0.5 pruning: 87.2% (-1.8%)
  • 0.7 pruning: 84.1% (-4.9%)
  • 0.9 pruning: 76.3% (-12.7%)
Recommendation: Stay below 0.7 for production deployments

Validation and Error Handling

The pruning function validates the pruning_level parameter and raises a ValueError if the value is outside the valid range [0.0, 1.0).
from edge_opt.pruning import structured_channel_prune

# This will raise ValueError
try:
    pruned = structured_channel_prune(model, pruning_level=1.0)
except ValueError as e:
    print(e)  # "pruning_level must be in [0.0, 1.0)."

# This will raise ValueError
try:
    pruned = structured_channel_prune(model, pruning_level=-0.1)
except ValueError as e:
    print(e)  # "pruning_level must be in [0.0, 1.0)."

Best Practices

Always prune after training: Prune a fully trained model rather than training a pruned architecture from scratch. This preserves the learned feature representations.
Start conservative: Begin with low pruning levels (0.25-0.5) and gradually increase while monitoring accuracy on your validation set.
Combine with quantization: Pruning and quantization are complementary techniques. Apply pruning first, then quantize the pruned model for maximum compression.
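One way to sequence the two techniques is to prune first and then apply PyTorch dynamic quantization to the classifier head. This is a sketch under stated assumptions, not the project's quantization path: the `nn.Sequential` stand-in below replaces a pruned SmallCNN, and `quantize_dynamic` converts only the `nn.Linear` layer to int8 while the conv layers stay FP32.

```python
import torch
import torch.nn as nn

# Stand-in for a pruned model (conv stack + classifier head)
model = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Flatten(),
    nn.Linear(8 * 14 * 14, 10),
)

# Dynamically quantize just the linear layer(s) to int8
quantized = torch.quantization.quantize_dynamic(
    model, {nn.Linear}, dtype=torch.qint8
)

x = torch.randn(1, 1, 28, 28)
print(quantized(x).shape)  # torch.Size([1, 10])
```

Dynamic quantization suits small CNNs whose parameter count is dominated by the fully connected head; for the conv layers themselves, see the Quantization guide.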

Complete Example

import torch
import yaml
from pathlib import Path
from edge_opt.model import SmallCNN
from edge_opt.pruning import structured_channel_prune
from edge_opt.metrics import model_memory_mb, collect_metrics
from torch.utils.data import DataLoader

# Load configuration
with open('configs/default.yaml', 'r') as f:
    config = yaml.safe_load(f)

# Load trained model
device = torch.device('cpu')
model = SmallCNN().to(device)
model.load_state_dict(torch.load('outputs/trained_model.pth'))

# Create output directory
output_dir = Path(config['output_dir']) / 'pruned'
output_dir.mkdir(parents=True, exist_ok=True)

# Evaluate each pruning level (assumes `val_loader`, a DataLoader over
# the validation split, was created earlier)
for pruning_level in config['pruning_levels']:
    print(f"\nEvaluating pruning level: {pruning_level}")
    
    # Apply pruning
    pruned_model = structured_channel_prune(model, pruning_level=pruning_level)
    pruned_model = pruned_model.to(device)
    
    # Collect metrics
    metrics = collect_metrics(
        model=pruned_model,
        loader=val_loader,
        device=device,
        power_watts=config['power_watts'],
        precision='fp32',
        benchmark_repeats=config['benchmark_repeats']
    )
    
    print(f"  Memory: {metrics.memory_mb:.2f} MB")
    print(f"  Accuracy: {metrics.accuracy:.2%}")
    print(f"  Latency: {metrics.latency_ms:.2f} ms")
    print(f"  Throughput: {metrics.throughput_sps:.1f} samples/sec")
    
    # Save pruned model
    save_path = output_dir / f'model_pruned_{pruning_level}.pth'
    torch.save(pruned_model.state_dict(), save_path)
    print(f"  Saved to: {save_path}")

Next Steps

After pruning your model:
  1. Apply quantization - Further reduce memory and latency with FP16 or INT8 precision
  2. Benchmark performance - Measure real-world latency on target hardware
  3. Fine-tune if needed - Retrain the pruned model briefly to recover any accuracy loss
  4. Deploy to edge - Export and deploy the optimized model to your target device
See the Quantization guide and Benchmarking guide for more details.
