Neural Network Pruning Theory
Neural network pruning is a model compression technique that removes redundant parameters to reduce model size and computational cost while maintaining accuracy.
Overview
Modern neural networks are often over-parameterized, containing many weights that contribute little to the final prediction. Pruning identifies and removes these less important weights, producing a sparse model.
Key Benefits
- Reduced memory footprint - Fewer parameters to store
- Faster inference - Less computation required
- Energy efficiency - Lower power consumption
- Edge deployment - Enables deployment on resource-constrained devices
Pruning Criteria
Magnitude-Based Pruning
The simplest and most widely used pruning criterion treats weight magnitude as the importance metric.
L1 Magnitude (Absolute Value)
importance(w) = |w|
Weights with small absolute values contribute less to the output and can be safely removed.
L2 Magnitude (Squared Value)
importance(w) = w^2
Squaring widens the gap between small and large scores. For individual weights the resulting ranking is identical to L1, but the difference matters when scores are aggregated over groups such as rows, channels, or blocks.
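As a concrete illustration, here is a minimal, dependency-free sketch of both criteria on a flat weight slice (plain Rust, not the Aprender API shown later):

// Per-weight importance scores: L1 uses |w|, L2 uses w^2.
fn l1_importance(weights: &[f32]) -> Vec<f32> {
    weights.iter().map(|w| w.abs()).collect()
}

fn l2_importance(weights: &[f32]) -> Vec<f32> {
    weights.iter().map(|w| w * w).collect()
}

fn main() {
    let w = [0.5_f32, 0.1, -0.8, 0.02];
    // L1: [0.5, 0.1, 0.8, 0.02]; L2 squares these values.
    // Either way, 0.02 and 0.1 rank lowest and are pruned first.
    println!("{:?}", l1_importance(&w));
    println!("{:?}", l2_importance(&w));
}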
Activation-Weighted Pruning (Wanda)
Wanda (Weights AND Activations) considers both weight magnitude and activation statistics:
importance(w_ij) = |w_ij| * sqrt(sum(x_j^2) / n)
This captures how much each weight contributes to the output given typical inputs, requiring calibration data to estimate activation statistics.
Reference: Sun et al. (2023) - "A Simple and Effective Pruning Approach for Large Language Models"
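As a rough, framework-free sketch (function name and shapes here are illustrative assumptions, not the Aprender API), the score can be computed from a weight matrix and a small batch of calibration activations:

// Wanda-style importance for a weight matrix W (rows = outputs, cols = inputs),
// given calibration activations X (rows = samples, cols = input features).
fn wanda_importance(w: &[Vec<f32>], x: &[Vec<f32>]) -> Vec<Vec<f32>> {
    let n = x.len() as f32;
    let in_dim = w[0].len();
    // RMS activation per input feature j: sqrt(sum_i x[i][j]^2 / n)
    let mut rms = vec![0.0_f32; in_dim];
    for sample in x {
        for (j, v) in sample.iter().enumerate() {
            rms[j] += v * v;
        }
    }
    for r in rms.iter_mut() {
        *r = (*r / n).sqrt();
    }
    // importance(w_ij) = |w_ij| * rms_j
    w.iter()
        .map(|row| {
            row.iter()
                .zip(&rms)
                .map(|(wij, r)| wij.abs() * r)
                .collect::<Vec<f32>>()
        })
        .collect()
}

The 1/n factor is a constant shared by the whole matrix, so it does not change which weights rank lowest; it only keeps scores comparable across calibration sets of different sizes.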
Sparsity Patterns
Unstructured Sparsity
Individual weights are pruned independently. This offers maximum flexibility and compression, but the irregular sparsity pattern sees limited acceleration on standard dense hardware kernels.
Original: [0.5,  0.1, -0.8,  0.02]
Mask:     [1,    0,    1,    0   ]
Pruned:   [0.5,  0,   -0.8,  0   ]
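For example, a mask at a given sparsity level can be built by thresholding scores at the corresponding percentile; a minimal sketch (not the library call shown later):

// Build a binary mask keeping the (1 - sparsity) fraction of highest-scoring weights.
fn unstructured_mask(scores: &[f32], sparsity: f32) -> Vec<u8> {
    let mut sorted: Vec<f32> = scores.to_vec();
    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
    // Everything below the percentile threshold is pruned (set to 0 in the mask).
    let cutoff_idx = (sparsity * scores.len() as f32) as usize;
    let threshold = sorted[cutoff_idx.min(scores.len() - 1)];
    scores.iter().map(|s| if *s >= threshold { 1 } else { 0 }).collect()
}

fn main() {
    let scores = [0.5_f32, 0.1, 0.8, 0.02]; // |w| from the example above
    assert_eq!(unstructured_mask(&scores, 0.5), vec![1, 0, 1, 0]);
}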
N:M Structured Sparsity
Exactly N non-zero values are kept in every group of M consecutive elements. This pattern can be hardware-accelerated, most notably 2:4 sparsity on NVIDIA Ampere and later GPUs.
Common patterns:
- 2:4 - 2 non-zeros per 4 elements (50% sparsity)
- 4:8 - 4 non-zeros per 8 elements (50% sparsity)
2:4 Pattern:
Original: [0.5,  0.1, -0.8,  0.02]
Mask:     [1,    0,    1,    0   ]   // 2 ones per 4 elements
Pruned:   [0.5,  0,   -0.8,  0   ]
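A sketch of N:M mask generation: within every group of M consecutive scores, keep the N largest and zero the rest (function name is illustrative, not the library API):

// Generate an N:M mask: keep the `n` largest scores in every group of `m` elements.
fn nm_mask(scores: &[f32], n: usize, m: usize) -> Vec<u8> {
    let mut mask = vec![0_u8; scores.len()];
    for (g, group) in scores.chunks(m).enumerate() {
        // Sort indices within the group by descending score.
        let mut idx: Vec<usize> = (0..group.len()).collect();
        idx.sort_by(|&a, &b| group[b].partial_cmp(&group[a]).unwrap());
        for &i in idx.iter().take(n) {
            mask[g * m + i] = 1;
        }
    }
    mask
}

fn main() {
    let scores = [0.5_f32, 0.1, 0.8, 0.02]; // |w| from the 2:4 example
    assert_eq!(nm_mask(&scores, 2, 4), vec![1, 0, 1, 0]);
}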
Block Sparsity
Entire blocks of weights are pruned together, enabling efficient memory access patterns.
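A common way to score blocks is by their aggregate magnitude, for example the L2 norm of each block; a brief sketch under that assumption:

// Score contiguous blocks of `block_size` weights by their L2 norm;
// the lowest-scoring blocks are candidates for removal as a unit.
fn block_scores(weights: &[f32], block_size: usize) -> Vec<f32> {
    weights
        .chunks(block_size)
        .map(|block| block.iter().map(|w| w * w).sum::<f32>().sqrt())
        .collect()
}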
Pruning Schedules
One-Shot Pruning
Prune to the target sparsity in a single step, typically after pre-training.
let schedule = PruningSchedule::OneShot { step: 1000 };
Gradual Pruning
Incrementally increase sparsity over training, allowing the model to adapt.
let schedule = PruningSchedule::Gradual {
    start_step: 1000,
    end_step: 5000,
    initial_sparsity: 0.0,
    final_sparsity: 0.5,
    frequency: 500, // Prune every 500 steps
};
Cubic Pruning Schedule
The Zhu & Gupta (2017) cubic schedule provides a smooth increase in sparsity:
s_t = s_f * (1 - (1 - t/T)^3)
Where:
- s_t = sparsity at step t
- s_f = final target sparsity
- T = total pruning steps
This schedule prunes aggressively early, when the model is more plastic, and gradually slows as it approaches the target sparsity.
Reference: Zhu & Gupta (2017) - "To Prune or Not To Prune"
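The formula maps directly to code; a minimal sketch of the target sparsity at a given step (function name is illustrative):

// Cubic sparsity schedule (Zhu & Gupta, 2017) with initial sparsity 0:
// s_t = s_f * (1 - (1 - t/T)^3), with progress clamped to [0, 1].
fn cubic_sparsity(t: usize, total_steps: usize, final_sparsity: f32) -> f32 {
    let progress = (t as f32 / total_steps as f32).clamp(0.0, 1.0);
    final_sparsity * (1.0 - (1.0 - progress).powi(3))
}

fn main() {
    // Sparsity ramps quickly at first, then flattens out near the target.
    assert!((cubic_sparsity(0, 1000, 0.5) - 0.0).abs() < 1e-6);
    assert!(cubic_sparsity(500, 1000, 0.5) > 0.4); // 0.5 * (1 - 0.125) = 0.4375
    assert!((cubic_sparsity(1000, 1000, 0.5) - 0.5).abs() < 1e-6);
}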
Implementation in Aprender
Computing Importance Scores
use aprender::pruning::{MagnitudeImportance, Importance};
use aprender::nn::Linear;
let layer = Linear::new(768, 768);
// L1 magnitude
let l1 = MagnitudeImportance::l1();
let scores = l1.compute(&layer, None)?;
// L2 magnitude
let l2 = MagnitudeImportance::l2();
let scores = l2.compute(&layer, None)?;
Generating Sparsity Masks
use aprender::pruning::{generate_unstructured_mask, generate_nm_mask};
// 50% unstructured sparsity
let mask = generate_unstructured_mask(&scores.values, 0.5)?;
// 2:4 N:M sparsity
let nm_mask = generate_nm_mask(&scores.values, 2, 4)?;
Applying Masks
let mut weights = layer.weight().clone();
mask.apply(&mut weights)?;
// Verify sparsity
let actual_sparsity = mask.sparsity();
assert!((actual_sparsity - 0.5).abs() < 0.01);
Best Practices
- Start with magnitude pruning - Simple, effective, no calibration needed
- Use gradual schedules for high sparsity - Allows model adaptation
- Fine-tune after pruning - Recover accuracy lost during pruning
- Validate with representative data - Ensure pruned model generalizes
- Consider hardware targets - Use N:M patterns for GPU acceleration
Mathematical Properties
Importance Scores
- All importance scores are non-negative: importance(w) >= 0
- Zero weights have zero importance: importance(0) = 0
- Masks are idempotent: apply(apply(w, m), m) = apply(w, m)
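These properties are straightforward to check; the sketch below verifies mask idempotence with plain slices (independent of any particular mask type):

// Applying a binary mask twice gives the same result as applying it once.
fn apply_mask(weights: &[f32], mask: &[u8]) -> Vec<f32> {
    weights.iter().zip(mask).map(|(w, m)| w * (*m as f32)).collect()
}

fn main() {
    let w = [0.5_f32, 0.1, -0.8, 0.02];
    let m = [1_u8, 0, 1, 0];
    let once = apply_mask(&w, &m);
    let twice = apply_mask(&once, &m);
    assert_eq!(once, twice); // idempotent
}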
Sparsity Definition
sparsity = num_zeros / total_elements
For a 50% sparse model, half the weights are exactly zero.
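Measured directly, sparsity is just the fraction of exactly-zero entries; a trivial sketch:

// Fraction of exactly-zero elements in a weight slice.
fn sparsity(weights: &[f32]) -> f32 {
    let zeros = weights.iter().filter(|w| **w == 0.0).count();
    zeros as f32 / weights.len() as f32
}

fn main() {
    assert_eq!(sparsity(&[0.5, 0.0, -0.8, 0.0]), 0.5); // 50% sparse
}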
References
- Han et al. (2015) - "Learning both Weights and Connections for Efficient Neural Networks"
- Zhu & Gupta (2017) - "To Prune, or Not to Prune: Exploring the Efficacy of Pruning for Model Compression"
- Sun et al. (2023) - "A Simple and Effective Pruning Approach for Large Language Models"
- Frantar & Alistarh (2023) - "SparseGPT: Massive Language Models Can Be Accurately Pruned in One-Shot"