UNPKG

o1js

Version:

TypeScript framework for zk-SNARKs and zkApps

553 lines 23.7 kB
import { inverse, mod } from '../../../bindings/crypto/finite-field.js';
import { Field } from '../field.js';
import { Provable } from '../provable.js';
import { assert } from './common.js';
import { Field3, ForeignField, split, weakBound } from './foreign-field.js';
import { l, l2, l2Mask, multiRangeCheck } from './range-check.js';
import { sha256 } from 'js-sha256';
import { bigIntToBytes, bytesToBigInt } from '../../../bindings/crypto/bigint-helpers.js';
import { affineAdd, affineDouble, } from '../../../bindings/crypto/elliptic-curve.js';
import { Bool } from '../bool.js';
import { provable } from '../types/provable-derivers.js';
import { assertPositiveInteger } from '../../../bindings/crypto/non-negative.js';
import { arrayGetGeneric, assertNotVectorEquals } from './basic.js';
import { sliceField3 } from './bit-slices.js';
import { exists } from '../core/exists.js';
// external API
export { EllipticCurve, Point, Ecdsa };
// internal API
export { verifyEcdsaConstant, initialAggregator, simpleMapToCurve };
/**
 * Namespace object bundling the elliptic curve gadgets defined in this module.
 */
const EllipticCurve = {
    add,
    double,
    negate,
    assertOnCurve,
    scale,
    assertInSubgroup,
    multiScalarMul,
};
/**
 * Elliptic curve addition `p1 + p2` in affine coordinates.
 *
 * If both inputs are constant, the sum is computed on bigints with `affineAdd()`.
 * Otherwise the slope `m` and the result `(x3, y3)` are witnessed (3 limbs each) and tied to
 * the inputs by three `ForeignField.assertMul()` constraints.
 *
 * The non-constant path asserts `x1 != x2`, so it cannot be used to add a point to itself
 * or to its negation.
 *
 * @param p1 first summand `{ x, y }`
 * @param p2 second summand `{ x, y }`
 * @param Curve curve parameters; this function reads `modulus` and `a`
 * @returns the sum `{ x, y }`
 */
function add(p1, p2, Curve) {
    let { x: x1, y: y1 } = p1;
    let { x: x2, y: y2 } = p2;
    let f = Curve.modulus;
    let [f0, f1, f2] = split(f);
    let [, , fx22] = split(f * 2n);
    // constant case
    if (Point.isConstant(p1) && Point.isConstant(p2)) {
        let p3 = affineAdd(Point.toBigint(p1), Point.toBigint(p2), f, Curve.a);
        return Point.from(p3);
    }
    assert(Curve.modulus > l2Mask + 1n, 'Base field moduli smaller than 2^176 are not supported');
    // witness and range-check slope, x3, y3
    let witnesses = exists(9, () => {
        let [x1_, x2_, y1_, y2_] = Field3.toBigints(x1, x2, y1, y2);
        // fall back to 0 if `inverse()` yields no result (x1 = x2); the constraints below reject that case
        let denom = inverse(mod(x1_ - x2_, f), f) ?? 0n;
        let m = mod((y1_ - y2_) * denom, f);
        let x3 = mod(m * m - x1_ - x2_, f);
        let y3 = mod(m * (x1_ - x3) - y1_, f);
        return [...split(m), ...split(x3), ...split(y3)];
    });
    let [m0, m1, m2, x30, x31, x32, y30, y31, y32] = witnesses;
    let m = [m0, m1, m2];
    let x3 = [x30, x31, x32];
    let y3 = [y30, y31, y32];
    ForeignField.assertAlmostReduced([m, x3, y3], f);
    // check that x1 != x2
    // we assume x1, x2 are almost reduced, so deltaX <= x1 - x2 + f < 3f
    // which means we need to check that deltaX != 0, f, 2f
    let deltaX = ForeignField.sub(x1, x2, f);
    let deltaX01 = deltaX[0].add(deltaX[1].mul(1n << l)).seal();
    assertNotVectorEquals([deltaX01, deltaX[2]], [0n, 0n]); // != 0
    assertNotVectorEquals([deltaX01, deltaX[2]], [f0 + (f1 << l), f2]); // != f
    deltaX[2].assertNotEquals(fx22); // != 2f (stronger check bc assuming deltaX < f doesn't harm completeness)
    // (x1 - x2)*m = y1 - y2
    let deltaY = ForeignField.Sum(y1).sub(y2);
    ForeignField.assertMul(deltaX, m, deltaY, f);
    // m^2 = x1 + x2 + x3
    let xSum = ForeignField.Sum(x1).add(x2).add(x3);
    ForeignField.assertMul(m, m, xSum, f);
    // (x1 - x3)*m = y1 + y3
    let deltaX1X3 = ForeignField.Sum(x1).sub(x3);
    let ySum = ForeignField.Sum(y1).add(y3);
    ForeignField.assertMul(deltaX1X3, m, ySum, f);
    return { x: x3, y: y3 };
}
/**
 * Elliptic curve doubling `2 * p1` in affine coordinates.
 *
 * If the input is constant, the result is computed on bigints with `affineDouble()`.
 * Otherwise the tangent slope `m` and the result `(x3, y3)` are witnessed and tied to the
 * input by `ForeignField` multiplication constraints.
 *
 * @param p1 point `{ x, y }` to double
 * @param Curve curve parameters; this function reads `modulus` and `a`
 * @returns the doubled point `{ x, y }`
 */
