Magnitude-Based Unstructured Pruning
Status: Verified | Idempotent: Yes | Coverage: 95%+
Run Command
cargo run --example prune_magnitude
Code
//! # Recipe: Magnitude-Based Unstructured Pruning
//!
//! **Category**: optimize
//! **CLI Equivalent**: `apr prune --method magnitude --target 0.5`
//! Contract: contracts/recipe-iiur-v1.yaml
//!
//! Demonstrates magnitude-based unstructured pruning: zeroing out the
//! smallest-magnitude weights to achieve a target sparsity. This is the
//! simplest and most widely used pruning strategy.
//!
//! ## QA Checklist
//! 1. [x] `cargo run` succeeds (Exit Code 0)
//! 2. [x] `cargo test` passes
//! 3. [x] Deterministic output (Verified)
//! 4. [x] No temp files leaked
//! 5. [x] Clippy clean
//! 6. [x] No `unwrap()` in logic
//!
//!
//! ## Format Variants
//! ```bash
//! apr prune model.apr # APR native format
//! apr prune model.gguf # GGUF (llama.cpp compatible)
//! apr prune model.safetensors # SafeTensors (HuggingFace)
//! ```
//! ## References
//! - Frantar, E. & Alistarh, D. (2023). *SparseGPT: Massive Language Models Can Be Accurately Pruned in One-Shot*. ICML. arXiv:2301.00774
use apr_cookbook::prelude::*;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
/// Produce a reproducible pseudo-random weight vector from a seed.
///
/// Each element is derived by hashing `(seed, index)` with `DefaultHasher`,
/// then mapping the 64 hash bits onto a value in roughly [-1.3, 1.3].
/// Note: despite the spread modulation this is a simple linear mapping,
/// not a true Box-Muller normal transform.
fn det_weights(seed: u64, count: usize) -> Vec<f32> {
    let mut out = Vec::with_capacity(count);
    for i in 0..count {
        let mut hasher = DefaultHasher::new();
        (seed, i as u64).hash(&mut hasher);
        let bits = hasher.finish();
        // Split the 64 hash bits into two uniform-ish values in [0, 1].
        let u = (bits & 0xFFFF_FFFF) as f64 / f64::from(u32::MAX);
        let v = ((bits >> 32) & 0xFFFF_FFFF) as f64 / f64::from(u32::MAX);
        // Center `u` on zero, then let `v` modulate the spread slightly.
        let centered = (u - 0.5) * 2.0;
        let scaled = centered * (1.0 + v * 0.3);
        out.push(scaled as f32);
    }
    out
}
/// Prune weights by magnitude: zero out the smallest weights to reach target sparsity.
///
/// `target_sparsity` is the desired fraction of zeroed weights in `[0.0, 1.0]`;
/// the prune count is rounded to the nearest whole weight and clamped to the
/// slice length.
///
/// Returns a new weight vector with the smallest-magnitude weights set to zero.
fn prune_magnitude(weights: &[f32], target_sparsity: f64) -> Vec<f32> {
    if weights.is_empty() {
        return Vec::new();
    }
    let num_to_prune =
        ((weights.len() as f64 * target_sparsity).round() as usize).min(weights.len());
    // Sort indices by absolute magnitude (ascending). `total_cmp` provides the
    // total order `sort_by` requires; the previous `partial_cmp` +
    // `unwrap_or(Equal)` comparator is not antisymmetric when NaN is present
    // and could produce an inconsistent ordering. For finite inputs the result
    // is identical, and the stable sort preserves original index order on ties.
    let mut indices_by_mag: Vec<usize> = (0..weights.len()).collect();
    indices_by_mag.sort_by(|&a, &b| weights[a].abs().total_cmp(&weights[b].abs()));
    let mut pruned = weights.to_vec();
    for &idx in indices_by_mag.iter().take(num_to_prune) {
        pruned[idx] = 0.0;
    }
    pruned
}
/// Fraction of exactly-zero entries in `weights` (0.0 for an empty slice).
fn compute_sparsity(weights: &[f32]) -> f64 {
    match weights.len() {
        0 => 0.0,
        total => {
            let zero_count = weights.iter().filter(|&&w| w == 0.0).count();
            zero_count as f64 / total as f64
        }
    }
}
/// Root-mean-square error between two equal-length weight vectors,
/// accumulated in f64 for precision. Returns 0.0 for empty input.
///
/// # Panics
/// Panics if the slices differ in length.
fn compute_rmse(original: &[f32], pruned: &[f32]) -> f64 {
    assert_eq!(original.len(), pruned.len());
    if original.is_empty() {
        return 0.0;
    }
    // Sum squared differences in the same left-to-right order as before so
    // floating-point results are bit-identical.
    let mut sum_sq = 0.0_f64;
    for (a, b) in original.iter().zip(pruned.iter()) {
        let delta = f64::from(*a) - f64::from(*b);
        sum_sq += delta * delta;
    }
    (sum_sq / original.len() as f64).sqrt()
}
/// Render a simple ASCII histogram of weight magnitudes to stdout.
///
/// Magnitudes are bucketed into `bins` equal-width bins spanning
/// `[0, max |w|]`; each bar is scaled so the fullest bin renders at 40 chars.
/// Degenerate inputs (all-zero weights, or `bins == 0`) print a notice
/// instead of panicking.
fn weight_histogram(weights: &[f32], bins: usize, label: &str) {
    let max_mag = weights.iter().map(|w| w.abs()).fold(0.0_f32, f32::max);
    if max_mag == 0.0 {
        println!(" [{label}] All weights are zero");
        return;
    }
    // Guard: `bins == 0` would underflow `bins - 1` below and divide by zero
    // when computing `bin_width` — nothing sensible to draw, so bail out.
    if bins == 0 {
        println!(" [{label}] No bins requested");
        return;
    }
    let bin_width = f64::from(max_mag) / bins as f64;
    let mut counts = vec![0usize; bins];
    for &w in weights {
        let mag = f64::from(w.abs());
        // `min(bins - 1)` keeps the maximum-magnitude weight in the last bin.
        let bin = ((mag / bin_width) as usize).min(bins - 1);
        counts[bin] += 1;
    }
    let max_count = counts.iter().copied().max().unwrap_or(1);
    let bar_max = 40;
    println!(
        " [{label}] Weight magnitude histogram ({} weights):",
        weights.len()
    );
    for (i, &count) in counts.iter().enumerate() {
        let lo = i as f64 * bin_width;
        let hi = (i + 1) as f64 * bin_width;
        // `max_count >= 1` whenever we reach here, but `checked_div` keeps
        // the bar computation panic-free regardless.
        let bar_len = (count * bar_max).checked_div(max_count).unwrap_or(0);
        let bar: String = "#".repeat(bar_len);
        println!(" [{lo:>5.2}, {hi:>5.2}) | {bar:<40} ({count})");
    }
}
/// Entry point for the magnitude-pruning recipe.
///
/// Walks five sections: synthesize deterministic weights and report their
/// distribution, sweep several target sparsities, show the post-pruning
/// histogram, measure RMSE quality impact, and serialize the pruned tensor
/// as an APR v2 bundle. Metrics are recorded through the recipe context.
///
/// NOTE(review): `Result`, `RecipeContext`, `ModelBundleV2`, `Compression`,
/// and `Quantization` come from `apr_cookbook::prelude`; their exact
/// semantics are assumed from usage here — confirm against the crate docs.
fn main() -> Result<()> {
    let mut ctx = RecipeContext::new("prune_magnitude")?;
    // --- Section 1: Weight Distribution ---
    println!("=== Magnitude-Based Unstructured Pruning ===\n");
    println!("--- Weight Distribution ---");
    // Fixed seed (42) keeps every run byte-identical (determinism QA item).
    let weights = det_weights(42, 1024);
    let mean: f64 = weights.iter().map(|w| f64::from(*w)).sum::<f64>() / weights.len() as f64;
    // Population variance (divide by N, not N-1); std dev follows.
    let variance: f64 = weights
        .iter()
        .map(|w| {
            let d = f64::from(*w) - mean;
            d * d
        })
        .sum::<f64>()
        / weights.len() as f64;
    let std_dev = variance.sqrt();
    println!(" Weights: {} parameters", weights.len());
    println!(" Mean: {mean:.4}");
    println!(" Std Dev: {std_dev:.4}");
    println!(
        " Min: {:.4}",
        weights.iter().copied().fold(f32::INFINITY, f32::min)
    );
    println!(
        " Max: {:.4}",
        weights.iter().copied().fold(f32::NEG_INFINITY, f32::max)
    );
    println!();
    weight_histogram(&weights, 8, "Original");
    println!();
    ctx.record_metric("weight_count", weights.len() as i64);
    // --- Section 2: Pruning at Multiple Sparsities ---
    println!("--- Pruning at Multiple Sparsities ---");
    let sparsities = [0.1, 0.3, 0.5, 0.7, 0.9];
    for &target in &sparsities {
        let pruned = prune_magnitude(&weights, target);
        let actual = compute_sparsity(&pruned);
        let nonzero = pruned.iter().filter(|&&w| w != 0.0).count();
        println!(
            " Target: {:.0}% | Actual: {:.1}% | Non-zero: {}/{} | Zeros: {}",
            target * 100.0,
            actual * 100.0,
            nonzero,
            pruned.len(),
            pruned.len() - nonzero
        );
    }
    println!();
    // --- Section 3: Histogram After 50% Pruning ---
    println!("--- After 50% Pruning ---");
    let pruned_50 = prune_magnitude(&weights, 0.5);
    weight_histogram(&pruned_50, 8, "Pruned@50%");
    println!();
    // --- Section 4: Quality Impact (RMSE) ---
    println!("--- Quality Impact (RMSE from Original) ---");
    for &target in &sparsities {
        let pruned = prune_magnitude(&weights, target);
        let rmse = compute_rmse(&weights, &pruned);
        // Scale RMSE onto a 0-50 character bar for quick visual comparison.
        let bar_len = (rmse * 50.0).round() as usize;
        let bar: String = "|".repeat(bar_len.min(50));
        println!(" Sparsity {:.0}%: RMSE = {rmse:.6} {bar}", target * 100.0);
        // Metric name e.g. "rmse_at_50" for target 0.5.
        let metric_name = format!("rmse_at_{}", (target * 100.0) as i64);
        ctx.record_float_metric(&metric_name, rmse);
    }
    println!();
    // --- Section 5: Save Pruned Model to APR v2 ---
    println!("--- Save Pruned Model (APR v2) ---");
    let pruned_final = prune_magnitude(&weights, 0.5);
    let final_sparsity = compute_sparsity(&pruned_final);
    // Serialize FP32 weights as little-endian bytes, 4 bytes per weight.
    let weight_bytes: Vec<u8> = pruned_final.iter().flat_map(|f| f.to_le_bytes()).collect();
    let bundle = ModelBundleV2::new()
        .with_name("pruned_magnitude_50")
        .with_compression(Compression::Lz4)
        .with_quantization(Quantization::FP32)
        .add_tensor("pruned_weights", vec![1, pruned_final.len()], weight_bytes)
        .build();
    // Sanity-check the serialized magic bytes before reporting.
    assert_eq!(&bundle[0..4], b"APR2");
    println!(" Bundle size: {} bytes", bundle.len());
    println!(" Format: APR v2 (LZ4 compressed)");
    println!(" Final sparsity: {:.1}%", final_sparsity * 100.0);
    println!(" Compression advantage: sparse tensors compress well with LZ4");
    ctx.record_metric("bundle_size_bytes", bundle.len() as i64);
    ctx.record_float_metric("final_sparsity", final_sparsity);
    ctx.report()?;
    Ok(())
}
/// Unit tests covering pruning behavior, determinism, and serialization.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_preserves_length() {
        let weights = det_weights(1, 256);
        assert_eq!(weights.len(), prune_magnitude(&weights, 0.5).len());
    }

    #[test]
    fn test_achieves_target_sparsity_50() {
        let actual = compute_sparsity(&prune_magnitude(&det_weights(2, 1000), 0.5));
        assert!(
            (actual - 0.5).abs() < 0.01,
            "Expected ~50% sparsity, got {actual}"
        );
    }

    #[test]
    fn test_achieves_target_sparsity_90() {
        let actual = compute_sparsity(&prune_magnitude(&det_weights(3, 1000), 0.9));
        assert!(
            (actual - 0.9).abs() < 0.01,
            "Expected ~90% sparsity, got {actual}"
        );
    }

    #[test]
    fn test_zero_sparsity_preserves_all() {
        let weights = det_weights(4, 128);
        assert_eq!(weights, prune_magnitude(&weights, 0.0));
    }

    #[test]
    fn test_full_sparsity_zeros_all() {
        let fully_pruned = prune_magnitude(&det_weights(5, 128), 1.0);
        assert!(fully_pruned.iter().all(|&w| w == 0.0));
    }

    #[test]
    fn test_zero_weights_are_smallest() {
        let weights = det_weights(6, 256);
        let pruned = prune_magnitude(&weights, 0.5);
        // The smallest surviving magnitude bounds every pruned weight above.
        let surviving_min = pruned
            .iter()
            .filter(|&&w| w != 0.0)
            .map(|w| w.abs())
            .fold(f32::INFINITY, f32::min);
        for (i, (&orig, &pr)) in weights.iter().zip(pruned.iter()).enumerate() {
            if pr != 0.0 {
                continue;
            }
            assert!(
                orig.abs() <= surviving_min + f32::EPSILON,
                "Pruned weight at {i} had magnitude {} > surviving min {surviving_min}",
                orig.abs()
            );
        }
    }

    #[test]
    fn test_rmse_increases_with_sparsity() {
        let weights = det_weights(7, 512);
        let mut rmses = Vec::new();
        for &s in &[0.1, 0.3, 0.5, 0.7, 0.9] {
            rmses.push(compute_rmse(&weights, &prune_magnitude(&weights, s)));
        }
        for window in rmses.windows(2) {
            assert!(
                window[1] >= window[0],
                "RMSE should increase: {} vs {}",
                window[0],
                window[1]
            );
        }
    }

    #[test]
    fn test_rmse_zero_at_no_pruning() {
        let weights = det_weights(8, 256);
        let rmse = compute_rmse(&weights, &prune_magnitude(&weights, 0.0));
        assert!(rmse < f64::EPSILON, "RMSE should be 0 with no pruning");
    }

    #[test]
    fn test_deterministic_output() {
        let w1 = det_weights(99, 512);
        let w2 = det_weights(99, 512);
        assert_eq!(w1, w2, "det_weights must be deterministic");
        assert_eq!(
            prune_magnitude(&w1, 0.5),
            prune_magnitude(&w2, 0.5),
            "prune_magnitude must be deterministic"
        );
    }

    #[test]
    fn test_empty_weights() {
        assert!(prune_magnitude(&[], 0.5).is_empty());
    }

    #[test]
    fn test_sparsity_computation() {
        let sparsity = compute_sparsity(&[0.0, 1.0, 0.0, 2.0, 0.0]);
        assert!((sparsity - 0.6).abs() < f64::EPSILON);
    }

    #[test]
    fn test_apr_v2_bundle_valid() {
        let pruned = prune_magnitude(&det_weights(10, 64), 0.5);
        let payload: Vec<u8> = pruned.iter().flat_map(|f| f.to_le_bytes()).collect();
        let bundle = ModelBundleV2::new()
            .with_name("test_pruned")
            .with_compression(Compression::Lz4)
            .with_quantization(Quantization::FP32)
            .add_tensor("w", vec![1, pruned.len()], payload)
            .build();
        assert_eq!(&bundle[0..4], b"APR2");
    }
}