UNPKG

micro-zk-proofs

Version:

Create & verify zero-knowledge SNARK proofs in parallel, using noble cryptography

644 lines (626 loc) 23.2 kB
/**
 * The code is only used if you plan to run **legacy circom-js programs**. It is unused in WASM.
 * Minimal witness program executor for circom programs, based on websnark/wasmsnark/snarkjs.
 * Unsafe: it evaluates circuit-provided JS source (via `new Function`), better to be used
 * inside worker threads.
 * Depends on **monkey-patched BigInt** prototypes due to how circom programs are serialized.
 * We only patch prototypes before execution. After finishing, patches are reverted.
 * @module
 */
import { invert, pow, type IField } from '@noble/curves/abstract/modular.js';
import { bn254 as nobleBn254 } from '@noble/curves/bn254.js';
import { bitMask } from '@noble/curves/utils.js';
import * as P from 'micro-packed';
import {
  type CircuitInfo,
  type Constraint,
  type G1Point,
  type G2Point,
  type ProvingKey,
  type VerificationKey,
} from './index.ts';
import type { BlsCurvePair as BLSCurvePair } from '@noble/curves/abstract/bls.js';
import type { TArg, TRet } from '@noble/hashes/utils.js';

/**
 * Creates a reversible patcher for `BigInt.prototype`.
 * `patch()` installs the snarkjs-era helper methods (`add`, `mul`, `mod`, `eq`, `shr`, ...) that
 * serialized circom-js templates call directly on bigint values. `restore()` reinstates the
 * exact original property descriptors, or deletes the property when none existed.
 * `patch()` throws if already patched; `restore()` throws if not patched.
 */
function monkeyPatchBigInt() {
  const methods = {
    // Equality
    eq: (a: bigint, b: bigint) => a === b,
    neq: (a: bigint, b: bigint) => a !== b,
    greaterOrEquals: (a: bigint, b: bigint) => a >= b,
    greater: (a: bigint, b: bigint) => a > b,
    gt: (a: bigint, b: bigint) => a > b,
    lesserOrEquals: (a: bigint, b: bigint) => a <= b,
    lesser: (a: bigint, b: bigint) => a < b,
    lt: (a: bigint, b: bigint) => a < b,
    // Basic math
    sub: (a: bigint, b: bigint) => a - b,
    add: (a: bigint, b: bigint) => a + b,
    mul: (a: bigint, b: bigint) => a * b,
    div: (a: bigint, b: bigint) => a / b,
    mod: (a: bigint, b: any) => a % b,
    // Fields
    inverse: (n: bigint, modulo: bigint) => invert(n, modulo),
    modPow: (a: bigint, power: bigint, modulo: bigint) => pow(a, power, modulo),
    // Binary
    and: (a: bigint, b: bigint) => a & b,
    shr: (a: bigint, b: bigint) => a >> BigInt(b),
  };
  let patched = false;
  let orig: Record<string, PropertyDescriptor | undefined> = {};
  const proto = BigInt.prototype as any;
  // Puts back the saved descriptor for `name`, or removes our property if there was none.
  const restoreOne = (name: string) => {
    const desc = orig[name];
    if (!desc) delete proto[name];
    else Object.defineProperty(proto, name, desc);
  };
  return {
    patch() {
      if (patched) throw new Error('bigint: already patched');
      const done: string[] = [];
      try {
        for (const name in methods) {
          // Preserve descriptors: callers may have accessors or own undefined-valued properties here.
          orig[name] = Object.getOwnPropertyDescriptor(proto, name);
          Object.defineProperty(proto, name, {
            configurable: true,
            enumerable: orig[name]?.enumerable ?? false,
            value: function (this: bigint, ...args: any[]) {
              return (methods as any)[name](this, ...args);
            },
            writable: true,
          });
          done.push(name);
        }
      } catch (e) {
        // A partial patch (e.g. hitting a non-configurable property) would otherwise leave the
        // prototype modified while `patched` is false, so restore() could never undo it.
        for (const name of done) restoreOne(name);
        orig = {};
        throw e;
      }
      patched = true;
    },
    restore() {
      if (!patched) throw new Error('bigint: not patched');
      for (const name in methods) restoreOne(name);
      orig = {};
      patched = false;
    },
  };
}

