UNPKG

o1js

Version:

TypeScript framework for zk-SNARKs and zkApps

378 lines (337 loc) 9.9 kB
import { assert } from '../../lib/util/assert.js';
import { bytesToBigInt, log2 } from './bigint-helpers.js';
import { randomBytes } from './random.js';

export { createField, Fp, Fq, FiniteField, p, q, mod, inverse };

// CONSTANTS

// the modulus. called `p` in most of our code.
const p = 0x40000000000000000000000000000000224698fc094cf91b992d30ed00000001n;
const q = 0x40000000000000000000000000000000224698fc0994a8dd8c46eb2100000001n;

// this is `t`, where p = 2^32 * t + 1 (and likewise for q)
const pMinusOneOddFactor =
  0x40000000000000000000000000000000224698fc094cf91b992d30edn;
const qMinusOneOddFactor =
  0x40000000000000000000000000000000224698fc0994a8dd8c46eb21n;

// primitive roots of unity, computed as (5^t mod p). this works because 5 generates the multiplicative group mod p
const twoadicRootFp =
  0x2bce74deac30ebda362120830561f81aea322bf2b7bb7584bdad6fabd87ea32fn;
const twoadicRootFq =
  0x2de6a9b8746d3f589e5c4dfd492ae26e9bb97ea3c106f049a70e2c1102b6d05fn;

// GENERAL FINITE FIELD ALGORITHMS

/**
 * Canonical remainder of `x` modulo a positive modulus `p`, in [0, p).
 *
 * JS `%` keeps the sign of the dividend, so a negative remainder is shifted
 * up by `p`.
 */
function mod(x: bigint, p: bigint) {
  x = x % p;
  if (x < 0) return x + p;
  return x;
}

/**
 * Modular exponentiation `a^n mod p` by right-to-left square-and-multiply.
 *
 * Note: the loop only runs while `n > 0n`, so any non-positive exponent
 * returns `1n` (negative exponents are NOT interpreted as inverses).
 */
function power(a: bigint, n: bigint, p: bigint) {
  a = mod(a, p);
  let x = 1n;
  for (; n > 0n; n >>= 1n) {
    if (n & 1n) x = mod(x * a, p);
    a = mod(a * a, p);
  }
  return x;
}

/**
 * Modular inverse `1/a mod p` via the extended Euclidean algorithm.
 *
 * @returns the inverse in [0, p), or `undefined` if `a ≡ 0 (mod p)` or
 *   `gcd(a, p) !== 1` (no inverse exists).
 */
function inverse(a: bigint, p: bigint) {
  a = mod(a, p);
  if (a === 0n) return undefined;
  // Invariant: b ≡ x·a0 and a ≡ u·a0 (mod p), where a0 is the reduced input.
  // Only the coefficient of a0 is needed, so the coefficient of p is not tracked.
  let b = p;
  let x = 0n;
  let u = 1n;
  while (a !== 0n) {
    let quotient = b / a;
    let r = mod(b, a);
    let m = x - u * quotient;
    b = a;
    a = r;
    x = u;
    u = m;
  }
  // b is now gcd(a0, p); an inverse exists only when it is 1
  if (b !== 1n) return undefined;
  return mod(x, p);
}

/**
 * Faster modular inversion based on
 * Thomas Pornin, "Optimized Binary GCD for Modular Inversion", https://eprint.iacr.org/2020/972.pdf
 * (about 3x faster than `inverse()`).
 *
 * Each outer iteration runs `w` binary-GCD steps on small approximations of
 * `u` and `v` (their low `w` bits and their top `hiBits` bits), then applies
 * the accumulated transition matrix [f0 g0; f1 g1] to the full bigints.
 *
 * @param x value to invert
 * @param p modulus
 * @param n bound on outer iterations is `2 * n` (createField uses ceil(bits / 31))
 * @param kmax `2 * n * w`, the maximum power of two the result may carry
 * @param twoToMinusKmax `2^(-kmax) mod p`, used to strip the 2^k factor
 * @returns `1/x mod p`, or `undefined` when `x ≡ 0 (mod p)`.
 *   The result is checked with `assert` before returning.
 */
