Case Study: Lottery Ticket Pruning
This case study demonstrates how to find winning tickets with Aprender's Lottery Ticket Hypothesis implementation.
Overview
The Lottery Ticket Hypothesis (Frankle & Carbin, 2018) posits that a randomly initialized dense network contains sparse subnetworks ("winning tickets") that, when trained in isolation from their original initialization, can match the accuracy of the full network. We'll use Aprender's LotteryTicketPruner to find these winning tickets.
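At its core, lottery-ticket search is iterative magnitude pruning: prune a small fraction of the lowest-magnitude surviving weights, rewind the survivors to an earlier checkpoint, retrain, and repeat until the target sparsity is reached. The sketch below illustrates that loop on a flat weight slice; it is a simplified, standalone illustration (training and rewinding are elided as comments), not Aprender's internal implementation.

// Illustrative sketch of iterative magnitude pruning (IMP).
fn imp_mask(weights: &[f32], target_sparsity: f32, rounds: usize) -> Vec<bool> {
    let mut mask = vec![true; weights.len()];
    // Keep the same fraction of survivors each round so that after
    // `rounds` rounds the overall density is (1 - target_sparsity).
    let keep_per_round = (1.0 - target_sparsity).powf(1.0 / rounds as f32);
    for _ in 0..rounds {
        // Rank surviving weights by magnitude.
        let mut survivors: Vec<(usize, f32)> = weights
            .iter()
            .enumerate()
            .filter(|&(i, _)| mask[i])
            .map(|(i, &w)| (i, w.abs()))
            .collect();
        survivors.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
        // Prune the smallest (1 - keep_per_round) fraction of the survivors.
        let to_prune = ((1.0 - keep_per_round) * survivors.len() as f32).round() as usize;
        for &(idx, _) in survivors.iter().take(to_prune) {
            mask[idx] = false;
        }
        // A full LTH run retrains the masked network here and rewinds the
        // surviving weights to an earlier checkpoint before the next round.
    }
    mask
}

With 32,768 weights, a 90% sparsity target, and 10 rounds, this loop leaves roughly 3,277 weights enabled, which is what the basic example below reports.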
Finding a Winning Ticket
Basic Example
use aprender::pruning::{
LotteryTicketPruner, LotteryTicketConfig, RewindStrategy, Pruner
};
use aprender::nn::Linear;
fn main() -> Result<(), Box<dyn std::error::Error>> {
// Create a dense layer
let layer = Linear::new(256, 128);
// Configure lottery ticket search
// 90% sparsity over 10 iterative pruning rounds
let config = LotteryTicketConfig::new(0.9, 10)
.with_rewind_strategy(RewindStrategy::Init);
let pruner = LotteryTicketPruner::with_config(config);
// Find the winning ticket
let ticket = pruner.find_ticket(&layer)?;
println!("=== Winning Ticket Found ===");
println!("Total parameters: {}", ticket.total_parameters);
println!("Remaining parameters: {}", ticket.remaining_parameters);
println!("Sparsity: {:.2}%", ticket.sparsity * 100.0);
println!("Compression ratio: {:.1}x", ticket.compression_ratio());
println!("Density: {:.2}%", ticket.density() * 100.0);
Ok(())
}
Output:
=== Winning Ticket Found ===
Total parameters: 32768
Remaining parameters: 3277
Sparsity: 90.00%
Compression ratio: 10.0x
Density: 10.00%
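The derived metrics follow directly from the counts above: compression is total over remaining parameters, and density is the complement of sparsity. A quick standalone check of the printed numbers (plain arithmetic, not an Aprender API):

fn main() {
    let total = 32_768_f32;    // 256 * 128 weights
    let remaining = 3_277_f32; // survivors at 90% sparsity
    println!("sparsity:    {:.2}%", (1.0 - remaining / total) * 100.0); // 90.00%
    println!("compression: {:.1}x", total / remaining);                 // 10.0x
    println!("density:     {:.2}%", (remaining / total) * 100.0);       // 10.00%
}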
Tracking Pruning Progress
Observing Iterative Pruning
use aprender::pruning::{LotteryTicketPruner, LotteryTicketConfig};
use aprender::nn::Linear;
fn main() -> Result<(), Box<dyn std::error::Error>> {
let layer = Linear::new(100, 100);
let config = LotteryTicketConfig::new(0.95, 15);
let pruner = LotteryTicketPruner::with_config(config);
let ticket = pruner.find_ticket(&layer)?;
println!("Sparsity progression over {} rounds:", ticket.sparsity_history.len());
for (round, sparsity) in ticket.sparsity_history.iter().enumerate() {
let remaining = (1.0 - sparsity) * 100.0;
println!(" Round {:2}: {:.1}% sparse ({:.1}% remaining)",
round + 1, sparsity * 100.0, remaining);
}
Ok(())
}
Output:
Sparsity progression over 15 rounds:
Round 1: 18.1% sparse (81.9% remaining)
Round 2: 32.9% sparse (67.1% remaining)
Round 3: 45.1% sparse (54.9% remaining)
Round 4: 55.0% sparse (45.0% remaining)
Round 5: 63.2% sparse (36.8% remaining)
Round 6: 69.9% sparse (30.1% remaining)
Round 7: 75.4% sparse (24.6% remaining)
Round 8: 79.8% sparse (20.2% remaining)
Round 9: 83.5% sparse (16.5% remaining)
Round 10: 86.5% sparse (13.5% remaining)
Round 11: 88.9% sparse (11.1% remaining)
Round 12: 90.9% sparse (9.1% remaining)
Round 13: 92.6% sparse (7.4% remaining)
Round 14: 93.9% sparse (6.1% remaining)
Round 15: 95.0% sparse (5.0% remaining)
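The progression above follows a geometric schedule: each round removes the same fraction of the remaining weights, so after round k of R rounds the surviving fraction is (1 - s_target)^(k/R). The short sketch below reproduces the printed numbers; the closed-form schedule is inferred from the output, not read from Aprender's source.

fn main() {
    let target_sparsity = 0.95_f32;
    let rounds = 15;
    for k in 1..=rounds {
        // Fraction of weights still alive after k equal-ratio pruning rounds.
        let remaining = (1.0 - target_sparsity).powf(k as f32 / rounds as f32);
        println!("  Round {:2}: {:.1}% sparse ({:.1}% remaining)",
            k, (1.0 - remaining) * 100.0, remaining * 100.0);
    }
}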
Using the Builder Pattern
Configuring All Options
use aprender::pruning::{
LotteryTicketPruner, LotteryTicketPrunerBuilder, RewindStrategy
};
use aprender::nn::Linear;
fn main() -> Result<(), Box<dyn std::error::Error>> {
let layer = Linear::new(512, 256);
// Builder provides fluent configuration
let pruner = LotteryTicketPruner::builder()
.target_sparsity(0.8) // 80% sparsity target
.pruning_rounds(5) // 5 iterative rounds
.rewind_strategy(RewindStrategy::Early { iteration: 100 })
.global_pruning(true) // Prune globally across layers
.build();
let ticket = pruner.find_ticket(&layer)?;
println!("Configuration:");
println!(" Target sparsity: 80%");
println!(" Pruning rounds: 5");
println!(" Rewind strategy: Early (iteration 100)");
println!("\nResult:");
println!(" Achieved sparsity: {:.2}%", ticket.sparsity * 100.0);
Ok(())
}
Comparing Rewind Strategies
Init vs. Early vs. Late Rewinding
use aprender::pruning::{
LotteryTicketPruner, LotteryTicketConfig, RewindStrategy
};
use aprender::nn::Linear;
fn find_ticket_with_strategy(
layer: &Linear,
strategy: RewindStrategy,
name: &str
) -> Result<(), Box<dyn std::error::Error>> {
let config = LotteryTicketConfig::new(0.9, 10)
.with_rewind_strategy(strategy);
let pruner = LotteryTicketPruner::with_config(config);
let ticket = pruner.find_ticket(layer)?;
println!("{} Rewinding:", name);
println!(" Sparsity: {:.2}%", ticket.sparsity * 100.0);
println!(" Compression: {:.1}x\n", ticket.compression_ratio());
Ok(())
}
fn main() -> Result<(), Box<dyn std::error::Error>> {
let layer = Linear::new(256, 256);
// Original LTH: rewind to initialization
find_ticket_with_strategy(&layer, RewindStrategy::Init, "Init")?;
// Early rewinding: rewind to early training checkpoint
find_ticket_with_strategy(
&layer,
RewindStrategy::Early { iteration: 100 },
"Early"
)?;
// Late rewinding: rewind to fraction of training
find_ticket_with_strategy(
&layer,
RewindStrategy::Late { fraction: 0.1 },
"Late"
)?;
// No rewinding: standard pruning
find_ticket_with_strategy(&layer, RewindStrategy::None, "None")?;
Ok(())
}
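Conceptually, the rewind strategy only decides which weight snapshot the surviving connections are reset to before retraining: the original initialization (Init), an early checkpoint (Early), a checkpoint at some fraction of total training (Late), or no reset at all (None). Below is a hedged sketch of that selection logic over a hypothetical list of per-iteration checkpoints; the enum mirrors the variants used above, but the helper is illustrative, not Aprender's implementation.

// Hypothetical mirror of the four strategies exercised above.
enum Rewind {
    Init,                       // original LTH: reset to the initial weights
    Early { iteration: usize }, // reset to an early-training checkpoint
    Late { fraction: f32 },     // reset to a fraction of total training
    None,                       // keep trained weights (standard pruning)
}

// Pick which checkpoint to rewind to, given `num_checkpoints` snapshots
// saved once per training iteration (checkpoint 0 = initialization).
fn rewind_index(strategy: &Rewind, num_checkpoints: usize) -> Option<usize> {
    match strategy {
        Rewind::Init => Some(0),
        Rewind::Early { iteration } => Some((*iteration).min(num_checkpoints - 1)),
        Rewind::Late { fraction } => {
            Some((fraction * (num_checkpoints as f32 - 1.0)).round() as usize)
        }
        Rewind::None => None, // no rewinding
    }
}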
Applying Winning Tickets
Pruning Weights with Rewinding
use aprender::pruning::{LotteryTicketPruner, LotteryTicketConfig};
use aprender::nn::Linear;
fn main() -> Result<(), Box<dyn std::error::Error>> {
let layer = Linear::new(64, 32);
let config = LotteryTicketConfig::new(0.75, 5);
let pruner = LotteryTicketPruner::with_config(config);
// Find winning ticket
let ticket = pruner.find_ticket(&layer)?;
// Apply ticket to get pruned weights with rewinding
let pruned_weights = pruner.apply_ticket(&ticket, &layer)?;
// Count zeros in pruned weights
let zeros = pruned_weights.iter().filter(|&&w| w == 0.0).count();
let total = pruned_weights.len();
let actual_sparsity = zeros as f32 / total as f32;
println!("Applied winning ticket:");
println!(" Total weights: {}", total);
println!(" Zero weights: {}", zeros);
println!(" Actual sparsity: {:.2}%", actual_sparsity * 100.0);
Ok(())
}
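Applying a ticket is an element-wise mask: positions kept by the ticket retain their (rewound) values and pruned positions become zero, which is exactly what the zero-count above measures. A minimal standalone sketch of that operation (hypothetical helper, not the apply_ticket implementation):

// Zero out weights whose mask entry is false; keep the rest unchanged.
fn apply_mask(weights: &[f32], mask: &[bool]) -> Vec<f32> {
    weights
        .iter()
        .zip(mask.iter())
        .map(|(&w, &keep)| if keep { w } else { 0.0 })
        .collect()
}

fn main() {
    let weights = vec![0.3, -0.01, 0.7, 0.002];
    let mask = vec![true, false, true, false];
    assert_eq!(apply_mask(&weights, &mask), vec![0.3, 0.0, 0.7, 0.0]);
}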
Using the Pruner Trait
Generic Pruning Interface
use aprender::pruning::{
Pruner, LotteryTicketPruner, LotteryTicketConfig, SparsityPattern
};
use aprender::nn::Linear;
fn prune_with_any_pruner<P: Pruner>(
pruner: &P,
module: &dyn aprender::nn::Module,
) -> Result<(), Box<dyn std::error::Error>> {
println!("Pruner: {}", pruner.name());
// Compute importance scores
let scores = pruner.importance(module, None)?;
println!(" Importance range: [{:.4}, {:.4}]",
scores.stats.min, scores.stats.max);
// Generate mask at 50% sparsity
let mask = pruner.generate_mask(
module,
SparsityPattern::Unstructured,
0.5,
None
)?;
println!(" Mask sparsity: {:.2}%", mask.sparsity() * 100.0);
Ok(())
}
fn main() -> Result<(), Box<dyn std::error::Error>> {
let layer = Linear::new(128, 64);
let config = LotteryTicketConfig::new(0.5, 3);
let pruner = LotteryTicketPruner::with_config(config);
prune_with_any_pruner(&pruner, &layer)?;
Ok(())
}
Output:
Pruner: LotteryTicket
Importance range: [0.0001, 0.9823]
Mask sparsity: 50.00%
High Sparsity Example
Finding Extremely Sparse Tickets
use aprender::pruning::{LotteryTicketPruner, LotteryTicketConfig};
use aprender::nn::Linear;
fn main() -> Result<(), Box<dyn std::error::Error>> {
let layer = Linear::new(1024, 512);
// Target 99% sparsity (100x compression)
let config = LotteryTicketConfig::new(0.99, 20);
let pruner = LotteryTicketPruner::with_config(config);
let ticket = pruner.find_ticket(&layer)?;
println!("=== Extreme Sparsity Winning Ticket ===");
println!("Original parameters: {}", ticket.total_parameters);
println!("Remaining parameters: {}", ticket.remaining_parameters);
println!("Sparsity: {:.2}%", ticket.sparsity * 100.0);
println!("Compression: {:.0}x", ticket.compression_ratio());
println!("\nMemory savings:");
let original_mb = ticket.total_parameters as f32 * 4.0 / 1_000_000.0;
let pruned_mb = ticket.remaining_parameters as f32 * 4.0 / 1_000_000.0;
println!(" Original: {:.2} MB", original_mb);
println!(" Pruned: {:.3} MB", pruned_mb);
println!(" Saved: {:.2} MB ({:.1}%)",
original_mb - pruned_mb,
(1.0 - pruned_mb / original_mb) * 100.0);
Ok(())
}
Output:
=== Extreme Sparsity Winning Ticket ===
Original parameters: 524288
Remaining parameters: 5243
Sparsity: 99.00%
Compression: 100x
Memory savings:
Original: 2.10 MB
Pruned: 0.021 MB
Saved: 2.08 MB (99.0%)
Key Takeaways
- Iterative Pruning - LTH uses multiple prune-rewind cycles to find sparse subnetworks
- Rewind Strategies - Different rewinding points affect ticket quality
- Compression Ratios - 10-100x compression is achievable
- Pruner Trait - LotteryTicketPruner implements the standard Pruner interface
- Builder Pattern - Fluent API for configuration
References
- Frankle, J., & Carbin, M. (2018). "The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks." ICLR 2019.
- Aprender Pruning Module: src/pruning/lottery.rs