Fused Kernel Optimizations Contract

Contract: contracts/fused-kernels-v1.yaml
Version: 1.0.0
Status: NEW (ALB-075 Phase 4+)
Depends on: cublas-gemm-v1, training-gpu-kernel-v1, training-step-budget-v1
Source: unslothai/unsloth analysis

Equations

fused_cross_entropy

For each row r in logits [B*S, V]:
  logsumexp_r = log(sum(exp(logit[r, i])))
  loss_r = logsumexp_r - logit[r, label_r]
  grad_r[i] = exp(logit[r, i] - logsumexp_r) - delta(i, label_r)

Single kernel pass. FP32 accumulation. Softmax tensor never materialized. Backward grad overwrites logits buffer in-place (zero extra allocation).
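
The three equations above can be checked against a NumPy reference sketch (not the fused CUDA kernel) that also mimics the in-place gradient overwrite; the function name here is illustrative:

```python
import numpy as np

def fused_cross_entropy(logits, labels):
    """Reference sketch of the fused CE forward + backward.

    logits: [N, V] float array (N = B*S); overwritten in place with the
    gradient, so no separate softmax tensor is allocated.
    labels: [N] int array. Returns per-row loss [N].
    """
    N = logits.shape[0]
    rows = np.arange(N)
    m = logits.max(axis=1, keepdims=True)                    # numerical stability
    lse = m[:, 0] + np.log(np.exp(logits - m).sum(axis=1))   # logsumexp_r
    loss = lse - logits[rows, labels]                        # loss_r
    # grad_r[i] = softmax - one_hot, written back into the logits buffer
    np.exp(logits - lse[:, None], out=logits)
    logits[rows, labels] -= 1.0
    return loss
```

Each gradient row sums to zero (softmax minus a one-hot), which is a quick sanity check after the in-place overwrite.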

rmsnorm_activation_reuse

Forward: save ONLY inv_var [B*S] (not normed — recompute in backward)
Backward: normed = X_cached * inv_var_saved (bit-exact recompute)
Memory savings: 24 layers * B * S * H * 4 bytes = 384 MB
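
A minimal NumPy sketch of the save/recompute split, assuming the input X is still cached for the backward pass as the contract states; the function names are illustrative:

```python
import numpy as np

def rmsnorm_forward(x, weight, eps=1e-6):
    # x: [N, H]. Save ONLY inv_var [N] for backward, not the normed tensor.
    inv_var = 1.0 / np.sqrt((x * x).mean(axis=-1) + eps)  # [N]
    y = x * inv_var[:, None] * weight
    return y, inv_var

def rmsnorm_recompute_normed(x, inv_var):
    # Backward-side recompute: bit-exact because it repeats exactly the
    # same multiply the forward performed on the same cached inputs.
    return x * inv_var[:, None]
```

Because the recompute repeats the identical floating-point operations, the result matches the forward's intermediate bitwise, not just to tolerance.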

swiglu_inplace_backward

d_up = grad_output * silu(gate)          → written to up buffer
d_gate = grad_output * up * silu'(gate)  → written to gate buffer

gate and up consumed before overwrite. Peak workspace reduced by 128 MB.
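
A NumPy sketch of the in-place backward, assuming silu(g) = g * sigmoid(g), hence silu'(g) = sigmoid(g) * (1 + g * (1 - sigmoid(g))); the buffer overwrite is emulated with in-place array writes:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def swiglu_backward_inplace(grad_out, gate, up):
    """Forward was: out = silu(gate) * up.

    Both inputs are read (into temporaries here, registers in a real
    kernel) before either buffer is overwritten with its gradient.
    """
    s = sigmoid(gate)
    silu = gate * s
    dsilu = s * (1.0 + gate * (1.0 - s))  # silu'(gate)
    d_gate = grad_out * up * dsilu
    d_up = grad_out * silu
    gate[...] = d_gate                    # d_gate -> gate buffer
    up[...] = d_up                        # d_up   -> up buffer
    return gate, up
```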

rope_head_grouping

Load sin/cos once per group (G=4 heads)
Apply to all heads in group with single memory load
Q: 4 groups of 4, K: 1 group of 4

Bit-exact with per-head RoPE. ~10% attention speedup from L2 cache reuse.
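
A NumPy sketch of the grouped application, assuming the rotate-half RoPE layout (the real kernel's layout may differ); broadcasting one [S, D/2] sin/cos table over the G heads of a group stands in for the single memory load:

```python
import numpy as np

def rope_apply_grouped(q, sin, cos):
    """Apply RoPE to a group of heads sharing one sin/cos load.

    q: [G, S, D] -- G heads rotated by the same tables.
    sin, cos: [S, D/2], broadcast over the G heads.
    """
    d2 = q.shape[-1] // 2
    q1, q2 = q[..., :d2], q[..., d2:]
    # Rotate-half formulation; identical elementwise ops per head,
    # so grouped and per-head application are bitwise equal.
    return np.concatenate([q1 * cos - q2 * sin,
                           q2 * cos + q1 * sin], axis=-1)
```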

fused_tiled_attention

For tile_q, tile_k in tiled [0, S):
  scores_tile = Q[tile_q] @ K[tile_k]^T / sqrt(d_k)
  Online softmax (Milakov & Gimelshein 2018):
    m_new = max(m_old, max(scores_tile))
    l_new = l_old * exp(m_old - m_new) + sum(exp(scores_tile - m_new))
  O = O * exp(m_old - m_new) + exp(scores_tile - m_new) @ V[tile_k]
O = O / l_new

Full [S, S] attention matrix never materialized. Memory: O(B·H·S·d_k) instead of O(B·H·S²). Saves 256 MB per layer.
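
A NumPy sketch of the online-softmax recurrence (single head, FP64, tiling only the key dimension for brevity, whereas the contract tiles both); the running output O is rescaled by exp(m_old - m_new) whenever the running max is updated, exactly as the running normalizer l is:

```python
import numpy as np

def tiled_attention(Q, K, V, tile=2):
    """Attention without ever forming the full [S, S] score matrix.

    Keeps per-query-row running state: max m, normalizer l, and the
    unnormalized output O, updated one key tile at a time.
    """
    S, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    m = np.full(S, -np.inf)
    l = np.zeros(S)
    O = np.zeros((S, d))
    for k0 in range(0, S, tile):
        scores = (Q @ K[k0:k0 + tile].T) * scale          # [S, tile]
        m_new = np.maximum(m, scores.max(axis=1))
        alpha = np.exp(m - m_new)                         # rescale factor
        p = np.exp(scores - m_new[:, None])
        l = l * alpha + p.sum(axis=1)
        O = O * alpha[:, None] + p @ V[k0:k0 + tile]      # rescale, then add
        m = m_new
    return O / l[:, None]
```

Against a naive softmax(QK^T/sqrt(d_k))V reference this matches to floating-point roundoff, independent of the tile size.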

chunked_cross_entropy (deferred)

For vocab > 65K: split the logsumexp over the vocab dimension into chunks of 65K. Mathematically exact (logsumexp is associative). Current vocab = 32K: single chunk, no overhead.
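
The associativity claim can be checked with a sketch that combines per-chunk (max, sum) partial states pairwise, the same recurrence the online softmax uses:

```python
import numpy as np

def chunked_logsumexp(logits, chunk):
    """Exact logsumexp over a 1-D logits vector, one chunk at a time.

    Combining partial (m, l) states is associative, so the chunked
    result equals the single-pass result to floating-point roundoff.
    """
    m, l = -np.inf, 0.0
    for c0 in range(0, logits.shape[0], chunk):
        x = logits[c0:c0 + chunk]
        m_new = max(m, x.max())
        l = l * np.exp(m - m_new) + np.exp(x - m_new).sum()
        m = m_new
    return m + np.log(l)
```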

Proof Obligations (10)

| ID | Type | Property |
|----|------|----------|
| 1 | equivalence | Fused CE matches separate CE (< 1e-5) |
| 2 | invariant | Fused CE never allocates softmax tensor |
| 3 | equivalence | RMS norm recompute is bit-exact |
| 4 | bound | Activation memory reduced by >= 300 MB |
| 5 | equivalence | SwiGLU in-place backward correct (< 1e-5) |
| 6 | equivalence | RoPE grouped matches individual (bitwise) |
| 7 | equivalence | Fused attention matches separate (< 1e-3) |
| 8 | bound | Fused attention memory < separate / 4 |
| 9 | invariant | Training stability preserved (loss finite) |
| 10 | invariant | Gradient flow preserved (all params) |

Falsification Tests (10)

| ID | Rule | Prediction |
|----|------|------------|
| FALSIFY-FUSED-001 | Fused CE matches separate | max_abs_diff(loss) < 1e-5 over 50 steps |
| FALSIFY-FUSED-002 | RMS norm recompute exact | Bitwise match, all 24 layers |
| FALSIFY-FUSED-003 | SwiGLU in-place correct | max_abs_diff(d_gate, d_up) < 1e-5 |
| FALSIFY-FUSED-004 | RoPE grouped matches | Bit-exact, 16 Q + 4 K heads |
| FALSIFY-FUSED-005 | Fused attention matches | max_abs_diff < 1e-3 (FP32) |
| FALSIFY-FUSED-006 | Memory savings >= 300 MB | Activation peak reduction measured |
| FALSIFY-FUSED-007 | No full softmax alloc | Peak CE memory < B*S*V*4 |
| FALSIFY-FUSED-008 | Grad checkpoint exact | Bitwise gradient match |
| FALSIFY-FUSED-009 | Fused attn backward OK | All params get grads, loss within 1% |
| FALSIFY-FUSED-010 | No instability | 100 steps, loss finite, gnorm < 100 |

Priority Matrix

| # | Optimization | Gain | Memory | Phase |
|---|--------------|------|--------|-------|
| 1 | Fused CE loss | 20-40 ms/step | -512 MB bandwidth | 4 |
| 2 | RMS norm reuse | 0 compute | -384 MB | 4 |
| 3 | SwiGLU in-place | 10-20 ms/step | -128 MB peak | 4 |
| 4 | RoPE grouping | 5-10 ms/step | 0 | 4 |
| 5 | Fused attention | 15% attn speedup | -256 MB/layer | 5 |
| 6 | Chunked CE | future | 0 | Deferred |
| 7 | Grad checkpoint | -2x backward | -66% activations | 7 |

QA Gate

F-FUSED-001: All 10 falsification tests must pass. If combined run shows instability, bisect fusions individually to identify the culprit.