function double(p1, Curve) {
    let { x: x1, y: y1 } = p1;
    let f = Curve.modulus;
    // constant case
    if (Point.isConstant(p1)) {
        let p3 = affineDouble(Point.toBigint(p1), f, Curve.a);
        return Point.from(p3);
    }
    // witness and range-check slope, x3, y3
    let witnesses = exists(9, () => {
        let [x1_, y1_] = Field3.toBigints(x1, y1);
        // fall back to 0 if `inverse()` yields no result
        let denom = inverse(mod(2n * y1_, f), f) ?? 0n;
        let m = mod((3n * mod(x1_ ** 2n, f) + Curve.a) * denom, f);
        let x3 = mod(m * m - 2n * x1_, f);
        let y3 = mod(m * (x1_ - x3) - y1_, f);
        return [...split(m), ...split(x3), ...split(y3)];
    });
    let [m0, m1, m2, x30, x31, x32, y30, y31, y32] = witnesses;
    let m = [m0, m1, m2];
    let x3 = [x30, x31, x32];
    let y3 = [y30, y31, y32];
    ForeignField.assertAlmostReduced([m, x3, y3], f);
    // x1^2 = x1x1
    let x1x1 = ForeignField.mul(x1, x1, f);
    // 2*y1*m = 3*x1x1 + a
    let y1Times2 = ForeignField.Sum(y1).add(y1);
    let x1x1Times3PlusA = ForeignField.Sum(x1x1).add(x1x1).add(x1x1);
    if (Curve.a !== 0n)
        x1x1Times3PlusA = x1x1Times3PlusA.add(Field3.from(Curve.a));
    ForeignField.assertMul(y1Times2, m, x1x1Times3PlusA, f);
    // m^2 = 2*x1 + x3
    let xSum = ForeignField.Sum(x1).add(x1).add(x3);
    ForeignField.assertMul(m, m, xSum, f);
    // (x1 - x3)*m = y1 + y3
    let deltaX1X3 = ForeignField.Sum(x1).sub(x3);
    let ySum = ForeignField.Sum(y1).add(y3);
    ForeignField.assertMul(deltaX1X3, m, ySum, f);
    return { x: x3, y: y3 };
}
/**
 * Negate a point: keeps `x` and negates `y` modulo `Curve.modulus`.
 */
function negate({ x, y }, Curve) {
    return { x, y: ForeignField.negate(y, Curve.modulus) };
}
/**
 * Assert that `p` satisfies the curve equation `y^2 = x^3 + a*x + b`,
 * checked in the form `(x^2 + a) * x = y^2 - b` modulo `f`.
 *
 * @param p point `{ x, y }` to check
 * @param Curve curve parameters `{ modulus, a, b }`
 */