function fastInverse(
  x: bigint,
  p: bigint,
  n: number,
  kmax: bigint,
  twoToMinusKmax: bigint
) {
  x = mod(x, p);
  if (x === 0n) return undefined;

  // fixed constants
  const w = 31;
  const hiBits = 31;
  const wn = BigInt(w);
  const wMask = (1n << wn) - 1n;

  let u = p;
  let v = x;
  let r = 0n;
  let s = 1n;

  let i = 0;

  for (; i < 2 * n; i++) {
    // transition matrix for this batch of w steps, kept in JS numbers
    let f0 = 1;
    let g0 = 0;
    let f1 = 0;
    let g1 = 1;

    // low w bits decide parity exactly
    let ulo = Number(u & wMask);
    let vlo = Number(v & wMask);

    // top bits (aligned to a common shift) approximate the comparison u >= v
    let len = Math.max(log2(u), log2(v));
    let shift = BigInt(Math.max(len - hiBits, 0));
    let uhi = Number(u >> shift);
    let vhi = Number(v >> shift);

    for (let j = 0; j < w; j++) {
      if ((ulo & 1) === 0) {
        uhi >>= 1;
        ulo >>= 1;
        f1 <<= 1;
        g1 <<= 1;
      } else if ((vlo & 1) === 0) {
        vhi >>= 1;
        vlo >>= 1;
        f0 <<= 1;
        g0 <<= 1;
      } else {
        if (vhi <= uhi) {
          uhi = (uhi - vhi) >> 1;
          ulo = (ulo - vlo) >> 1;
          f0 = f0 + f1;
          g0 = g0 + g1;
          f1 <<= 1;
          g1 <<= 1;
        } else {
          vhi = (vhi - uhi) >> 1;
          vlo = (vlo - ulo) >> 1;
          f1 = f0 + f1;
          g1 = g0 + g1;
          f0 <<= 1;
          g0 <<= 1;
        }
      }
    }

    let f0n = BigInt(f0);
    let g0n = BigInt(g0);
    let f1n = BigInt(f1);
    let g1n = BigInt(g1);

    // apply the batched matrix to the full values; the results are divisible by 2^w
    let unew = u * f0n - v * g0n;
    let vnew = v * g1n - u * f1n;
    u = unew >> wn;
    v = vnew >> wn;

    // the approximate comparison may have picked the wrong branch: fix signs
    if (u < 0) (u = -u), (f0n = -f0n), (g0n = -g0n);
    if (v < 0) (v = -v), (f1n = -f1n), (g1n = -g1n);

    let rnew = r * f0n + s * g0n;
    let snew = s * g1n + r * f1n;
    r = rnew;
    s = snew;

    // these assertions are all true, enable when debugging:
    // let lin = v * r + u * s;
    // assert(lin === p || lin === -p, 'linear combination');
    // let k = BigInt((i + 1) * w);
    // assert(mod(x * r + u * 2n ** k, p) === 0n, 'mod p, r');
    // assert(mod(x * s - v * 2n ** k, p) === 0n, 'mod p, s');

    if (u === 0n) break;

    // empirically this never happens, but there might be unlucky edge cases where it does, due to sign flips
    if (v === 0n) {
      assert(u === 1n, 'u = 1');
      s = mod(-r, p);
      break;
    }
  }
  let k = BigInt((i + 1) * w);

  // now s = 2^k/x mod p
  // correction step to go from 2^k/x to 1/x
  s = mod(s * twoToMinusKmax, p); // s <- s * 2^(-kmax) = 2^(k - kmax)/x
  s = mod(s << (kmax - k), p); // s <- s * 2^(kmax - k) = 1/x

  // yes this has a slight cost and the assert is never triggered,
  // but it's worth having for the sake of assurance
  assert(mod(x * s - 1n, p) === 0n, 'mod p');

  return s;
}

/**
 * Square root mod p via Tonelli–Shanks.
 * https://en.wikipedia.org/wiki/Tonelli-Shanks_algorithm#The_algorithm
 * Variable naming follows that link.
 *
 * @param n value whose square root is wanted (need not be reduced)
 * @param p prime modulus
 * @param Q what we call `t` elsewhere - the odd factor in p - 1
 * @param c a primitive 2^M-th root of unity mod p
 * @param M the twoadicity = exponent of 2 in the factorization of p - 1
 * @returns `0n` if `n ≡ 0 (mod p)`, `undefined` if `n` is not a square,
 *   otherwise one of the two square roots of `n`.
 */
