Case Study: PTX Parity Validation (GH-219)
This chapter documents the PTX parity validation system—a compile-time Poka-Yoke that catches GPU kernel generation bugs before they reach runtime. It validates that batched kernels maintain structural parity with their single-vector reference implementations.
The Problem: Batched Kernels Diverge Silently
When we added batched prefill (processing all prompt tokens in one GPU pass), we created batched variants of 6 GPU kernels. Each batched kernel must implement the same mathematical operation as its single-vector reference, just for M vectors instead of 1.
But three classes of bugs can creep in silently:
| Bug Class | Example | Impact |
|---|---|---|
| Missing batch dispatch | No ctaid.y in batched RmsNorm | All vectors processed as batch=0 |
| u64 shared memory | ld.shared.u64 [%rd4] instead of [%r4] | Wrong shared memory addressing on some GPUs |
| Wrong dispatch strategy | grid_y for GEMV instead of register_unroll | Poor memory coalescing, 10x slowdown |
Real bug found (GH-219): BatchedQ6KGemvKernel had 3 dequantization bugs:
- Wrong thread-to-value mapping (contiguous vs strided)
- Wrong ql/qh addressing (naive linear vs Q6K super-block layout)
- Wrong bit combination (`ql + 4*qh - 32` vs `(ql | (qh << 4)) - 32`)
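A CPU sketch of Q6K super-block dequantization shows the layout the batched kernel got wrong. Each 6-bit value takes its low 4 bits from a nibble of `ql` and its high 2 bits from a 2-bit field of `qh`. The two parts are combined with OR, and the result is re-centered by 32. The indexing follows ggml's reference `dequantize_row_q6_K`. `BlockQ6K` is a simplified stand-in, not the trueno-gpu type: it stores `d` as `f32`, whereas GGUF stores it as `f16`.

```rust
/// Number of values in one Q6K super-block.
const QK_K: usize = 256;

/// Simplified Q6K super-block (assumption: `d` is f32 here; GGUF stores f16).
struct BlockQ6K {
    ql: [u8; QK_K / 2],      // low 4 bits, two values per byte
    qh: [u8; QK_K / 4],      // high 2 bits, four values per byte
    scales: [i8; QK_K / 16], // one signed 8-bit scale per 16 values
    d: f32,                  // super-block scale
}

/// Dequantize one super-block using the interleaved Q6K layout.
fn dequantize_q6k(b: &BlockQ6K) -> [f32; QK_K] {
    let mut y = [0.0f32; QK_K];
    // Two halves of 128 values; each half uses 64 bytes of ql, 32 of qh, 8 scales.
    for half in 0..2 {
        let ql = &b.ql[half * 64..half * 64 + 64];
        let qh = &b.qh[half * 32..half * 32 + 32];
        let sc = &b.scales[half * 8..half * 8 + 8];
        let out = &mut y[half * 128..half * 128 + 128];
        for l in 0..32 {
            let is = l / 16; // scale index within this half
            // Correct combination: low nibble OR (2 high bits << 4), then subtract 32.
            let q1 = ((ql[l] & 0x0F) | ((qh[l] & 3) << 4)) as i32 - 32;
            let q2 = ((ql[l + 32] & 0x0F) | (((qh[l] >> 2) & 3) << 4)) as i32 - 32;
            let q3 = ((ql[l] >> 4) | (((qh[l] >> 4) & 3) << 4)) as i32 - 32;
            let q4 = ((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) as i32 - 32;
            out[l] = b.d * sc[is] as f32 * q1 as f32;
            out[l + 32] = b.d * sc[is + 2] as f32 * q2 as f32;
            out[l + 64] = b.d * sc[is + 4] as f32 * q3 as f32;
            out[l + 96] = b.d * sc[is + 6] as f32 * q4 as f32;
        }
    }
    y
}
```

With all-zero quants, every output is `-32 * d * scale`. The naive `ql + 4*qh - 32` form can reach only 0..=27 before re-centering, so it compresses the 64 possible codes and makes several of them alias each other.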
These bugs produced garbage output, but only for Q6K quantized models and only during batched prefill. Serial prefill worked perfectly, which made the bugs extremely hard to catch with traditional testing.
The Solution: Structural PTX Analysis
Instead of testing numerical outputs (which requires models and is flaky), we validate the structure of generated PTX assembly at compile time.
The KernelParity Trait
/// Implemented by every batched kernel in trueno-gpu
pub trait KernelParity: Kernel {
/// Expected batch dispatch mechanism
fn expected_dispatch() -> BatchDispatch;
/// The reference (single-vector) kernel for comparison
type Reference: Kernel;
/// Validate structural parity between batched and reference PTX
fn validate_batch_dispatch(&self) -> ParityReport;
}
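The following self-contained sketch makes the shape of the trait concrete. It is a simplification under stated assumptions: `Kernel`, `BatchDispatch`, `ParityReport`, and the toy `AddKernel`/`BatchedAddKernel` pair are hypothetical stand-ins, not trueno-gpu's actual types. The PTX is a hard-coded fragment, and the default validation only checks that the marker for the expected dispatch strategy appears.

```rust
/// Hypothetical minimal kernel abstraction: a name plus generated PTX text.
trait Kernel {
    fn name(&self) -> &'static str;
    fn emit_ptx(&self) -> String;
}

/// How a batched kernel selects its batch element (the two strategies).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum BatchDispatch {
    GridY,
    RegisterUnroll,
}

/// Hypothetical report: which kernel was checked and what went wrong.
#[derive(Debug)]
struct ParityReport {
    kernel: &'static str,
    violations: Vec<String>,
}

impl ParityReport {
    fn passed(&self) -> bool {
        self.violations.is_empty()
    }
}

trait KernelParity: Kernel {
    type Reference: Kernel;
    fn expected_dispatch() -> BatchDispatch;

    /// Simplified default check: the PTX must mention the dispatch marker.
    fn validate_batch_dispatch(&self) -> ParityReport {
        let ptx = self.emit_ptx();
        let marker = match Self::expected_dispatch() {
            BatchDispatch::GridY => "%ctaid.y",
            BatchDispatch::RegisterUnroll => "m_dim",
        };
        let mut violations = Vec::new();
        if !ptx.contains(marker) {
            violations.push(format!("missing batch dispatch marker `{marker}`"));
        }
        ParityReport { kernel: self.name(), violations }
    }
}

/// Toy single-vector reference kernel.
struct AddKernel;
impl Kernel for AddKernel {
    fn name(&self) -> &'static str { "ResidualAdd" }
    fn emit_ptx(&self) -> String { "mov.u32 %r1, %ctaid.x;".to_string() }
}

/// Toy batched kernel; `with_batch_dispatch: false` simulates the missing-ctaid.y bug.
struct BatchedAddKernel { with_batch_dispatch: bool }
impl Kernel for BatchedAddKernel {
    fn name(&self) -> &'static str { "BatchedResidualAdd" }
    fn emit_ptx(&self) -> String {
        let mut ptx = String::from("mov.u32 %r1, %ctaid.x;\n");
        if self.with_batch_dispatch {
            ptx.push_str("mov.u32 %r2, %ctaid.y;\n");
        }
        ptx
    }
}
impl KernelParity for BatchedAddKernel {
    type Reference = AddKernel;
    fn expected_dispatch() -> BatchDispatch { BatchDispatch::GridY }
}
```

A plausible reason for making `Reference` an associated type is that it binds each batched kernel to exactly one reference at the type level. The pairing then lives in the code rather than in a separate lookup table.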
Two Dispatch Strategies
grid_y (ctaid.y) — For elementwise kernels (RmsNorm, ResidualAdd, RoPE, SwiGLU):
// Single-vector: grid.x covers the hidden dimension
kernel_rmsnorm<<<grid_x, block>>>(input, output, eps);
// Batched: grid.y selects which vector in the batch
kernel_batched_rmsnorm<<<dim3(grid_x, batch_size), block>>>(input, output, eps);
// PTX: mov.u32 %r_batch, %ctaid.y;
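The effect of `ctaid.y` can be simulated on the CPU. In this plain-Rust sketch, each simulated thread block derives its batch row from `ctaid_y`. It is not trueno-gpu code, and `launch_batched_residual_add` and its parameters are hypothetical. Passing `use_ctaid_y = false` reproduces the "missing batch dispatch" bug from the table above: every block writes to batch 0.

```rust
/// Simulate a 2-D grid: grid.x tiles the hidden dimension, grid.y indexes the batch.
fn launch_batched_residual_add(
    x: &[f32],
    residual: &[f32],
    out: &mut [f32],
    hidden_dim: usize,
    batch_size: usize,
    block_dim: usize,
    use_ctaid_y: bool,
) {
    let grid_x = (hidden_dim + block_dim - 1) / block_dim;
    for ctaid_y in 0..batch_size {
        for ctaid_x in 0..grid_x {
            for tid in 0..block_dim {
                let col = ctaid_x * block_dim + tid;
                if col >= hidden_dim {
                    continue; // bounds guard for the last partial block
                }
                // grid_y dispatch: the batch row comes from ctaid.y.
                // Without it, every block addresses batch row 0.
                let row = if use_ctaid_y { ctaid_y } else { 0 };
                let i = row * hidden_dim + col;
                out[i] = x[i] + residual[i];
            }
        }
    }
}
```

In the buggy variant, rows 1 and above are never written. On a GPU they would keep whatever the output buffer held, which is why this class of bug surfaces as garbage rather than as a crash.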
register_unroll (m_dim) — For quantized GEMV kernels (Q4K, Q6K):
// Single-vector: one output row per thread block
kernel_q4k_gemv<<<n_rows, block>>>(input, weights, output);
// Batched: same n_rows grid; each block computes its output row for all m_dim batch vectors
kernel_batched_q4k_gemv<<<n_rows, block>>>(input, weights, output, m_dim);
// PTX: ld.param.u32 %r_m, [m_dim];
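The register_unroll idea can be illustrated with a CPU sketch that leaves out quantization. Each simulated block owns one output row, streams that row of weights once, and keeps one accumulator per batch vector. On the GPU, those `m_dim` accumulators would live in registers. The function name, the `MAX_M` limit, and the row-major `f32` layout are assumptions for illustration. Real Q4K/Q6K kernels dequantize weights on the fly.

```rust
/// Hypothetical upper bound on the batch size the unrolled accumulators support.
const MAX_M: usize = 8;

/// y[m][row] = sum_k w[row][k] * x[m][k], with each weight loaded once
/// and reused for all `m_dim` batch vectors.
fn batched_gemv_register_unroll(
    w: &[f32],     // n_rows x k_dim, row-major
    x: &[f32],     // m_dim x k_dim, row-major
    y: &mut [f32], // m_dim x n_rows, row-major
    n_rows: usize,
    k_dim: usize,
    m_dim: usize,
) {
    assert!(m_dim <= MAX_M, "m_dim exceeds the unrolled accumulator count");
    for row in 0..n_rows {
        // One "thread block" per output row; accumulators stay in a fixed-size array.
        let mut acc = [0.0f32; MAX_M];
        for k in 0..k_dim {
            let wk = w[row * k_dim + k]; // single weight load...
            for m in 0..m_dim {
                acc[m] += wk * x[m * k_dim + k]; // ...reused across the batch
            }
        }
        for m in 0..m_dim {
            y[m * n_rows + row] = acc[m];
        }
    }
}
```

Launching `m_dim` separate GEMVs, or a grid_y launch, would read each weight row `m_dim` times. This layout reads it once, which is where the benefit comes from for a memory-bound operation.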
What Gets Validated
For each kernel pair, the validator checks:
- Batch dispatch mechanism exists: the PTX contains `%ctaid.y` (grid_y) or an `m_dim` parameter (register_unroll)
- No u64 shared memory addressing: `st.shared` and `ld.shared` instructions use `[%r...]` (32-bit) addresses, not `[%rd...]` (64-bit)
- Dispatch strategy matches expectation: elementwise kernels use grid_y, GEMV kernels use register_unroll
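The three checks reduce to text scans over the generated PTX. The sketch below shows one way such a validator could work; it is not the trueno-gpu implementation. `check_ptx` and `Strategy` are hypothetical names. The shared-memory check flags any `ld.shared` or `st.shared` line whose address operand starts with `[%rd`.

```rust
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Strategy {
    GridY,
    RegisterUnroll,
}

/// Return the list of violations; an empty list means the PTX passed.
fn check_ptx(ptx: &str, expected: Strategy) -> Vec<String> {
    let mut violations = Vec::new();
    let uses_grid_y = ptx.contains("%ctaid.y");
    let uses_m_dim = ptx.contains("m_dim");

    // Check 1: some batch dispatch mechanism must exist.
    if !uses_grid_y && !uses_m_dim {
        violations.push("no batch dispatch (%ctaid.y or m_dim)".to_string());
    }

    // Check 2: shared-memory accesses must use 32-bit address registers.
    for (n, line) in ptx.lines().enumerate() {
        let is_shared = line.contains("ld.shared") || line.contains("st.shared");
        if is_shared && line.contains("[%rd") {
            violations.push(format!("line {}: u64 shared address: {}", n + 1, line.trim()));
        }
    }

    // Check 3: a dispatch mechanism was found, but it is the wrong one for this kernel.
    let found_expected = match expected {
        Strategy::GridY => uses_grid_y,
        Strategy::RegisterUnroll => uses_m_dim,
    };
    if (uses_grid_y || uses_m_dim) && !found_expected {
        violations.push(format!("expected {:?} dispatch", expected));
    }
    violations
}
```

Substring matching keeps the sketch short. A production validator would more likely parse instructions and `.param` declarations so that comments or unrelated identifiers cannot trigger a match.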
The 6 Kernel Pairs
| # | Batched Kernel | Reference | Strategy | Validates |
|---|---|---|---|---|
| 1 | BatchedVectorizedRmsNormKernel | VectorizedRmsNormKernel | grid_y | Attention/FFN layer norm |
| 2 | BatchedQ4KGemvKernel | Q4KGemvKernel | register_unroll | QKV/output/FFN projections |
| 3 | BatchedQ6KGemvKernel | Q6KGemvKernel | register_unroll | Q6K quantized models |
| 4 | BatchedResidualAddKernel | ResidualAddKernel | grid_y | Skip connections |
| 5 | BatchedRopeKernel | RopeKernel | grid_y | Rotary position embeddings |
| 6 | BatchedSwigluKernel | SwigluKernel | grid_y | FFN activation |
Integration: apr qa Gate 6
The validation runs automatically as part of the QA suite:
# Runs all 7 gates including PTX parity
apr qa model.gguf --verbose
# Output:
# Running PTX parity validation...
# ✓ PASS PTX Parity 6/6 kernel pairs passed PTX parity
# 14ms
The gate:
- Detects GGUF format from magic bytes (first 8 bytes, not the full file)
- Extracts model dimensions from GGUF metadata (`GGUFConfig::from_gguf`)
- Instantiates all 6 batched kernels with those dimensions
- Runs structural PTX validation on each
- Reports pass/fail with specific violations
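The format-detection step needs only the file header. In the common little-endian layout, a GGUF file starts with the 4-byte magic `GGUF` followed by a `u32` format version, and together these make up the first 8 bytes. The minimal sketch below uses `sniff_gguf`, a hypothetical helper name rather than the apr-cli function.

```rust
use std::io::{self, Read};

/// Read only the first 8 bytes: 4-byte magic + little-endian u32 version.
/// Returns Ok(Some(version)) for GGUF input, Ok(None) otherwise.
fn sniff_gguf<R: Read>(mut reader: R) -> io::Result<Option<u32>> {
    let mut header = [0u8; 8];
    match reader.read_exact(&mut header) {
        Ok(()) => {}
        // Inputs shorter than 8 bytes cannot be GGUF.
        Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
        Err(e) => return Err(e),
    }
    if !header.starts_with(b"GGUF") {
        return Ok(None);
    }
    let version = u32::from_le_bytes([header[4], header[5], header[6], header[7]]);
    Ok(Some(version))
}
```

Passing a `std::fs::File` reads just those 8 bytes, which matches the header-only detection described above.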
Skip flag
# Skip PTX parity if not needed (e.g., CPU-only testing)
apr qa model.gguf --skip-ptx-parity
Running the Example
# With CUDA (validates actual PTX)
cargo run -p apr-cli --example ptx_parity_validation --features inference,cuda
# Without CUDA (shows structure only)
cargo run -p apr-cli --example ptx_parity_validation --features inference
Output:
═══════════════════════════════════════════════════════════════════
GH-219: PTX Parity Validation — Poka-Yoke for GPU Kernels
═══════════════════════════════════════════════════════════════════
┌─────────────────────────────────────────────────────────────────┐
│ Demo 1: Qwen2.5-Coder-1.5B (Q4K) — 6 Kernel Pairs │
└─────────────────────────────────────────────────────────────────┘
Model dimensions:
hidden_dim: 1536
intermediate_dim: 8960
num_heads: 12
head_dim: 128
┌──────────────────────────────────┬──────────┬──────────────────┐
│ Kernel Pair │ Status │ Dispatch │
├──────────────────────────────────┼──────────┼──────────────────┤
│ BatchedRmsNorm ↔ RmsNorm │ PASS │ grid_y │
│ BatchedQ4KGemv ↔ Q4KGemv │ PASS │ register_unroll │
│ BatchedQ6KGemv ↔ Q6KGemv │ PASS │ register_unroll │
│ BatchedResidualAdd ↔ ResidualAdd│ PASS │ grid_y │
│ BatchedRoPE ↔ RoPE │ PASS │ grid_y │
│ BatchedSwiGLU ↔ SwiGLU │ PASS │ grid_y │
└──────────────────────────────────┴──────────┴──────────────────┘
6/6 kernel pairs passed PTX parity
Toyota Way Principles Applied
Poka-Yoke (Mistake-Proofing)
The validation runs at compile time (PTX generation), not at runtime. You cannot ship a broken batched kernel because the QA gate catches it before the model runs.
Jidoka (Stop the Line)
If any kernel pair fails validation, apr qa fails the entire suite. You cannot ship a model with broken PTX parity.
Genchi Genbutsu (Go and See)
The --verbose flag shows exactly which PTX instruction violated parity, with the specific line from the generated assembly. No guessing—you see the actual problem.
Lessons Learned
- Test structure, not output — Numerical output tests are flaky and require models. Structural PTX analysis is deterministic and fast (14ms for all 6 pairs).
- Two dispatch strategies exist for a reason — Elementwise ops are embarrassingly parallel (grid_y). GEMV is memory-bound and benefits from register unrolling across the batch dimension.
- Copy dequant logic exactly — When writing a batched variant of a quantized kernel, copy the dequantization logic verbatim from the reference. The Q6K bug came from rewriting it "more cleanly."
References
- GH-219: PTX Parity Validation issue
- `trueno-gpu/src/kernels/parity_impls.rs`: KernelParity implementations (27 tests)
- `realizar/src/ptx_parity.rs`: Wrapper module with KernelDimensions and PtxParityReport
- `crates/apr-cli/src/commands/qa.rs`: Gate 6 implementation
- Shingo, S. (1986). Zero Quality Control: Source Inspection and the Poka-Yoke System. Productivity Press.