function assertOnCurve(p, { modulus: f, a, b }) {
    let { x, y } = p;
    let x2 = ForeignField.mul(x, x, f);
    // Ensure x2, x, and y are almost reduced to prevent potential exploitation
    // by a malicious prover adding large multiples of f, which could violate
    // the precondition of ForeignField.assertMul
    ForeignField.assertAlmostReduced([x2, x, y], f);
    let y2 = ForeignField.mul(y, y, f);
    let y2MinusB = ForeignField.Sum(y2).sub(Field3.from(b));
    // (x^2 + a) * x = y^2 - b
    let x2PlusA = ForeignField.Sum(x2);
    if (a !== 0n)
        x2PlusA = x2PlusA.add(Field3.from(a));
    // a custom error message is only built for constant points
    let message;
    if (Point.isConstant(p)) {
        message = `assertOnCurve(): (${x}, ${y}) is not on the curve.`;
    }
    ForeignField.assertMul(x2PlusA, x, y2MinusB, f, message);
}
/**
 * EC scalar multiplication, `scalar*point`
 *
 * With the default `config.mode` of `'assert-nonzero'`, the result is constrained to be not zero.
 * With `config.mode === 'assert-zero'`, the result is instead asserted to be zero
 * (see {@link multiScalarMul}).
 *
 * If `config.windowSize` is not set, 4 is used for a constant point and 3 otherwise.
 * The passed `config` object is not modified, so it can safely be reused across calls.
 *
 * @param scalar the scalar, as a `Field3`
 * @param point the point `{ x, y }` to scale
 * @param Curve curve parameters
 * @param config optional `{ mode, windowSize, multiples }`
 * @returns the scaled point
 */
function scale(scalar, point, Curve, config = { mode: 'assert-nonzero' }) {
    // resolve the default per call instead of writing it back into the caller's object
    let windowSize = config.windowSize ?? (Point.isConstant(point) ? 4 : 3);
    return multiScalarMul([scalar], [point], Curve, [{ ...config, windowSize }], config.mode);
}
// checks whether the elliptic curve point g is in the subgroup defined by [order]g = 0
// (no-op when the curve has no cofactor)
function assertInSubgroup(p, Curve) {
    if (!Curve.hasCofactor)
        return;
    scale(Field3.from(Curve.order), p, Curve, { mode: 'assert-zero' });
}
// check whether a point equals a constant point
// TODO implement the full case of two vars
function equals(p1, p2, Curve) {
    let xEquals = ForeignField.equals(p1.x, p2.x, Curve.modulus);
    let yEquals = ForeignField.equals(p1.y, p2.y, Curve.modulus);
    return xEquals.and(yEquals);
}
/**
 * ECDSA verification, parameterized by the multi-scalar multiplication implementation.
 *
 * If signature, message hash and public key are all constant, this delegates to
 * {@link verifyEcdsaConstant} and wraps the result in a `Bool`.
 *
 * @param Curve curve parameters
 * @param signature `{ r, s }`
 * @param msgHash message hash, as a `Field3`
 * @param publicKey public key point `{ x, y }`
 * @param multiScalarMul function called as `(scalars, points, Curve, configs, mode, ia)`
 * @param config optional `{ G, P, ia }` table configs and initial aggregator
 * @returns a `Bool` which is true if the reduced `R.x` equals `r`
 */
