Overview
The pruning module provides structured channel pruning for convolutional neural networks. This technique removes entire channels from convolutional layers based on their importance scores, which reduces model size and computation while aiming to preserve accuracy.
Functions
structured_channel_prune
Performs structured channel pruning on a SmallCNN model by removing the least important channels from its convolutional layers.
Parameters
The input CNN model to prune. Must be an instance of the SmallCNN class with conv1, conv2, and classifier layers.
The fraction of channels to remove from each convolutional layer. Must be in the range [0.0, 1.0).
- 0.0: No pruning (keeps all channels)
- 0.3: Removes about 30% of channels (the kept count is rounded to an integer)
- 0.5: Removes about 50% of channels
- Values >= 1.0: Invalid; raises ValueError
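A minimal sketch of how a pruning level turns into a kept-channel count, assuming the keep formula given under Implementation Details. The helper name channels_kept and the layer widths 16 and 32 are hypothetical, and rejecting negative levels is an assumption beyond the documented ValueError for values >= 1.0.

```python
def channels_kept(total: int, pruning_level: float) -> int:
    """Channels retained in a layer with `total` output channels (hypothetical helper).

    Uses the documented formula keep = max(1, int(round(total * (1.0 - pruning_level)))).
    Rejecting negative levels is an assumption; the docs only state that values >= 1.0 raise.
    """
    if not 0.0 <= pruning_level < 1.0:
        raise ValueError(f"pruning_level must be in [0.0, 1.0), got {pruning_level}")
    return max(1, int(round(total * (1.0 - pruning_level))))


# Hypothetical layer widths; SmallCNN's real channel counts are not stated in this section.
for total in (16, 32):
    for level in (0.0, 0.3, 0.5):
        kept = channels_kept(total, level)
        removed = total - kept
        print(f"total={total} level={level} kept={kept} removed={removed} ({removed / total:.2%})")
```

With 16 channels, a level of 0.3 keeps 11 channels (16 × 0.7 = 11.2, rounded to 11), so 5 channels, or 31.25%, are removed. This is why the percentages in the list are approximate.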
Returns
A new SmallCNN instance with reduced channel counts. The returned model has:
- Reduced conv1 output channels based on importance scores
- Reduced conv2 output channels based on importance scores
- Adjusted classifier input dimensions to match pruned conv2 output
- Weights copied from the original model for retained channels
Raises
ValueError: If pruning_level is 1.0 or greater.
Implementation Details
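The function uses L1-norm-based importance scoring:
- Channel Importance Scoring: For each convolutional layer, every output channel c is scored by the sum of the absolute values of all weights in its filter, taken over input channels i and kernel positions (h, w):
$$s_c = \sum_{i} \sum_{h} \sum_{w} \lvert W_{c,i,h,w} \rvert$$
- Channel Selection: The k highest-scoring channels are retained, where total is the layer's number of output channels and k = int(round(total * (1.0 - pruning_level))).
- Weight Copying: Weights of the retained channels are copied from the original model into the new pruned model. Because conv1 loses output channels, conv2 keeps only the input-channel weights that correspond to the retained conv1 channels.
- Dimension Adjustment: Each conv2 channel contributes a 7 × 7 feature map (features_per_channel = 7 × 7), so the classifier's input size becomes (retained conv2 channels) × 49.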
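The four steps can be traced on small tensors. This sketch uses a made-up 4-channel weight with 1×1 kernels and a randomly initialized classifier weight (all shapes are hypothetical) to show the L1 scores, the top-k selection, the channel slicing, and why classifier columns are removed in blocks of 7 × 7 = 49 per conv2 channel: PyTorch flattens an (N, C, 7, 7) tensor channel by channel.

```python
import torch

# Made-up conv weight, shaped like nn.Conv2d.weight: (out_channels=4, in_channels=2, kH=1, kW=1).
weight = torch.tensor([
    [[[0.5]], [[-0.5]]],   # channel 0: |0.5| + |-0.5| = 1.0
    [[[2.0]], [[-1.0]]],   # channel 1: 3.0
    [[[0.25]], [[0.0]]],   # channel 2: 0.25
    [[[-1.5]], [[0.5]]],   # channel 3: 2.0
])
pruning_level = 0.5

# 1. Channel importance: L1 norm of each output channel's filter.
scores = weight.abs().sum(dim=(1, 2, 3))
print(scores.tolist())  # [1.0, 3.0, 0.25, 2.0]

# 2. Channel selection: keep the k highest-scoring channels, in ascending index order.
k = max(1, int(round(scores.numel() * (1.0 - pruning_level))))  # 2
keep = torch.topk(scores, k).indices.sort().values
print(keep.tolist())  # [1, 3]

# 3. Weight copying: slice the retained output channels. (For conv2, its input-channel
#    dimension would also be sliced with conv1's kept indices: weight[keep2][:, keep1].)
pruned_weight = weight[keep]
print(tuple(pruned_weight.shape))  # (2, 2, 1, 1)

# 4. Dimension adjustment: if this layer were conv2, the classifier would see a
#    channel-major flatten of (C, 7, 7), so channel c owns columns c*49 .. c*49 + 48.
num_classes = 10
classifier_weight = torch.randn(num_classes, 4 * 7 * 7)
pruned_classifier_weight = (
    classifier_weight.view(num_classes, 4, 7 * 7)[:, keep].reshape(num_classes, -1)
)
print(tuple(pruned_classifier_weight.shape))  # (10, 98)
```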
At least one channel is always retained per layer, even at high pruning levels, because the function computes:
keep = max(1, int(round(total * (1.0 - pruning_level))))
Example
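A self-contained sketch of pruning a model end to end. It assumes a hypothetical SmallCNN stand-in (single-channel 28×28 input, two 3×3 convolutions each followed by 2×2 max-pooling, default widths 16 and 32, 10 classes, constructor SmallCNN(c1, c2, num_classes)), chosen so that conv2's feature maps are 7 × 7. The structured_channel_prune and _topk_indices defined here follow the documented steps but are illustrations, not the module's own implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SmallCNN(nn.Module):
    """Hypothetical stand-in: 28x28 input -> two conv+pool stages -> 7x7 maps -> linear."""

    def __init__(self, c1: int = 16, c2: int = 32, num_classes: int = 10):
        super().__init__()
        self.conv1 = nn.Conv2d(1, c1, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(c1, c2, kernel_size=3, padding=1)
        self.classifier = nn.Linear(c2 * 7 * 7, num_classes)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)  # 28x28 -> 14x14
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)  # 14x14 -> 7x7
        return self.classifier(x.flatten(1))        # channel-major flatten


def _topk_indices(scores: torch.Tensor, pruning_level: float) -> torch.Tensor:
    keep = max(1, int(round(scores.numel() * (1.0 - pruning_level))))
    return torch.topk(scores, keep).indices.sort().values


def structured_channel_prune(model: SmallCNN, pruning_level: float) -> SmallCNN:
    if not 0.0 <= pruning_level < 1.0:
        raise ValueError(f"pruning_level must be in [0.0, 1.0), got {pruning_level}")
    # L1-norm score per output channel: sum of |w| over (in_channels, kH, kW).
    keep1 = _topk_indices(model.conv1.weight.detach().abs().sum(dim=(1, 2, 3)), pruning_level)
    keep2 = _topk_indices(model.conv2.weight.detach().abs().sum(dim=(1, 2, 3)), pruning_level)

    num_classes = model.classifier.out_features
    pruned = SmallCNN(len(keep1), len(keep2), num_classes)
    with torch.no_grad():
        # conv1: keep the selected output channels.
        pruned.conv1.weight.copy_(model.conv1.weight[keep1])
        pruned.conv1.bias.copy_(model.conv1.bias[keep1])
        # conv2: keep the selected output channels and only inputs from kept conv1 channels.
        pruned.conv2.weight.copy_(model.conv2.weight[keep2][:, keep1])
        pruned.conv2.bias.copy_(model.conv2.bias[keep2])
        # classifier: each conv2 channel owns 7*7 consecutive input columns.
        w = model.classifier.weight.view(num_classes, -1, 7 * 7)
        pruned.classifier.weight.copy_(w[:, keep2].reshape(num_classes, -1))
        pruned.classifier.bias.copy_(model.classifier.bias)
    return pruned


torch.manual_seed(0)
model = SmallCNN()
pruned = structured_channel_prune(model, 0.5)
print(pruned.conv1.out_channels, pruned.conv2.out_channels, pruned.classifier.in_features)  # 8 16 784
print(pruned(torch.randn(2, 1, 28, 28)).shape)  # torch.Size([2, 10])
```

Pruning at level 0.0 keeps every channel in its original order, so the copied model reproduces the original outputs exactly; this is a cheap sanity check for any implementation of the weight-copying step.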
_topk_indices
Helper function that selects the indices of the top-k highest-scoring channels based on the pruning level.
Parameters
A 1D tensor containing importance scores for each channel. Higher scores indicate more important channels that should be retained.
The fraction of channels to remove. Must be in the range [0.0, 1.0).
Returns
A 1D tensor of sorted indices corresponding to the channels to keep. The indices are:
- Selected based on highest scores
- Sorted in ascending order
- Guaranteed to have at least one element
Implementation Details
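The function calculates the number of channels to keep, where total is the number of scores:
keep = max(1, int(round(total * (1.0 - pruning_level))))
It then uses torch.topk to select the indices with the highest scores and sorts them in ascending order for consistent ordering.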
This is an internal helper function. Users should typically use
structured_channel_prune instead of calling this directly.