Overview

The pruning module provides structured channel pruning capabilities for convolutional neural networks. This technique removes entire channels from convolutional layers based on their importance scores, reducing model size while maintaining performance.

Functions

structured_channel_prune

Performs structured channel pruning on a SmallCNN model by removing the least important channels from convolutional layers.
from edge_opt.pruning import structured_channel_prune

pruned_model = structured_channel_prune(model, pruning_level=0.3)

Parameters

model
SmallCNN
required
The input CNN model to prune. Must be an instance of the SmallCNN class with conv1, conv2, and classifier layers.
pruning_level
float
required
The fraction of channels to remove from each convolutional layer. Must be in the range [0.0, 1.0).
  • 0.0: No pruning (keeps all channels)
  • 0.3: Removes 30% of channels
  • 0.5: Removes 50% of channels
  • Values outside [0.0, 1.0) (negative or >= 1.0): Invalid and will raise ValueError

Returns

pruned_model
SmallCNN
A new SmallCNN instance with reduced channel counts. The returned model has:
  • Reduced conv1 output channels based on importance scores
  • Reduced conv2 output channels based on importance scores
  • Adjusted classifier input dimensions to match pruned conv2 output
  • Weights copied from the original model for retained channels

Raises

ValueError: Raised when pruning_level is not in the range [0.0, 1.0)
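A guard matching this contract could look like the following. This is an illustrative sketch, not the library's actual code; the function name `check_pruning_level` is hypothetical:

```python
def check_pruning_level(pruning_level: float) -> None:
    """Raise ValueError unless 0.0 <= pruning_level < 1.0."""
    if not 0.0 <= pruning_level < 1.0:
        raise ValueError(
            f"pruning_level must be in [0.0, 1.0), got {pruning_level}"
        )
```

Note that the half-open range also rejects negative values, not just values at or above 1.0.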

Implementation Details

The function uses L1-norm based importance scoring:
  1. Channel Importance Scoring: For each convolutional layer, channels are scored by computing the sum of absolute values across all weights in that channel:
    conv1_scores = model.conv1.weight.data.abs().sum(dim=(1, 2, 3))
    
  2. Channel Selection: The top-k channels with the highest scores are retained, where k = max(1, int(round(total * (1.0 - pruning_level))))
  3. Weight Copying: Weights from retained channels are copied to the new pruned model
  4. Dimension Adjustment: The classifier layer is adjusted to accept the reduced feature map size from pruned conv2 (features_per_channel = 7 × 7)
At least one channel is always retained per layer, even at high pruning levels, because the keep count is computed as keep = max(1, int(round(total * (1.0 - pruning_level)))).
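The four steps above can be sketched end to end. Everything here is illustrative: the `SmallCNN` stand-in assumes single-channel 28×28 inputs, two 3×3 conv layers with 2×2 max pooling (yielding the 7×7 feature maps the classifier expects), and the sketch restricts conv2's input channels and the classifier's input columns to match the retained channels, which the weight-copying step implies but does not spell out:

```python
import torch
import torch.nn as nn

class SmallCNN(nn.Module):
    """Hypothetical stand-in for edge_opt.model.SmallCNN (illustration only)."""
    def __init__(self, conv1_channels=16, conv2_channels=32, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(1, conv1_channels, 3, padding=1)
        self.conv2 = nn.Conv2d(conv1_channels, conv2_channels, 3, padding=1)
        self.pool = nn.MaxPool2d(2)
        self.classifier = nn.Linear(conv2_channels * 7 * 7, num_classes)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))   # 28x28 -> 14x14
        x = self.pool(torch.relu(self.conv2(x)))   # 14x14 -> 7x7
        return self.classifier(x.flatten(1))

def structured_channel_prune_sketch(model, pruning_level):
    if not 0.0 <= pruning_level < 1.0:
        raise ValueError("pruning_level must be in [0.0, 1.0)")

    def keep_idx(scores):
        # Keep the highest-scoring channels; always retain at least one.
        keep = max(1, int(round(scores.numel() * (1.0 - pruning_level))))
        return torch.sort(torch.topk(scores, keep).indices).values

    # Step 1: L1-norm importance per output channel.
    idx1 = keep_idx(model.conv1.weight.data.abs().sum(dim=(1, 2, 3)))
    idx2 = keep_idx(model.conv2.weight.data.abs().sum(dim=(1, 2, 3)))

    # Steps 2-4: build a smaller model and copy the retained weights.
    pruned = SmallCNN(conv1_channels=len(idx1), conv2_channels=len(idx2),
                      num_classes=model.classifier.out_features)
    pruned.conv1.weight.data = model.conv1.weight.data[idx1].clone()
    pruned.conv1.bias.data = model.conv1.bias.data[idx1].clone()
    # conv2 keeps retained output channels AND only the retained conv1 inputs.
    pruned.conv2.weight.data = model.conv2.weight.data[idx2][:, idx1].clone()
    pruned.conv2.bias.data = model.conv2.bias.data[idx2].clone()
    # Classifier columns for retained conv2 channels (7*7 features each).
    fpc = 7 * 7
    cols = torch.cat([torch.arange(i * fpc, (i + 1) * fpc)
                      for i in idx2.tolist()])
    pruned.classifier.weight.data = model.classifier.weight.data[:, cols].clone()
    pruned.classifier.bias.data = model.classifier.bias.data.clone()
    return pruned
```

The sketch returns a fresh, smaller model rather than masking weights in place, matching the documented behavior of returning a new SmallCNN instance.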

Example

from edge_opt.model import SmallCNN
from edge_opt.pruning import structured_channel_prune

# Create base model with default channels (16, 32)
model = SmallCNN(conv1_channels=16, conv2_channels=32)

# Prune 40% of channels
pruned = structured_channel_prune(model, pruning_level=0.4)

# Resulting model will have:
# - conv1: round(16 * (1 - 0.4)) = 10 channels
# - conv2: round(32 * (1 - 0.4)) = 19 channels
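Because keep = max(1, int(round(total * (1.0 - pruning_level)))), the exact retained counts can be computed directly (the helper name `kept_channels` is illustrative, not part of the library):

```python
def kept_channels(total: int, pruning_level: float) -> int:
    """Channels retained per layer, per the documented keep formula."""
    return max(1, int(round(total * (1.0 - pruning_level))))

# For the example above: 16 -> 10 and 32 -> 19 at pruning_level=0.4.
```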

_topk_indices

Helper function that selects the indices of the top-k highest scoring channels based on the pruning level.
from edge_opt.pruning import _topk_indices

indices = _topk_indices(channel_scores, pruning_level=0.3)

Parameters

channel_scores
torch.Tensor
required
A 1D tensor containing importance scores for each channel. Higher scores indicate more important channels that should be retained.
pruning_level
float
required
The fraction of channels to remove. Must be in the range [0.0, 1.0).

Returns

indices
torch.Tensor
A 1D tensor of sorted indices corresponding to the channels to keep. The indices are:
  • Selected based on highest scores
  • Sorted in ascending order
  • Guaranteed to have at least one element

Implementation Details

The function calculates the number of channels to keep:
total = channel_scores.numel()
keep = max(1, int(round(total * (1.0 - pruning_level))))
Then uses torch.topk to select the indices with the highest scores and sorts them for consistent ordering.
This is an internal helper function. Users should typically use structured_channel_prune instead of calling this directly.
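The described logic can be re-implemented as a short standalone sketch (the name `topk_indices_sketch` is illustrative; the real helper is `_topk_indices` in edge_opt.pruning):

```python
import torch

def topk_indices_sketch(channel_scores: torch.Tensor,
                        pruning_level: float) -> torch.Tensor:
    """Return sorted indices of the channels to keep (illustrative)."""
    total = channel_scores.numel()
    keep = max(1, int(round(total * (1.0 - pruning_level))))  # at least one
    top = torch.topk(channel_scores, keep).indices            # highest scores
    return torch.sort(top).values                             # ascending order
```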

Example

import torch
from edge_opt.pruning import _topk_indices

# Channel importance scores
scores = torch.tensor([0.5, 1.2, 0.3, 2.1, 0.8])

# Keep top 60% of channels (prune 40%)
keep_indices = _topk_indices(scores, pruning_level=0.4)
# Returns: tensor([1, 3, 4]) - the indices of the 3 highest scores