function verifyEcdsaGeneric(Curve, signature, msgHash, publicKey, multiScalarMul, config = { G: { windowSize: 4 }, P: { windowSize: 4 } }) {
    // constant case
    if (EcdsaSignature.isConstant(signature) &&
        Field3.isConstant(msgHash) &&
        Point.isConstant(publicKey)) {
        let isValid = verifyEcdsaConstant(Curve, EcdsaSignature.toBigint(signature), Field3.toBigint(msgHash), Point.toBigint(publicKey));
        return new Bool(isValid);
    }
    // provable case
    // note: usually we don't check validity of inputs, like that the public key is a valid curve point
    // we make an exception for the two non-standard conditions r != 0 and s != 0,
    // which are unusual to capture in types and could be considered part of the verification algorithm
    let { r, s } = signature;
    ForeignField.inv(r, Curve.order); // proves r != 0 (important, because r = 0 => u2 = 0 kills the private key contribution)
    let sInv = ForeignField.inv(s, Curve.order); // proves s != 0
    let u1 = ForeignField.mul(msgHash, sInv, Curve.order);
    let u2 = ForeignField.mul(r, sInv, Curve.order);
    let G = Point.from(Curve.one);
    let R = multiScalarMul([u1, u2], [G, publicKey], Curve, config && [config.G, config.P], 'assert-nonzero', config?.ia);
    // this ^ already proves that R != 0 (part of ECDSA verification)
    // reduce R.x modulo the curve order
    let Rx = ForeignField.mul(R.x, Field3.from(1n), Curve.order);
    // we have to prove that Rx is canonical, because we check signature validity based on whether Rx _exactly_ equals the input r.
    // if we allowed non-canonical Rx, the prover could make verify() return false on a valid signature, by adding a multiple of `Curve.order` to Rx.
    ForeignField.assertLessThan(Rx, Curve.order);
    // assert s to be canonical
    ForeignField.assertLessThan(s, Curve.order);
    return Provable.equal(Field3, Rx, r);
}
/**
 * Verify an ECDSA signature.
 *
 * Details about the `config` parameter:
 * - For both the generator point `G` and public key `P`, `config` allows you to specify:
 *   - the `windowSize` which is used in scalar multiplication for this point.
 *     this flexibility is good because the optimal window size is different for constant and non-constant points.
 *     empirically, `windowSize=4` for constants and 3 for variables leads to the fewest constraints.
 *     our defaults reflect that the generator is always constant and the public key is variable in typical applications.
 *   - a table of multiples of those points, of length `2^windowSize`, which is used in the scalar multiplication gadget to speed up the computation.
 *     if these are not provided, they are computed on the fly.
 *     for the constant G, computing multiples costs no constraints, so passing them in makes no real difference.
 *     for variable public key, there is a possible use case: if the public key is a public input, then its multiples could also be.
 *     in that case, passing them in would avoid computing them in-circuit and save a few constraints.
 * - The initial aggregator `ia`, see {@link initialAggregator}. By default, `ia` is computed deterministically on the fly.
 *
 *
 * _Note_: If `signature.s` is a non-canonical element, an error will be thrown.
 * If `signature.r` is non-canonical, however, `false` will be returned.
 */
function verifyEcdsa(Curve, signature, msgHash, publicKey, config = { G: { windowSize: 4 }, P: { windowSize: 3 } }) {
    return verifyEcdsaGeneric(Curve, signature, msgHash, publicKey, (scalars, points, Curve, configs, mode, ia) => multiScalarMul(scalars, points, Curve, configs, mode, ia), config);
}
/**
 * Bigint implementation of ECDSA verify
 *
 * Returns `false` (rather than throwing) if the public key is zero or outside the subgroup,
 * if `r` or `s` is outside `[1, Curve.order)`, or if the computed point `R` is zero.
 */
function verifyEcdsaConstant(Curve, { r, s }, msgHash, publicKey) {
    let pk = Curve.from(publicKey);
    if (Curve.equal(pk, Curve.zero))
        return false;
    if (Curve.hasCofactor && !Curve.isInSubgroup(pk))
        return false;
    if (r < 1n || r >= Curve.order)
        return false;
    if (s < 1n || s >= Curve.order)
        return false;
    let sInv = Curve.Scalar.inverse(s);
    assert(sInv !== undefined);
    let u1 = Curve.Scalar.mul(msgHash, sInv);
    let u2 = Curve.Scalar.mul(r, sInv);
    let R = Curve.add(Curve.scale(Curve.one, u1), Curve.scale(pk, u2));
    if (Curve.equal(R, Curve.zero))
        return false;
    return Curve.Scalar.equal(R.x, r);
}
/**
 * Multi-scalar multiplication for the case where all scalars and points are constant:
 * the sum is computed on bigints and converted back with `Point.from()`.
 *
 * In mode `'assert-zero'` the sum must be the point at infinity and the constant zero point is
 * returned; otherwise the sum must not be the point at infinity.
 */