/** Renders selectors as index suffixes: `['1', '2']` -> `'[1][2]'`. */
const selectorStr = (lst: string[]) => lst.map((i) => `[${i}]`).join('');
/** Appends rendered selectors to a signal name: `('main.in', ['0'])` -> `'main.in[0]'`. */
const signalStr = (name: string, selectors: string[]) => name + selectorStr(selectors);
// Apply selectors
const select = (a: any, selectors: string[]): any => {
  for (const s of selectors) a = a[s];
  return a;
};
type Scope = Record<string, any>;

/**
 * Builds a witness generator for a legacy circom-js circuit JSON.
 * @param circJson - Circom circuit JSON artifact.
 * @returns Function that executes the circuit and returns the witness.
 * @throws From the returned function: if an input/signal stays unassigned, a variable or
 * function is undefined, or a circuit `assert` fails.
 * @example
 * Build a witness runner from a circom JSON circuit artifact.
 * ```ts
 * import { generateWitness } from 'micro-zk-proofs/witness.js';
 * // Addition circuit: witness output is one, a + b, b, a.
 * const circuitJson = {
 *   nVars: 4,
 *   nInputs: 2,
 *   nOutputs: 1,
 *   nSignals: 4,
 *   templates: {
 *     Main: `function(ctx) {
 *       ctx.setSignal(
 *         "out",
 *         [],
 *         bigInt(ctx.getSignal("a", [])).add(bigInt(ctx.getSignal("b", []))).mod(__P__)
 *       );
 *     }`,
 *   },
 *   functions: {},
 *   components: [{ name: 'main', params: {}, template: 'Main', inputSignals: 2 }],
 *   signals: [
 *     { names: ['one'], triggerComponents: [] },
 *     { names: ['main.out'], triggerComponents: [] },
 *     { names: ['main.b'], triggerComponents: [0] },
 *     { names: ['main.a'], triggerComponents: [0] },
 *   ],
 *   signalName2Idx: { one: 0, 'main.out': 1, 'main.b': 2, 'main.a': 3 },
 * };
 * const witness = generateWitness(circuitJson)({ a: '33', b: '34' });
 * // [1n, 67n, 34n, 33n]
 * ```
 */