function sqrt(n: bigint, p: bigint, Q: bigint, c: bigint, M: bigint) {
  // Reduce first: a nonzero multiple of p would otherwise give t = 0, and the
  // repeated-squaring loop below (which waits for s === 1n) would never end.
  n = mod(n, p);
  if (n === 0n) return 0n;
  let t = power(n, (Q - 1n) >> 1n, p); // n^(Q - 1)/2
  let R = mod(t * n, p); // n^((Q - 1)/2 + 1) = n^((Q + 1)/2)
  t = mod(t * R, p); // n^((Q - 1)/2 + (Q + 1)/2) = n^Q
  while (true) {
    if (t === 1n) return R;
    // use repeated squaring to find the least i, 0 < i < M, such that t^(2^i) = 1
    let i = 0n;
    let s = t;
    while (s !== 1n) {
      s = mod(s * s, p);
      i = i + 1n;
    }
    if (i === M) return undefined; // no solution
    let b = power(c, 1n << (M - i - 1n), p); // c^(2^(M-i-1))
    M = i;
    c = mod(b * b, p);
    t = mod(t * c, p);
    R = mod(R * b, p);
  }
}

/**
 * Euler's criterion: `x` is a square mod the odd prime `p` iff
 * `x^((p-1)/2) ≡ 1`. Zero (including any multiple of p) counts as a square.
 */
function isSquare(x: bigint, p: bigint) {
  // reduce first so that multiples of p are treated like 0, consistent with sqrt()
  x = mod(x, p);
  if (x === 0n) return true;
  let sqrt1 = power(x, (p - 1n) / 2n, p);
  return sqrt1 === 1n;
}

/**
 * Uniformly random field element in [0, p) via rejection sampling.
 *
 * @param sizeInBytes number of random bytes drawn per attempt
 * @param hiBitMask mask applied to the last byte so that the candidate has at
 *   most as many bits as `p` (assumes `bytesToBigInt` is little-endian, so the
 *   last byte is the most significant — TODO confirm)
 */
function randomField(p: bigint, sizeInBytes: number, hiBitMask: number) {
  // strategy: draw random bigints with the bit length of p and use the first that's smaller than p
  while (true) {
    let bytes = randomBytes(sizeInBytes);
    bytes[sizeInBytes - 1] &= hiBitMask; // drop bits above the bit length of p
    let x = bytesToBigInt(bytes);
    if (x < p) return x;
  }
}

// SPECIALIZATIONS TO FP, FQ
// these should be mostly trivial

const Fp = createField(p, {
  oddFactor: pMinusOneOddFactor,
  twoadicRoot: twoadicRootFp,
  twoadicity: 32n,
});
const Fq = createField(q, {
  oddFactor: qMinusOneOddFactor,
  twoadicRoot: twoadicRootFq,
  twoadicity: 32n,
});

type FiniteField = ReturnType<typeof createField>;

/**
 * Build an object of arithmetic helpers for the prime field of order `p`.
 *
 * @param p prime modulus
 * @param constants optional precomputed `p - 1 = 2^twoadicity * oddFactor`
 *   decomposition and 2-adic root of unity; computed from `p` when omitted
 * @returns field helpers; all arithmetic results are reduced into [0, p)
 *   except the raw bit operations (`rot`, `leftShift`, `rightShift`, `isEven`)
 */