function multiScalarMulConstant(scalars, points, Curve, mode = 'assert-nonzero') {
    let n = points.length;
    assert(scalars.length === n, 'Points and scalars lengths must match');
    assertPositiveInteger(n, 'Expected at least 1 point and scalar');
    let useGlv = Curve.hasEndomorphism;
    // TODO dedicated MSM
    let s = scalars.map(Field3.toBigint);
    let P = points.map(Point.toBigint);
    let sum = Curve.zero;
    for (let i = 0; i < n; i++) {
        if (useGlv) {
            sum = Curve.add(sum, Curve.Endo.scale(P[i], s[i]));
        }
        else {
            sum = Curve.add(sum, Curve.scale(P[i], s[i]));
        }
    }
    if (mode === 'assert-zero') {
        assert(sum.infinity, 'scalar multiplication: expected zero result');
        return Point.from(Curve.zero);
    }
    assert(!sum.infinity, 'scalar multiplication: expected non-zero result');
    return Point.from(sum);
}
/**
 * Multi-scalar multiplication:
 *
 * s_0 * P_0 + ... + s_(n-1) * P_(n-1)
 *
 * where P_i are any points.
 *
 * By default, we prove that the result is not zero.
 *
 * If you set the `mode` parameter to `'assert-zero'`, on the other hand,
 * we assert that the result is zero and just return the constant zero point.
 *
 * Implementation: We double all points together and leverage a precomputed table of size 2^c to avoid all but every cth addition.
 *
 * Note: this algorithm targets a small number of points, like 2 needed for ECDSA verification.
 *
 * TODO: could use lookups for picking precomputed multiples, instead of O(2^c) provable switch
 * TODO: custom bit representation for the scalar that avoids 0, to get rid of the degenerate addition case
 */
