Chapter 07: Model Selection and Evaluation

Contract: apr-book-ch07

Run: cargo run -p aprender-core --example ch07_model_selection

#![allow(clippy::disallowed_methods)]
//! Chapter 7: Model Selection and Evaluation
//!
//! Demonstrates KFold cross-validation with a real model + metrics.
//! Citation: Kohavi, "Cross-Validation and Bootstrap," IJCAI 1995
//! Contract: contracts/apr-book-ch07-v1.yaml (v2 — api_calls enforced)

use aprender::prelude::*;
use aprender::metrics::classification::{accuracy, precision, recall, Average};
use aprender::model_selection::KFold;

fn main() {
    // Dataset: two linearly separable classes (6 points each, 2 features)
    let x = Matrix::from_vec(12, 2, vec![
        1.0, 1.5, 2.0, 2.5, 1.5, 2.0, 2.5, 1.0, 3.0, 2.0, 2.0, 1.5,
        7.0, 7.5, 8.0, 8.5, 7.5, 8.0, 8.5, 7.0, 9.0, 8.0, 7.0, 7.5,
    ]).expect("valid 12x2 matrix");
    let y: Vec<usize> = vec![0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1];

    // Gather the rows of `x` (and matching labels in `y`) named by `idx`
    // into a fresh matrix + label vector. A closure lets type inference
    // supply the concrete Matrix type, and removes the previously
    // duplicated train/test extraction code.
    let subset = |idx: &[usize]| {
        let mut data = Vec::with_capacity(idx.len() * 2);
        let mut labels = Vec::with_capacity(idx.len());
        for &i in idx {
            for col in 0..2 {
                data.push(x.get(i, col));
            }
            labels.push(y[i]);
        }
        let m = Matrix::from_vec(idx.len(), 2, data)
            .expect("idx.len() rows x 2 cols always forms a valid matrix");
        (m, labels)
    };

    // KFold cross-validation with 3 folds (shuffled, seeded for reproducibility)
    let kfold = KFold::new(3).with_shuffle(true).with_random_state(42);
    let splits = kfold.split(x.n_rows());

    println!("KFold cross-validation: k=3, {} samples", x.n_rows());
    assert_eq!(splits.len(), 3, "3-fold must produce 3 splits");

    let mut fold_accuracies = Vec::with_capacity(splits.len());
    for (fold_idx, (train_idx, test_idx)) in splits.iter().enumerate() {
        // Extract train/test data by index
        let (x_train, y_train) = subset(train_idx);
        let (x_test, y_test) = subset(test_idx);

        // Train a depth-limited decision tree on this fold only;
        // evaluate on the held-out fold.
        let mut tree = DecisionTreeClassifier::new().with_max_depth(3);
        tree.fit(&x_train, &y_train).expect("fit on fold");
        let preds = tree.predict(&x_test);

        let acc = accuracy(&preds, &y_test);
        fold_accuracies.push(acc);
        println!("  Fold {}: train={} test={} accuracy={acc:.2}",
            fold_idx + 1, train_idx.len(), test_idx.len());
    }

    // Cross-validation mean accuracy (splits.len() == 3, so never divides by zero)
    let mean_acc: f32 = fold_accuracies.iter().sum::<f32>() / fold_accuracies.len() as f32;
    println!("\nMean accuracy: {mean_acc:.2}");
    assert!(mean_acc > 0.5, "Mean CV accuracy must exceed chance");

    // Full-data metrics: refit on all 12 samples, then macro-averaged
    // precision/recall on the training data itself (optimistic by design —
    // this demonstrates the metrics API, not generalization).
    let mut full_tree = DecisionTreeClassifier::new().with_max_depth(3);
    full_tree.fit(&x, &y).expect("fit on full dataset");
    let full_preds = full_tree.predict(&x);
    let prec = precision(&full_preds, &y, Average::Macro);
    let rec = recall(&full_preds, &y, Average::Macro);
    println!("Full-data precision={prec:.2}, recall={rec:.2}");
    assert!(prec > 0.0 && prec <= 1.0, "Precision in (0,1]");
    assert!(rec > 0.0 && rec <= 1.0, "Recall in (0,1]");

    println!("Chapter 7 contracts: PASSED");
}