diff --git a/Ix/Aiur/Compiler/Check.lean b/Ix/Aiur/Compiler/Check.lean index 54a9635c..6970abbb 100644 --- a/Ix/Aiur/Compiler/Check.lean +++ b/Ix/Aiur/Compiler/Check.lean @@ -735,6 +735,10 @@ def inferTerm (t : Term) : CheckM Typed.Term := match t with let a' ← checkNoEscape a .field let b' ← checkNoEscape b .field pure (Typed.Term.u8Add (.tuple #[.field, .field]) false a' b') + | .u8Mul a b => do + let a' ← checkNoEscape a .field + let b' ← checkNoEscape b .field + pure (Typed.Term.u8Mul (.tuple #[.field, .field]) false a' b') | .u8Sub a b => do let a' ← checkNoEscape a .field let b' ← checkNoEscape b .field @@ -900,6 +904,7 @@ def zonkTypedTerm (t : Typed.Term) : CheckM Typed.Term := match t with | .u8ShiftRight τ e a => do pure (.u8ShiftRight (← zonkTyp τ) e (← zonkTypedTerm a)) | .u8Xor τ e a b => do pure (.u8Xor (← zonkTyp τ) e (← zonkTypedTerm a) (← zonkTypedTerm b)) | .u8Add τ e a b => do pure (.u8Add (← zonkTyp τ) e (← zonkTypedTerm a) (← zonkTypedTerm b)) + | .u8Mul τ e a b => do pure (.u8Mul (← zonkTyp τ) e (← zonkTypedTerm a) (← zonkTypedTerm b)) | .u8Sub τ e a b => do pure (.u8Sub (← zonkTyp τ) e (← zonkTypedTerm a) (← zonkTypedTerm b)) | .u8And τ e a b => do pure (.u8And (← zonkTyp τ) e (← zonkTypedTerm a) (← zonkTypedTerm b)) | .u8Or τ e a b => do pure (.u8Or (← zonkTyp τ) e (← zonkTypedTerm a) (← zonkTypedTerm b)) diff --git a/Ix/Aiur/Compiler/Concretize.lean b/Ix/Aiur/Compiler/Concretize.lean index 152ceb72..b9924aee 100644 --- a/Ix/Aiur/Compiler/Concretize.lean +++ b/Ix/Aiur/Compiler/Concretize.lean @@ -325,6 +325,9 @@ def termToConcrete | .u8Add τ e a b => do pure (.u8Add (← typToConcrete mono τ) e (← termToConcrete mono a) (← termToConcrete mono b)) + | .u8Mul τ e a b => do + pure (.u8Mul (← typToConcrete mono τ) e + (← termToConcrete mono a) (← termToConcrete mono b)) | .u8Sub τ e a b => do pure (.u8Sub (← typToConcrete mono τ) e (← termToConcrete mono a) (← termToConcrete mono b)) @@ -521,6 +524,8 @@ def rewriteTypedTerm (decls : Typed.Decls) (rewriteTypedTerm decls subst mono a) (rewriteTypedTerm decls subst mono b) | .u8Add τ e a b => .u8Add (rewriteTyp subst mono τ) e (rewriteTypedTerm decls subst mono a) (rewriteTypedTerm decls subst mono b) + | .u8Mul τ e a b => .u8Mul (rewriteTyp subst mono τ) e + (rewriteTypedTerm decls subst mono a) (rewriteTypedTerm decls subst mono b) | .u8Sub τ e a b => .u8Sub (rewriteTyp subst mono τ) e (rewriteTypedTerm decls subst mono a) (rewriteTypedTerm decls subst mono b) | .u8And τ e a b => .u8And (rewriteTyp subst mono τ) e @@ -597,7 +602,7 @@ def collectInTypedTerm (seen : Std.HashSet (Global × Array Typ)) : let seen := tArgs.foldl collectInTyp seen args.attach.foldl (fun s ⟨a, _⟩ => collectInTypedTerm s a) seen | .add τ _ a b | .sub τ _ a b | .mul τ _ a b - | .u8Xor τ _ a b | .u8Add τ _ a b | .u8Sub τ _ a b + | .u8Xor τ _ a b | .u8Add τ _ a b | .u8Mul τ _ a b | .u8Sub τ _ a b | .u8And τ _ a b | .u8Or τ _ a b | .u8LessThan τ _ a b | .u32LessThan τ _ a b => collectInTypedTerm (collectInTypedTerm (collectInTyp seen τ) a) b @@ -659,7 +664,7 @@ def collectCalls (decls : Typed.Decls) let seen := collectCalls decls seen scrut bs.attach.foldl (fun s ⟨(_, b), _⟩ => collectCalls decls s b) seen | .add _ _ a b | .sub _ _ a b | .mul _ _ a b - | .u8Xor _ _ a b | .u8Add _ _ a b | .u8Sub _ _ a b + | .u8Xor _ _ a b | .u8Add _ _ a b | .u8Mul _ _ a b | .u8Sub _ _ a b | .u8And _ _ a b | .u8Or _ _ a b | .u8LessThan _ _ a b | .u32LessThan _ _ a b => collectCalls decls (collectCalls decls seen a) b @@ -749,6 +754,8 @@ def substInTypedTerm (subst : Global → Option Typ) : Typed.Term → Typed.Term (substInTypedTerm subst a) (substInTypedTerm subst b) | .u8Add τ e a b => .u8Add (Typ.instantiate subst τ) e (substInTypedTerm subst a) (substInTypedTerm subst b) + | .u8Mul τ e a b => .u8Mul (Typ.instantiate subst τ) e + (substInTypedTerm subst a) (substInTypedTerm subst b) | .u8Sub τ e a b => .u8Sub (Typ.instantiate subst τ) e (substInTypedTerm subst a) (substInTypedTerm subst b) | .u8And τ e a b => .u8And (Typ.instantiate subst τ) e diff --git a/Ix/Aiur/Compiler/Layout.lean b/Ix/Aiur/Compiler/Layout.lean index 5c425ea8..4b3164ff 100644 --- a/Ix/Aiur/Compiler/Layout.lean +++ b/Ix/Aiur/Compiler/Layout.lean @@ -185,7 +185,7 @@ def opLayout : Bytecode.Op → LayoutM Unit | .u8BitDecomposition _ => do pushDegrees $ .replicate 8 1; bumpAuxiliaries 8; bumpLookups | .u8ShiftLeft _ | .u8ShiftRight _ | .u8Xor .. | .u8And .. | .u8Or .. => do pushDegree 1; bumpAuxiliaries; bumpLookups - | .u8Add .. | .u8Sub .. => do pushDegrees #[1, 1]; bumpAuxiliaries 2; bumpLookups + | .u8Add .. | .u8Mul .. | .u8Sub .. => do pushDegrees #[1, 1]; bumpAuxiliaries 2; bumpLookups | .u8LessThan .. => do pushDegree 1; bumpAuxiliaries; bumpLookups | .u32LessThan .. => do pushDegree 1; bumpAuxiliaries 12; bumpLookups 6 | .debug .. => pure () diff --git a/Ix/Aiur/Compiler/Lower.lean b/Ix/Aiur/Compiler/Lower.lean index 43c98267..a50f4ef2 100644 --- a/Ix/Aiur/Compiler/Lower.lean +++ b/Ix/Aiur/Compiler/Lower.lean @@ -283,6 +283,9 @@ def toIndex | .u8Add _ _ i j => do let i ← expectIdx layoutMap bindings i; let j ← expectIdx layoutMap bindings j pushOp (.u8Add i j) 2 + | .u8Mul _ _ i j => do + let i ← expectIdx layoutMap bindings i; let j ← expectIdx layoutMap bindings j + pushOp (.u8Mul i j) 2 | .u8Sub _ _ i j => do let i ← expectIdx layoutMap bindings i; let j ← expectIdx layoutMap bindings j pushOp (.u8Sub i j) 2 diff --git a/Ix/Aiur/Compiler/Match.lean b/Ix/Aiur/Compiler/Match.lean index 1e148b83..3a92d00e 100644 --- a/Ix/Aiur/Compiler/Match.lean +++ b/Ix/Aiur/Compiler/Match.lean @@ -380,6 +380,7 @@ def typedToSimple : Term → Simple.Term | .u8ShiftRight τ e a => .u8ShiftRight τ e (typedToSimple a) | .u8Xor τ e a b => .u8Xor τ e (typedToSimple a) (typedToSimple b) | .u8Add τ e a b => .u8Add τ e (typedToSimple a) (typedToSimple b) + | .u8Mul τ e a b => .u8Mul τ e (typedToSimple a) (typedToSimple b) | .u8Sub τ e a b => .u8Sub τ e (typedToSimple a) (typedToSimple b) | .u8And τ e a b => .u8And τ e (typedToSimple a) (typedToSimple b) | .u8Or τ e a b => .u8Or τ e (typedToSimple a) (typedToSimple b) diff --git a/Ix/Aiur/Goldilocks.lean b/Ix/Aiur/Goldilocks.lean index 937bab60..3a98fdcd 100644 --- a/Ix/Aiur/Goldilocks.lean +++ b/Ix/Aiur/Goldilocks.lean @@ -56,6 +56,10 @@ def G.u8LessThan (a b : G) : G := if a.n < b.n then 1 else 0 def G.u8Add (a b : G) : G × G := (G.ofNat ((a.n + b.n) % 256), G.ofNat ((a.n + b.n) / 256)) +/-- u8 multiplication returns `(low byte, high byte)`. -/ +def G.u8Mul (a b : G) : G × G := + (G.ofNat ((a.n * b.n) % 256), G.ofNat ((a.n * b.n) / 256)) + /-- u8 subtraction returns `(result % 256, borrow)`. -/ def G.u8Sub (a b : G) : G × G := (G.ofNat ((a.n + 256 - b.n) % 256), if a.n < b.n then 1 else 0) diff --git a/Ix/Aiur/Interpret.lean b/Ix/Aiur/Interpret.lean index 45b788f5..d8a7a294 100644 --- a/Ix/Aiur/Interpret.lean +++ b/Ix/Aiur/Interpret.lean @@ -326,6 +326,14 @@ partial def interp (decls : Decls) (bindings : Bindings) : Term → InterpM Valu let overflow : Value := .field (if x >= 256 then 1 else 0) return .tuple #[sum, overflow] | _, _ => throwErr "u8Add: expected field values" + | .u8Mul t1 t2 => do + match (← interp decls bindings t1), (← interp decls bindings t2) with + | .field a, .field b => + let x := a.val.toUInt8.toNat * b.val.toUInt8.toNat + let lo : Value := .field (G.ofUInt8 x.toUInt8) + let hi : Value := .field (G.ofUInt8 (x / 256).toUInt8) + return .tuple #[lo, hi] + | _, _ => throwErr "u8Mul: expected field values" | .u8Sub t1 t2 => do match (← interp decls bindings t1), (← interp decls bindings t2) with | .field a, .field b => diff --git a/Ix/Aiur/Meta.lean b/Ix/Aiur/Meta.lean index 5a5cb2d9..c85a5b94 100644 --- a/Ix/Aiur/Meta.lean +++ b/Ix/Aiur/Meta.lean @@ -173,6 +173,7 @@ syntax "u8_shift_left" "(" aiur_trm ")" : a syntax "u8_shift_right" "(" aiur_trm ")" : aiur_trm syntax "u8_xor" "(" aiur_trm ", " aiur_trm ")" : aiur_trm syntax "u8_add" "(" aiur_trm ", " aiur_trm ")" : aiur_trm +syntax "u8_mul" "(" aiur_trm ", " aiur_trm ")" : aiur_trm syntax "u8_sub" "(" aiur_trm ", " aiur_trm ")" : aiur_trm syntax "u8_and" "(" aiur_trm ", " aiur_trm ")" : aiur_trm syntax "u8_or" "(" aiur_trm ", " aiur_trm ")" : aiur_trm @@ -292,6 +293,8 @@ partial def elabTrm : ElabStxCat `aiur_trm mkAppM ``Source.Term.u8Xor #[← elabTrm i, ← elabTrm j] | `(aiur_trm| u8_add($i:aiur_trm, $j:aiur_trm)) => do mkAppM ``Source.Term.u8Add #[← elabTrm i, ← elabTrm j] + | `(aiur_trm| u8_mul($i:aiur_trm, $j:aiur_trm)) => do + mkAppM ``Source.Term.u8Mul #[← elabTrm i, ← elabTrm j] | `(aiur_trm| u8_sub($i:aiur_trm, $j:aiur_trm)) => do mkAppM ``Source.Term.u8Sub #[← elabTrm i, ← elabTrm j] | `(aiur_trm| u8_and($i:aiur_trm, $j:aiur_trm)) => do @@ -480,6 +483,10 @@ where let i ← replaceToken old new i let j ← replaceToken old new j `(aiur_trm| u8_add($i, $j)) + | `(aiur_trm| u8_mul($i:aiur_trm, $j:aiur_trm)) => do + let i ← replaceToken old new i + let j ← replaceToken old new j + `(aiur_trm| u8_mul($i, $j)) | `(aiur_trm| u8_sub($i:aiur_trm, $j:aiur_trm)) => do let i ← replaceToken old new i let j ← replaceToken old new j diff --git a/Ix/Aiur/Semantics/BytecodeEval.lean b/Ix/Aiur/Semantics/BytecodeEval.lean index b1669c37..a45c377e 100644 --- a/Ix/Aiur/Semantics/BytecodeEval.lean +++ b/Ix/Aiur/Semantics/BytecodeEval.lean @@ -247,6 +247,11 @@ def evalOp (t : Bytecode.Toplevel) (fuel : Nat) (op : Op) (st : EvalState) : let sum := x.val.toUInt8.toNat + y.val.toUInt8.toNat let st1 := pushMap st (G.ofUInt8 sum.toUInt8) pure (pushMap st1 (if sum ≥ 256 then 1 else 0)) + | .u8Mul a b => do + let x ← readIdx st a; let y ← readIdx st b + let prod := x.val.toUInt8.toNat * y.val.toUInt8.toNat + let st1 := pushMap st (G.ofUInt8 prod.toUInt8) + pure (pushMap st1 (G.ofUInt8 (prod / 256).toUInt8)) | .u8Sub a b => do let x ← readIdx st a; let y ← readIdx st b let i := x.val.toUInt8; let j := y.val.toUInt8 diff --git a/Ix/Aiur/Semantics/SourceEval.lean b/Ix/Aiur/Semantics/SourceEval.lean index d2e6899a..c0cffea7 100644 --- a/Ix/Aiur/Semantics/SourceEval.lean +++ b/Ix/Aiur/Semantics/SourceEval.lean @@ -382,6 +382,14 @@ def interp (decls : Decls) (fuel : Nat) (bindings : Bindings) .field (if x >= 256 then 1 else 0)]) (interp decls fuel bindings t1 st) (fun st1 => interp decls fuel bindings t2 st1) + | .u8Mul t1 t2 => + combineFieldsResult + (fun a b => + let x := a.val.toUInt8.toNat * b.val.toUInt8.toNat + .tuple #[.field (G.ofUInt8 x.toUInt8), + .field (G.ofUInt8 (x / 256).toUInt8)]) + (interp decls fuel bindings t1 st) + (fun st1 => interp decls fuel bindings t2 st1) | .u8Sub t1 t2 => combineFieldsResult (fun a b => diff --git a/Ix/Aiur/Stages/Bytecode.lean b/Ix/Aiur/Stages/Bytecode.lean index acea8020..527e5431 100644 --- a/Ix/Aiur/Stages/Bytecode.lean +++ b/Ix/Aiur/Stages/Bytecode.lean @@ -37,6 +37,7 @@ inductive Op | u8ShiftRight : ValIdx → Op | u8Xor : ValIdx → ValIdx → Op | u8Add : ValIdx → ValIdx → Op + | u8Mul : ValIdx → ValIdx → Op | u8Sub : ValIdx → ValIdx → Op | u8And : ValIdx → ValIdx → Op | u8Or : ValIdx → ValIdx → Op diff --git a/Ix/Aiur/Stages/Concrete.lean b/Ix/Aiur/Stages/Concrete.lean index 1faa134b..3e9eaa50 100644 --- a/Ix/Aiur/Stages/Concrete.lean +++ b/Ix/Aiur/Stages/Concrete.lean @@ -78,6 +78,7 @@ inductive Term : Type where | u8ShiftRight (typ : Typ) (escapes : Bool) (a : Term) : Term | u8Xor (typ : Typ) (escapes : Bool) (a : Term) (b : Term) : Term | u8Add (typ : Typ) (escapes : Bool) (a : Term) (b : Term) : Term + | u8Mul (typ : Typ) (escapes : Bool) (a : Term) (b : Term) : Term | u8Sub (typ : Typ) (escapes : Bool) (a : Term) (b : Term) : Term | u8And (typ : Typ) (escapes : Bool) (a : Term) (b : Term) : Term | u8Or (typ : Typ) (escapes : Bool) (a : Term) (b : Term) : Term @@ -98,7 +99,7 @@ def Term.typ : Term → Typ | .assertEq t _ _ _ _ | .ioGetInfo t _ _ | .ioSetInfo t _ _ _ _ _ | .ioRead t _ _ _ | .ioWrite t _ _ _ | .u8BitDecomposition t _ _ | .u8ShiftLeft t _ _ | .u8ShiftRight t _ _ - | .u8Xor t _ _ _ | .u8Add t _ _ _ | .u8Sub t _ _ _ + | .u8Xor t _ _ _ | .u8Add t _ _ _ | .u8Mul t _ _ _ | .u8Sub t _ _ _ | .u8And t _ _ _ | .u8Or t _ _ _ | .u8LessThan t _ _ _ | .u32LessThan t _ _ _ | .debug t _ _ _ _ => t @@ -114,7 +115,7 @@ def Term.escapes : Term → Bool | .assertEq _ e _ _ _ | .ioGetInfo _ e _ | .ioSetInfo _ e _ _ _ _ | .ioRead _ e _ _ | .ioWrite _ e _ _ | .u8BitDecomposition _ e _ | .u8ShiftLeft _ e _ | .u8ShiftRight _ e _ - | .u8Xor _ e _ _ | .u8Add _ e _ _ | .u8Sub _ e _ _ + | .u8Xor _ e _ _ | .u8Add _ e _ _ | .u8Mul _ e _ _ | .u8Sub _ e _ _ | .u8And _ e _ _ | .u8Or _ e _ _ | .u8LessThan _ e _ _ | .u32LessThan _ e _ _ | .debug _ e _ _ _ => e diff --git a/Ix/Aiur/Stages/Simple.lean b/Ix/Aiur/Stages/Simple.lean index 2aed45b8..996b5d21 100644 --- a/Ix/Aiur/Stages/Simple.lean +++ b/Ix/Aiur/Stages/Simple.lean @@ -74,6 +74,7 @@ inductive Term : Type where | u8ShiftRight (typ : Typ) (escapes : Bool) (a : Term) : Term | u8Xor (typ : Typ) (escapes : Bool) (a : Term) (b : Term) : Term | u8Add (typ : Typ) (escapes : Bool) (a : Term) (b : Term) : Term + | u8Mul (typ : Typ) (escapes : Bool) (a : Term) (b : Term) : Term | u8Sub (typ : Typ) (escapes : Bool) (a : Term) (b : Term) : Term | u8And (typ : Typ) (escapes : Bool) (a : Term) (b : Term) : Term | u8Or (typ : Typ) (escapes : Bool) (a : Term) (b : Term) : Term @@ -93,7 +94,7 @@ def Term.typ : Term → Typ | .assertEq t _ _ _ _ | .ioGetInfo t _ _ | .ioSetInfo t _ _ _ _ _ | .ioRead t _ _ _ | .ioWrite t _ _ _ | .u8BitDecomposition t _ _ | .u8ShiftLeft t _ _ | .u8ShiftRight t _ _ - | .u8Xor t _ _ _ | .u8Add t _ _ _ | .u8Sub t _ _ _ + | .u8Xor t _ _ _ | .u8Add t _ _ _ | .u8Mul t _ _ _ | .u8Sub t _ _ _ | .u8And t _ _ _ | .u8Or t _ _ _ | .u8LessThan t _ _ _ | .u32LessThan t _ _ _ | .debug t _ _ _ _ => t @@ -108,7 +109,7 @@ def Term.escapes : Term → Bool | .assertEq _ e _ _ _ | .ioGetInfo _ e _ | .ioSetInfo _ e _ _ _ _ | .ioRead _ e _ _ | .ioWrite _ e _ _ | .u8BitDecomposition _ e _ | .u8ShiftLeft _ e _ | .u8ShiftRight _ e _ - | .u8Xor _ e _ _ | .u8Add _ e _ _ | .u8Sub _ e _ _ + | .u8Xor _ e _ _ | .u8Add _ e _ _ | .u8Mul _ e _ _ | .u8Sub _ e _ _ | .u8And _ e _ _ | .u8Or _ e _ _ | .u8LessThan _ e _ _ | .u32LessThan _ e _ _ | .debug _ e _ _ _ => e diff --git a/Ix/Aiur/Stages/Source.lean b/Ix/Aiur/Stages/Source.lean index 440fa683..f1209021 100644 --- a/Ix/Aiur/Stages/Source.lean +++ b/Ix/Aiur/Stages/Source.lean @@ -382,6 +382,7 @@ inductive Term | u8ShiftRight : Term → Term | u8Xor : Term → Term → Term | u8Add : Term → Term → Term + | u8Mul : Term → Term → Term | u8Sub : Term → Term → Term | u8And : Term → Term → Term | u8Or : Term → Term → Term diff --git a/Ix/Aiur/Stages/Typed.lean b/Ix/Aiur/Stages/Typed.lean index 9ba4dd82..2a3c6f33 100644 --- a/Ix/Aiur/Stages/Typed.lean +++ b/Ix/Aiur/Stages/Typed.lean @@ -50,6 +50,7 @@ inductive Term : Type where | u8ShiftRight (typ : Typ) (escapes : Bool) (a : Term) : Term | u8Xor (typ : Typ) (escapes : Bool) (a : Term) (b : Term) : Term | u8Add (typ : Typ) (escapes : Bool) (a : Term) (b : Term) : Term + | u8Mul (typ : Typ) (escapes : Bool) (a : Term) (b : Term) : Term | u8Sub (typ : Typ) (escapes : Bool) (a : Term) (b : Term) : Term | u8And (typ : Typ) (escapes : Bool) (a : Term) (b : Term) : Term | u8Or (typ : Typ) (escapes : Bool) (a : Term) (b : Term) : Term @@ -69,7 +70,7 @@ def Term.typ : Term → Typ | .assertEq t _ _ _ _ | .ioGetInfo t _ _ | .ioSetInfo t _ _ _ _ _ | .ioRead t _ _ _ | .ioWrite t _ _ _ | .u8BitDecomposition t _ _ | .u8ShiftLeft t _ _ | .u8ShiftRight t _ _ - | .u8Xor t _ _ _ | .u8Add t _ _ _ | .u8Sub t _ _ _ + | .u8Xor t _ _ _ | .u8Add t _ _ _ | .u8Mul t _ _ _ | .u8Sub t _ _ _ | .u8And t _ _ _ | .u8Or t _ _ _ | .u8LessThan t _ _ _ | .u32LessThan t _ _ _ | .debug t _ _ _ _ => t @@ -84,7 +85,7 @@ def Term.escapes : Term → Bool | .assertEq _ e _ _ _ | .ioGetInfo _ e _ | .ioSetInfo _ e _ _ _ _ | .ioRead _ e _ _ | .ioWrite _ e _ _ | .u8BitDecomposition _ e _ | .u8ShiftLeft _ e _ | .u8ShiftRight _ e _ - | .u8Xor _ e _ _ | .u8Add _ e _ _ | .u8Sub _ e _ _ + | .u8Xor _ e _ _ | .u8Add _ e _ _ | .u8Mul _ e _ _ | .u8Sub _ e _ _ | .u8And _ e _ _ | .u8Or _ e _ _ | .u8LessThan _ e _ _ | .u32LessThan _ e _ _ | .debug _ e _ _ _ => e diff --git a/Ix/IxVM/ByteStream.lean b/Ix/IxVM/ByteStream.lean index f3cc1d93..d657a890 100644 --- a/Ix/IxVM/ByteStream.lean +++ b/Ix/IxVM/ByteStream.lean @@ -99,8 +99,9 @@ def byteStream := ⟦ } } - -- `u64` addition with carry propagation (little-endian bytes) - fn u64_add(a: U64, b: U64) -> U64 { + -- `u64` addition with carry propagation (little-endian bytes). + -- Returns the sum together with the final carry-out. + fn u64_add(a: U64, b: U64) -> (U64, G) { let [a0, a1, a2, a3, a4, a5, a6, a7] = a; let [b0, b1, b2, b3, b4, b5, b6, b7] = b; let (s0, c1) = u8_add(a0, b0); @@ -122,9 +123,10 @@ def byteStream := ⟦ let (t6, o6) = u8_add(a6, b6); let (s6, c6a) = u8_add(t6, c6); let c7 = u8_xor(o6, c6a); - let (t7, _) = u8_add(a7, b7); - let (s7, _) = u8_add(t7, c7); - [s0, s1, s2, s3, s4, s5, s6, s7] + let (t7, o7) = u8_add(a7, b7); + let (s7, c7a) = u8_add(t7, c7); + let final_carry = u8_xor(o7, c7a); + ([s0, s1, s2, s3, s4, s5, s6, s7], final_carry) } -- `u64` subtraction via repeated decrement (correct for small b) diff --git a/Ix/IxVM/Kernel/Check.lean b/Ix/IxVM/Kernel/Check.lean index 648b2a47..7f14e84c 100644 --- a/Ix/IxVM/Kernel/Check.lean +++ b/Ix/IxVM/Kernel/Check.lean @@ -59,11 +59,7 @@ def check := ⟦ KExprNode.Srt(_) => 1, KExprNode.Const(idx, _) => let ci = load(list_lookup(top, idx)); - let u = is_unsafe_ci(ci); - match u { - 0 => 1, - 1 => 0, - }, + 1 - is_unsafe_ci(ci), KExprNode.App(f, a) => safe_refs_only(f, top) * safe_refs_only(a, top), KExprNode.Lam(t, b) => diff --git a/Ix/IxVM/Kernel/Primitive.lean b/Ix/IxVM/Kernel/Primitive.lean index ed571090..abb6c0a0 100644 --- a/Ix/IxVM/Kernel/Primitive.lean +++ b/Ix/IxVM/Kernel/Primitive.lean @@ -693,48 +693,13 @@ def primitive := ⟦ 0xa4, 0x00, 0xfb, 0x06, 0x03, 0xae, 0xdf, 0x2f] } - -- Mirror: `u64_add` in ByteStream.lean expanded to expose final - -- carry. Used by klimbs_succ / klimbs_add. - -- - -- TODO: delete this once `ByteStream.lean::u64_add` is patched to - -- return `(U64, G)` and existing call sites updated. Tracking this - -- as a follow-up because the patch ripples beyond the kernel - -- (Blake3 / IxonSerialize / ByteStream itself). - fn u64_add_with_carry(a: U64, b: U64) -> (U64, G) { - let [a0, a1, a2, a3, a4, a5, a6, a7] = a; - let [b0, b1, b2, b3, b4, b5, b6, b7] = b; - let (s0, c1) = u8_add(a0, b0); - let (t1, o1) = u8_add(a1, b1); - let (s1, c1a) = u8_add(t1, c1); - let c2 = u8_xor(o1, c1a); - let (t2, o2) = u8_add(a2, b2); - let (s2, c2a) = u8_add(t2, c2); - let c3 = u8_xor(o2, c2a); - let (t3, o3) = u8_add(a3, b3); - let (s3, c3a) = u8_add(t3, c3); - let c4 = u8_xor(o3, c3a); - let (t4, o4) = u8_add(a4, b4); - let (s4, c4a) = u8_add(t4, c4); - let c5 = u8_xor(o4, c4a); - let (t5, o5) = u8_add(a5, b5); - let (s5, c5a) = u8_add(t5, c5); - let c6 = u8_xor(o5, c5a); - let (t6, o6) = u8_add(a6, b6); - let (s6, c6a) = u8_add(t6, c6); - let c7 = u8_xor(o6, c6a); - let (t7, o7) = u8_add(a7, b7); - let (s7, c7a) = u8_add(t7, c7); - let final_carry = u8_xor(o7, c7a); - ([s0, s1, s2, s3, s4, s5, s6, s7], final_carry) - } - -- Mirror: BigUint::succ. Increment a KLimbs by 1; ripple carry. fn klimbs_succ(n: KLimbs) -> KLimbs { match load(n) { ListNode.Nil => store(ListNode.Cons([1, 0, 0, 0, 0, 0, 0, 0], store(ListNode.Nil))), ListNode.Cons(limb, rest) => - let pair = u64_add_with_carry(limb, [1, 0, 0, 0, 0, 0, 0, 0]); + let pair = u64_add(limb, [1, 0, 0, 0, 0, 0, 0, 0]); match pair { (sum, carry) => match carry { @@ -764,10 +729,10 @@ def primitive := ⟦ _ => klimbs_succ(a), }, ListNode.Cons(lb, rb) => - let pair1 = u64_add_with_carry(la, lb); + let pair1 = u64_add(la, lb); match pair1 { (sum1, carry1) => - let pair2 = u64_add_with_carry(sum1, [carry, 0, 0, 0, 0, 0, 0, 0]); + let pair2 = u64_add(sum1, [carry, 0, 0, 0, 0, 0, 0, 0]); match pair2 { (sum2, carry2) => let total_carry = g_or(carry1, carry2); @@ -908,10 +873,10 @@ def primitive := ⟦ klimbs_sub(a, one) } - -- TODO(u8_mul_gadget): replace `divmod_256` + byte-schoolbook `u64_mul` - -- with a proper u8_mul Aiur gadget once it lands. Tracking on a separate - -- branch. - -- Returns (remainder, quotient) where remainder = x mod 256, quotient = x / 256. + -- Returns (remainder, quotient): remainder = x mod 256, quotient = x / 256. + -- Repeated subtraction. Only ever invoked from the `#split_carry` / + -- `#split_u32` unconstrained witness generators, so the O(x/256) + -- iteration cost is off-circuit (untraced). fn divmod_256(x: G, q: G) -> (G, G) { match u32_less_than(x, 256) { 1 => (x, q), @@ -919,60 +884,153 @@ def primitive := ⟦ } } - -- u64×u64 → (lo: U64, hi: U64) via byte schoolbook. + -- Unconstrained witness generator: split `x` into its low byte `limb` + -- and the two bytes (clo, chi) of `x div 256`. Always invoked as + -- `#split_carry(...)`; the result is prover-provided and MUST be pinned + -- by the caller with u8 range checks + a reconstruction assert. The + -- division here is off-circuit (untraced), so its cost is irrelevant. + fn split_carry(x: G) -> (G, G, G) { + match divmod_256(x, 0) { + (limb, quot) => + match divmod_256(quot, 0) { + (clo, chi) => (limb, clo, chi), + }, + } + } + + -- u64×u64 → (lo: U64, hi: U64) via byte schoolbook. Faithful port of the + -- `MulWitness`/`Product` reference: column k is the raw field sum + -- Σ_{i+j=k} a[i]*b[j]; each column accumulator `out` is decomposed by a + -- prover-provided (unconstrained) split into result byte + 16-bit carry, + -- then pinned by three u8 range checks (`u8_xor(_, 0)`) and the + -- reconstruction assert `out == limb + 256·clo + 65536·chi`. No `u8_mul` + -- gadget and no constrained division. Column accumulators are < 2^19, so + -- the decomposition into (limb, clo, chi) ∈ [0,256)³ is unique → sound. fn u64_mul(a: U64, b: U64) -> (U64, U64) { let [a0, a1, a2, a3, a4, a5, a6, a7] = a; let [b0, b1, b2, b3, b4, b5, b6, b7] = b; - let pp0 = a0*b0; - let pp1 = a0*b1 + a1*b0; - let pp2 = a0*b2 + a1*b1 + a2*b0; - let pp3 = a0*b3 + a1*b2 + a2*b1 + a3*b0; - let pp4 = a0*b4 + a1*b3 + a2*b2 + a3*b1 + a4*b0; - let pp5 = a0*b5 + a1*b4 + a2*b3 + a3*b2 + a4*b1 + a5*b0; - let pp6 = a0*b6 + a1*b5 + a2*b4 + a3*b3 + a4*b2 + a5*b1 + a6*b0; - let pp7 = a0*b7 + a1*b6 + a2*b5 + a3*b4 + a4*b3 + a5*b2 + a6*b1 + a7*b0; - let pp8 = a1*b7 + a2*b6 + a3*b5 + a4*b4 + a5*b3 + a6*b2 + a7*b1; - let pp9 = a2*b7 + a3*b6 + a4*b5 + a5*b4 + a6*b3 + a7*b2; - let pp10 = a3*b7 + a4*b6 + a5*b5 + a6*b4 + a7*b3; - let pp11 = a4*b7 + a5*b6 + a6*b5 + a7*b4; - let pp12 = a5*b7 + a6*b6 + a7*b5; - let pp13 = a6*b7 + a7*b6; - let pp14 = a7*b7; - match divmod_256(pp0, 0) { - (r0, c1) => - match divmod_256(pp1 + c1, 0) { - (r1, c2) => - match divmod_256(pp2 + c2, 0) { - (r2, c3) => - match divmod_256(pp3 + c3, 0) { - (r3, c4) => - match divmod_256(pp4 + c4, 0) { - (r4, c5) => - match divmod_256(pp5 + c5, 0) { - (r5, c6) => - match divmod_256(pp6 + c6, 0) { - (r6, c7) => - match divmod_256(pp7 + c7, 0) { - (r7, c8) => - match divmod_256(pp8 + c8, 0) { - (r8, c9) => - match divmod_256(pp9 + c9, 0) { - (r9, c10) => - match divmod_256(pp10 + c10, 0) { - (r10, c11) => - match divmod_256(pp11 + c11, 0) { - (r11, c12) => - match divmod_256(pp12 + c12, 0) { - (r12, c13) => - match divmod_256(pp13 + c13, 0) { - (r13, c14) => - match divmod_256(pp14 + c14, 0) { - (r14, c15) => - match divmod_256(c15, 0) { - (r15, _) => - ([r0, r1, r2, r3, r4, r5, r6, r7], - [r8, r9, r10, r11, r12, r13, r14, r15]), - }, + let col0 = (a0 * b0); + let col1 = (a0 * b1) + (a1 * b0); + let col2 = (a0 * b2) + (a1 * b1) + (a2 * b0); + let col3 = (a0 * b3) + (a1 * b2) + (a2 * b1) + (a3 * b0); + let col4 = (a0 * b4) + (a1 * b3) + (a2 * b2) + (a3 * b1) + (a4 * b0); + let col5 = (a0 * b5) + (a1 * b4) + (a2 * b3) + (a3 * b2) + (a4 * b1) + (a5 * b0); + let col6 = (a0 * b6) + (a1 * b5) + (a2 * b4) + (a3 * b3) + (a4 * b2) + (a5 * b1) + (a6 * b0); + let col7 = (a0 * b7) + (a1 * b6) + (a2 * b5) + (a3 * b4) + (a4 * b3) + (a5 * b2) + (a6 * b1) + (a7 * b0); + let col8 = (a1 * b7) + (a2 * b6) + (a3 * b5) + (a4 * b4) + (a5 * b3) + (a6 * b2) + (a7 * b1); + let col9 = (a2 * b7) + (a3 * b6) + (a4 * b5) + (a5 * b4) + (a6 * b3) + (a7 * b2); + let col10 = (a3 * b7) + (a4 * b6) + (a5 * b5) + (a6 * b4) + (a7 * b3); + let col11 = (a4 * b7) + (a5 * b6) + (a6 * b5) + (a7 * b4); + let col12 = (a5 * b7) + (a6 * b6) + (a7 * b5); + let col13 = (a6 * b7) + (a7 * b6); + let col14 = (a7 * b7); + match #split_carry(col0) { + (rl0, rc0, rh0) => + let r0 = u8_xor(rl0, 0); + let lo0 = u8_xor(rc0, 0); + let hi0 = u8_xor(rh0, 0); + assert_eq!(col0, r0 + (256 * lo0) + (65536 * hi0)); + let out1 = col1 + lo0 + (256 * hi0); + match #split_carry(out1) { + (rl1, rc1, rh1) => + let r1 = u8_xor(rl1, 0); + let lo1 = u8_xor(rc1, 0); + let hi1 = u8_xor(rh1, 0); + assert_eq!(out1, r1 + (256 * lo1) + (65536 * hi1)); + let out2 = col2 + lo1 + (256 * hi1); + match #split_carry(out2) { + (rl2, rc2, rh2) => + let r2 = u8_xor(rl2, 0); + let lo2 = u8_xor(rc2, 0); + let hi2 = u8_xor(rh2, 0); + assert_eq!(out2, r2 + (256 * lo2) + (65536 * hi2)); + let out3 = col3 + lo2 + (256 * hi2); + match #split_carry(out3) { + (rl3, rc3, rh3) => + let r3 = u8_xor(rl3, 0); + let lo3 = u8_xor(rc3, 0); + let hi3 = u8_xor(rh3, 0); + assert_eq!(out3, r3 + (256 * lo3) + (65536 * hi3)); + let out4 = col4 + lo3 + (256 * hi3); + match #split_carry(out4) { + (rl4, rc4, rh4) => + let r4 = u8_xor(rl4, 0); + let lo4 = u8_xor(rc4, 0); + let hi4 = u8_xor(rh4, 0); + assert_eq!(out4, r4 + (256 * lo4) + (65536 * hi4)); + let out5 = col5 + lo4 + (256 * hi4); + match #split_carry(out5) { + (rl5, rc5, rh5) => + let r5 = u8_xor(rl5, 0); + let lo5 = u8_xor(rc5, 0); + let hi5 = u8_xor(rh5, 0); + assert_eq!(out5, r5 + (256 * lo5) + (65536 * hi5)); + let out6 = col6 + lo5 + (256 * hi5); + match #split_carry(out6) { + (rl6, rc6, rh6) => + let r6 = u8_xor(rl6, 0); + let lo6 = u8_xor(rc6, 0); + let hi6 = u8_xor(rh6, 0); + assert_eq!(out6, r6 + (256 * lo6) + (65536 * hi6)); + let out7 = col7 + lo6 + (256 * hi6); + match #split_carry(out7) { + (rl7, rc7, rh7) => + let r7 = u8_xor(rl7, 0); + let lo7 = u8_xor(rc7, 0); + let hi7 = u8_xor(rh7, 0); + assert_eq!(out7, r7 + (256 * lo7) + (65536 * hi7)); + let out8 = col8 + lo7 + (256 * hi7); + match #split_carry(out8) { + (rl8, rc8, rh8) => + let r8 = u8_xor(rl8, 0); + let lo8 = u8_xor(rc8, 0); + let hi8 = u8_xor(rh8, 0); + assert_eq!(out8, r8 + (256 * lo8) + (65536 * hi8)); + let out9 = col9 + lo8 + (256 * hi8); + match #split_carry(out9) { + (rl9, rc9, rh9) => + let r9 = u8_xor(rl9, 0); + let lo9 = u8_xor(rc9, 0); + let hi9 = u8_xor(rh9, 0); + assert_eq!(out9, r9 + (256 * lo9) + (65536 * hi9)); + let out10 = col10 + lo9 + (256 * hi9); + match #split_carry(out10) { + (rl10, rc10, rh10) => + let r10 = u8_xor(rl10, 0); + let lo10 = u8_xor(rc10, 0); + let hi10 = u8_xor(rh10, 0); + assert_eq!(out10, r10 + (256 * lo10) + (65536 * hi10)); + let out11 = col11 + lo10 + (256 * hi10); + match #split_carry(out11) { + (rl11, rc11, rh11) => + let r11 = u8_xor(rl11, 0); + let lo11 = u8_xor(rc11, 0); + let hi11 = u8_xor(rh11, 0); + assert_eq!(out11, r11 + (256 * lo11) + (65536 * hi11)); + let out12 = col12 + lo11 + (256 * hi11); + match #split_carry(out12) { + (rl12, rc12, rh12) => + let r12 = u8_xor(rl12, 0); + let lo12 = u8_xor(rc12, 0); + let hi12 = u8_xor(rh12, 0); + assert_eq!(out12, r12 + (256 * lo12) + (65536 * hi12)); + let out13 = col13 + lo12 + (256 * hi12); + match #split_carry(out13) { + (rl13, rc13, rh13) => + let r13 = u8_xor(rl13, 0); + let lo13 = u8_xor(rc13, 0); + let hi13 = u8_xor(rh13, 0); + assert_eq!(out13, r13 + (256 * lo13) + (65536 * hi13)); + let out14 = col14 + lo13 + (256 * hi13); + match #split_carry(out14) { + (rl14, rc14, rh14) => + let r14 = u8_xor(rl14, 0); + let lo14 = u8_xor(rc14, 0); + let hi14 = u8_xor(rh14, 0); + assert_eq!(out14, r14 + (256 * lo14) + (65536 * hi14)); + let r15 = lo14 + (256 * hi14); + ([r0, r1, r2, r3, r4, r5, r6, r7], + [r8, r9, r10, r11, r12, r13, r14, r15]), }, }, }, @@ -1016,9 +1074,9 @@ def primitive := ⟦ ListNode.Cons(b_limb, rest) => match u64_mul(a_limb, b_limb) { (lo, hi) => - match u64_add_with_carry(lo, carry) { + match u64_add(lo, carry) { (sum, carry_out) => - match u64_add_with_carry(hi, [carry_out, 0, 0, 0, 0, 0, 0, 0]) { + match u64_add(hi, [carry_out, 0, 0, 0, 0, 0, 0, 0]) { (new_carry, _) => let new_acc = list_snoc(acc, sum); klimbs_mul_single(a_limb, rest, new_carry, new_acc), @@ -2135,8 +2193,11 @@ def primitive := ⟦ } } - -- Convert G value (≤ 2^32) into single-limb KLimbs via byte decomp. - fn klimbs_from_g(x: G) -> KLimbs { + -- Unconstrained witness generator: split `x` (< 2^32) into 4 + -- little-endian bytes. Always invoked as `#split_u32(...)`; the result + -- is prover-provided and MUST be pinned by the caller with u8 range + -- checks + a reconstruction assert. Division is off-circuit (untraced). + fn split_u32(x: G) -> (G, G, G, G) { match divmod_256(x, 0) { (b0, q1) => match divmod_256(q1, 0) { @@ -2144,15 +2205,30 @@ def primitive := ⟦ match divmod_256(q2, 0) { (b2, q3) => match divmod_256(q3, 0) { - (b3, _q4) => - store(ListNode.Cons([b0, b1, b2, b3, 0, 0, 0, 0], - store(ListNode.Nil))), + (b3, _) => (b0, b1, b2, b3), }, }, }, } } + -- Convert G value (< 2^32) into single-limb KLimbs. The 4-byte + -- decomposition is a prover-provided (unconstrained) witness, pinned by + -- four u8 range checks + the reconstruction assert. `x >= 2^32` is + -- rejected (assert fails) rather than silently truncated. + fn klimbs_from_g(x: G) -> KLimbs { + match #split_u32(x) { + (rb0, rb1, rb2, rb3) => + let b0 = u8_xor(rb0, 0); + let b1 = u8_xor(rb1, 0); + let b2 = u8_xor(rb2, 0); + let b3 = u8_xor(rb3, 0); + assert_eq!(x, b0 + (256 * b1) + (65536 * b2) + (16777216 * b3)); + store(ListNode.Cons([b0, b1, b2, b3, 0, 0, 0, 0], + store(ListNode.Nil))), + } + } + -- Walk byte stream forward decoding UTF-8 codepoints; return last. -- Empty → 65 ('A') per Rust default. fn utf8_last_codepoint(bs: ByteStream) -> G { diff --git a/Tests/Aiur/Aiur.lean b/Tests/Aiur/Aiur.lean index a1a01c7f..6a0913b3 100644 --- a/Tests/Aiur/Aiur.lean +++ b/Tests/Aiur/Aiur.lean @@ -330,6 +330,10 @@ def toplevel := ⟦ u8_sub(i, j) } + pub fn u8_mul_function(i: G, j: G) -> (G, G) { + u8_mul(i, j) + } + pub fn u8_less_than_function(i: G, j: G) -> G { u8_less_than(i, j) } @@ -848,6 +852,7 @@ def aiurTestCases : List AiurTestCase := [ .noIO `shr_shr_shl_decompose #[87] #[0, 1, 0, 1, 0, 1, 0, 0], .noIO `u8_add_xor #[45, 131] #[219, 0, 49, 1], .noIO `u8_sub_function #[45, 131] #[170, 1], + .noIO `u8_mul_function #[45, 131] #[7, 23], .noIO `u8_less_than_function #[45, 131] #[1], .noIO `u8_and_function #[45, 131] #[1], .noIO `u8_or_function #[45, 131] #[175], diff --git a/Tests/Aiur/Cross.lean b/Tests/Aiur/Cross.lean index 4960da32..25d1e988 100644 --- a/Tests/Aiur/Cross.lean +++ b/Tests/Aiur/Cross.lean @@ -375,6 +375,7 @@ def toplevel : Source.Toplevel := ⟦ -- u8 op single-call wrappers pub fn u8_sub_function(i: G, j: G) -> (G, G) { u8_sub(i, j) } + pub fn u8_mul_function(i: G, j: G) -> (G, G) { u8_mul(i, j) } pub fn u8_less_than_function(i: G, j: G) -> G { u8_less_than(i, j) } pub fn u8_and_function(i: G, j: G) -> G { u8_and(i, j) } pub fn u8_or_function(i: G, j: G) -> G { u8_or(i, j) } @@ -1188,6 +1189,8 @@ def tests : TestSeq := runAgreement "assert_eq_trivial" "assert_eq_trivial" [] ++ runAgreement "store_and_load(42)" "store_and_load" [42] ++ runAgreement "u8_sub_function(45,131)" "u8_sub_function" [45, 131] ++ + runAgreement "u8_mul_function(45,131)" "u8_mul_function" [45, 131] ++ + runAgreement "u8_mul_function(255,255)" "u8_mul_function" [255, 255] ++ runAgreement "u8_less_than_function(45,131)" "u8_less_than_function" [45, 131] ++ runAgreement "u8_less_than_function(131,45)" "u8_less_than_function" [131, 45] ++ runAgreement "u8_and_function(45,131)" "u8_and_function" [45, 131] ++ diff --git a/Tests/Ix/IxVM.lean b/Tests/Ix/IxVM.lean index ef1df320..974b4828 100644 --- a/Tests/Ix/IxVM.lean +++ b/Tests/Ix/IxVM.lean @@ -29,6 +29,7 @@ namespace IxVMPrim public theorem nat_add_lit : 100 + 200 = 300 := rfl public theorem nat_sub_lit : 1000 - 250 = 750 := rfl public theorem nat_mul_lit : Nat.mul 6 7 = 42 := rfl +public theorem nat_mul_big : Nat.mul 1000000 1000000 = 1000000000000 := rfl public theorem nat_div_lit : Nat.div 100 7 = 14 := rfl public theorem nat_mod_lit : Nat.mod 100 7 = 2 := rfl public theorem nat_succ_lit : Nat.succ 41 = 42 := rfl @@ -111,6 +112,7 @@ private def kernelCheckNames : List String := [ "Nat.sub_le_of_le_add", -- Primitive reduction theorems (`IxVMPrim`) "IxVMPrim.nat_add_lit", "IxVMPrim.nat_sub_lit", "IxVMPrim.nat_mul_lit", + "IxVMPrim.nat_mul_big", "IxVMPrim.nat_div_lit", "IxVMPrim.nat_mod_lit", "IxVMPrim.nat_succ_lit", "IxVMPrim.nat_pred_lit", "IxVMPrim.nat_gcd_lit", "IxVMPrim.nat_land_lit", "IxVMPrim.nat_lor_lit", "IxVMPrim.nat_xor_lit", diff --git a/src/aiur.rs b/src/aiur.rs index c4794b85..2e904331 100644 --- a/src/aiur.rs +++ b/src/aiur.rs @@ -69,3 +69,8 @@ pub fn u8_less_than_channel() -> G { pub fn u8_range_check_channel() -> G { G::from_u8(11) } + +#[inline] +pub fn u8_mul_channel() -> G { + G::from_u8(12) +} diff --git a/src/aiur/bytecode.rs b/src/aiur/bytecode.rs index 31d7a88e..10e27cf7 100644 --- a/src/aiur/bytecode.rs +++ b/src/aiur/bytecode.rs @@ -52,6 +52,7 @@ pub enum Op { U8ShiftRight(ValIdx), U8Xor(ValIdx, ValIdx), U8Add(ValIdx, ValIdx), + U8Mul(ValIdx, ValIdx), U8Sub(ValIdx, ValIdx), U8And(ValIdx, ValIdx), U8Or(ValIdx, ValIdx), diff --git a/src/aiur/constraints.rs b/src/aiur/constraints.rs index 777eb099..4f0d6c68 100644 --- a/src/aiur/constraints.rs +++ b/src/aiur/constraints.rs @@ -18,9 +18,9 @@ use crate::{ bytes2::{Bytes2, Bytes2Op}, }, memory_channel, u8_add_channel, u8_and_channel, - u8_bit_decomposition_channel, u8_less_than_channel, u8_or_channel, - u8_range_check_channel, u8_shift_left_channel, u8_shift_right_channel, - u8_sub_channel, u8_xor_channel, + u8_bit_decomposition_channel, u8_less_than_channel, u8_mul_channel, + u8_or_channel, u8_range_check_channel, u8_shift_left_channel, + u8_shift_right_channel, u8_sub_channel, u8_xor_channel, }, }; @@ -492,6 +492,14 @@ impl Op { sel.clone(), state, ), + Op::U8Mul(i, j) => bytes2_constraints( + *i, + *j, + &Bytes2Op::Mul, + u8_mul_channel(), + sel.clone(), + state, + ), Op::U8Sub(i, j) => bytes2_constraints( *i, *j, diff --git a/src/aiur/execute.rs b/src/aiur/execute.rs index 9fec900b..9e607ebc 100644 --- a/src/aiur/execute.rs +++ b/src/aiur/execute.rs @@ -363,6 +363,14 @@ impl Function { bytes2_execute(*i, *j, &Bytes2Op::Add, &mut map, record); } }, + ExecEntry::Op(Op::U8Mul(i, j)) => { + if unconstrained { + let (lo, hi) = Bytes2::mul(&map[*i], &map[*j]); + map.extend([lo, hi]); + } else { + bytes2_execute(*i, *j, &Bytes2Op::Mul, &mut map, record); + } + }, ExecEntry::Op(Op::U8Sub(i, j)) => { if unconstrained { let (r, u) = Bytes2::sub(&map[*i], &map[*j]); diff --git a/src/aiur/gadgets/bytes2.rs b/src/aiur/gadgets/bytes2.rs index c3d382c8..e0701f71 100644 --- a/src/aiur/gadgets/bytes2.rs +++ b/src/aiur/gadgets/bytes2.rs @@ -8,8 +8,8 @@ use multi_stark::{ use crate::aiur::{ G, execute::QueryRecord, gadgets::AiurGadget, u8_add_channel, u8_and_channel, - u8_less_than_channel, u8_or_channel, u8_range_check_channel, u8_sub_channel, - u8_xor_channel, + u8_less_than_channel, u8_mul_channel, u8_or_channel, u8_range_check_channel, + u8_sub_channel, u8_xor_channel, }; /// Number of columns in the trace with multiplicities for @@ -20,7 +20,8 @@ use crate::aiur::{ /// - or /// - less_than /// - range_check -const TRACE_WIDTH: usize = 7; +/// - mul +const TRACE_WIDTH: usize = 8; /// Number of columns in the preprocessed trace: /// - first raw byte value @@ -33,7 +34,9 @@ const TRACE_WIDTH: usize = 7; /// - and result /// - or result /// - less_than result -const PREPROCESSED_TRACE_WIDTH: usize = 10; +/// - mul low byte +/// - mul high byte +const PREPROCESSED_TRACE_WIDTH: usize = 12; /// AIR implementer for arity 2 byte-related lookups. pub(crate) struct Bytes2; @@ -41,6 +44,7 @@ pub(crate) struct Bytes2; pub(crate) enum Bytes2Op { Xor, Add, + Mul, Sub, And, Or, @@ -83,6 +87,11 @@ impl BaseAir for Bytes2 { // Less than trace_values.push(G::from_bool(i < j)); + + // Mul (low byte, high byte) + let p = u16::from(i) * u16::from(j); + trace_values.push(G::from_u8((p & 0xff) as u8)); + trace_values.push(G::from_u8((p >> 8) as u8)); } } Some(RowMajorMatrix::new(trace_values, PREPROCESSED_TRACE_WIDTH)) @@ -100,7 +109,7 @@ impl AiurGadget for Bytes2 { fn output_size(&self, op: &Bytes2Op) -> usize { match op { Bytes2Op::Xor | Bytes2Op::And | Bytes2Op::Or | Bytes2Op::LessThan => 1, - Bytes2Op::Add | Bytes2Op::Sub => 2, + Bytes2Op::Add | Bytes2Op::Sub | Bytes2Op::Mul => 2, } } @@ -122,6 +131,11 @@ impl AiurGadget for Bytes2 { let (r, o) = Self::add(i, j); vec![r, o] }, + Bytes2Op::Mul => { + record.bytes2_queries.bump_mul(i, j); + let (lo, hi) = Self::mul(i, j); + vec![lo, hi] + }, Bytes2Op::Sub => { record.bytes2_queries.bump_sub(i, j); let (r, u) = Self::sub(i, j); @@ -151,6 +165,7 @@ impl AiurGadget for Bytes2 { let or_channel = u8_or_channel().into(); let less_than_channel = u8_less_than_channel().into(); let range_check_channel = u8_range_check_channel().into(); + let mul_channel = u8_mul_channel().into(); // Multiplicity columns let xor_multiplicity = var(0); @@ -160,6 +175,7 @@ impl AiurGadget for Bytes2 { let or_multiplicity = var(4); let less_than_multiplicity = var(5); let range_check_multiplicity = var(6); + let mul_multiplicity = var(7); // Preprocessed columns let i = preprocessed_var(0); @@ -172,6 +188,8 @@ impl AiurGadget for Bytes2 { let and = preprocessed_var(7); let or = preprocessed_var(8); let less_than = preprocessed_var(9); + let mul_lo = preprocessed_var(10); + let mul_hi = preprocessed_var(11); let pull_xor = Lookup::pull( xor_multiplicity, @@ -201,6 +219,11 @@ impl AiurGadget for Bytes2 { vec![less_than_channel, i.clone(), j.clone(), less_than], ); + let pull_mul = Lookup::pull( + mul_multiplicity, + vec![mul_channel, i.clone(), j.clone(), mul_lo, mul_hi], + ); + let pull_range_check = Lookup::pull(range_check_multiplicity, vec![range_check_channel, i, j]); @@ -212,6 +235,7 @@ impl AiurGadget for Bytes2 { pull_or, pull_less_than, pull_range_check, + pull_mul, ] } @@ -231,6 +255,7 @@ impl AiurGadget for Bytes2 { let or_channel = u8_or_channel(); let less_than_channel = u8_less_than_channel(); let range_check_channel = u8_range_check_channel(); + let mul_channel = u8_mul_channel(); rows .chunks_exact_mut(TRACE_WIDTH) @@ -239,7 +264,10 @@ impl AiurGadget for Bytes2 { .zip(&mut lookups) .for_each( |( - ((row_idx, row), &[xor, add, sub, and, or, less_than, range_check]), + ( + (row_idx, row), + &[xor, add, sub, and, or, less_than, range_check, mul], + ), row_lookups, )| { let i = G::from_usize(row_idx / 256); @@ -252,6 +280,7 @@ impl AiurGadget for Bytes2 { row[4] = or; row[5] = less_than; row[6] = range_check; + row[7] = mul; // Pull xor. row_lookups[0] = @@ -282,6 +311,10 @@ impl AiurGadget for Bytes2 { // Pull range_check. row_lookups[6] = Lookup::pull(range_check, vec![range_check_channel, i, j]); + + // Pull mul. + let (lo, hi) = Self::mul(&i, &j); + row_lookups[7] = Lookup::pull(mul, vec![mul_channel, i, j, lo, hi]); }, ); (RowMajorMatrix::new(rows, TRACE_WIDTH), lookups) @@ -325,6 +358,10 @@ impl Bytes2Queries { self.bump_multiplicity_for(i, j, 6) } + fn bump_mul(&mut self, i: &G, j: &G) { + self.bump_multiplicity_for(i, j, 7) + } + fn bump_multiplicity_for(&mut self, i: &G, j: &G, col: usize) { let i = usize::try_from(i.as_canonical_u64()).unwrap(); let j = usize::try_from(j.as_canonical_u64()).unwrap(); @@ -377,4 +414,13 @@ impl Bytes2 { let j: u8 = j.as_canonical_u64().try_into().unwrap(); G::from_bool(i < j) } + + /// `u8 * u8 -> (low byte, high byte)`. The product fits in 16 bits. + #[inline] + pub(crate) fn mul(i: &G, j: &G) -> (G, G) { + let i: u8 = i.as_canonical_u64().try_into().unwrap(); + let j: u8 = j.as_canonical_u64().try_into().unwrap(); + let p = u16::from(i) * u16::from(j); + (G::from_u8((p & 0xff) as u8), G::from_u8((p >> 8) as u8)) + } } diff --git a/src/aiur/trace.rs b/src/aiur/trace.rs index 5cc6f9f6..02d5a928 100644 --- a/src/aiur/trace.rs +++ b/src/aiur/trace.rs @@ -20,9 +20,9 @@ use crate::{ gadgets::{bytes1::Bytes1, bytes2::Bytes2}, memory::Memory, u8_add_channel, u8_and_channel, u8_bit_decomposition_channel, - u8_less_than_channel, u8_or_channel, u8_range_check_channel, - u8_shift_left_channel, u8_shift_right_channel, u8_sub_channel, - u8_xor_channel, + u8_less_than_channel, u8_mul_channel, u8_or_channel, + u8_range_check_channel, u8_shift_left_channel, u8_shift_right_channel, + u8_sub_channel, u8_xor_channel, }, }; @@ -426,6 +426,17 @@ impl Op { let lookup_args = vec![u8_add_channel(), i, j, r, o]; slice.push_lookup(index, Lookup::push(G::ONE, lookup_args)); }, + Op::U8Mul(i, j) => { + let (i, _) = map[*i]; + let (j, _) = map[*j]; + let (lo, hi) = Bytes2::mul(&i, &j); + map.push((lo, 1)); + map.push((hi, 1)); + slice.push_auxiliary(index, lo); + slice.push_auxiliary(index, hi); + let lookup_args = vec![u8_mul_channel(), i, j, lo, hi]; + slice.push_lookup(index, Lookup::push(G::ONE, lookup_args)); + }, Op::U8Sub(i, j) => { let (i, _) = map[*i]; let (j, _) = map[*j]; diff --git a/src/ffi/aiur/toplevel.rs b/src/ffi/aiur/toplevel.rs index 02b8a4e8..e1f0df73 100644 --- a/src/ffi/aiur/toplevel.rs +++ b/src/ffi/aiur/toplevel.rs @@ -107,25 +107,29 @@ fn decode_op(ctor: LeanCtor>) -> Op { }, 18 => { let [i, j] = ctor.objs::<2>().map(|x| lean_unbox_nat_as_usize(&x)); - Op::U8Sub(i, j) + Op::U8Mul(i, j) }, 19 => { let [i, j] = ctor.objs::<2>().map(|x| lean_unbox_nat_as_usize(&x)); - Op::U8And(i, j) + Op::U8Sub(i, j) }, 20 => { let [i, j] = ctor.objs::<2>().map(|x| lean_unbox_nat_as_usize(&x)); - Op::U8Or(i, j) + Op::U8And(i, j) }, 21 => { let [i, j] = ctor.objs::<2>().map(|x| lean_unbox_nat_as_usize(&x)); - Op::U8LessThan(i, j) + Op::U8Or(i, j) }, 22 => { let [i, j] = ctor.objs::<2>().map(|x| lean_unbox_nat_as_usize(&x)); - Op::U32LessThan(i, j) + Op::U8LessThan(i, j) }, 23 => { + let [i, j] = ctor.objs::<2>().map(|x| lean_unbox_nat_as_usize(&x)); + Op::U32LessThan(i, j) + }, + 24 => { let [label_obj, idxs_obj] = ctor.objs::<2>(); let label = label_obj.as_string().to_string(); let idxs = if idxs_obj.is_scalar() {