function multiScalarMul(scalars, points, Curve, tableConfigs = [], mode = 'assert-nonzero', ia) {
    let n = points.length;
    assert(scalars.length === n, 'Points and scalars lengths must match');
    assertPositiveInteger(n, 'Expected at least 1 point and scalar');
    let useGlv = Curve.hasEndomorphism;
    // constant case
    if (scalars.every(Field3.isConstant) && points.every(Point.isConstant)) {
        return multiScalarMulConstant(scalars, points, Curve, mode);
    }
    // parse or build point tables
    let windowSizes = points.map((_, i) => tableConfigs[i]?.windowSize ?? 1);
    let tables = points.map((P, i) => getPointTable(Curve, P, windowSizes[i], tableConfigs[i]?.multiples));
    let maxBits = Curve.Scalar.sizeInBits;
    if (useGlv) {
        maxBits = Curve.Endo.decomposeMaxBits;
        // decompose scalars and handle signs
        let n2 = 2 * n;
        let scalars2 = Array(n2);
        let points2 = Array(n2);
        let windowSizes2 = Array(n2);
        let tables2 = Array(n2);
        let mrcStack = [];
        for (let i = 0; i < n; i++) {
            let [s0, s1] = decomposeNoRangeCheck(Curve, scalars[i]);
            scalars2[2 * i] = s0.abs;
            scalars2[2 * i + 1] = s1.abs;
            let table = tables[i];
            // table entry 0 is passed through unchanged; all other entries are mapped through the endomorphism
            let endoTable = table.map((P, i) => {
                if (i === 0)
                    return P;
                let [phiP, betaXBound] = endomorphism(Curve, P);
                mrcStack.push(betaXBound);
                return phiP;
            });
            tables2[2 * i] = table.map((P) => negateIf(s0.isNegative, P, Curve.modulus));
            tables2[2 * i + 1] = endoTable.map((P) => negateIf(s1.isNegative, P, Curve.modulus));
            points2[2 * i] = tables2[2 * i][1];
            points2[2 * i + 1] = tables2[2 * i + 1][1];
            windowSizes2[2 * i] = windowSizes2[2 * i + 1] = windowSizes[i];
        }
        reduceMrcStack(mrcStack);
        // from now on, everything is the same as if these were the original points and scalars
        points = points2;
        tables = tables2;
        scalars = scalars2;
        windowSizes = windowSizes2;
        n = n2;
    }
    // slice scalars
    let scalarChunks = scalars.map((s, i) => sliceField3(s, { maxBits, chunkSize: windowSizes[i] }));
    // initialize sum to the initial aggregator, which is expected to be unrelated to any point that this gadget is used with
    // note: this is a trick to ensure _completeness_ of the gadget
    // soundness follows because add() and double() are sound, on all inputs that are valid non-zero curve points
    ia ??= initialAggregator(Curve);
    let sum = Point.from(ia);
    for (let i = maxBits - 1; i >= 0; i--) {
        // add in multiple of each point
        for (let j = 0; j < n; j++) {
            let windowSize = windowSizes[j];
            if (i % windowSize === 0) {
                // pick point to add based on the scalar chunk
                let sj = scalarChunks[j][i / windowSize];
                let sjP = windowSize === 1
                    ? points[j]
                    : arrayGetGeneric(Point.provable, tables[j], sj);
                // ec addition
                let added = add(sum, sjP, Curve);
                // handle degenerate case (if sj = 0, Gj is all zeros and the add result is garbage)
                sum = Provable.if(sj.equals(0), Point, sum, added);
            }
        }
        if (i === 0)
            break;
        // jointly double all points
        // (note: the highest couple of bits will not create any constraints because sum is constant; no need to handle that explicitly)
        sum = double(sum, Curve);
    }
    // the sum is now 2^(b-1)*IA + sum_i s_i*P_i
    // we assert that sum != 2^(b-1)*IA, and add -2^(b-1)*IA to get our result
    let iaFinal = Curve.scale(Curve.fromNonzero(ia), 1n << BigInt(maxBits - 1));
    let isZero = equals(sum, iaFinal, Curve);
    if (mode === 'assert-nonzero') {
        isZero.assertFalse();
        sum = add(sum, Point.from(Curve.negate(iaFinal)), Curve);
    }
    else {
        isZero.assertTrue();
        // for type consistency with the 'assert-nonzero' case
        sum = Point.from(Curve.zero);
    }
    return sum;
}
/**
 * Returns `P` with its `y` coordinate negated modulo `f` if `condition` is 1, and `P` unchanged otherwise.
 *
 * `condition` is converted with `Bool.Unsafe.fromField()`, so the caller is responsible for
 * having constrained it to be boolean.
 */
function negateIf(condition, P, f) {
    let y = Provable.if(Bool.Unsafe.fromField(condition), Field3, ForeignField.negate(P.y, f), P.y);
    return { x: P.x, y };
}
/**
 * Maps `P = (x, y)` to `(beta * x, y)`, where `beta` is `Curve.Endo.base`.
 *
 * @returns a pair of the mapped point and `weakBound()` of the highest limb of `beta * x`;
 * the caller collects these bounds and range-checks them (see `reduceMrcStack()`).
 */
function endomorphism(Curve, P) {
    let beta = Field3.from(Curve.Endo.base);
    let betaX = ForeignField.mul(beta, P.x, Curve.modulus);
    return [{ x: betaX, y: P.y }, weakBound(betaX[2], Curve.modulus)];
}
/**
 * Decompose s = s0 + s1*lambda where s0, s1 are guaranteed to be small
 *
 * Note: This assumes that s0 and s1 are range-checked externally; in scalar multiplication this happens because they are split into chunks.
 *
 * @returns `[s0, s1]`, each of the form `{ isNegative, abs }`, where `isNegative` is a field element
 * constrained to be boolean and `abs` is a `Field3` whose highest limb is the constant 0
 */
