GPU Compute Architecture Specification

Version: 1.2.0
Status: IMPLEMENTED (wgpu fallback + root cause corrected)
Created: 2026-03-26
Updated: 2026-03-27
GH Issues: aprender#559, entrenar#309, albor#82
Author: PAIML Engineering


Abstract

This specification defines the multi-backend GPU compute architecture for the sovereign Rust AI stack (trueno, realizar, entrenar). It addresses a critical finding: inference through custom PTX kernels diverges numerically from the CPU reference on Blackwell sm_121 (GH-559), while PyTorch's pre-compiled CUDA kernels work correctly on the same hardware. The divergence was first attributed to NVIDIA's PTX JIT compiler and later traced to FP32 accumulation ordering (Section 1.2). We propose a hybrid dispatch architecture that routes computation to the best available backend (wgpu, CUDA+NVRTC, or CPU) based on runtime correctness validation.

1. Problem Statement

1.1 The sm_121 JIT Bug

On NVIDIA GB10 Blackwell (sm_121), the end-to-end forward pass built from custom PTX kernels JIT-compiled via cuModuleLoadData diverges from the CPU reference, even though individual kernels are accurate in isolation:

| Evidence | Value |
|---|---|
| CUDA GPU/CPU logit cosine | -0.005 (completely uncorrelated) |
| Individual RMSNorm kernel error | 5e-7 (CORRECT, within FP32 epsilon) |
| Individual Q4K GEMV error | ~1% per operation (FP32 rounding) |
| wgpu GPU/CPU cosine | 0.999863 (near-perfect parity) |
| PyTorch GPU/CPU cosine | 1.000000 (pre-compiled CUDA) |
| Our PTX via Python ctypes | 1.000000 (JIT is correct) |

1.2 Root Cause (Corrected 2026-03-27)

Previous diagnosis (WRONG): "NVIDIA JIT compiler bug on sm_121." Falsified by: Loading our exact PTX via Python ctypes → cosine=1.0.

Actual root cause: FP32 non-associativity in accumulation ordering. Each Q4K GEMV kernel accumulates partial sums across 32 parallel threads, in a different order from the CPU's sequential sum. This produces a ~0.1% per-kernel rounding difference. Over 28 layers × 10+ kernels ≈ 280 operations, the differences compound:

(1.001)^280 ≈ 1.32 → 32% divergence → cosine ≈ -0.005

PyTorch avoids this because cuBLAS uses TF32/FP64 internal accumulators. wgpu avoids it because WGSL shaders use sequential accumulation matching CPU.
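
The ordering effect can be reproduced on the CPU. The following is a minimal sketch, not the trueno-gpu kernel: `sum_sequential`, `sum_lanes`, and `compounded_drift` are hypothetical helpers, and the lane-strided layout is only an assumed stand-in for how a warp might split a dot product.

```rust
/// Sequential left-to-right f32 sum, as a scalar CPU loop would compute it.
fn sum_sequential(xs: &[f32]) -> f32 {
    xs.iter().fold(0.0f32, |acc, &x| acc + x)
}

/// Lane-strided sum: lane `l` accumulates elements l, l + lanes, l + 2*lanes, ...
/// and the per-lane partials are combined afterwards.
/// (Assumed layout for illustration; `lanes` must be non-zero.)
fn sum_lanes(xs: &[f32], lanes: usize) -> f32 {
    let mut partial = vec![0.0f32; lanes];
    for (i, &x) in xs.iter().enumerate() {
        partial[i % lanes] += x;
    }
    partial.iter().fold(0.0f32, |acc, &p| acc + p)
}

/// Relative drift after `ops` operations that are each off by `per_op`.
fn compounded_drift(per_op: f64, ops: i32) -> f64 {
    (1.0 + per_op).powi(ops)
}
```

For the input `[1.0e8, 1.0, -1.0e8, 1.0]`, `sum_sequential` returns 1.0 while `sum_lanes` with two lanes returns 2.0 (the exact answer), because the sequential order absorbs the first `1.0` into `1.0e8`. `compounded_drift(0.001, 280)` reproduces the ≈1.32 factor quoted above.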

Fix options:

  1. wgpu (DONE) — same accumulation order as CPU, cosine=0.999863
  2. FP64 accumulation — use .f64 for GEMV partial sums in PTX
  3. Kahan compensation — compensated summation in GEMV inner loop
  4. cuBLAS fallback — pre-compiled TF32 accumulators (3.5x bandwidth cost)
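
Options 2 and 3 can be compared on the CPU before touching PTX. This is a hedged sketch in plain Rust: `sum_naive`, `sum_kahan`, and `sum_f64_acc` are illustrative names, and the real fix would live in the GEMV inner loop rather than in a slice helper.

```rust
/// Plain f32 accumulation (the behaviour being fixed).
fn sum_naive(xs: &[f32]) -> f32 {
    xs.iter().fold(0.0f32, |acc, &x| acc + x)
}

/// Option 3: Kahan (compensated) summation, entirely in f32.
fn sum_kahan(xs: &[f32]) -> f32 {
    let mut sum = 0.0f32;
    let mut c = 0.0f32; // compensation: low-order bits lost so far
    for &x in xs {
        let y = x - c; // re-inject the previously lost bits
        let t = sum + y; // low-order bits of y may be lost here
        c = (t - sum) - y; // recover what was lost
        sum = t;
    }
    sum
}

/// Option 2: accumulate in f64 and round to f32 once at the end.
fn sum_f64_acc(xs: &[f32]) -> f32 {
    xs.iter().map(|&x| f64::from(x)).sum::<f64>() as f32
}
```

With one `1.0` followed by a million `1.0e-8` terms, the naive sum stays at exactly 1.0 because every small term falls below half an ulp, while both alternatives land near the true value of 1.01. Kahan keeps registers in FP32 at the cost of extra arithmetic per element; FP64 accumulation is simpler but depends on the target's FP64 throughput.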

1.3 Connection to Training Quality (entrenar#309)

The albor project independently discovered that entrenar training converges 21x slower than PyTorch on an identical configuration (albor#82). Since the same trueno-gpu PTX kernels are used for RMSNorm in the training backward pass, wrong gradient norms compound → wrong learning trajectory.

1.4 Falsifiable Claim

The sovereign Rust AI stack can produce inference results within cosine similarity ≥0.98 of CPU on any GPU supported by wgpu (Vulkan 1.2+) or CUDA (sm_50+), without depending on NVIDIA's runtime JIT compiler.

Falsified if: wgpu inference or NVRTC-compiled CUDA produces cosine < 0.98 on any supported GPU.

2. Architecture: Hybrid Backend Dispatch

2.1 Backend Selection

let backend = if cuda_available && parity_gate_passes() {
    Backend::Cuda       // NVIDIA-only, fastest (custom Q4K GEMV)
} else if wgpu_available {
    Backend::Wgpu       // All vendors, portable (Vulkan/Metal/DX12)
} else {
    Backend::Cpu        // Always works (SIMD-accelerated)
};

The existing parity gate (validate_gpu_first_token + cosine similarity ≥0.98) serves as the runtime correctness validator. Toyota Way: the gate detects the bug, the system routes around it automatically. No env vars, no workarounds.
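
The metric behind the gate is easy to state in code. The sketch below is an assumption about its shape, not the realizar implementation: the two-argument `parity_gate_passes` signature, `cosine_similarity`, and `PARITY_THRESHOLD` are hypothetical names, and the comparison is accumulated in f64 so the gate itself does not add rounding noise.

```rust
/// Threshold taken from the falsifiable claim in Section 1.4.
const PARITY_THRESHOLD: f64 = 0.98;

/// Cosine similarity between two logit vectors, accumulated in f64.
/// Returns None for mismatched lengths, empty input, or a zero-norm vector.
fn cosine_similarity(a: &[f32], b: &[f32]) -> Option<f64> {
    if a.len() != b.len() || a.is_empty() {
        return None;
    }
    let (mut dot, mut norm_a, mut norm_b) = (0.0f64, 0.0f64, 0.0f64);
    for (&x, &y) in a.iter().zip(b) {
        let (x, y) = (f64::from(x), f64::from(y));
        dot += x * y;
        norm_a += x * x;
        norm_b += y * y;
    }
    if norm_a == 0.0 || norm_b == 0.0 {
        return None;
    }
    Some(dot / (norm_a.sqrt() * norm_b.sqrt()))
}

/// Hypothetical gate: true only when the GPU logits track the CPU logits.
/// A NaN cosine fails the comparison, so corrupted logits are rejected.
fn parity_gate_passes(gpu_logits: &[f32], cpu_logits: &[f32]) -> bool {
    matches!(cosine_similarity(gpu_logits, cpu_logits), Some(c) if c >= PARITY_THRESHOLD)
}
```

Treating every failure mode (length mismatch, zero norm, NaN) as "gate fails" keeps the dispatch logic above a simple boolean check.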

2.2 Backend Capabilities

| Capability | CPU (trueno SIMD) | wgpu (Vulkan) | CUDA PTX (JIT) | CUDA NVRTC |
|---|---|---|---|---|
| Vendor support | All | AMD, Intel, NVIDIA, Apple | NVIDIA only | NVIDIA only |
| Q4K GEMV | AVX2/NEON | WGSL compute shader | Custom PTX | Custom PTX |
| Bandwidth efficiency | N/A (CPU) | ~80-85% peak | ~95% peak | ~95% peak |
| Tensor Cores | No | Limited (coop matrices) | Full (WMMA PTX) | Full |
| Compilation | Ahead-of-time | Driver shader compiler | Runtime JIT | NVRTC library |
| sm_121 correct | Yes | Yes (Vulkan compiler) | No (JIT bug) | Expected yes |
| Dependency | None | Vulkan driver | CUDA driver | CUDA toolkit |
| Provable contracts | Yes | Yes | Yes | Yes |

2.3 Performance Budget

For single-token decode (M=1), the dominant cost is memory bandwidth (loading model weights). Compute intensity is low — the GPU is bandwidth-bound.

Q4K weight bytes per token:  7.2 GB (7B model)
FP16 weight bytes per token: 25.2 GB (3.5x more)

GB10 memory bandwidth: 273 GB/s (unified memory)

Theoretical minimum latency:
  Q4K (custom kernel):  7.2 / 273 = 26 ms/token (38 tok/s)
  FP16 (cuBLAS):       25.2 / 273 = 92 ms/token (11 tok/s)

| Backend | Read efficiency | Expected tok/s | vs cuBLAS |
|---|---|---|---|
| CUDA Q4K GEMV | 95% | ~36 | 3.3x faster |
| wgpu Q4K WGSL | 80% | ~30 | 2.7x faster |
| cuBLAS FP16 | 100% (but 3.5x data) | ~11 | baseline |
| CPU SIMD | N/A | ~3 | 0.3x |

Key insight from Ivanov et al. (2021) "Data Movement Is All You Need": For autoregressive LLM inference, the arithmetic intensity is below the roofline knee — performance is determined by memory bandwidth, not FLOPs. A kernel that reads quantized data directly (Q4K = 0.5625 B/elem) beats a kernel that reads dequantized data (FP16 = 2.0 B/elem) by the bandwidth ratio, regardless of compute optimizations.
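
The budget figures in this section follow from a single division. A minimal sketch of that arithmetic, using the weight sizes, bandwidth, and read efficiencies quoted in this section (the function names are illustrative):

```rust
/// Lower bound on per-token decode latency, in seconds, for a kernel that
/// must stream `weight_gb` gigabytes at `bandwidth_gb_s` GB/s.
fn min_latency_s(weight_gb: f64, bandwidth_gb_s: f64) -> f64 {
    weight_gb / bandwidth_gb_s
}

/// Expected decode rate when only a fraction `read_efficiency` of peak
/// bandwidth is achieved.
fn tokens_per_s(weight_gb: f64, bandwidth_gb_s: f64, read_efficiency: f64) -> f64 {
    read_efficiency * bandwidth_gb_s / weight_gb
}
```

For example, `tokens_per_s(7.2, 273.0, 0.95)` gives about 36 tok/s and `tokens_per_s(25.2, 273.0, 1.0)` about 11 tok/s, matching the table rows for CUDA Q4K GEMV and cuBLAS FP16.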

3. wgpu Inference Path

3.1 Current Status

The wgpu inference kernels are individually implemented in trueno:

| Kernel | PMAT | WGSL Shader | Status |
|---|---|---|---|
| RMSNorm | PMAT-336 | rmsnorm_shader | Done |
| Q4K dequant+GEMV | PMAT-363 | q4k_gemv_shader | Done |
| Bias add | PMAT-356 | bias_add_shader | Done |
| RoPE | PMAT-358 | rope_shader | Done |
| Attention | PMAT-361 | attention_shader | Done |
| LM Head | PMAT-347 | lm_head_shader | Done |
| SwiGLU/SiLU | PMAT-346 | silu_shader | Done (overflow fixed) |
| KV Cache | PMAT-344 | kv_cache_shader | Partial |
| End-to-end forward | PMAT-037 | wgpu_parity_test.rs | PASS: cosine=0.999863 |

3.2 Completion Plan

Wire the individual shaders into a complete forward_wgpu_resident() function in realizar that can serve as a drop-in replacement for forward_gpu_resident():

// In realizar/src/gguf/cuda/mod.rs (or new wgpu module)
pub fn forward_wgpu_resident(
    &mut self,
    token_id: u32,
    cache: &mut OwnedQuantizedKVCache,
    position: usize,
) -> Result<Vec<f32>> {
    // 1. Embed token (CPU)
    let embed = self.model.embed(&[token_id]);

    // 2. Upload to GPU via wgpu (mutable: reassigned once per layer below)
    let mut hidden = self.wgpu_device.upload(&embed);

    // 3. For each layer: RMSNorm → QKV → RoPE → Attention → OProj → Residual → FFN → Residual
    for layer_idx in 0..self.model.config.num_layers {
        hidden = self.wgpu_transformer_layer(hidden, layer_idx, position)?;
    }

    // 4. Output RMSNorm → LM Head → download logits
    let normed = self.wgpu_rmsnorm(hidden, &self.output_norm_gamma)?;
    let logits = self.wgpu_lm_head(normed)?;
    logits.download()
}

3.3 wgpu Compute Shader Limitations

Relevant to performance parity with CUDA:

No warp shuffle equivalent. Vulkan subgroup operations (subgroupAdd, subgroupBroadcast) provide similar functionality but with vendor-variable subgroup sizes (32 on NVIDIA, 64 on AMD, variable on Intel). Design reduction algorithms for any subgroup size.

Reference: Xu et al. (2024) "Efficient Parallel Reductions on GPUs using Subgroup Operations" — demonstrates that subgroup-based reductions achieve 90-95% of warp-shuffle performance when subgroup size is known at compile time.

Limited shared-memory control. Vulkan workgroup shared memory is declared in WGSL (var<workgroup>), but the driver controls banking and allocation, which gives less control than CUDA's configurable shared memory. It is sufficient for RMSNorm reductions and tiled GEMV.

No tensor core access (yet). Vulkan cooperative matrices (VK_KHR_cooperative_matrix) expose tensor cores but adoption is limited. For M=1 decode this doesn't matter — tensor cores help at M≥4 prefill.

4. CUDA Fix Strategy: NVRTC

4.1 Approach

Replace the driver JIT path with NVRTC (NVIDIA Runtime Compilation Library) for sm_120+ GPUs:

Current (broken):
  Rust → PTX string → cuModuleLoadData → driver JIT → wrong SASS

Fixed:
  Rust → PTX string → nvrtcCompileProgram(--gpu-architecture=sm_121)
                     → cubin → cuModuleLoadData → correct SASS

NVRTC uses the same compiler backend as nvcc — the full optimizing compiler, not the lightweight driver JIT.

4.2 Implementation

// In trueno-gpu/src/driver/module.rs
pub fn from_ptx_nvrtc(ctx: &CudaContext, ptx: &str) -> Result<Self, GpuError> {
    let (major, minor) = ctx.compute_capability()?;

    // Load NVRTC dynamically (optional dependency)
    let nvrtc = dlopen("libnvrtc.so")?;

    // Compile PTX → cubin for the exact target architecture.
    // A real architecture (sm_XY) is required: a virtual one (compute_XY)
    // yields PTX, and there is then no cubin to retrieve.
    // NOTE: NVRTC is documented as a CUDA C++ compiler; accepting a PTX
    // string here is an assumption of this sketch that needs verifying.
    let target = format!("--gpu-architecture=sm_{}{}", major, minor);
    let program = nvrtc.create_program(ptx, "kernel.ptx")?;
    nvrtc.compile_program(program, &[&target])?;

    // Load compiled cubin (no JIT)
    let cubin = nvrtc.get_cubin(program)?;
    let mut module = ptr::null_mut();
    cuModuleLoadData(&mut module, cubin.as_ptr())?;

    Ok(Self { module, functions: HashMap::new() })
}

4.3 Pros and Cons

| Pro | Con |
|---|---|
| Fixes sm_121 without losing Q4K speed | Requires libnvrtc.so (~100 MB) |
| Same PTX source, same provable contracts | 2-5x slower first-run compilation |
| Compile-once, cache cubin forever | ABI coupled to CUDA toolkit version |
| Offline testable (CI validation) | NVIDIA-only (doesn't help wgpu) |
| Explicit sm_121 target | Adds ~10 new FFI bindings |

4.4 Hybrid Loading Strategy

pub fn from_ptx(ctx: &CudaContext, ptx: &str) -> Result<Self, GpuError> {
    let (major, _) = ctx.compute_capability()?;

    if major >= 12 {
        // Blackwell+: prefer NVRTC (bypasses buggy JIT)
        if let Ok(module) = Self::from_ptx_nvrtc(ctx, ptx) {
            return Ok(module);
        }
        // NVRTC unavailable: fall back to wgpu (via caller)
        return Err(GpuError::NvrtcUnavailable);
    }

    // Pre-Blackwell: driver JIT works correctly
    Self::from_ptx_jit(ctx, ptx)
}
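
The loading policy above can be separated from the FFI calls so that it is testable without a GPU. A minimal sketch under that assumption; `PtxLoader`, `select_loader`, and `nvrtc_arch_flag` are hypothetical names and not part of trueno-gpu:

```rust
/// Which path turns a PTX string into a loadable module (hypothetical type).
#[derive(Debug, PartialEq, Eq)]
enum PtxLoader {
    /// Offline compilation to a cubin for the exact architecture.
    Nvrtc,
    /// Driver JIT through cuModuleLoadData.
    DriverJit,
}

/// Policy from Section 4.4: compute capability 12.x and newer uses NVRTC,
/// everything older keeps the driver JIT.
fn select_loader(major: u32) -> PtxLoader {
    if major >= 12 {
        PtxLoader::Nvrtc
    } else {
        PtxLoader::DriverJit
    }
}

/// Builds the architecture option for a real (sm_XY) target.
fn nvrtc_arch_flag(major: u32, minor: u32) -> String {
    format!("--gpu-architecture=sm_{}{}", major, minor)
}
```

On GB10, `select_loader(12)` picks `PtxLoader::Nvrtc` and `nvrtc_arch_flag(12, 1)` yields `--gpu-architecture=sm_121`, the flag shown in Section 4.1.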

5. Parity Gate Architecture

5.1 Multi-Backend Validation

The parity gate validates correctness at model load time by comparing a one-token forward pass between the candidate GPU backend and CPU:

              ┌─────────────┐
              │  Load Model  │
              └──────┬───────┘
                     │
              ┌──────▼───────┐
              │ CPU Forward   │ ← reference (always correct)
              │ (1 token)     │
              └──────┬───────┘
                     │
         ┌───────────┼───────────┐
         │           │           │
   ┌─────▼─────┐┌───▼───┐┌─────▼─────┐
   │CUDA Forward││ wgpu  ││  cuBLAS   │
   │ (1 token)  ││Forward││ (fallback)│
   └─────┬─────┘└───┬───┘└─────┬─────┘
         │          │           │
   cosine ≥ 0.98?  cosine?    cosine?
         │          │           │
         └───────use best───────┘
              passing backend

5.2 Contract Enforcement

Full provable contract: ../provable-contracts/contracts/gpu-multi-backend-parity-v1.yaml

4 equations:

| Equation | Formula | Status |
|---|---|---|
| multi_backend_parity | exists b: cosine(forward(b), forward(cpu)) >= 0.98 | Enforced |
| backend_priority | select = first(b in [cuda, wgpu, cpu] where parity >= 0.98) | Enforced |
| bandwidth_bound_theorem | latency >= model_bytes / bandwidth (Ivanov 2021) | Proven |
| jit_compilation_correctness | cosine(jit_sass, ref_sass) >= 0.9999 | Violated sm_121 |
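
The backend_priority equation reads as a first-match search over an ordered candidate list. A minimal sketch, assuming each candidate arrives with its measured parity score (`None` when the backend is unavailable); the `Backend` enum here mirrors Section 2.1 and `select_backend` is a hypothetical name:

```rust
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Backend {
    Cuda,
    Wgpu,
    Cpu,
}

/// Minimum cosine similarity against the CPU reference.
const PARITY_THRESHOLD: f64 = 0.98;

/// Returns the first candidate, in priority order, whose parity score meets
/// the threshold. Unavailable backends (None) and NaN scores are skipped.
fn select_backend(candidates: &[(Backend, Option<f64>)]) -> Option<Backend> {
    for &(backend, parity) in candidates {
        if let Some(cosine) = parity {
            if cosine >= PARITY_THRESHOLD {
                return Some(backend);
            }
        }
    }
    None
}
```

With the measurements from Section 1.1, `[(Cuda, Some(-0.005)), (Wgpu, Some(0.999863)), (Cpu, Some(1.0))]` selects `Wgpu`: CUDA is excluded by the gate and the CPU reference is only reached if every GPU candidate fails.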

6 proof obligations: parity exists, no garbage serving, determinism, wgpu equiv, NVRTC equiv, Q4K bandwidth bound.

7 falsification tests (F-MBP-001..007): wgpu parity, NVRTC parity, PyTorch canary, pre-Blackwell JIT, Q4K advantage, Toyota Way (no silent garbage), driver update.

2 Kani harnesses: backend selection determinism, failed backend exclusion.

Five-whys embedded in contract YAML for audit trail (GH-559 root cause → NVIDIA JIT bug).

See also:

  • gpu-context-health-v1.yaml — FP8 architecture guard (GH-542)
  • ptx-target-parity-v1.yaml — PTX .target directive (violated on sm_121)
  • gqa-kernel-v1.yaml — GQA attention correctness

# Key falsification tests from gpu-multi-backend-parity-v1.yaml:
  - id: F-PARITY-001
    rule: "wgpu parity on sm_121"
    prediction: "cosine(wgpu_forward, cpu_forward) >= 0.98 on GB10"
    test: "Run canary with wgpu backend on gx10"
    if_fails: "wgpu Vulkan shader compiler also has sm_121 issues"

  - id: F-PARITY-002
    rule: "NVRTC parity on sm_121"
    prediction: "cosine(nvrtc_forward, cpu_forward) >= 0.98 on GB10"
    test: "Run canary with NVRTC-compiled CUDA on gx10"
    if_fails: "NVRTC compiler also produces wrong sm_121 SASS"

6. Scientific References

  1. Ivanov et al. (2021) "Data Movement Is All You Need: A Case Study on Optimizing Transformers." MLSys 2021. Establishes that transformer inference is memory-bandwidth bound, not compute bound. Quantized kernels (reading less data) outperform dense kernels (more FLOPs but more data movement).

  2. Frantar et al. (2022) "GPTQ: Accurate Post-Training Quantization for Generative Pre-trained Transformers." INT4/Q4K quantization preserves model quality while reducing memory footprint 4x. Our Q4K GEMV kernels implement this in custom PTX and WGSL.

  3. Frantar et al. (2023) "SparseGPT: Massive Language Models Can Be Accurately Pruned in One-Shot." One-shot pruning achieves target sparsity with minimal quality loss. The Wanda pruning used in our pipeline is a separate one-shot method (Sun et al., 2023) from the same line of work.

  4. Lin et al. (2024) "AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration." Per-channel quantization scales (related to our Q4K super-block format) improve quantization quality.

  5. NVIDIA PTX ISA (2024) "Parallel Thread Execution ISA Version 8.5." Specifies forward compatibility: PTX compiled for sm_90 must run correctly on sm_121 via JIT. GH-559 was first read as a violation of this guarantee; the corrected root cause (Section 1.2) is FP32 accumulation order, not the JIT.

  6. Ainslie et al. (2023) "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints." Grouped Query Attention used by Qwen2.5. Our provable contract gqa-kernel-v1.yaml verifies this.

7. Implementation Roadmap

| Phase | Work | Priority | Status |
|---|---|---|---|
| 1 | Wire wgpu end-to-end forward in realizar | Critical | DONE: try_apr_wgpu_inference in gguf_gpu_generate.rs |
| 2 | Run parity gate on wgpu (F-PARITY-001) | Critical | DONE: cosine=0.999863 on sm_121 |
| 3 | Smart backend dispatch in realizar | Medium | DONE: CUDA → wgpu → CPU auto-fallback |
| 4 | Wire wgpu into batch path (GH-560) | Critical | DONE: GH-560 FIXED (2026-03-28). 84.15% HumanEval on wgpu batch. |
| 5 | Push trueno to unblock Q4K wgpu shader | Critical | DONE: 51 lint errors fixed, pushed to origin, gx10 updated |
| 6 | Fix CUDA FP32 precision (GH-561) | High | f64 accumulators in 6 backward GEMM variants. Training verified: loss 13.61→12.02. |
| 7 | Benchmark wgpu vs CUDA vs cuBLAS | Low | Planned |

8. Memory Analysis (2026-04-04)

8.1 LoRA Merge Memory Profile

The apr finetune --merge operation holds the full FP32 model in memory:

| Component | Memory |
|---|---|
| Q4K base model (7B) | 7.5 GB (compressed) |
| FP32 dequantized base | ~28 GB |
| FP32 output model | ~28 GB |
| LoRA adapter | 40 MB |
| Working memory | ~5 GB |
| Peak RSS | ~49 GB |

Finding (2026-04-04): Merge OOM-killed twice on gx10 when running concurrently with N-sampling (18 GB). 49 + 18 + 15 (system) = 82 GB — should fit in 119 GB, but zram swap compression on FP32 data is poor, reducing effective swap from 32 GB to ~16 GB. OOM killer triggered at anon-rss=48.9 GB.

Resolution: Merge must run solo on gx10 (not concurrent with batch inference). Auto-merge pipeline (PID 1886069) queued to run after N-sampling completes.

8.2 Batch Inference Memory Profile

| Component | Memory |
|---|---|
| Q4K model (7B) | 7.5 GB (mmap) |
| KV cache (512 tokens) | ~1 GB |
| Working buffers | ~10 GB |
| Steady-state RSS | ~18.6 GB |

Batch inference is memory-stable at 18.6 GB across 1640+ prompts. No memory leak detected over 16h continuous operation.