function createField(
  p: bigint,
  constants?: { oddFactor: bigint; twoadicRoot: bigint; twoadicity: bigint }
) {
  let { oddFactor, twoadicRoot, twoadicity } =
    constants ?? computeFieldConstants(p);
  let sizeInBits = log2(p);
  let sizeInBytes = Math.ceil(sizeInBits / 8);
  let sizeHighestByte = sizeInBits - 8 * (sizeInBytes - 1);
  let hiBitMask = (1 << sizeHighestByte) - 1;

  // parameters for fast inverse
  const w = 31;
  const n = Math.ceil(sizeInBits / w);
  const kmax = BigInt(2 * n * w);

  // constant for correcting 2^k/x -> 1/x, by multiplying with 2^-kmax * 2^(kmax - k)
  const twoToMinusKmax = inverse(1n << kmax, p);

  // fall back to plain EGCD when 2 is not invertible mod p (i.e. p is even)
  const exportedInverse =
    twoToMinusKmax !== undefined
      ? (x: bigint) => fastInverse(x, p, n, kmax, twoToMinusKmax)
      : (x: bigint) => inverse(x, p);

  return {
    modulus: p,
    sizeInBits,
    t: oddFactor,
    M: twoadicity,
    twoadicRoot,
    mod(x: bigint) {
      return mod(x, p);
    },
    add(x: bigint, y: bigint) {
      return mod(x + y, p);
    },
    /** Bitwise NOT of `x` within a `bits`-bit word, reduced mod p. */
    not(x: bigint, bits: number) {
      return mod(2n ** BigInt(bits) - (x + 1n), p);
    },
    /** Additive inverse; always canonical, even for unreduced or negative `x`. */
    negate(x: bigint) {
      return mod(-x, p);
    },
    sub(x: bigint, y: bigint) {
      return mod(x - y, p);
    },
    mul(x: bigint, y: bigint) {
      return mod(x * y, p);
    },
    inverse: exportedInverse,
    /** `x / y`, or `undefined` when `y` has no inverse. */
    div(x: bigint, y: bigint) {
      let yinv = exportedInverse(y);
      if (yinv === undefined) return;
      return mod(x * yinv, p);
    },
    square(x: bigint) {
      return mod(x * x, p);
    },
    isSquare(x: bigint) {
      return isSquare(x, p);
    },
    sqrt(x: bigint) {
      return sqrt(x, p, oddFactor, twoadicRoot, twoadicity);
    },
    power(x: bigint, n: bigint) {
      return power(x, n, p);
    },
    /** Inner product of `x` and `y` (iterates over `x.length`), reduced once at the end. */
    dot(x: bigint[], y: bigint[]) {
      let z = 0n;
      let n = x.length;
      for (let i = 0; i < n; i++) {
        z += x[i] * y[i];
      }
      return mod(z, p);
    },
    equal(x: bigint, y: bigint) {
      return mod(x - y, p) === 0n;
    },
    /** Parity of the integer representative `x` (no reduction is applied). */
    isEven(x: bigint) {
      return !(x & 1n);
    },
    random() {
      return randomField(p, sizeInBytes, hiBitMask);
    },
    fromNumber(x: number) {
      return mod(BigInt(x), p);
    },
    fromBigint(x: bigint) {
      return mod(x, p);
    },
    /**
     * Rotate `x` by `bits` within a `maxBits`-bit word (no field reduction).
     * Assumes `x` fits in `maxBits` bits. A right rotation is performed as a
     * left rotation by `maxBits - bits`.
     */
    rot(
      x: bigint,
      bits: bigint,
      direction: 'left' | 'right' = 'left',
      maxBits = 64n
    ) {
      if (direction === 'right') bits = maxBits - bits;
      let full = x << bits;
      let excess = full >> maxBits;
      let shifted = full & ((1n << maxBits) - 1n);
      return shifted | excess;
    },
    /** Left shift truncated to `maxBitSize` bits (no field reduction). */
    leftShift(x: bigint, bits: number, maxBitSize: number = 64) {
      let shifted = x << BigInt(bits);
      return shifted & ((1n << BigInt(maxBitSize)) - 1n);
    },
    /** Plain right shift of the integer representative (no field reduction). */
    rightShift(x: bigint, bits: number) {
      return x >> BigInt(bits);
    },
  };
}

/**
 * Compute constants to instantiate a finite field just from the modulus.
 *
 * Finds the factorization `p - 1 = 2^twoadicity * oddFactor` and a primitive
 * 2^twoadicity-th root of unity `z^oddFactor`, where `z` is the smallest
 * non-square >= 2. NOTE(review): the search only terminates if a non-square
 * exists, i.e. p is an odd prime (for p = 2 it loops forever).
 */
function computeFieldConstants(p: bigint) {
  // figure out the factorization p - 1 = 2^M * t
  let oddFactor = p - 1n;
  let twoadicity = 0n;
  while ((oddFactor & 1n) === 0n) {
    oddFactor >>= 1n;
    twoadicity++;
  }

  // find z = non-square
  // start with 2 and increment until we find one
  let z = 2n;
  while (isSquare(z, p)) z++;

  // primitive root of unity is z^t
  let twoadicRoot = power(z, oddFactor, p);

  return { oddFactor, twoadicRoot, twoadicity };
}