function decomposeNoRangeCheck(Curve, s) {
    assert(Curve.Endo.decomposeMaxBits < l2, 'decomposed scalars assumed to be < 2*88 bits');
    // witness s0, s1
    let witnesses = exists(6, () => {
        let [s0, s1] = Curve.Endo.decompose(Field3.toBigint(s));
        let [s00, s01] = split(s0.abs);
        let [s10, s11] = split(s1.abs);
        // prettier-ignore
        return [
            s0.isNegative ? 1n : 0n, s00, s01,
            s1.isNegative ? 1n : 0n, s10, s11,
        ];
    });
    let [s0Negative, s00, s01, s1Negative, s10, s11] = witnesses;
    // we can hard-code highest limb to zero
    // (in theory this would allow us to hard-code the high quotient limb to zero in the ffmul below, and save 2 RCs.. but not worth it)
    let s0 = [s00, s01, Field.from(0n)];
    let s1 = [s10, s11, Field.from(0n)];
    s0Negative.assertBool();
    s1Negative.assertBool();
    // prove that s1*lambda = s - s0
    let lambda = Provable.if(Bool.Unsafe.fromField(s1Negative), Field3, Field3.from(Curve.Scalar.negate(Curve.Endo.scalar)), Field3.from(Curve.Endo.scalar));
    let rhs = Provable.if(Bool.Unsafe.fromField(s0Negative), Field3, ForeignField.Sum(s).add(s0).finish(Curve.order), ForeignField.Sum(s).sub(s0).finish(Curve.order));
    ForeignField.assertMul(s1, lambda, rhs, Curve.order);
    return [
        { isNegative: s0Negative, abs: s0 },
        { isNegative: s1Negative, abs: s1 },
    ];
}
/**
 * Sign a message hash using ECDSA.
 *
 * The nonce `k` is drawn with `Curve.Scalar.random()`.
 *
 * @returns the signature `{ r, s }` as bigints
 */
function signEcdsa(Curve, msgHash, privateKey) {
    let { Scalar } = Curve;
    let k = Scalar.random();
    let R = Curve.scale(Curve.one, k);
    let r = Scalar.mod(R.x);
    let kInv = Scalar.inverse(k);
    assert(kInv !== undefined);
    let s = Scalar.mul(kInv, Scalar.add(msgHash, Scalar.mul(r, privateKey)));
    return { r, s };
}
/**
 * Given a point P, create the list of multiples [0, P, 2P, 3P, ..., (2^windowSize-1) * P].
 * This method is provable, but won't create any constraints given a constant point.
 *
 * If `table` is passed in, it is returned as is after checking that its length is `2^windowSize`.
 */
function getPointTable(Curve, P, windowSize, table) {
    assertPositiveInteger(windowSize, 'invalid window size');
    let n = 1 << windowSize; // n >= 2
    assert(table === undefined || table.length === n, 'invalid table');
    if (table !== undefined)
        return table;
    table = [Point.from(Curve.zero), P];
    if (n === 2)
        return table;
    // 2P needs double() because add() asserts that its inputs have different x coordinates
    let Pi = double(P, Curve);
    table.push(Pi);
    for (let i = 3; i < n; i++) {
        Pi = add(Pi, P, Curve);
        table.push(Pi);
    }
    return table;
}
/**
 * For EC scalar multiplication we use an initial point which is subtracted
 * at the end, to avoid encountering the point at infinity.
 *
 * This is a simple hash-to-group algorithm which finds that initial point.
 * It's important that this point has no known discrete logarithm so that nobody
 * can create an invalid proof of EC scaling.
 */
function initialAggregator(Curve) {
    // hash that identifies the curve
    let h = sha256.create();
    h.update('initial-aggregator');
    h.update(bigIntToBytes(Curve.modulus));
    h.update(bigIntToBytes(Curve.order));
    h.update(bigIntToBytes(Curve.a));
    h.update(bigIntToBytes(Curve.b));
    let bytes = h.array();
    // bytes represent a 256-bit number
    // use that as x coordinate
    const F = Curve.Field;
    let x = F.mod(bytesToBigInt(bytes));
    return simpleMapToCurve(x, Curve);
}
/**
 * Random point on the curve, as bigints: maps a random base field element to the curve
 * with {@link simpleMapToCurve}.
 */