export function generateWitness(circJson: any): (input: any) => any {
  // Named `modulus` (not `P`) to avoid shadowing the micro-packed namespace import.
  const modulus = nobleBn254.fields.Fr.ORDER;
  const MASK = bitMask(nobleBn254.fields.Fr.BITS);
  const signals = circJson.signals;
  const components = circJson.components;
  // Templates/functions are JS source text. Bind bigInt, __P__ & __MASK__ directly as
  // parameters, so we see the dependency. SECURITY: this executes circuit-provided code.
  const compile = (src: string): Function =>
    new Function('bigInt', '__P__', '__MASK__', 'return ' + src)(BigInt, modulus, MASK);
  const templates: Record<string, Function> = {};
  for (const t in circJson.templates) templates[t] = compile(circJson.templates[t]);
  const functions: Record<string, { params: any[]; func: Function }> = {};
  for (const f in circJson.functions) {
    functions[f] = {
      params: circJson.functions[f].params,
      func: compile(circJson.functions[f].func),
    };
  }
  function inputIdx(i: any) {
    if (i >= circJson.nInputs) throw new Error('Accessing an invalid input: ' + i);
    // Witness slot 0 is the constant one, so declared inputs start after the output slots.
    return circJson.nOutputs + 1 + i;
  }
  function getSignalIdx(name: any) {
    if (circJson.signalName2Idx[name] !== undefined) return circJson.signalName2Idx[name];
    // signalNames() also queries raw witness indices when building error messages.
    if (!isNaN(name)) return Number(name);
    throw new Error('Invalid signal identifier: ' + name);
  }
  const signalNames = (i: any) => signals[getSignalIdx(i)].names.join(', ');
  const patcher = monkeyPatchBigInt();
  return function (input: any): any {
    const witness = new Array(circJson.nSignals);
    let currentComponent: string | undefined;
    // Root scope shared by all components/functions; every scope stack starts with it, so a
    // failed getVar() lookup reaches 'Variable not defined' instead of reading a property of
    // an undefined scope, and setVar() always has a scope to write into.
    const rootScope: Scope = {};
    let scopes: Scope[] = [rootScope]; // scope stack
    // Per component: how many input signals are still unassigned; it runs when this hits 0.
    const notInitSignals: Record<string, number> = {};
    // Runs `cb` with the stack [root, newScope], restoring the previous stack even on throw.
    function inScope(newScope: Scope, cb: () => any): any {
      const oldScope = scopes;
      scopes = [rootScope, newScope];
      try {
        return cb();
      } finally {
        scopes = oldScope;
      }
    }
    function triggerComponent(c: any) {
      notInitSignals[c]--;
      const oldComponent = currentComponent;
      currentComponent = components[c].name;
      const template = components[c].template;
      const newScope: Scope = {};
      for (const p in components[c].params) newScope[p] = components[c].params[p];
      try {
        inScope(newScope, () => templates[template](ctx));
      } finally {
        currentComponent = oldComponent;
      }
    }
    function setSignalFullName(fullName: any, value: any) {
      const sId = getSignalIdx(fullName);
      const firstInit = witness[sId] === undefined;
      witness[sId] = BigInt(value);
      const triggers: any[] = signals[sId].triggerComponents;
      // Update every dependent counter before firing any component (snarkjs ordering).
      if (firstInit) for (const c of triggers) notInitSignals[c]--;
      for (const c of triggers) if (notInitSignals[c] == 0) triggerComponent(c);
      return witness[sId];
    }
    function getSignalFullName(name: string) {
      const id = getSignalIdx(name);
      if (witness[id] === undefined) throw new Error('Signal not initialized: ' + name);
      return witness[id];
    }
    const cName = (name: string) => (name == 'one' ? 'one' : currentComponent + '.' + name);
    // Minimal API that used inside evaluated code
    const ctx = {
      // Pins
      setPin(compName: string, compSel: string[], sigName: string, sigSel: string[], value: any) {
        const name = signalStr(cName(compName), compSel) + '.' + signalStr(sigName, sigSel);
        setSignalFullName(name, value);
      },
      getPin(compName: string, componentSels: string[], sigName: string, sigSel: string[]) {
        const name = signalStr(cName(compName), componentSels) + '.' + signalStr(sigName, sigSel);
        return getSignalFullName(name);
      },
      // Vars
      setVar(name: string, sels: string[], value: any) {
        const scope = scopes[scopes.length - 1];
        if (sels.length == 0) {
          scope[name] = value;
        } else {
          if (scope[name] === undefined) scope[name] = [];
          let cur = scope[name];
          for (let i = 0; i < sels.length - 1; i++) {
            if (cur[sels[i]] === undefined) cur[sels[i]] = [];
            cur = cur[sels[i]];
          }
          cur[sels[sels.length - 1]] = value;
        }
        return value;
      },
      getVar(name: string, sels: string[]) {
        for (let i = scopes.length - 1; i >= 0; i--)
          if (scopes[i][name] !== undefined) return select(scopes[i][name], sels);
        throw new Error('Variable not defined: ' + name);
      },
      // Signals
      setSignal(name: string, sels: string[], value: any) {
        setSignalFullName(
          signalStr(currentComponent ? currentComponent + '.' + name : name, sels),
          value
        );
      },
      getSignal(name: string, sels: string[]) {
        return getSignalFullName(signalStr(cName(name), sels));
      },
      // Utils
      callFunction(name: string, params: any) {
        const fn = functions[name];
        if (!fn) throw new Error('Function not defined: ' + name);
        const newScope: Scope = {};
        for (let p = 0; p < fn.params.length; p++) newScope[fn.params[p]] = params[p];
        return inScope(newScope, () => fn.func(ctx));
      },
      assert(a: any, b: any, errStr: string = '') {
        a = BigInt(a);
        b = BigInt(b);
        if (a === b) return;
        throw new Error(`Constraint doesn't match ${currentComponent}: ${errStr} -> ${a} != ${b}`);
      },
    };
    patcher.patch();
    try {
      // Processing
      for (const c in components) notInitSignals[c] = components[c].inputSignals;
      ctx.setSignal('one', [], BigInt(1));
      for (const c in notInitSignals) if (notInitSignals[c] == 0) triggerComponent(c);
      // Circuit JSON inputs are own fields; prototypes may carry unrelated app metadata.
      for (const s of Object.keys(input)) {
        currentComponent = 'main';
        // Walk nested arrays depth-first in index order, turning positions into selectors.
        const stack = [{ selectors: [] as string[], values: input[s] }];
        while (stack.length) {
          const { selectors, values } = stack.pop()!;
          if (!Array.isArray(values)) {
            if (values === undefined) throw new Error('Signal not defined: ' + s);
            ctx.setSignal(s, selectors, BigInt(values));
            continue;
          }
          for (let j = values.length - 1; j >= 0; j--) {
            stack.push({ selectors: [...selectors, `${j}`], values: values[j] });
          }
        }
      }
      for (let i = 0; i < circJson.nInputs; i++) {
        const idx = inputIdx(i);
        if (witness[idx] === undefined)
          throw new Error('Input Signal not assigned: ' + signalNames(idx));
      }
      for (let i = 0; i < witness.length; i++)
        if (witness[i] === undefined) throw new Error('Signal not assigned: ' + signalNames(i));
      return witness.slice(0, circJson.nVars);
    } finally {
      patcher.restore();
    }
  };
}

