From 7d28c2ec9a8d0ef55cbf5540533a301547b9e456 Mon Sep 17 00:00:00 2001 From: Arthur Paulino Date: Fri, 15 May 2026 08:52:11 -0700 Subject: [PATCH 1/5] perf(kernel): boolean-inversion opt + dedup u64_add MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Check.lean safe_refs_only: replace `match u { 0=>1, 1=>0 }` with `1 - is_unsafe_ci(ci)` — pure arithmetic, drops the match's selectors. - ByteStream.u64_add now returns (U64, G), exposing the final carry-out; removes the kernel's duplicate u64_add_with_carry. klimbs_succ / klimbs_add_carry / klimbs_mul_single call u64_add directly. u64_add had no prior callers, so no ripple to Blake3 / IxonSerialize. --- Ix/IxVM/ByteStream.lean | 12 ++++++---- Ix/IxVM/Kernel/Check.lean | 6 +---- Ix/IxVM/Kernel/Primitive.lean | 45 ++++------------------------------- 3 files changed, 13 insertions(+), 50 deletions(-) diff --git a/Ix/IxVM/ByteStream.lean b/Ix/IxVM/ByteStream.lean index f3cc1d93..d657a890 100644 --- a/Ix/IxVM/ByteStream.lean +++ b/Ix/IxVM/ByteStream.lean @@ -99,8 +99,9 @@ def byteStream := ⟦ } } - -- `u64` addition with carry propagation (little-endian bytes) - fn u64_add(a: U64, b: U64) -> U64 { + -- `u64` addition with carry propagation (little-endian bytes). + -- Returns the sum together with the final carry-out. + fn u64_add(a: U64, b: U64) -> (U64, G) { let [a0, a1, a2, a3, a4, a5, a6, a7] = a; let [b0, b1, b2, b3, b4, b5, b6, b7] = b; let (s0, c1) = u8_add(a0, b0); @@ -122,9 +123,10 @@ def byteStream := ⟦ let (t6, o6) = u8_add(a6, b6); let (s6, c6a) = u8_add(t6, c6); let c7 = u8_xor(o6, c6a); - let (t7, _) = u8_add(a7, b7); - let (s7, _) = u8_add(t7, c7); - [s0, s1, s2, s3, s4, s5, s6, s7] + let (t7, o7) = u8_add(a7, b7); + let (s7, c7a) = u8_add(t7, c7); + let final_carry = u8_xor(o7, c7a); + ([s0, s1, s2, s3, s4, s5, s6, s7], final_carry) } -- `u64` subtraction via repeated decrement (correct for small b) diff --git a/Ix/IxVM/Kernel/Check.lean b/Ix/IxVM/Kernel/Check.lean index 648b2a47..7f14e84c 100644 --- a/Ix/IxVM/Kernel/Check.lean +++ b/Ix/IxVM/Kernel/Check.lean @@ -59,11 +59,7 @@ def check := ⟦ KExprNode.Srt(_) => 1, KExprNode.Const(idx, _) => let ci = load(list_lookup(top, idx)); - let u = is_unsafe_ci(ci); - match u { - 0 => 1, - 1 => 0, - }, + 1 - is_unsafe_ci(ci), KExprNode.App(f, a) => safe_refs_only(f, top) * safe_refs_only(a, top), KExprNode.Lam(t, b) => diff --git a/Ix/IxVM/Kernel/Primitive.lean b/Ix/IxVM/Kernel/Primitive.lean index ed571090..383da8b7 100644 --- a/Ix/IxVM/Kernel/Primitive.lean +++ b/Ix/IxVM/Kernel/Primitive.lean @@ -693,48 +693,13 @@ def primitive := ⟦ 0xa4, 0x00, 0xfb, 0x06, 0x03, 0xae, 0xdf, 0x2f] } - -- Mirror: `u64_add` in ByteStream.lean expanded to expose final - -- carry. Used by klimbs_succ / klimbs_add. - -- - -- TODO: delete this once `ByteStream.lean::u64_add` is patched to - -- return `(U64, G)` and existing call sites updated. Tracking this - -- as a follow-up because the patch ripples beyond the kernel - -- (Blake3 / IxonSerialize / ByteStream itself). - fn u64_add_with_carry(a: U64, b: U64) -> (U64, G) { - let [a0, a1, a2, a3, a4, a5, a6, a7] = a; - let [b0, b1, b2, b3, b4, b5, b6, b7] = b; - let (s0, c1) = u8_add(a0, b0); - let (t1, o1) = u8_add(a1, b1); - let (s1, c1a) = u8_add(t1, c1); - let c2 = u8_xor(o1, c1a); - let (t2, o2) = u8_add(a2, b2); - let (s2, c2a) = u8_add(t2, c2); - let c3 = u8_xor(o2, c2a); - let (t3, o3) = u8_add(a3, b3); - let (s3, c3a) = u8_add(t3, c3); - let c4 = u8_xor(o3, c3a); - let (t4, o4) = u8_add(a4, b4); - let (s4, c4a) = u8_add(t4, c4); - let c5 = u8_xor(o4, c4a); - let (t5, o5) = u8_add(a5, b5); - let (s5, c5a) = u8_add(t5, c5); - let c6 = u8_xor(o5, c5a); - let (t6, o6) = u8_add(a6, b6); - let (s6, c6a) = u8_add(t6, c6); - let c7 = u8_xor(o6, c6a); - let (t7, o7) = u8_add(a7, b7); - let (s7, c7a) = u8_add(t7, c7); - let final_carry = u8_xor(o7, c7a); - ([s0, s1, s2, s3, s4, s5, s6, s7], final_carry) - } - -- Mirror: BigUint::succ. Increment a KLimbs by 1; ripple carry. fn klimbs_succ(n: KLimbs) -> KLimbs { match load(n) { ListNode.Nil => store(ListNode.Cons([1, 0, 0, 0, 0, 0, 0, 0], store(ListNode.Nil))), ListNode.Cons(limb, rest) => - let pair = u64_add_with_carry(limb, [1, 0, 0, 0, 0, 0, 0, 0]); + let pair = u64_add(limb, [1, 0, 0, 0, 0, 0, 0, 0]); match pair { (sum, carry) => match carry { @@ -764,10 +729,10 @@ def primitive := ⟦ _ => klimbs_succ(a), }, ListNode.Cons(lb, rb) => - let pair1 = u64_add_with_carry(la, lb); + let pair1 = u64_add(la, lb); match pair1 { (sum1, carry1) => - let pair2 = u64_add_with_carry(sum1, [carry, 0, 0, 0, 0, 0, 0, 0]); + let pair2 = u64_add(sum1, [carry, 0, 0, 0, 0, 0, 0, 0]); match pair2 { (sum2, carry2) => let total_carry = g_or(carry1, carry2); @@ -1016,9 +981,9 @@ def primitive := ⟦ ListNode.Cons(b_limb, rest) => match u64_mul(a_limb, b_limb) { (lo, hi) => - match u64_add_with_carry(lo, carry) { + match u64_add(lo, carry) { (sum, carry_out) => - match u64_add_with_carry(hi, [carry_out, 0, 0, 0, 0, 0, 0, 0]) { + match u64_add(hi, [carry_out, 0, 0, 0, 0, 0, 0, 0]) { (new_carry, _) => let new_acc = list_snoc(acc, sum); klimbs_mul_single(a_limb, rest, new_carry, new_acc), From 6cae736e4823da961f760961441b457b96a5c217 Mon Sep 17 00:00:00 2001 From: Arthur Paulino Date: Fri, 15 May 2026 09:46:24 -0700 Subject: [PATCH 2/5] feat(aiur): add u8_mul gadget New `u8_mul(a, b)` Aiur primitive: byte * byte -> (low, high), with low + 256*high = a*b. Slots into the existing Bytes2 preprocessed 256x256 table as two extra columns plus a lookup channel, so its cost class matches u8_add (one lookup, two auxiliaries). Wires the op through the Rust prover (bytecode, bytes2 gadget, constraints, trace, execute, FFI) and every Lean Aiur stage (Source through Bytecode), compiler, semantics, and Meta syntax. FFI ctor tag 18 is u8Mul; subsequent Op tags shift by one. Verified via the aiur and aiur-cross suites: prove/verify passes for 45*131 = (7, 23) and 255*255 = (1, 254). --- Ix/Aiur/Compiler/Check.lean | 5 +++ Ix/Aiur/Compiler/Concretize.lean | 11 +++++- Ix/Aiur/Compiler/Layout.lean | 2 +- Ix/Aiur/Compiler/Lower.lean | 3 ++ Ix/Aiur/Compiler/Match.lean | 1 + Ix/Aiur/Goldilocks.lean | 4 ++ Ix/Aiur/Interpret.lean | 8 ++++ Ix/Aiur/Meta.lean | 7 ++++ Ix/Aiur/Semantics/BytecodeEval.lean | 5 +++ Ix/Aiur/Semantics/SourceEval.lean | 8 ++++ Ix/Aiur/Stages/Bytecode.lean | 1 + Ix/Aiur/Stages/Concrete.lean | 5 ++- Ix/Aiur/Stages/Simple.lean | 5 ++- Ix/Aiur/Stages/Source.lean | 1 + Ix/Aiur/Stages/Typed.lean | 5 ++- Tests/Aiur/Aiur.lean | 5 +++ Tests/Aiur/Cross.lean | 3 ++ src/aiur.rs | 5 +++ src/aiur/bytecode.rs | 1 + src/aiur/constraints.rs | 14 +++++-- src/aiur/execute.rs | 8 ++++ src/aiur/gadgets/bytes2.rs | 58 ++++++++++++++++++++++++++--- src/aiur/trace.rs | 17 +++++++-- src/ffi/aiur/toplevel.rs | 14 ++++--- 24 files changed, 170 insertions(+), 26 deletions(-) diff --git a/Ix/Aiur/Compiler/Check.lean b/Ix/Aiur/Compiler/Check.lean index 54a9635c..6970abbb 100644 --- a/Ix/Aiur/Compiler/Check.lean +++ b/Ix/Aiur/Compiler/Check.lean @@ -735,6 +735,10 @@ def inferTerm (t : Term) : CheckM Typed.Term := match t with let a' ← checkNoEscape a .field let b' ← checkNoEscape b .field pure (Typed.Term.u8Add (.tuple #[.field, .field]) false a' b') + | .u8Mul a b => do + let a' ← checkNoEscape a .field + let b' ← checkNoEscape b .field + pure (Typed.Term.u8Mul (.tuple #[.field, .field]) false a' b') | .u8Sub a b => do let a' ← checkNoEscape a .field let b' ← checkNoEscape b .field @@ -900,6 +904,7 @@ def zonkTypedTerm (t : Typed.Term) : CheckM Typed.Term := match t with | .u8ShiftRight τ e a => do pure (.u8ShiftRight (← zonkTyp τ) e (← zonkTypedTerm a)) | .u8Xor τ e a b => do pure (.u8Xor (← zonkTyp τ) e (← zonkTypedTerm a) (← zonkTypedTerm b)) | .u8Add τ e a b => do pure (.u8Add (← zonkTyp τ) e (← zonkTypedTerm a) (← zonkTypedTerm b)) + | .u8Mul τ e a b => do pure (.u8Mul (← zonkTyp τ) e (← zonkTypedTerm a) (← zonkTypedTerm b)) | .u8Sub τ e a b => do pure (.u8Sub (← zonkTyp τ) e (← zonkTypedTerm a) (← zonkTypedTerm b)) | .u8And τ e a b => do pure (.u8And (← zonkTyp τ) e (← zonkTypedTerm a) (← zonkTypedTerm b)) | .u8Or τ e a b => do pure (.u8Or (← zonkTyp τ) e (← zonkTypedTerm a) (← zonkTypedTerm b)) diff --git a/Ix/Aiur/Compiler/Concretize.lean b/Ix/Aiur/Compiler/Concretize.lean index 152ceb72..b9924aee 100644 --- a/Ix/Aiur/Compiler/Concretize.lean +++ b/Ix/Aiur/Compiler/Concretize.lean @@ -325,6 +325,9 @@ def termToConcrete | .u8Add τ e a b => do pure (.u8Add (← typToConcrete mono τ) e (← termToConcrete mono a) (← termToConcrete mono b)) + | .u8Mul τ e a b => do + pure (.u8Mul (← typToConcrete mono τ) e + (← termToConcrete mono a) (← termToConcrete mono b)) | .u8Sub τ e a b => do pure (.u8Sub (← typToConcrete mono τ) e (← termToConcrete mono a) (← termToConcrete mono b)) @@ -521,6 +524,8 @@ def rewriteTypedTerm (decls : Typed.Decls) (rewriteTypedTerm decls subst mono a) (rewriteTypedTerm decls subst mono b) | .u8Add τ e a b => .u8Add (rewriteTyp subst mono τ) e (rewriteTypedTerm decls subst mono a) (rewriteTypedTerm decls subst mono b) + | .u8Mul τ e a b => .u8Mul (rewriteTyp subst mono τ) e + (rewriteTypedTerm decls subst mono a) (rewriteTypedTerm decls subst mono b) | .u8Sub τ e a b => .u8Sub (rewriteTyp subst mono τ) e (rewriteTypedTerm decls subst mono a) (rewriteTypedTerm decls subst mono b) | .u8And τ e a b => .u8And (rewriteTyp subst mono τ) e @@ -597,7 +602,7 @@ def collectInTypedTerm (seen : Std.HashSet (Global × Array Typ)) : let seen := tArgs.foldl collectInTyp seen args.attach.foldl (fun s ⟨a, _⟩ => collectInTypedTerm s a) seen | .add τ _ a b | .sub τ _ a b | .mul τ _ a b - | .u8Xor τ _ a b | .u8Add τ _ a b | .u8Sub τ _ a b + | .u8Xor τ _ a b | .u8Add τ _ a b | .u8Mul τ _ a b | .u8Sub τ _ a b | .u8And τ _ a b | .u8Or τ _ a b | .u8LessThan τ _ a b | .u32LessThan τ _ a b => collectInTypedTerm (collectInTypedTerm (collectInTyp seen τ) a) b @@ -659,7 +664,7 @@ def collectCalls (decls : Typed.Decls) let seen := collectCalls decls seen scrut bs.attach.foldl (fun s ⟨(_, b), _⟩ => collectCalls decls s b) seen | .add _ _ a b | .sub _ _ a b | .mul _ _ a b - | .u8Xor _ _ a b | .u8Add _ _ a b | .u8Sub _ _ a b + | .u8Xor _ _ a b | .u8Add _ _ a b | .u8Mul _ _ a b | .u8Sub _ _ a b | .u8And _ _ a b | .u8Or _ _ a b | .u8LessThan _ _ a b | .u32LessThan _ _ a b => collectCalls decls (collectCalls decls seen a) b @@ -749,6 +754,8 @@ def substInTypedTerm (subst : Global → Option Typ) : Typed.Term → Typed.Term (substInTypedTerm subst a) (substInTypedTerm subst b) | .u8Add τ e a b => .u8Add (Typ.instantiate subst τ) e (substInTypedTerm subst a) (substInTypedTerm subst b) + | .u8Mul τ e a b => .u8Mul (Typ.instantiate subst τ) e + (substInTypedTerm subst a) (substInTypedTerm subst b) | .u8Sub τ e a b => .u8Sub (Typ.instantiate subst τ) e (substInTypedTerm subst a) (substInTypedTerm subst b) | .u8And τ e a b => .u8And (Typ.instantiate subst τ) e diff --git a/Ix/Aiur/Compiler/Layout.lean b/Ix/Aiur/Compiler/Layout.lean index 5c425ea8..4b3164ff 100644 --- a/Ix/Aiur/Compiler/Layout.lean +++ b/Ix/Aiur/Compiler/Layout.lean @@ -185,7 +185,7 @@ def opLayout : Bytecode.Op → LayoutM Unit | .u8BitDecomposition _ => do pushDegrees $ .replicate 8 1; bumpAuxiliaries 8; bumpLookups | .u8ShiftLeft _ | .u8ShiftRight _ | .u8Xor .. | .u8And .. | .u8Or .. => do pushDegree 1; bumpAuxiliaries; bumpLookups - | .u8Add .. | .u8Sub .. => do pushDegrees #[1, 1]; bumpAuxiliaries 2; bumpLookups + | .u8Add .. | .u8Mul .. | .u8Sub .. => do pushDegrees #[1, 1]; bumpAuxiliaries 2; bumpLookups | .u8LessThan .. => do pushDegree 1; bumpAuxiliaries; bumpLookups | .u32LessThan .. => do pushDegree 1; bumpAuxiliaries 12; bumpLookups 6 | .debug .. => pure () diff --git a/Ix/Aiur/Compiler/Lower.lean b/Ix/Aiur/Compiler/Lower.lean index 43c98267..a50f4ef2 100644 --- a/Ix/Aiur/Compiler/Lower.lean +++ b/Ix/Aiur/Compiler/Lower.lean @@ -283,6 +283,9 @@ def toIndex | .u8Add _ _ i j => do let i ← expectIdx layoutMap bindings i; let j ← expectIdx layoutMap bindings j pushOp (.u8Add i j) 2 + | .u8Mul _ _ i j => do + let i ← expectIdx layoutMap bindings i; let j ← expectIdx layoutMap bindings j + pushOp (.u8Mul i j) 2 | .u8Sub _ _ i j => do let i ← expectIdx layoutMap bindings i; let j ← expectIdx layoutMap bindings j pushOp (.u8Sub i j) 2 diff --git a/Ix/Aiur/Compiler/Match.lean b/Ix/Aiur/Compiler/Match.lean index 1e148b83..3a92d00e 100644 --- a/Ix/Aiur/Compiler/Match.lean +++ b/Ix/Aiur/Compiler/Match.lean @@ -380,6 +380,7 @@ def typedToSimple : Term → Simple.Term | .u8ShiftRight τ e a => .u8ShiftRight τ e (typedToSimple a) | .u8Xor τ e a b => .u8Xor τ e (typedToSimple a) (typedToSimple b) | .u8Add τ e a b => .u8Add τ e (typedToSimple a) (typedToSimple b) + | .u8Mul τ e a b => .u8Mul τ e (typedToSimple a) (typedToSimple b) | .u8Sub τ e a b => .u8Sub τ e (typedToSimple a) (typedToSimple b) | .u8And τ e a b => .u8And τ e (typedToSimple a) (typedToSimple b) | .u8Or τ e a b => .u8Or τ e (typedToSimple a) (typedToSimple b) diff --git a/Ix/Aiur/Goldilocks.lean b/Ix/Aiur/Goldilocks.lean index 937bab60..3a98fdcd 100644 --- a/Ix/Aiur/Goldilocks.lean +++ b/Ix/Aiur/Goldilocks.lean @@ -56,6 +56,10 @@ def G.u8LessThan (a b : G) : G := if a.n < b.n then 1 else 0 def G.u8Add (a b : G) : G × G := (G.ofNat ((a.n + b.n) % 256), G.ofNat ((a.n + b.n) / 256)) +/-- u8 multiplication returns `(low byte, high byte)`. -/ +def G.u8Mul (a b : G) : G × G := + (G.ofNat ((a.n * b.n) % 256), G.ofNat ((a.n * b.n) / 256)) + /-- u8 subtraction returns `(result % 256, borrow)`. -/ def G.u8Sub (a b : G) : G × G := (G.ofNat ((a.n + 256 - b.n) % 256), if a.n < b.n then 1 else 0) diff --git a/Ix/Aiur/Interpret.lean b/Ix/Aiur/Interpret.lean index 45b788f5..d8a7a294 100644 --- a/Ix/Aiur/Interpret.lean +++ b/Ix/Aiur/Interpret.lean @@ -326,6 +326,14 @@ partial def interp (decls : Decls) (bindings : Bindings) : Term → InterpM Valu let overflow : Value := .field (if x >= 256 then 1 else 0) return .tuple #[sum, overflow] | _, _ => throwErr "u8Add: expected field values" + | .u8Mul t1 t2 => do + match (← interp decls bindings t1), (← interp decls bindings t2) with + | .field a, .field b => + let x := a.val.toUInt8.toNat * b.val.toUInt8.toNat + let lo : Value := .field (G.ofUInt8 x.toUInt8) + let hi : Value := .field (G.ofUInt8 (x / 256).toUInt8) + return .tuple #[lo, hi] + | _, _ => throwErr "u8Mul: expected field values" | .u8Sub t1 t2 => do match (← interp decls bindings t1), (← interp decls bindings t2) with | .field a, .field b => diff --git a/Ix/Aiur/Meta.lean b/Ix/Aiur/Meta.lean index 5a5cb2d9..c85a5b94 100644 --- a/Ix/Aiur/Meta.lean +++ b/Ix/Aiur/Meta.lean @@ -173,6 +173,7 @@ syntax "u8_shift_left" "(" aiur_trm ")" : a syntax "u8_shift_right" "(" aiur_trm ")" : aiur_trm syntax "u8_xor" "(" aiur_trm ", " aiur_trm ")" : aiur_trm syntax "u8_add" "(" aiur_trm ", " aiur_trm ")" : aiur_trm +syntax "u8_mul" "(" aiur_trm ", " aiur_trm ")" : aiur_trm syntax "u8_sub" "(" aiur_trm ", " aiur_trm ")" : aiur_trm syntax "u8_and" "(" aiur_trm ", " aiur_trm ")" : aiur_trm syntax "u8_or" "(" aiur_trm ", " aiur_trm ")" : aiur_trm @@ -292,6 +293,8 @@ partial def elabTrm : ElabStxCat `aiur_trm mkAppM ``Source.Term.u8Xor #[← elabTrm i, ← elabTrm j] | `(aiur_trm| u8_add($i:aiur_trm, $j:aiur_trm)) => do mkAppM ``Source.Term.u8Add #[← elabTrm i, ← elabTrm j] + | `(aiur_trm| u8_mul($i:aiur_trm, $j:aiur_trm)) => do + mkAppM ``Source.Term.u8Mul #[← elabTrm i, ← elabTrm j] | `(aiur_trm| u8_sub($i:aiur_trm, $j:aiur_trm)) => do mkAppM ``Source.Term.u8Sub #[← elabTrm i, ← elabTrm j] | `(aiur_trm| u8_and($i:aiur_trm, $j:aiur_trm)) => do @@ -480,6 +483,10 @@ where let i ← replaceToken old new i let j ← replaceToken old new j `(aiur_trm| u8_add($i, $j)) + | `(aiur_trm| u8_mul($i:aiur_trm, $j:aiur_trm)) => do + let i ← replaceToken old new i + let j ← replaceToken old new j + `(aiur_trm| u8_mul($i, $j)) | `(aiur_trm| u8_sub($i:aiur_trm, $j:aiur_trm)) => do let i ← replaceToken old new i let j ← replaceToken old new j diff --git a/Ix/Aiur/Semantics/BytecodeEval.lean b/Ix/Aiur/Semantics/BytecodeEval.lean index b1669c37..a45c377e 100644 --- a/Ix/Aiur/Semantics/BytecodeEval.lean +++ b/Ix/Aiur/Semantics/BytecodeEval.lean @@ -247,6 +247,11 @@ def evalOp (t : Bytecode.Toplevel) (fuel : Nat) (op : Op) (st : EvalState) : let sum := x.val.toUInt8.toNat + y.val.toUInt8.toNat let st1 := pushMap st (G.ofUInt8 sum.toUInt8) pure (pushMap st1 (if sum ≥ 256 then 1 else 0)) + | .u8Mul a b => do + let x ← readIdx st a; let y ← readIdx st b + let prod := x.val.toUInt8.toNat * y.val.toUInt8.toNat + let st1 := pushMap st (G.ofUInt8 prod.toUInt8) + pure (pushMap st1 (G.ofUInt8 (prod / 256).toUInt8)) | .u8Sub a b => do let x ← readIdx st a; let y ← readIdx st b let i := x.val.toUInt8; let j := y.val.toUInt8 diff --git a/Ix/Aiur/Semantics/SourceEval.lean b/Ix/Aiur/Semantics/SourceEval.lean index d2e6899a..c0cffea7 100644 --- a/Ix/Aiur/Semantics/SourceEval.lean +++ b/Ix/Aiur/Semantics/SourceEval.lean @@ -382,6 +382,14 @@ def interp (decls : Decls) (fuel : Nat) (bindings : Bindings) .field (if x >= 256 then 1 else 0)]) (interp decls fuel bindings t1 st) (fun st1 => interp decls fuel bindings t2 st1) + | .u8Mul t1 t2 => + combineFieldsResult + (fun a b => + let x := a.val.toUInt8.toNat * b.val.toUInt8.toNat + .tuple #[.field (G.ofUInt8 x.toUInt8), + .field (G.ofUInt8 (x / 256).toUInt8)]) + (interp decls fuel bindings t1 st) + (fun st1 => interp decls fuel bindings t2 st1) | .u8Sub t1 t2 => combineFieldsResult (fun a b => diff --git a/Ix/Aiur/Stages/Bytecode.lean b/Ix/Aiur/Stages/Bytecode.lean index acea8020..527e5431 100644 --- a/Ix/Aiur/Stages/Bytecode.lean +++ b/Ix/Aiur/Stages/Bytecode.lean @@ -37,6 +37,7 @@ inductive Op | u8ShiftRight : ValIdx → Op | u8Xor : ValIdx → ValIdx → Op | u8Add : ValIdx → ValIdx → Op + | u8Mul : ValIdx → ValIdx → Op | u8Sub : ValIdx → ValIdx → Op | u8And : ValIdx → ValIdx → Op | u8Or : ValIdx → ValIdx → Op diff --git a/Ix/Aiur/Stages/Concrete.lean b/Ix/Aiur/Stages/Concrete.lean index 1faa134b..3e9eaa50 100644 --- a/Ix/Aiur/Stages/Concrete.lean +++ b/Ix/Aiur/Stages/Concrete.lean @@ -78,6 +78,7 @@ inductive Term : Type where | u8ShiftRight (typ : Typ) (escapes : Bool) (a : Term) : Term | u8Xor (typ : Typ) (escapes : Bool) (a : Term) (b : Term) : Term | u8Add (typ : Typ) (escapes : Bool) (a : Term) (b : Term) : Term + | u8Mul (typ : Typ) (escapes : Bool) (a : Term) (b : Term) : Term | u8Sub (typ : Typ) (escapes : Bool) (a : Term) (b : Term) : Term | u8And (typ : Typ) (escapes : Bool) (a : Term) (b : Term) : Term | u8Or (typ : Typ) (escapes : Bool) (a : Term) (b : Term) : Term @@ -98,7 +99,7 @@ def Term.typ : Term → Typ | .assertEq t _ _ _ _ | .ioGetInfo t _ _ | .ioSetInfo t _ _ _ _ _ | .ioRead t _ _ _ | .ioWrite t _ _ _ | .u8BitDecomposition t _ _ | .u8ShiftLeft t _ _ | .u8ShiftRight t _ _ - | .u8Xor t _ _ _ | .u8Add t _ _ _ | .u8Sub t _ _ _ + | .u8Xor t _ _ _ | .u8Add t _ _ _ | .u8Mul t _ _ _ | .u8Sub t _ _ _ | .u8And t _ _ _ | .u8Or t _ _ _ | .u8LessThan t _ _ _ | .u32LessThan t _ _ _ | .debug t _ _ _ _ => t @@ -114,7 +115,7 @@ def Term.escapes : Term → Bool | .assertEq _ e _ _ _ | .ioGetInfo _ e _ | .ioSetInfo _ e _ _ _ _ | .ioRead _ e _ _ | .ioWrite _ e _ _ | .u8BitDecomposition _ e _ | .u8ShiftLeft _ e _ | .u8ShiftRight _ e _ - | .u8Xor _ e _ _ | .u8Add _ e _ _ | .u8Sub _ e _ _ + | .u8Xor _ e _ _ | .u8Add _ e _ _ | .u8Mul _ e _ _ | .u8Sub _ e _ _ | .u8And _ e _ _ | .u8Or _ e _ _ | .u8LessThan _ e _ _ | .u32LessThan _ e _ _ | .debug _ e _ _ _ => e diff --git a/Ix/Aiur/Stages/Simple.lean b/Ix/Aiur/Stages/Simple.lean index 2aed45b8..996b5d21 100644 --- a/Ix/Aiur/Stages/Simple.lean +++ b/Ix/Aiur/Stages/Simple.lean @@ -74,6 +74,7 @@ inductive Term : Type where | u8ShiftRight (typ : Typ) (escapes : Bool) (a : Term) : Term | u8Xor (typ : Typ) (escapes : Bool) (a : Term) (b : Term) : Term | u8Add (typ : Typ) (escapes : Bool) (a : Term) (b : Term) : Term + | u8Mul (typ : Typ) (escapes : Bool) (a : Term) (b : Term) : Term | u8Sub (typ : Typ) (escapes : Bool) (a : Term) (b : Term) : Term | u8And (typ : Typ) (escapes : Bool) (a : Term) (b : Term) : Term | u8Or (typ : Typ) (escapes : Bool) (a : Term) (b : Term) : Term @@ -93,7 +94,7 @@ def Term.typ : Term → Typ | .assertEq t _ _ _ _ | .ioGetInfo t _ _ | .ioSetInfo t _ _ _ _ _ | .ioRead t _ _ _ | .ioWrite t _ _ _ | .u8BitDecomposition t _ _ | .u8ShiftLeft t _ _ | .u8ShiftRight t _ _ - | .u8Xor t _ _ _ | .u8Add t _ _ _ | .u8Sub t _ _ _ + | .u8Xor t _ _ _ | .u8Add t _ _ _ | .u8Mul t _ _ _ | .u8Sub t _ _ _ | .u8And t _ _ _ | .u8Or t _ _ _ | .u8LessThan t _ _ _ | .u32LessThan t _ _ _ | .debug t _ _ _ _ => t @@ -108,7 +109,7 @@ def Term.escapes : Term → Bool | .assertEq _ e _ _ _ | .ioGetInfo _ e _ | .ioSetInfo _ e _ _ _ _ | .ioRead _ e _ _ | .ioWrite _ e _ _ | .u8BitDecomposition _ e _ | .u8ShiftLeft _ e _ | .u8ShiftRight _ e _ - | .u8Xor _ e _ _ | .u8Add _ e _ _ | .u8Sub _ e _ _ + | .u8Xor _ e _ _ | .u8Add _ e _ _ | .u8Mul _ e _ _ | .u8Sub _ e _ _ | .u8And _ e _ _ | .u8Or _ e _ _ | .u8LessThan _ e _ _ | .u32LessThan _ e _ _ | .debug _ e _ _ _ => e diff --git a/Ix/Aiur/Stages/Source.lean b/Ix/Aiur/Stages/Source.lean index 440fa683..f1209021 100644 --- a/Ix/Aiur/Stages/Source.lean +++ b/Ix/Aiur/Stages/Source.lean @@ -382,6 +382,7 @@ inductive Term | u8ShiftRight : Term → Term | u8Xor : Term → Term → Term | u8Add : Term → Term → Term + | u8Mul : Term → Term → Term | u8Sub : Term → Term → Term | u8And : Term → Term → Term | u8Or : Term → Term → Term diff --git a/Ix/Aiur/Stages/Typed.lean b/Ix/Aiur/Stages/Typed.lean index 9ba4dd82..2a3c6f33 100644 --- a/Ix/Aiur/Stages/Typed.lean +++ b/Ix/Aiur/Stages/Typed.lean @@ -50,6 +50,7 @@ inductive Term : Type where | u8ShiftRight (typ : Typ) (escapes : Bool) (a : Term) : Term | u8Xor (typ : Typ) (escapes : Bool) (a : Term) (b : Term) : Term | u8Add (typ : Typ) (escapes : Bool) (a : Term) (b : Term) : Term + | u8Mul (typ : Typ) (escapes : Bool) (a : Term) (b : Term) : Term | u8Sub (typ : Typ) (escapes : Bool) (a : Term) (b : Term) : Term | u8And (typ : Typ) (escapes : Bool) (a : Term) (b : Term) : Term | u8Or (typ : Typ) (escapes : Bool) (a : Term) (b : Term) : Term @@ -69,7 +70,7 @@ def Term.typ : Term → Typ | .assertEq t _ _ _ _ | .ioGetInfo t _ _ | .ioSetInfo t _ _ _ _ _ | .ioRead t _ _ _ | .ioWrite t _ _ _ | .u8BitDecomposition t _ _ | .u8ShiftLeft t _ _ | .u8ShiftRight t _ _ - | .u8Xor t _ _ _ | .u8Add t _ _ _ | .u8Sub t _ _ _ + | .u8Xor t _ _ _ | .u8Add t _ _ _ | .u8Mul t _ _ _ | .u8Sub t _ _ _ | .u8And t _ _ _ | .u8Or t _ _ _ | .u8LessThan t _ _ _ | .u32LessThan t _ _ _ | .debug t _ _ _ _ => t @@ -84,7 +85,7 @@ def Term.escapes : Term → Bool | .assertEq _ e _ _ _ | .ioGetInfo _ e _ | .ioSetInfo _ e _ _ _ _ | .ioRead _ e _ _ | .ioWrite _ e _ _ | .u8BitDecomposition _ e _ | .u8ShiftLeft _ e _ | .u8ShiftRight _ e _ - | .u8Xor _ e _ _ | .u8Add _ e _ _ | .u8Sub _ e _ _ + | .u8Xor _ e _ _ | .u8Add _ e _ _ | .u8Mul _ e _ _ | .u8Sub _ e _ _ | .u8And _ e _ _ | .u8Or _ e _ _ | .u8LessThan _ e _ _ | .u32LessThan _ e _ _ | .debug _ e _ _ _ => e diff --git a/Tests/Aiur/Aiur.lean b/Tests/Aiur/Aiur.lean index a1a01c7f..6a0913b3 100644 --- a/Tests/Aiur/Aiur.lean +++ b/Tests/Aiur/Aiur.lean @@ -330,6 +330,10 @@ def toplevel := ⟦ u8_sub(i, j) } + pub fn u8_mul_function(i: G, j: G) -> (G, G) { + u8_mul(i, j) + } + pub fn u8_less_than_function(i: G, j: G) -> G { u8_less_than(i, j) } @@ -848,6 +852,7 @@ def aiurTestCases : List AiurTestCase := [ .noIO `shr_shr_shl_decompose #[87] #[0, 1, 0, 1, 0, 1, 0, 0], .noIO `u8_add_xor #[45, 131] #[219, 0, 49, 1], .noIO `u8_sub_function #[45, 131] #[170, 1], + .noIO `u8_mul_function #[45, 131] #[7, 23], .noIO `u8_less_than_function #[45, 131] #[1], .noIO `u8_and_function #[45, 131] #[1], .noIO `u8_or_function #[45, 131] #[175], diff --git a/Tests/Aiur/Cross.lean b/Tests/Aiur/Cross.lean index 4960da32..25d1e988 100644 --- a/Tests/Aiur/Cross.lean +++ b/Tests/Aiur/Cross.lean @@ -375,6 +375,7 @@ def toplevel : Source.Toplevel := ⟦ -- u8 op single-call wrappers pub fn u8_sub_function(i: G, j: G) -> (G, G) { u8_sub(i, j) } + pub fn u8_mul_function(i: G, j: G) -> (G, G) { u8_mul(i, j) } pub fn u8_less_than_function(i: G, j: G) -> G { u8_less_than(i, j) } pub fn u8_and_function(i: G, j: G) -> G { u8_and(i, j) } pub fn u8_or_function(i: G, j: G) -> G { u8_or(i, j) } @@ -1188,6 +1189,8 @@ def tests : TestSeq := runAgreement "assert_eq_trivial" "assert_eq_trivial" [] ++ runAgreement "store_and_load(42)" "store_and_load" [42] ++ runAgreement "u8_sub_function(45,131)" "u8_sub_function" [45, 131] ++ + runAgreement "u8_mul_function(45,131)" "u8_mul_function" [45, 131] ++ + runAgreement "u8_mul_function(255,255)" "u8_mul_function" [255, 255] ++ runAgreement "u8_less_than_function(45,131)" "u8_less_than_function" [45, 131] ++ runAgreement "u8_less_than_function(131,45)" "u8_less_than_function" [131, 45] ++ runAgreement "u8_and_function(45,131)" "u8_and_function" [45, 131] ++ diff --git a/src/aiur.rs b/src/aiur.rs index c4794b85..2e904331 100644 --- a/src/aiur.rs +++ b/src/aiur.rs @@ -69,3 +69,8 @@ pub fn u8_less_than_channel() -> G { pub fn u8_range_check_channel() -> G { G::from_u8(11) } + +#[inline] +pub fn u8_mul_channel() -> G { + G::from_u8(12) +} diff --git a/src/aiur/bytecode.rs b/src/aiur/bytecode.rs index 31d7a88e..10e27cf7 100644 --- a/src/aiur/bytecode.rs +++ b/src/aiur/bytecode.rs @@ -52,6 +52,7 @@ pub enum Op { U8ShiftRight(ValIdx), U8Xor(ValIdx, ValIdx), U8Add(ValIdx, ValIdx), + U8Mul(ValIdx, ValIdx), U8Sub(ValIdx, ValIdx), U8And(ValIdx, ValIdx), U8Or(ValIdx, ValIdx), diff --git a/src/aiur/constraints.rs b/src/aiur/constraints.rs index 777eb099..4f0d6c68 100644 --- a/src/aiur/constraints.rs +++ b/src/aiur/constraints.rs @@ -18,9 +18,9 @@ use crate::{ bytes2::{Bytes2, Bytes2Op}, }, memory_channel, u8_add_channel, u8_and_channel, - u8_bit_decomposition_channel, u8_less_than_channel, u8_or_channel, - u8_range_check_channel, u8_shift_left_channel, u8_shift_right_channel, - u8_sub_channel, u8_xor_channel, + u8_bit_decomposition_channel, u8_less_than_channel, u8_mul_channel, + u8_or_channel, u8_range_check_channel, u8_shift_left_channel, + u8_shift_right_channel, u8_sub_channel, u8_xor_channel, }, }; @@ -492,6 +492,14 @@ impl Op { sel.clone(), state, ), + Op::U8Mul(i, j) => bytes2_constraints( + *i, + *j, + &Bytes2Op::Mul, + u8_mul_channel(), + sel.clone(), + state, + ), Op::U8Sub(i, j) => bytes2_constraints( *i, *j, diff --git a/src/aiur/execute.rs b/src/aiur/execute.rs index 9fec900b..9e607ebc 100644 --- a/src/aiur/execute.rs +++ b/src/aiur/execute.rs @@ -363,6 +363,14 @@ impl Function { bytes2_execute(*i, *j, &Bytes2Op::Add, &mut map, record); } }, + ExecEntry::Op(Op::U8Mul(i, j)) => { + if unconstrained { + let (lo, hi) = Bytes2::mul(&map[*i], &map[*j]); + map.extend([lo, hi]); + } else { + bytes2_execute(*i, *j, &Bytes2Op::Mul, &mut map, record); + } + }, ExecEntry::Op(Op::U8Sub(i, j)) => { if unconstrained { let (r, u) = Bytes2::sub(&map[*i], &map[*j]); diff --git a/src/aiur/gadgets/bytes2.rs b/src/aiur/gadgets/bytes2.rs index c3d382c8..e0701f71 100644 --- a/src/aiur/gadgets/bytes2.rs +++ b/src/aiur/gadgets/bytes2.rs @@ -8,8 +8,8 @@ use multi_stark::{ use crate::aiur::{ G, execute::QueryRecord, gadgets::AiurGadget, u8_add_channel, u8_and_channel, - u8_less_than_channel, u8_or_channel, u8_range_check_channel, u8_sub_channel, - u8_xor_channel, + u8_less_than_channel, u8_mul_channel, u8_or_channel, u8_range_check_channel, + u8_sub_channel, u8_xor_channel, }; /// Number of columns in the trace with multiplicities for @@ -20,7 +20,8 @@ use crate::aiur::{ /// - or /// - less_than /// - range_check -const TRACE_WIDTH: usize = 7; +/// - mul +const TRACE_WIDTH: usize = 8; /// Number of columns in the preprocessed trace: /// - first raw byte value @@ -33,7 +34,9 @@ const TRACE_WIDTH: usize = 7; /// - and result /// - or result /// - less_than result -const PREPROCESSED_TRACE_WIDTH: usize = 10; +/// - mul low byte +/// - mul high byte +const PREPROCESSED_TRACE_WIDTH: usize = 12; /// AIR implementer for arity 2 byte-related lookups. pub(crate) struct Bytes2; @@ -41,6 +44,7 @@ pub(crate) struct Bytes2; pub(crate) enum Bytes2Op { Xor, Add, + Mul, Sub, And, Or, @@ -83,6 +87,11 @@ impl BaseAir for Bytes2 { // Less than trace_values.push(G::from_bool(i < j)); + + // Mul (low byte, high byte) + let p = u16::from(i) * u16::from(j); + trace_values.push(G::from_u8((p & 0xff) as u8)); + trace_values.push(G::from_u8((p >> 8) as u8)); } } Some(RowMajorMatrix::new(trace_values, PREPROCESSED_TRACE_WIDTH)) @@ -100,7 +109,7 @@ impl AiurGadget for Bytes2 { fn output_size(&self, op: &Bytes2Op) -> usize { match op { Bytes2Op::Xor | Bytes2Op::And | Bytes2Op::Or | Bytes2Op::LessThan => 1, - Bytes2Op::Add | Bytes2Op::Sub => 2, + Bytes2Op::Add | Bytes2Op::Sub | Bytes2Op::Mul => 2, } } @@ -122,6 +131,11 @@ impl AiurGadget for Bytes2 { let (r, o) = Self::add(i, j); vec![r, o] }, + Bytes2Op::Mul => { + record.bytes2_queries.bump_mul(i, j); + let (lo, hi) = Self::mul(i, j); + vec![lo, hi] + }, Bytes2Op::Sub => { record.bytes2_queries.bump_sub(i, j); let (r, u) = Self::sub(i, j); @@ -151,6 +165,7 @@ impl AiurGadget for Bytes2 { let or_channel = u8_or_channel().into(); let less_than_channel = u8_less_than_channel().into(); let range_check_channel = u8_range_check_channel().into(); + let mul_channel = u8_mul_channel().into(); // Multiplicity columns let xor_multiplicity = var(0); @@ -160,6 +175,7 @@ impl AiurGadget for Bytes2 { let or_multiplicity = var(4); let less_than_multiplicity = var(5); let range_check_multiplicity = var(6); + let mul_multiplicity = var(7); // Preprocessed columns let i = preprocessed_var(0); @@ -172,6 +188,8 @@ impl AiurGadget for Bytes2 { let and = preprocessed_var(7); let or = preprocessed_var(8); let less_than = preprocessed_var(9); + let mul_lo = preprocessed_var(10); + let mul_hi = preprocessed_var(11); let pull_xor = Lookup::pull( xor_multiplicity, @@ -201,6 +219,11 @@ impl AiurGadget for Bytes2 { vec![less_than_channel, i.clone(), j.clone(), less_than], ); + let pull_mul = Lookup::pull( + mul_multiplicity, + vec![mul_channel, i.clone(), j.clone(), mul_lo, mul_hi], + ); + let pull_range_check = Lookup::pull(range_check_multiplicity, vec![range_check_channel, i, j]); @@ -212,6 +235,7 @@ impl AiurGadget for Bytes2 { pull_or, pull_less_than, pull_range_check, + pull_mul, ] } @@ -231,6 +255,7 @@ impl AiurGadget for Bytes2 { let or_channel = u8_or_channel(); let less_than_channel = u8_less_than_channel(); let range_check_channel = u8_range_check_channel(); + let mul_channel = u8_mul_channel(); rows .chunks_exact_mut(TRACE_WIDTH) @@ -239,7 +264,10 @@ impl AiurGadget for Bytes2 { .zip(&mut lookups) .for_each( |( - ((row_idx, row), &[xor, add, sub, and, or, less_than, range_check]), + ( + (row_idx, row), + &[xor, add, sub, and, or, less_than, range_check, mul], + ), row_lookups, )| { let i = G::from_usize(row_idx / 256); @@ -252,6 +280,7 @@ impl AiurGadget for Bytes2 { row[4] = or; row[5] = less_than; row[6] = range_check; + row[7] = mul; // Pull xor. row_lookups[0] = @@ -282,6 +311,10 @@ impl AiurGadget for Bytes2 { // Pull range_check. row_lookups[6] = Lookup::pull(range_check, vec![range_check_channel, i, j]); + + // Pull mul. + let (lo, hi) = Self::mul(&i, &j); + row_lookups[7] = Lookup::pull(mul, vec![mul_channel, i, j, lo, hi]); }, ); (RowMajorMatrix::new(rows, TRACE_WIDTH), lookups) @@ -325,6 +358,10 @@ impl Bytes2Queries { self.bump_multiplicity_for(i, j, 6) } + fn bump_mul(&mut self, i: &G, j: &G) { + self.bump_multiplicity_for(i, j, 7) + } + fn bump_multiplicity_for(&mut self, i: &G, j: &G, col: usize) { let i = usize::try_from(i.as_canonical_u64()).unwrap(); let j = usize::try_from(j.as_canonical_u64()).unwrap(); @@ -377,4 +414,13 @@ impl Bytes2 { let j: u8 = j.as_canonical_u64().try_into().unwrap(); G::from_bool(i < j) } + + /// `u8 * u8 -> (low byte, high byte)`. The product fits in 16 bits. + #[inline] + pub(crate) fn mul(i: &G, j: &G) -> (G, G) { + let i: u8 = i.as_canonical_u64().try_into().unwrap(); + let j: u8 = j.as_canonical_u64().try_into().unwrap(); + let p = u16::from(i) * u16::from(j); + (G::from_u8((p & 0xff) as u8), G::from_u8((p >> 8) as u8)) + } } diff --git a/src/aiur/trace.rs b/src/aiur/trace.rs index 5cc6f9f6..02d5a928 100644 --- a/src/aiur/trace.rs +++ b/src/aiur/trace.rs @@ -20,9 +20,9 @@ use crate::{ gadgets::{bytes1::Bytes1, bytes2::Bytes2}, memory::Memory, u8_add_channel, u8_and_channel, u8_bit_decomposition_channel, - u8_less_than_channel, u8_or_channel, u8_range_check_channel, - u8_shift_left_channel, u8_shift_right_channel, u8_sub_channel, - u8_xor_channel, + u8_less_than_channel, u8_mul_channel, u8_or_channel, + u8_range_check_channel, u8_shift_left_channel, u8_shift_right_channel, + u8_sub_channel, u8_xor_channel, }, }; @@ -426,6 +426,17 @@ impl Op { let lookup_args = vec![u8_add_channel(), i, j, r, o]; slice.push_lookup(index, Lookup::push(G::ONE, lookup_args)); }, + Op::U8Mul(i, j) => { + let (i, _) = map[*i]; + let (j, _) = map[*j]; + let (lo, hi) = Bytes2::mul(&i, &j); + map.push((lo, 1)); + map.push((hi, 1)); + slice.push_auxiliary(index, lo); + slice.push_auxiliary(index, hi); + let lookup_args = vec![u8_mul_channel(), i, j, lo, hi]; + slice.push_lookup(index, Lookup::push(G::ONE, lookup_args)); + }, Op::U8Sub(i, j) => { let (i, _) = map[*i]; let (j, _) = map[*j]; diff --git a/src/ffi/aiur/toplevel.rs b/src/ffi/aiur/toplevel.rs index 02b8a4e8..e1f0df73 100644 --- a/src/ffi/aiur/toplevel.rs +++ b/src/ffi/aiur/toplevel.rs @@ -107,25 +107,29 @@ fn decode_op(ctor: LeanCtor>) -> Op { }, 18 => { let [i, j] = ctor.objs::<2>().map(|x| lean_unbox_nat_as_usize(&x)); - Op::U8Sub(i, j) + Op::U8Mul(i, j) }, 19 => { let [i, j] = ctor.objs::<2>().map(|x| lean_unbox_nat_as_usize(&x)); - Op::U8And(i, j) + Op::U8Sub(i, j) }, 20 => { let [i, j] = ctor.objs::<2>().map(|x| lean_unbox_nat_as_usize(&x)); - Op::U8Or(i, j) + Op::U8And(i, j) }, 21 => { let [i, j] = ctor.objs::<2>().map(|x| lean_unbox_nat_as_usize(&x)); - Op::U8LessThan(i, j) + Op::U8Or(i, j) }, 22 => { let [i, j] = ctor.objs::<2>().map(|x| lean_unbox_nat_as_usize(&x)); - Op::U32LessThan(i, j) + Op::U8LessThan(i, j) }, 23 => { + let [i, j] = ctor.objs::<2>().map(|x| lean_unbox_nat_as_usize(&x)); + Op::U32LessThan(i, j) + }, + 24 => { let [label_obj, idxs_obj] = ctor.objs::<2>(); let label = label_obj.as_string().to_string(); let idxs = if idxs_obj.is_scalar() { From 091881b3b6c9c72579bb13cf417ec0abc6a4be61 Mon Sep 17 00:00:00 2001 From: Arthur Paulino Date: Fri, 15 May 2026 10:53:44 -0700 Subject: [PATCH 3/5] perf(kernel): use u8_mul gadget in u64_mul, shrink divmod_256 u64_mul's byte schoolbook now splits each byte product via the u8_mul gadget instead of a field mul, so column accumulators are sums of bytes (< ~4096) rather than sums of products (< ~520k). divmod_256 therefore carry-propagates only small values: for Nat.mul 1000000 1000000 its trace height drops 86 -> 8 (FFT -29k). u64_mul itself widens (186 -> 506) from the u8_mul lookups -- a fixed +320 -- so the net is a win for any non-trivial multiplication. Adds IxVMPrim.nat_mul_big as a multi-byte mul check target. --- Ix/IxVM/Kernel/Primitive.lean | 141 +++++++++++++++++++++++++--------- Tests/Ix/IxVM.lean | 2 + 2 files changed, 107 insertions(+), 36 deletions(-) diff --git a/Ix/IxVM/Kernel/Primitive.lean b/Ix/IxVM/Kernel/Primitive.lean index 383da8b7..9758d47f 100644 --- a/Ix/IxVM/Kernel/Primitive.lean +++ b/Ix/IxVM/Kernel/Primitive.lean @@ -873,10 +873,9 @@ def primitive := ⟦ klimbs_sub(a, one) } - -- TODO(u8_mul_gadget): replace `divmod_256` + byte-schoolbook `u64_mul` - -- with a proper u8_mul Aiur gadget once it lands. Tracking on a separate - -- branch. - -- Returns (remainder, quotient) where remainder = x mod 256, quotient = x / 256. + -- Returns (remainder, quotient): remainder = x mod 256, quotient = x / 256. + -- `u64_mul` feeds this only small column sums (sums of bytes, < ~4096), + -- so the repeated subtraction terminates in a handful of steps. fn divmod_256(x: G, q: G) -> (G, G) { match u32_less_than(x, 256) { 1 => (x, q), @@ -884,56 +883,126 @@ def primitive := ⟦ } } - -- u64×u64 → (lo: U64, hi: U64) via byte schoolbook. + -- u64×u64 → (lo: U64, hi: U64) via byte schoolbook. Each byte×byte + -- product is split into (low, high) bytes by the `u8_mul` gadget, so + -- every column is a sum of bytes (not products) and `divmod_256` only + -- carry-propagates small values. fn u64_mul(a: U64, b: U64) -> (U64, U64) { let [a0, a1, a2, a3, a4, a5, a6, a7] = a; let [b0, b1, b2, b3, b4, b5, b6, b7] = b; - let pp0 = a0*b0; - let pp1 = a0*b1 + a1*b0; - let pp2 = a0*b2 + a1*b1 + a2*b0; - let pp3 = a0*b3 + a1*b2 + a2*b1 + a3*b0; - let pp4 = a0*b4 + a1*b3 + a2*b2 + a3*b1 + a4*b0; - let pp5 = a0*b5 + a1*b4 + a2*b3 + a3*b2 + a4*b1 + a5*b0; - let pp6 = a0*b6 + a1*b5 + a2*b4 + a3*b3 + a4*b2 + a5*b1 + a6*b0; - let pp7 = a0*b7 + a1*b6 + a2*b5 + a3*b4 + a4*b3 + a5*b2 + a6*b1 + a7*b0; - let pp8 = a1*b7 + a2*b6 + a3*b5 + a4*b4 + a5*b3 + a6*b2 + a7*b1; - let pp9 = a2*b7 + a3*b6 + a4*b5 + a5*b4 + a6*b3 + a7*b2; - let pp10 = a3*b7 + a4*b6 + a5*b5 + a6*b4 + a7*b3; - let pp11 = a4*b7 + a5*b6 + a6*b5 + a7*b4; - let pp12 = a5*b7 + a6*b6 + a7*b5; - let pp13 = a6*b7 + a7*b6; - let pp14 = a7*b7; - match divmod_256(pp0, 0) { + let (l00, h00) = u8_mul(a0, b0); + let (l01, h01) = u8_mul(a0, b1); + let (l02, h02) = u8_mul(a0, b2); + let (l03, h03) = u8_mul(a0, b3); + let (l04, h04) = u8_mul(a0, b4); + let (l05, h05) = u8_mul(a0, b5); + let (l06, h06) = u8_mul(a0, b6); + let (l07, h07) = u8_mul(a0, b7); + let (l10, h10) = u8_mul(a1, b0); + let (l11, h11) = u8_mul(a1, b1); + let (l12, h12) = u8_mul(a1, b2); + let (l13, h13) = u8_mul(a1, b3); + let (l14, h14) = u8_mul(a1, b4); + let (l15, h15) = u8_mul(a1, b5); + let (l16, h16) = u8_mul(a1, b6); + let (l17, h17) = u8_mul(a1, b7); + let (l20, h20) = u8_mul(a2, b0); + let (l21, h21) = u8_mul(a2, b1); + let (l22, h22) = u8_mul(a2, b2); + let (l23, h23) = u8_mul(a2, b3); + let (l24, h24) = u8_mul(a2, b4); + let (l25, h25) = u8_mul(a2, b5); + let (l26, h26) = u8_mul(a2, b6); + let (l27, h27) = u8_mul(a2, b7); + let (l30, h30) = u8_mul(a3, b0); + let (l31, h31) = u8_mul(a3, b1); + let (l32, h32) = u8_mul(a3, b2); + let (l33, h33) = u8_mul(a3, b3); + let (l34, h34) = u8_mul(a3, b4); + let (l35, h35) = u8_mul(a3, b5); + let (l36, h36) = u8_mul(a3, b6); + let (l37, h37) = u8_mul(a3, b7); + let (l40, h40) = u8_mul(a4, b0); + let (l41, h41) = u8_mul(a4, b1); + let (l42, h42) = u8_mul(a4, b2); + let (l43, h43) = u8_mul(a4, b3); + let (l44, h44) = u8_mul(a4, b4); + let (l45, h45) = u8_mul(a4, b5); + let (l46, h46) = u8_mul(a4, b6); + let (l47, h47) = u8_mul(a4, b7); + let (l50, h50) = u8_mul(a5, b0); + let (l51, h51) = u8_mul(a5, b1); + let (l52, h52) = u8_mul(a5, b2); + let (l53, h53) = u8_mul(a5, b3); + let (l54, h54) = u8_mul(a5, b4); + let (l55, h55) = u8_mul(a5, b5); + let (l56, h56) = u8_mul(a5, b6); + let (l57, h57) = u8_mul(a5, b7); + let (l60, h60) = u8_mul(a6, b0); + let (l61, h61) = u8_mul(a6, b1); + let (l62, h62) = u8_mul(a6, b2); + let (l63, h63) = u8_mul(a6, b3); + let (l64, h64) = u8_mul(a6, b4); + let (l65, h65) = u8_mul(a6, b5); + let (l66, h66) = u8_mul(a6, b6); + let (l67, h67) = u8_mul(a6, b7); + let (l70, h70) = u8_mul(a7, b0); + let (l71, h71) = u8_mul(a7, b1); + let (l72, h72) = u8_mul(a7, b2); + let (l73, h73) = u8_mul(a7, b3); + let (l74, h74) = u8_mul(a7, b4); + let (l75, h75) = u8_mul(a7, b5); + let (l76, h76) = u8_mul(a7, b6); + let (l77, h77) = u8_mul(a7, b7); + -- Column k gathers low bytes of products with i+j=k and high bytes + -- of products with i+j=k-1. + let col0 = l00; + let col1 = l01 + l10 + h00; + let col2 = l02 + l11 + l20 + h01 + h10; + let col3 = l03 + l12 + l21 + l30 + h02 + h11 + h20; + let col4 = l04 + l13 + l22 + l31 + l40 + h03 + h12 + h21 + h30; + let col5 = l05 + l14 + l23 + l32 + l41 + l50 + h04 + h13 + h22 + h31 + h40; + let col6 = l06 + l15 + l24 + l33 + l42 + l51 + l60 + h05 + h14 + h23 + h32 + h41 + h50; + let col7 = l07 + l16 + l25 + l34 + l43 + l52 + l61 + l70 + h06 + h15 + h24 + h33 + h42 + h51 + h60; + let col8 = l17 + l26 + l35 + l44 + l53 + l62 + l71 + h07 + h16 + h25 + h34 + h43 + h52 + h61 + h70; + let col9 = l27 + l36 + l45 + l54 + l63 + l72 + h17 + h26 + h35 + h44 + h53 + h62 + h71; + let col10 = l37 + l46 + l55 + l64 + l73 + h27 + h36 + h45 + h54 + h63 + h72; + let col11 = l47 + l56 + l65 + l74 + h37 + h46 + h55 + h64 + h73; + let col12 = l57 + l66 + l75 + h47 + h56 + h65 + h74; + let col13 = l67 + l76 + h57 + h66 + h75; + let col14 = l77 + h67 + h76; + let col15 = h77; + match divmod_256(col0, 0) { (r0, c1) => - match divmod_256(pp1 + c1, 0) { + match divmod_256(col1 + c1, 0) { (r1, c2) => - match divmod_256(pp2 + c2, 0) { + match divmod_256(col2 + c2, 0) { (r2, c3) => - match divmod_256(pp3 + c3, 0) { + match divmod_256(col3 + c3, 0) { (r3, c4) => - match divmod_256(pp4 + c4, 0) { + match divmod_256(col4 + c4, 0) { (r4, c5) => - match divmod_256(pp5 + c5, 0) { + match divmod_256(col5 + c5, 0) { (r5, c6) => - match divmod_256(pp6 + c6, 0) { + match divmod_256(col6 + c6, 0) { (r6, c7) => - match divmod_256(pp7 + c7, 0) { + match divmod_256(col7 + c7, 0) { (r7, c8) => - match divmod_256(pp8 + c8, 0) { + match divmod_256(col8 + c8, 0) { (r8, c9) => - match divmod_256(pp9 + c9, 0) { + match divmod_256(col9 + c9, 0) { (r9, c10) => - match divmod_256(pp10 + c10, 0) { + match divmod_256(col10 + c10, 0) { (r10, c11) => - match divmod_256(pp11 + c11, 0) { + match divmod_256(col11 + c11, 0) { (r11, c12) => - match divmod_256(pp12 + c12, 0) { + match divmod_256(col12 + c12, 0) { (r12, c13) => - match divmod_256(pp13 + c13, 0) { + match divmod_256(col13 + c13, 0) { (r13, c14) => - match divmod_256(pp14 + c14, 0) { + match divmod_256(col14 + c14, 0) { (r14, c15) => - match divmod_256(c15, 0) { + match divmod_256(col15 + c15, 0) { (r15, _) => ([r0, r1, r2, r3, r4, r5, r6, r7], [r8, r9, r10, r11, r12, r13, r14, r15]), diff --git a/Tests/Ix/IxVM.lean b/Tests/Ix/IxVM.lean index ef1df320..974b4828 100644 --- a/Tests/Ix/IxVM.lean +++ b/Tests/Ix/IxVM.lean @@ -29,6 +29,7 @@ namespace IxVMPrim public theorem nat_add_lit : 100 + 200 = 300 := rfl public theorem nat_sub_lit : 1000 - 250 = 750 := rfl public theorem nat_mul_lit : Nat.mul 6 7 = 42 := rfl +public theorem nat_mul_big : Nat.mul 1000000 1000000 = 1000000000000 := rfl public theorem nat_div_lit : Nat.div 100 7 = 14 := rfl public theorem nat_mod_lit : Nat.mod 100 7 = 2 := rfl public theorem nat_succ_lit : Nat.succ 41 = 42 := rfl @@ -111,6 +112,7 @@ private def kernelCheckNames : List String := [ "Nat.sub_le_of_le_add", -- Primitive reduction theorems (`IxVMPrim`) "IxVMPrim.nat_add_lit", "IxVMPrim.nat_sub_lit", "IxVMPrim.nat_mul_lit", + "IxVMPrim.nat_mul_big", "IxVMPrim.nat_div_lit", "IxVMPrim.nat_mod_lit", "IxVMPrim.nat_succ_lit", "IxVMPrim.nat_pred_lit", "IxVMPrim.nat_gcd_lit", "IxVMPrim.nat_land_lit", "IxVMPrim.nat_lor_lit", "IxVMPrim.nat_xor_lit", From 2152020c5edc546131205fb5706cb4230e775281 Mon Sep 17 00:00:00 2001 From: Arthur Paulino Date: Sat, 16 May 2026 09:27:20 -0700 Subject: [PATCH 4/5] perf(kernel): u64_mul via unconstrained carry witness MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace the u8_mul-gadget schoolbook in u64_mul with the MulWitness/Product reference algorithm: column k is the raw field sum Σ_{i+j=k} a[i]*b[j], and each column accumulator is decomposed by a prover-provided (#-unconstrained) split into result byte + 16-bit carry, pinned by three u8 range checks and a reconstruction assert out == limb + 256·clo + 65536·chi. The division moves off-circuit: split_carry -> divmod_256 is reached only through the unconstrained call, so it costs zero trace rows. u64_mul drops to a constant 360 FFT, operand-independent. --- Ix/IxVM/Kernel/Primitive.lean | 269 ++++++++++++++++++---------------- 1 file changed, 146 insertions(+), 123 deletions(-) diff --git a/Ix/IxVM/Kernel/Primitive.lean b/Ix/IxVM/Kernel/Primitive.lean index 9758d47f..5a8e6231 100644 --- a/Ix/IxVM/Kernel/Primitive.lean +++ b/Ix/IxVM/Kernel/Primitive.lean @@ -874,8 +874,8 @@ def primitive := ⟦ } -- Returns (remainder, quotient): remainder = x mod 256, quotient = x / 256. - -- `u64_mul` feeds this only small column sums (sums of bytes, < ~4096), - -- so the repeated subtraction terminates in a handful of steps. + -- Repeated subtraction — callers feed it small `x`, or call it only from + -- unconstrained code where the iteration cost is off-circuit. fn divmod_256(x: G, q: G) -> (G, G) { match u32_less_than(x, 256) { 1 => (x, q), @@ -883,130 +883,153 @@ def primitive := ⟦ } } - -- u64×u64 → (lo: U64, hi: U64) via byte schoolbook. Each byte×byte - -- product is split into (low, high) bytes by the `u8_mul` gadget, so - -- every column is a sum of bytes (not products) and `divmod_256` only - -- carry-propagates small values. + -- Unconstrained witness generator: split `x` into its low byte `limb` + -- and the two bytes (clo, chi) of `x div 256`. Always invoked as + -- `#split_carry(...)`; the result is prover-provided and MUST be pinned + -- by the caller with u8 range checks + a reconstruction assert. The + -- division here is off-circuit (untraced), so its cost is irrelevant. + fn split_carry(x: G) -> (G, G, G) { + match divmod_256(x, 0) { + (limb, quot) => + match divmod_256(quot, 0) { + (clo, chi) => (limb, clo, chi), + }, + } + } + + -- u64×u64 → (lo: U64, hi: U64) via byte schoolbook. Faithful port of the + -- `MulWitness`/`Product` reference: column k is the raw field sum + -- Σ_{i+j=k} a[i]*b[j]; each column accumulator `out` is decomposed by a + -- prover-provided (unconstrained) split into result byte + 16-bit carry, + -- then pinned by three u8 range checks (`u8_xor(_, 0)`) and the + -- reconstruction assert `out == limb + 256·clo + 65536·chi`. No `u8_mul` + -- gadget and no constrained division. Column accumulators are < 2^19, so + -- the decomposition into (limb, clo, chi) ∈ [0,256)³ is unique → sound. fn u64_mul(a: U64, b: U64) -> (U64, U64) { let [a0, a1, a2, a3, a4, a5, a6, a7] = a; let [b0, b1, b2, b3, b4, b5, b6, b7] = b; - let (l00, h00) = u8_mul(a0, b0); - let (l01, h01) = u8_mul(a0, b1); - let (l02, h02) = u8_mul(a0, b2); - let (l03, h03) = u8_mul(a0, b3); - let (l04, h04) = u8_mul(a0, b4); - let (l05, h05) = u8_mul(a0, b5); - let (l06, h06) = u8_mul(a0, b6); - let (l07, h07) = u8_mul(a0, b7); - let (l10, h10) = u8_mul(a1, b0); - let (l11, h11) = u8_mul(a1, b1); - let (l12, h12) = u8_mul(a1, b2); - let (l13, h13) = u8_mul(a1, b3); - let (l14, h14) = u8_mul(a1, b4); - let (l15, h15) = u8_mul(a1, b5); - let (l16, h16) = u8_mul(a1, b6); - let (l17, h17) = u8_mul(a1, b7); - let (l20, h20) = u8_mul(a2, b0); - let (l21, h21) = u8_mul(a2, b1); - let (l22, h22) = u8_mul(a2, b2); - let (l23, h23) = u8_mul(a2, b3); - let (l24, h24) = u8_mul(a2, b4); - let (l25, h25) = u8_mul(a2, b5); - let (l26, h26) = u8_mul(a2, b6); - let (l27, h27) = u8_mul(a2, b7); - let (l30, h30) = u8_mul(a3, b0); - let (l31, h31) = u8_mul(a3, b1); - let (l32, h32) = u8_mul(a3, b2); - let (l33, h33) = u8_mul(a3, b3); - let (l34, h34) = u8_mul(a3, b4); - let (l35, h35) = u8_mul(a3, b5); - let (l36, h36) = u8_mul(a3, b6); - let (l37, h37) = u8_mul(a3, b7); - let (l40, h40) = u8_mul(a4, b0); - let (l41, h41) = u8_mul(a4, b1); - let (l42, h42) = u8_mul(a4, b2); - let (l43, h43) = u8_mul(a4, b3); - let (l44, h44) = u8_mul(a4, b4); - let (l45, h45) = u8_mul(a4, b5); - let (l46, h46) = u8_mul(a4, b6); - let (l47, h47) = u8_mul(a4, b7); - let (l50, h50) = u8_mul(a5, b0); - let (l51, h51) = u8_mul(a5, b1); - let (l52, h52) = u8_mul(a5, b2); - let (l53, h53) = u8_mul(a5, b3); - let (l54, h54) = u8_mul(a5, b4); - let (l55, h55) = u8_mul(a5, b5); - let (l56, h56) = u8_mul(a5, b6); - let (l57, h57) = u8_mul(a5, b7); - let (l60, h60) = u8_mul(a6, b0); - let (l61, h61) = u8_mul(a6, b1); - let (l62, h62) = u8_mul(a6, b2); - let (l63, h63) = u8_mul(a6, b3); - let (l64, h64) = u8_mul(a6, b4); - let (l65, h65) = u8_mul(a6, b5); - let (l66, h66) = u8_mul(a6, b6); - let (l67, h67) = u8_mul(a6, b7); - let (l70, h70) = u8_mul(a7, b0); - let (l71, h71) = u8_mul(a7, b1); - let (l72, h72) = u8_mul(a7, b2); - let (l73, h73) = u8_mul(a7, b3); - let (l74, h74) = u8_mul(a7, b4); - let (l75, h75) = u8_mul(a7, b5); - let (l76, h76) = u8_mul(a7, b6); - let (l77, h77) = u8_mul(a7, b7); - -- Column k gathers low bytes of products with i+j=k and high bytes - -- of products with i+j=k-1. - let col0 = l00; - let col1 = l01 + l10 + h00; - let col2 = l02 + l11 + l20 + h01 + h10; - let col3 = l03 + l12 + l21 + l30 + h02 + h11 + h20; - let col4 = l04 + l13 + l22 + l31 + l40 + h03 + h12 + h21 + h30; - let col5 = l05 + l14 + l23 + l32 + l41 + l50 + h04 + h13 + h22 + h31 + h40; - let col6 = l06 + l15 + l24 + l33 + l42 + l51 + l60 + h05 + h14 + h23 + h32 + h41 + h50; - let col7 = l07 + l16 + l25 + l34 + l43 + l52 + l61 + l70 + h06 + h15 + h24 + h33 + h42 + h51 + h60; - let col8 = l17 + l26 + l35 + l44 + l53 + l62 + l71 + h07 + h16 + h25 + h34 + h43 + h52 + h61 + h70; - let col9 = l27 + l36 + l45 + l54 + l63 + l72 + h17 + h26 + h35 + h44 + h53 + h62 + h71; - let col10 = l37 + l46 + l55 + l64 + l73 + h27 + h36 + h45 + h54 + h63 + h72; - let col11 = l47 + l56 + l65 + l74 + h37 + h46 + h55 + h64 + h73; - let col12 = l57 + l66 + l75 + h47 + h56 + h65 + h74; - let col13 = l67 + l76 + h57 + h66 + h75; - let col14 = l77 + h67 + h76; - let col15 = h77; - match divmod_256(col0, 0) { - (r0, c1) => - match divmod_256(col1 + c1, 0) { - (r1, c2) => - match divmod_256(col2 + c2, 0) { - (r2, c3) => - match divmod_256(col3 + c3, 0) { - (r3, c4) => - match divmod_256(col4 + c4, 0) { - (r4, c5) => - match divmod_256(col5 + c5, 0) { - (r5, c6) => - match divmod_256(col6 + c6, 0) { - (r6, c7) => - match divmod_256(col7 + c7, 0) { - (r7, c8) => - match divmod_256(col8 + c8, 0) { - (r8, c9) => - match divmod_256(col9 + c9, 0) { - (r9, c10) => - match divmod_256(col10 + c10, 0) { - (r10, c11) => - match divmod_256(col11 + c11, 0) { - (r11, c12) => - match divmod_256(col12 + c12, 0) { - (r12, c13) => - match divmod_256(col13 + c13, 0) { - (r13, c14) => - match divmod_256(col14 + c14, 0) { - (r14, c15) => - match divmod_256(col15 + c15, 0) { - (r15, _) => - ([r0, r1, r2, r3, r4, r5, r6, r7], - [r8, r9, r10, r11, r12, r13, r14, r15]), - }, + let col0 = (a0 * b0); + let col1 = (a0 * b1) + (a1 * b0); + let col2 = (a0 * b2) + (a1 * b1) + (a2 * b0); + let col3 = (a0 * b3) + (a1 * b2) + (a2 * b1) + (a3 * b0); + let col4 = (a0 * b4) + (a1 * b3) + (a2 * b2) + (a3 * b1) + (a4 * b0); + let col5 = (a0 * b5) + (a1 * b4) + (a2 * b3) + (a3 * b2) + (a4 * b1) + (a5 * b0); + let col6 = (a0 * b6) + (a1 * b5) + (a2 * b4) + (a3 * b3) + (a4 * b2) + (a5 * b1) + (a6 * b0); + let col7 = (a0 * b7) + (a1 * b6) + (a2 * b5) + (a3 * b4) + (a4 * b3) + (a5 * b2) + (a6 * b1) + (a7 * b0); + let col8 = (a1 * b7) + (a2 * b6) + (a3 * b5) + (a4 * b4) + (a5 * b3) + (a6 * b2) + (a7 * b1); + let col9 = (a2 * b7) + (a3 * b6) + (a4 * b5) + (a5 * b4) + (a6 * b3) + (a7 * b2); + let col10 = (a3 * b7) + (a4 * b6) + (a5 * b5) + (a6 * b4) + (a7 * b3); + let col11 = (a4 * b7) + (a5 * b6) + (a6 * b5) + (a7 * b4); + let col12 = (a5 * b7) + (a6 * b6) + (a7 * b5); + let col13 = (a6 * b7) + (a7 * b6); + let col14 = (a7 * b7); + match #split_carry(col0) { + (rl0, rc0, rh0) => + let r0 = u8_xor(rl0, 0); + let lo0 = u8_xor(rc0, 0); + let hi0 = u8_xor(rh0, 0); + assert_eq!(col0, r0 + (256 * lo0) + (65536 * hi0)); + let out1 = col1 + lo0 + (256 * hi0); + match #split_carry(out1) { + (rl1, rc1, rh1) => + let r1 = u8_xor(rl1, 0); + let lo1 = u8_xor(rc1, 0); + let hi1 = u8_xor(rh1, 0); + assert_eq!(out1, r1 + (256 * lo1) + (65536 * hi1)); + let out2 = col2 + lo1 + (256 * hi1); + match #split_carry(out2) { + (rl2, rc2, rh2) => + let r2 = u8_xor(rl2, 0); + let lo2 = u8_xor(rc2, 0); + let hi2 = u8_xor(rh2, 0); + assert_eq!(out2, r2 + (256 * lo2) + (65536 * hi2)); + let out3 = col3 + lo2 + (256 * hi2); + match #split_carry(out3) { + (rl3, rc3, rh3) => + let r3 = u8_xor(rl3, 0); + let lo3 = u8_xor(rc3, 0); + let hi3 = u8_xor(rh3, 0); + assert_eq!(out3, r3 + (256 * lo3) + (65536 * hi3)); + let out4 = col4 + lo3 + (256 * hi3); + match #split_carry(out4) { + (rl4, rc4, rh4) => + let r4 = u8_xor(rl4, 0); + let lo4 = u8_xor(rc4, 0); + let hi4 = u8_xor(rh4, 0); + assert_eq!(out4, r4 + (256 * lo4) + (65536 * hi4)); + let out5 = col5 + lo4 + (256 * hi4); + match #split_carry(out5) { + (rl5, rc5, rh5) => + let r5 = u8_xor(rl5, 0); + let lo5 = u8_xor(rc5, 0); + let hi5 = u8_xor(rh5, 0); + assert_eq!(out5, r5 + (256 * lo5) + (65536 * hi5)); + let out6 = col6 + lo5 + (256 * hi5); + match #split_carry(out6) { + (rl6, rc6, rh6) => + let r6 = u8_xor(rl6, 0); + let lo6 = u8_xor(rc6, 0); + let hi6 = u8_xor(rh6, 0); + assert_eq!(out6, r6 + (256 * lo6) + (65536 * hi6)); + let out7 = col7 + lo6 + (256 * hi6); + match #split_carry(out7) { + (rl7, rc7, rh7) => + let r7 = u8_xor(rl7, 0); + let lo7 = u8_xor(rc7, 0); + let hi7 = u8_xor(rh7, 0); + assert_eq!(out7, r7 + (256 * lo7) + (65536 * hi7)); + let out8 = col8 + lo7 + (256 * hi7); + match #split_carry(out8) { + (rl8, rc8, rh8) => + let r8 = u8_xor(rl8, 0); + let lo8 = u8_xor(rc8, 0); + let hi8 = u8_xor(rh8, 0); + assert_eq!(out8, r8 + (256 * lo8) + (65536 * hi8)); + let out9 = col9 + lo8 + (256 * hi8); + match #split_carry(out9) { + (rl9, rc9, rh9) => + let r9 = u8_xor(rl9, 0); + let lo9 = u8_xor(rc9, 0); + let hi9 = u8_xor(rh9, 0); + assert_eq!(out9, r9 + (256 * lo9) + (65536 * hi9)); + let out10 = col10 + lo9 + (256 * hi9); + match #split_carry(out10) { + (rl10, rc10, rh10) => + let r10 = u8_xor(rl10, 0); + let lo10 = u8_xor(rc10, 0); + let hi10 = u8_xor(rh10, 0); + assert_eq!(out10, r10 + (256 * lo10) + (65536 * hi10)); + let out11 = col11 + lo10 + (256 * hi10); + match #split_carry(out11) { + (rl11, rc11, rh11) => + let r11 = u8_xor(rl11, 0); + let lo11 = u8_xor(rc11, 0); + let hi11 = u8_xor(rh11, 0); + assert_eq!(out11, r11 + (256 * lo11) + (65536 * hi11)); + let out12 = col12 + lo11 + (256 * hi11); + match #split_carry(out12) { + (rl12, rc12, rh12) => + let r12 = u8_xor(rl12, 0); + let lo12 = u8_xor(rc12, 0); + let hi12 = u8_xor(rh12, 0); + assert_eq!(out12, r12 + (256 * lo12) + (65536 * hi12)); + let out13 = col13 + lo12 + (256 * hi12); + match #split_carry(out13) { + (rl13, rc13, rh13) => + let r13 = u8_xor(rl13, 0); + let lo13 = u8_xor(rc13, 0); + let hi13 = u8_xor(rh13, 0); + assert_eq!(out13, r13 + (256 * lo13) + (65536 * hi13)); + let out14 = col14 + lo13 + (256 * hi13); + match #split_carry(out14) { + (rl14, rc14, rh14) => + let r14 = u8_xor(rl14, 0); + let lo14 = u8_xor(rc14, 0); + let hi14 = u8_xor(rh14, 0); + assert_eq!(out14, r14 + (256 * lo14) + (65536 * hi14)); + let r15 = lo14 + (256 * hi14); + ([r0, r1, r2, r3, r4, r5, r6, r7], + [r8, r9, r10, r11, r12, r13, r14, r15]), }, }, }, From 11f6f4570b24e2615a6198f8c0455d880e31f396 Mon Sep 17 00:00:00 2001 From: Arthur Paulino Date: Sat, 16 May 2026 09:50:36 -0700 Subject: [PATCH 5/5] perf(kernel): klimbs_from_g via unconstrained byte witness Replace klimbs_from_g's 4-deep constrained divmod_256 chain with a prover-provided (#-unconstrained) #split_u32, pinned by four u8 range checks + the reconstruction assert x == b0 + 256*b1 + 65536*b2 + 16777216*b3. The byte decomposition moves off-circuit: divmod_256 is now reached only through #split_carry / #split_u32, so its repeated-subtraction cost is fully untraced. x >= 2^32 is rejected (assert) rather than silently truncated; callers only pass UTF-8 codepoints. --- Ix/IxVM/Kernel/Primitive.lean | 33 ++++++++++++++++++++++++++------- 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/Ix/IxVM/Kernel/Primitive.lean b/Ix/IxVM/Kernel/Primitive.lean index 5a8e6231..abb6c0a0 100644 --- a/Ix/IxVM/Kernel/Primitive.lean +++ b/Ix/IxVM/Kernel/Primitive.lean @@ -874,8 +874,9 @@ def primitive := ⟦ } -- Returns (remainder, quotient): remainder = x mod 256, quotient = x / 256. - -- Repeated subtraction — callers feed it small `x`, or call it only from - -- unconstrained code where the iteration cost is off-circuit. + -- Repeated subtraction. Only ever invoked from the `#split_carry` / + -- `#split_u32` unconstrained witness generators, so the O(x/256) + -- iteration cost is off-circuit (untraced). fn divmod_256(x: G, q: G) -> (G, G) { match u32_less_than(x, 256) { 1 => (x, q), @@ -2192,8 +2193,11 @@ def primitive := ⟦ } } - -- Convert G value (≤ 2^32) into single-limb KLimbs via byte decomp. - fn klimbs_from_g(x: G) -> KLimbs { + -- Unconstrained witness generator: split `x` (< 2^32) into 4 + -- little-endian bytes. Always invoked as `#split_u32(...)`; the result + -- is prover-provided and MUST be pinned by the caller with u8 range + -- checks + a reconstruction assert. Division is off-circuit (untraced). + fn split_u32(x: G) -> (G, G, G, G) { match divmod_256(x, 0) { (b0, q1) => match divmod_256(q1, 0) { @@ -2201,15 +2205,30 @@ def primitive := ⟦ match divmod_256(q2, 0) { (b2, q3) => match divmod_256(q3, 0) { - (b3, _q4) => - store(ListNode.Cons([b0, b1, b2, b3, 0, 0, 0, 0], - store(ListNode.Nil))), + (b3, _) => (b0, b1, b2, b3), }, }, }, } } + -- Convert G value (< 2^32) into single-limb KLimbs. The 4-byte + -- decomposition is a prover-provided (unconstrained) witness, pinned by + -- four u8 range checks + the reconstruction assert. `x >= 2^32` is + -- rejected (assert fails) rather than silently truncated. + fn klimbs_from_g(x: G) -> KLimbs { + match #split_u32(x) { + (rb0, rb1, rb2, rb3) => + let b0 = u8_xor(rb0, 0); + let b1 = u8_xor(rb1, 0); + let b2 = u8_xor(rb2, 0); + let b3 = u8_xor(rb3, 0); + assert_eq!(x, b0 + (256 * b1) + (65536 * b2) + (16777216 * b3)); + store(ListNode.Cons([b0, b1, b2, b3, 0, 0, 0, 0], + store(ListNode.Nil))), + } + } + -- Walk byte stream forward decoding UTF-8 codepoints; return last. -- Empty → 65 ('A') per Rust default. fn utf8_last_codepoint(bs: ByteStream) -> G {