Model Family Identification
CLI Equivalent: `apr oracle model.apr`
What This Demonstrates
Identifies model architecture family (Transformer, CNN, RNN, MLP) from weight tensor names and shapes using heuristic pattern matching. Scores confidence by counting pattern hits across tensor naming conventions (e.g., attn, q_proj for Transformer; conv, bn for CNN) and reports evidence for each classification signal.
Run
`cargo run --example analysis_oracle`
Key APIs
- `identify_family(&tensor_names, &shapes)` — classify a model as Transformer/CNN/RNN/MLP/Unknown with a confidence score
- `score_family(&tensor_names, &shapes, patterns)` — count pattern matches and compute a confidence score
- `TRANSFORMER_PATTERNS` / `CNN_PATTERNS` / `RNN_PATTERNS` / `MLP_PATTERNS` — heuristic pattern lists
- `OracleResult { family, confidence, evidence }` — classification result with accumulated evidence
Code
#![allow(unused_imports)]
//! # Model Family Oracle
//! **CLI Equivalent**: `apr oracle`
//! Contract: contracts/recipe-iiur-v1.yaml
//!
//! Identifies model architecture family from weight tensor names and shapes.
//!
//! ## CLI equivalent
//! ```bash
//! apr oracle model.apr
//! ```
//!
//! ## What this demonstrates
//! - Heuristic classification of model architectures
//! - Pattern matching on tensor naming conventions
//! - Confidence scoring with evidence accumulation
//!
//!
//! ## Format Variants
//! ```bash
//! apr oracle model.apr # APR native format
//! apr oracle model.gguf # GGUF (llama.cpp compatible)
//! apr oracle model.safetensors # SafeTensors (HuggingFace)
//! ```
//! ## References
//! - Paleyes, A. et al. (2022). *Challenges in Deploying Machine Learning*. ACM Computing Surveys. DOI: 10.1145/3533378
use apr_cookbook::prelude::*;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
mod types;
#[allow(unused_imports)]
#[allow(clippy::wildcard_imports)]
use types::*;
/// Splits a `(name, shape)` table into the two parallel inputs that
/// `identify_family` takes: the ordered tensor-name list and the
/// `(name, shape)` pairs.
///
/// Each tensor name is written once and both collections are derived from it,
/// so the name list and the shape list cannot silently drift apart.
fn tensor_table(entries: Vec<(&str, Vec<usize>)>) -> (Vec<String>, Vec<(String, Vec<usize>)>) {
    let shapes: Vec<(String, Vec<usize>)> = entries
        .into_iter()
        .map(|(name, dims)| (name.to_string(), dims))
        .collect();
    let names: Vec<String> = shapes.iter().map(|(name, _)| name.clone()).collect();
    (names, shapes)
}

/// Synthetic one-layer transformer: token embeddings, q/k/v/o attention
/// projections, MLP gate/up/down projections, a final norm, and an LM head.
fn transformer_fixture() -> (Vec<String>, Vec<(String, Vec<usize>)>) {
    tensor_table(vec![
        ("model.embed_tokens.weight", vec![32000, 768]),
        ("model.layers.0.self_attn.q_proj.weight", vec![768, 768]),
        ("model.layers.0.self_attn.k_proj.weight", vec![768, 768]),
        ("model.layers.0.self_attn.v_proj.weight", vec![768, 768]),
        ("model.layers.0.self_attn.o_proj.weight", vec![768, 768]),
        ("model.layers.0.mlp.gate_proj.weight", vec![768, 3072]),
        ("model.layers.0.mlp.up_proj.weight", vec![768, 3072]),
        ("model.layers.0.mlp.down_proj.weight", vec![3072, 768]),
        ("model.norm.weight", vec![768]),
        ("lm_head.weight", vec![32000, 768]),
    ])
}