/** Binary coder type for `.r1cs` files. */
export type R1CSType = P.CoderType<
  P.StructInput<{
    magic: undefined;
    version: number;
    sections: P.Values<{
      header: {
        TAG: 'header';
        data: P.StructInput<{
          prime: /*elided*/ any;
          nWires: /*elided*/ any;
          nPubOut: /*elided*/ any;
          nPubIn: /*elided*/ any;
          nPrvIn: /*elided*/ any;
          nLables: /*elided*/ any;
          mConstraints: /*elided*/ any;
        }>;
      };
      constraint: {
        TAG: 'constraint';
        data: [Constraint, Constraint, Constraint][];
      };
      wire2label: {
        TAG: 'wire2label';
        data: bigint[];
      };
      customGatesList: {
        TAG: 'customGatesList';
        data: P.Bytes;
      };
      customGatesApplication: {
        TAG: 'customGatesApplication';
        data: P.Bytes;
      };
    }>[];
  }>
>;

/** Binary coder type for `.wtns` files. */
export type WTNSType = P.CoderType<
  P.StructInput<{
    magic: undefined;
    version: number;
    sections: P.Values<{
      header: {
        TAG: 'header';
        data: P.StructInput<{
          prime: /*elided*/ any;
          size: /*elided*/ any;
        }>;
      };
      witness: {
        TAG: 'witness';
        data: bigint[];
      };
    }>[];
  }>
>;

type CodersOutput = {
  R1CS: R1CSType;
  binWitness: P.CoderType<bigint[]>;
  WTNS: WTNSType;
  getCircuitInfo: (bytes: Uint8Array) => CircuitInfo;
  ZKeyRaw: P.CoderType<any>;
  parseZKey: (bytes: Uint8Array) => { json: any; pkey: ProvingKey; vkey: VerificationKey };
};

/**
 * Binary coders and parsers for Circom2 artifacts.
 * Scalar-field values (R1CS/witness/coefficients) use `Fr` byte width; zkey base-field values
 * (the `q` modulus and G1/G2 coordinates) use `Fp` byte width.
 * @param curve - Curve pair used for field sizing and point decoding.
 * @returns R1CS, witness, and zkey coders plus parse helpers.
 * @example
 * Build the coders once, then use them to parse and encode Circom2 artifacts.
 * ```ts
 * const { bn254 } = await import('@noble/curves/bn254.js');
 * const coders = getCoders(bn254);
 * const bytes = coders.binWitness.encode([1n, 2n]);
 * coders.binWitness.decode(bytes);
 * ```
 */