function random(Curve) {
    let x = Curve.Field.random();
    return simpleMapToCurve(x, Curve);
}
/**
 * Given an x coordinate (base field element), increment it until we find one with
 * a y coordinate that satisfies the curve equation, and return the point.
 *
 * Note that `x` is incremented before the first attempt, so the first candidate is `x + 1`.
 *
 * If the curve has a cofactor, multiply by it to get a point in the correct subgroup.
 */
function simpleMapToCurve(x, Curve) {
    const F = Curve.Field;
    let y = undefined;
    // increment x until we find a y coordinate
    while (y === undefined) {
        x = F.add(x, 1n);
        // solve y^2 = x^3 + ax + b
        let x3 = F.mul(F.square(x), x);
        let y2 = F.add(x3, F.mul(Curve.a, x) + Curve.b);
        y = F.sqrt(y2);
    }
    let p = { x, y, infinity: false };
    // clear cofactor
    if (Curve.hasCofactor) {
        p = Curve.scale(p, Curve.cofactor);
    }
    return p;
}
// type/conversion helpers
const Point = {
    /**
     * Convert a point with `x`, `y` coordinates into a point of two `Field3`s.
     */
    from({ x, y }) {
        return { x: Field3.from(x), y: Field3.from(y) };
    },
    /**
     * Convert a point of `Field3`s into bigint coordinates; `infinity` is always set to `false`.
     */
    toBigint({ x, y }) {
        return { x: Field3.toBigint(x), y: Field3.toBigint(y), infinity: false };
    },
    isConstant: (P) => Provable.isConstant(Point, P),
    /**
     * Random point on the curve.
     */
    random(Curve) {
        return Point.from(random(Curve));
    },
    provable: provable({ x: Field3, y: Field3 }),
};
const EcdsaSignature = {
    /**
     * Convert a signature `{ r, s }` into a pair of `Field3`s.
     */
    from({ r, s }) {
        return { r: Field3.from(r), s: Field3.from(s) };
    },
    /**
     * Convert a signature of `Field3`s into bigints `{ r, s }`.
     */
    toBigint({ r, s }) {
        return { r: Field3.toBigint(r), s: Field3.toBigint(s) };
    },
    isConstant: (S) => Provable.isConstant(EcdsaSignature, S),
    /**
     * Create an {@link EcdsaSignature} from a raw 130-char hex string as used in
     * [Ethereum transactions](https://ethereum.org/en/developers/docs/transactions/#typed-transaction-envelope).
     *
     * Only the 128 hex characters after the `0x` prefix are read (64 for `r`, then 64 for `s`);
     * any characters after those are ignored.
     */
    fromHex(rawSignature) {
        let prefix = rawSignature.slice(0, 2);
        let signature = rawSignature.slice(2, 130);
        if (prefix !== '0x' || signature.length < 128) {
            throw Error(`Signature.fromHex(): Invalid signature, expected hex string 0x... of length at least 130.`);
        }
        let r = BigInt(`0x${signature.slice(0, 64)}`);
        let s = BigInt(`0x${signature.slice(64)}`);
        return EcdsaSignature.from({ r, s });
    },
    provable: provable({ r: Field3, s: Field3 }),
};
/**
 * Namespace object bundling ECDSA signing, verification and the signature helpers.
 */
const Ecdsa = {
    sign: signEcdsa,
    verify: verifyEcdsa,
    Signature: EcdsaSignature,
};
// MRC stack
/**
 * Runs `multiRangeCheck()` on the given field elements in groups of three.
 *
 * The last group is padded with constant zeros. Note that this final call is made
 * unconditionally, so it runs on three zeros when `xs.length` is a multiple of 3.
 */
function reduceMrcStack(xs) {
    let n = xs.length;
    let nRemaining = n % 3;
    let nFull = (n - nRemaining) / 3;
    for (let i = 0; i < nFull; i++) {
        multiRangeCheck([xs[3 * i], xs[3 * i + 1], xs[3 * i + 2]]);
    }
    let remaining = [Field.from(0n), Field.from(0n), Field.from(0n)];
    for (let i = 0; i < nRemaining; i++) {
        remaining[i] = xs[3 * nFull + i];
    }
    multiRangeCheck(remaining);
}
//# sourceMappingURL=elliptic-curve.js.map