/// Synthetic CNN: `conv`, `bn`, and `pool` tensors under `backbone`, plus an
/// `fc` head.
fn cnn_fixture() -> (Vec<String>, Vec<(String, Vec<usize>)>) {
    tensor_table(vec![
        ("backbone.conv1.weight", vec![64, 3, 7, 7]),
        ("backbone.conv1.bias", vec![64]),
        ("backbone.bn1.weight", vec![64]),
        ("backbone.conv2.weight", vec![128, 64, 3, 3]),
        ("backbone.pool.weight", vec![128]),
        ("head.fc.weight", vec![1000, 128]),
    ])
}

/// Runs the model-family oracle on a synthetic transformer and prints each
/// analysis stage. It then classifies a synthetic CNN and prints a fingerprint
/// of the transformer verdict.
///
/// # Errors
/// Propagates failures from `RecipeContext::new` and `RecipeContext::report`.
fn main() -> Result<()> {
    let ctx = RecipeContext::new("analysis_oracle")?;

    // ── Section 1: Build a synthetic transformer model ──────────────────
    println!("=== Model Family Oracle ===\n");
    let (tensor_names, shapes) = transformer_fixture();

    // ── Section 2: Tensor name analysis ─────────────────────────────────
    println!("--- Tensor Name Analysis ---");
    println!("Model contains {} tensors:", tensor_names.len());
    for name in &tensor_names {
        println!(" {}", name);
    }
    println!();

    // ── Section 3: Shape pattern matching ───────────────────────────────
    println!("--- Shape Pattern Matching ---");
    for (name, shape) in &shapes {
        let dims: Vec<String> = shape.iter().map(ToString::to_string).collect();
        println!(" {} : [{}]", name, dims.join(", "));
    }
    println!();

    // ── Section 4: Confidence scoring ───────────────────────────────────
    let result = identify_family(&tensor_names, &shapes);
    println!("--- Confidence Scoring ---");
    println!("Confidence: {:.1}%", result.confidence * 100.0);
    println!("Evidence ({} signals):", result.evidence.len());
    for ev in &result.evidence {
        println!(" - {}", ev);
    }
    println!();

    // ── Section 5: Family identification ────────────────────────────────
    println!("--- Family Identification ---");
    println!("Detected family: {}", result.family);
    println!("Confidence: {:.1}%", result.confidence * 100.0);

    // ── Section 6: Demonstrate with a CNN model ─────────────────────────
    println!("\n--- CNN Model Test ---");
    let (cnn_names, cnn_shapes) = cnn_fixture();
    let cnn_result = identify_family(&cnn_names, &cnn_shapes);
    println!(
        "Detected family: {} ({:.1}%)",
        cnn_result.family,
        cnn_result.confidence * 100.0
    );

    // Fingerprint the transformer verdict. `DefaultHasher::new()` always starts
    // from the same state, so the value repeats across runs of this binary. The
    // hashing algorithm is unspecified, though, so the printed value may change
    // between Rust releases and should not be stored as a golden value.
    let mut hasher = DefaultHasher::new();
    result.family.to_string().hash(&mut hasher);
    result.evidence.len().hash(&mut hasher);
    println!("\nOracle fingerprint: {:016x}", hasher.finish());

    ctx.report()?;
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Splits a `(name, shape)` table into the parallel name list and
    /// `(name, shape)` pairs that `identify_family` takes. Each tensor name is
    /// written once, so the two inputs cannot disagree because of a typo.
    fn fixture(table: &[(&str, Vec<usize>)]) -> (Vec<String>, Vec<(String, Vec<usize>)>) {
        let names: Vec<String> = table.iter().map(|(name, _)| (*name).to_string()).collect();
        let shapes: Vec<(String, Vec<usize>)> = table
            .iter()
            .map(|(name, dims)| ((*name).to_string(), dims.clone()))
            .collect();
        (names, shapes)
    }

    #[test]
    fn test_transformer_detected() {
        let (n, s) = fixture(&[
            ("layers.0.self_attn.q_proj.weight", vec![768, 768]),
            ("layers.0.self_attn.k_proj.weight", vec![768, 768]),
            ("layers.0.self_attn.v_proj.weight", vec![768, 768]),
            ("layers.0.mlp.gate.weight", vec![768, 3072]),
            ("embed_tokens.weight", vec![32000, 768]),
        ]);
        let result = identify_family(&n, &s);
        assert_eq!(result.family, ModelFamily::Transformer);
        assert!(result.confidence > 0.5);
    }

    #[test]
    fn test_cnn_detected() {
        let (n, s) = fixture(&[
            ("backbone.conv1.weight", vec![64, 3, 7, 7]),
            ("backbone.conv2.weight", vec![128, 64, 3, 3]),
            ("backbone.bn1.weight", vec![64]),
            ("backbone.pool.weight", vec![128]),
        ]);
        let result = identify_family(&n, &s);
        assert_eq!(result.family, ModelFamily::CNN);
        assert!(result.confidence > 0.5);
    }

    #[test]
    fn test_rnn_detected() {
        let (n, s) = fixture(&[
            ("encoder.lstm.weight_ih", vec![512, 128]),
            ("encoder.lstm.weight_hh", vec![512, 512]),
            ("encoder.lstm.cell.weight", vec![512]),
            ("decoder.rnn.hidden.weight", vec![256, 512]),
        ]);
        let result = identify_family(&n, &s);
        assert_eq!(result.family, ModelFamily::RNN);
        assert!(result.confidence > 0.5);
    }

    #[test]
    fn test_mlp_detected() {
        let (n, s) = fixture(&[
            ("classifier.fc1.weight", vec![256, 784]),
            ("classifier.fc1.bias", vec![256]),
            ("classifier.fc2.weight", vec![128, 256]),
            ("classifier.fc2.bias", vec![128]),
            ("classifier.fc3.weight", vec![10, 128]),
            ("classifier.linear.weight", vec![10, 128]),
        ]);
        let result = identify_family(&n, &s);
        assert_eq!(result.family, ModelFamily::MLP);
        assert!(result.confidence > 0.5);
    }

    #[test]
    fn test_unknown_for_random_names() {
        let (n, s) = fixture(&[
            ("xyz_123", vec![10]),
            ("foo_bar", vec![20]),
            ("baz_qux", vec![30]),
        ]);
        let result = identify_family(&n, &s);
        assert_eq!(result.family, ModelFamily::Unknown);
    }

    #[test]
    fn test_confidence_bounded_zero_to_one() {
        let (n, s) = fixture(&[
            ("layers.0.self_attn.q_proj", vec![768, 768]),
            ("layers.0.self_attn.k_proj", vec![768, 768]),
            ("layers.0.self_attn.v_proj", vec![768, 768]),
        ]);
        let result = identify_family(&n, &s);
        assert!(result.confidence >= 0.0);
        assert!(result.confidence <= 1.0);
    }

    #[test]
    fn test_evidence_populated_for_matches() {
        let (n, s) = fixture(&[
            ("layer.attn.q_proj.weight", vec![768, 768]),
            ("layer.attn.k_proj.weight", vec![768, 768]),
        ]);
        let result = identify_family(&n, &s);
        assert!(!result.evidence.is_empty());
    }

    #[test]
    fn test_empty_tensors_returns_unknown() {
        let (n, s) = fixture(&[]);
        let result = identify_family(&n, &s);
        assert_eq!(result.family, ModelFamily::Unknown);
    }

    #[test]
    fn test_display_impl_all_families() {
        assert_eq!(ModelFamily::Transformer.to_string(), "Transformer");
        assert_eq!(ModelFamily::CNN.to_string(), "CNN");
        assert_eq!(ModelFamily::RNN.to_string(), "RNN");
        assert_eq!(ModelFamily::MLP.to_string(), "MLP");
        assert_eq!(ModelFamily::Unknown.to_string(), "Unknown");
    }

    #[test]
    fn test_mixed_signals_picks_strongest() {
        // Mostly transformer with one CNN tensor
        let (n, s) = fixture(&[
            ("attn.q_proj", vec![768, 768]),
            ("attn.k_proj", vec![768, 768]),
            ("attn.v_proj", vec![768, 768]),
            ("conv1.weight", vec![64, 3, 7, 7]),
        ]);
        let result = identify_family(&n, &s);
        assert_eq!(result.family, ModelFamily::Transformer);
    }
}