export const getCoders = (curve: BLSCurvePair): TRet<CodersOutput> => {
  const field = curve.fields.Fr;
  // NOTE: we need to pass field here, even if bigints are variable size, they are fixed to field bytes!
  const fieldBytes = field.BYTES;
  const fieldCoder = P.bigint(fieldBytes, true, false);
  // zkey stores q and curve point coordinates with the base-field width (n8q); this equals the
  // scalar width on bn254 but differs on curves such as bls12-381 (48 vs 32 bytes).
  const baseFieldCoder = P.bigint(curve.fields.Fp.BYTES, true, false);
  const Header = P.struct({
    prime: P.prefix(P.U32LE, fieldCoder), // Scalar field modulus; getCircuitInfo() checks it equals field.ORDER.
    nWires: P.U32LE, // Total number of wires, including the ONE signal (index 0).
    nPubOut: P.U32LE, // Number of public output wires; they should start at index 1.
    nPubIn: P.U32LE, // Number of public input wires; they should start right after the public outputs.
    nPrvIn: P.U32LE, // Number of private input wires; they should start right after the public inputs.
    nLables: P.U64LE, // Number of labels (cf. the wire2label section); field name spelling kept as-is.
    mConstraints: P.U32LE, // Total number of constraints.
  });
  type ConstraintPair = [number, bigint];
  const constraintDict = {
    encode: (from: ConstraintPair[]): Constraint => {
      if (!Array.isArray(from)) throw new Error('array expected');
      const to: Constraint = {};
      for (const item of from) {
        if (!Array.isArray(item) || item.length !== 2)
          throw new Error(`array of two elements expected`);
        const [key, value] = item;
        if (Object.prototype.hasOwnProperty.call(to, key))
          throw new Error(`key(${key}) appears twice in constraint`);
        to[key] = value;
      }
      return to;
    },
    decode: (to: Constraint): ConstraintPair[] => {
      if (to === null || typeof to !== 'object' || Array.isArray(to))
        throw new Error(`expected constraint object, got ${to}`);
      return Object.entries(to).map(([key, value]): ConstraintPair => {
        // Object.entries() stringifies numeric R1CS signal ids; U32LE needs the number back.
        if (!/^(0|[1-9][0-9]*)$/.test(key))
          throw new Error(`expected uint32 constraint key, got ${key}`);
        const n = Number(key);
        if (!Number.isSafeInteger(n) || n < 0 || n > 0xffffffff)
          throw new Error(`expected uint32 constraint key, got ${key}`);
        return [n, value];
      });
    },
  };
  const Constraint: P.CoderType<Constraint> = P.apply(
    P.array(P.U32LE, P.tuple([P.U32LE, fieldCoder])),
    constraintDict
  );
  // A*B-C = 0
  const Constraints: P.CoderType<[Constraint, Constraint, Constraint][]> = P.array(
    null,
    P.tuple([Constraint, Constraint, Constraint])
  );
  const WireMap = P.array(null, P.U64LE);
  // prefix() emits JS byte lengths, while Circom section headers serialize them as u64.
  const sectionLen = P.apply(P.U64LE, P.coders.numberBigint);
  const section = <T>(inner: P.CoderType<T>) => P.prefix(sectionLen, inner);
  const empty = P.bytes(null);
  const R1CSSection = P.mappedTag(P.U32LE, {
    header: [0x01, section(Header)],
    constraint: [0x02, section(Constraints)],
    wire2label: [0x03, section(WireMap)],
    // not implemented: ultra-plonk
    customGatesList: [0x04, section(empty)],
    customGatesApplication: [0x05, section(empty)],
  });
  const R1CS = P.struct({
    magic: P.magic(P.string(4), 'r1cs'),
    version: P.U32LE,
    sections: P.array(P.U32LE, R1CSSection),
  });
  const binWitness = P.array(null, fieldCoder);
  const WTNSHeader = P.struct({
    prime: P.prefix(P.U32LE, fieldCoder),
    size: P.U32LE,
  });
  const WTNSSection = P.mappedTag(P.U32LE, {
    header: [0x01, section(WTNSHeader)],
    witness: [0x02, section(P.array(null, fieldCoder))],
  });
  const WTNS = P.struct({
    magic: P.magic(P.string(4), 'wtns'),
    version: P.U32LE,
    sections: P.array(P.U32LE, WTNSSection),
  });
  // Affine coordinates are base-field elements (Fp), not scalars.
  const G1 = P.tuple([baseFieldCoder, baseFieldCoder]);
  const G2 = P.tuple([baseFieldCoder, baseFieldCoder, baseFieldCoder, baseFieldCoder]);
  const ZKeyHeader = P.map(P.U32LE, {
    groth16: 1,
  });
  const ZKeyHeaderGroth = P.struct({
    n8q: P.U32LE,
    q: baseFieldCoder,
    n8r: P.U32LE,
    r: fieldCoder,
    nVars: P.U32LE,
    nPublic: P.U32LE,
    domainSize: P.U32LE,
    vk_alpha_1: G1,
    vk_beta_1: G1,
    vk_beta_2: G2,
    vk_gamma_2: G2,
    vk_delta_1: G1,
    vk_delta_2: G2,
  });
  const ZKeyCoeff = P.struct({
    matrix: P.U32LE,
    constraint: P.U32LE,
    signal: P.U32LE,
    value: fieldCoder,
  });
  const ZKeySection = P.mappedTag(P.U32LE, {
    header: [1, section(ZKeyHeader)],
    headerGroth: [2, section(ZKeyHeaderGroth)],
    IC: [3, section(P.array(null, G1))],
    ccoefs: [4, section(P.array(P.U32LE, ZKeyCoeff))],
    A: [5, section(P.array(null, G1))],
    B1: [6, section(P.array(null, G1))],
    B2: [7, section(P.array(null, G2))],
    C: [8, section(P.array(null, G1))],
    hExps: [9, section(P.array(null, G1))],
    Contributions: [10, section(P.bytes(null))],
  });
  const ZKeyRaw = P.struct({
    magic: P.magic(P.string(4), 'zkey'),
    version: P.U32LE,
    sections: P.array(P.U32LE, ZKeySection),
  });
  // Extracts circuit dimensions and constraints from an `.r1cs` file for this curve's Fr.
  const getCircuitInfo = (bytes: TArg<Uint8Array>): CircuitInfo => {
    const data = R1CS.decode(bytes);
    const constraints = data.sections.find((i) => i.TAG === 'constraint');
    if (!constraints) throw new Error('R1CS: cannot find constraints');
    const header = data.sections.find((i) => i.TAG === 'header');
    if (!header) throw new Error('R1CS: cannot find header');
    if (header.data.prime !== field.ORDER) throw new Error('R1CS: wrong field order');
    return {
      nVars: header.data.nWires,
      nPubInputs: header.data.nPubIn,
      nOutputs: header.data.nPubOut,
      constraints: constraints.data,
    };
  };
  /**
   * Decodes a Groth16 `.zkey` file into snarkjs-style JSON (Montgomery encoding removed)
   * plus proving/verification keys in this library's (old snarkjs compatible) format.
   * Throws if a required section is missing or the key was built for a different curve.
   */
  function parseZKey(zkey: TArg<Uint8Array>) {
    const { Fr, Fp } = curve.fields;
    // Montgomery encoding of field elements
    const fieldFromMont = (f: IField<bigint>, is1: boolean) => {
      const Rr = f.pow(BigInt(2), BigInt(f.BYTES * 8));
      const RRi = f.inv(Rr);
      const RRi2 = f.mul(RRi, RRi);
      // G1/G2 coordinates carry one Montgomery factor; coefficient field elements need two.
      return (x: bigint) => f.mul(x, is1 ? RRi : RRi2);
    };
    const is0 = (x: bigint) => x === BigInt(0);
    const convFr2 = fieldFromMont(Fr, false);
    const convFp = fieldFromMont(Fp, true);
    // An all-zero affine point encodes infinity; map it to projective (0, 1, 0).
    const convG1 = ([x, y]: bigint[]): G1Point =>
      is0(x) && is0(y) ? [BigInt(0), BigInt(1), BigInt(0)] : [convFp(x), convFp(y), BigInt(1)];
    // [ [ 0n, 0n ], [ 0n, 0n ], [ 1n, 0n ] ], -> [ [ 0n, 0n ], [ 1n, 0n ], [ 0n, 0n ] ],
    const convG2 = ([xc0, xc1, yc0, yc1]: bigint[]): G2Point =>
      is0(xc0) && is0(xc1) && is0(yc0) && is0(yc1)
        ? [
            [BigInt(0), BigInt(0)],
            [BigInt(1), BigInt(0)],
            [BigInt(0), BigInt(0)],
          ]
        : [
            [convFp(xc0), convFp(xc1)],
            [convFp(yc0), convFp(yc1)],
            [BigInt(1), BigInt(0)],
          ];
    const data = ZKeyRaw.decode(zkey);
    function getByTag<T extends { TAG: string; data: unknown }, K extends T['TAG']>(
      sections: T[],
      tag: K
    ): Extract<T, { TAG: K }>['data'] {
      const v = sections.find((i): i is Extract<T, { TAG: K }> => i.TAG === tag);
      if (!v) throw new Error('ZKey: cannot find ' + String(tag));
      return v.data;
    }
    function collect<T extends { TAG: string; data: unknown }, K extends readonly T['TAG'][]>(
      sections: T[],
      ks: K
    ): { [P in K[number]]: Extract<T, { TAG: P }>['data'] } {
      const out = {} as any;
      for (const k of ks) out[k] = getByTag<T, typeof k>(sections, k);
      return out;
    }
    const res = collect(data.sections, [
      'header',
      'headerGroth',
      'IC',
      'ccoefs',
      'A',
      'B1',
      'B2',
      'C',
      'hExps',
    ] as const);
    const hdr = res.headerGroth;
    // Same idea as getCircuitInfo()'s prime check: refuse keys produced for another curve.
    if (hdr.n8q !== Fp.BYTES || hdr.q !== Fp.ORDER)
      throw new Error('ZKey: wrong base field order');
    if (hdr.n8r !== Fr.BYTES || hdr.r !== Fr.ORDER)
      throw new Error('ZKey: wrong scalar field order');
    // Same format as verification key
    const json = {
      protocol: res.header,
      ...res.headerGroth,
      vk_alpha_1: convG1(res.headerGroth.vk_alpha_1),
      vk_beta_1: convG1(res.headerGroth.vk_beta_1),
      vk_delta_1: convG1(res.headerGroth.vk_delta_1),
      vk_beta_2: convG2(res.headerGroth.vk_beta_2),
      vk_delta_2: convG2(res.headerGroth.vk_delta_2),
      vk_gamma_2: convG2(res.headerGroth.vk_gamma_2),
      power: Math.log2(res.headerGroth.domainSize),
      IC: res.IC.map(convG1),
      ccoefs: res.ccoefs.map((i) => ({ ...i, value: convFr2(i.value) })),
      A: res.A.map(convG1),
      B1: res.B1.map(convG1),
      B2: res.B2.map(convG2),
      // snarkjs zkeys omit the leading zero C-query entries for public signals.
      C: new Array(res.headerGroth.nPublic + 1).fill(null).concat(res.C.map(convG1)),
      hExps: res.hExps.map(convG1),
    };
    // Our format (old snarkjs compat)
    const pkey: ProvingKey = {
      protocol: 'groth',
      nVars: json.nVars,
      nPublic: json.nPublic,
      domainSize: json.domainSize,
      domainBits: json.power,
      // Polynominals (instead polsA/polsB/polsC)
      ccoefs: json.ccoefs, // changed
      //
      A: json.A,
      B1: json.B1,
      B2: json.B2,
      C: json.C,
      //
      vk_alfa_1: json.vk_alpha_1,
      vk_beta_1: json.vk_beta_1,
      vk_delta_1: json.vk_delta_1,
      vk_beta_2: json.vk_beta_2,
      vk_delta_2: json.vk_delta_2,
      //
      hExps: json.hExps,
    };
    const vkey: VerificationKey = {
      protocol: 'groth',
      nPublic: json.nPublic,
      IC: json.IC,
      vk_alfa_1: json.vk_alpha_1,
      vk_beta_2: json.vk_beta_2,
      vk_gamma_2: json.vk_gamma_2,
      vk_delta_2: json.vk_delta_2,
    };
    return { json, pkey, vkey };
  }
  return { R1CS, binWitness, WTNS, getCircuitInfo, ZKeyRaw, parseZKey } as TRet<CodersOutput>;
};