UNPKG

micro-zk-proofs

Version:

Create & verify zero-knowledge SNARK proofs in parallel, using noble cryptography

319 lines 13.1 kB
/**
 * Minimal witness program executor for circom programs, based on websnark/wasmsnark/snarkjs.
 * Unsafe: it uses eval, better to be used inside worker threads.
 * Depends on **monkey-patched BigInt** prototypes due to how circom programs are serialized.
 * We only patch prototypes before execution. After finishing (successfully or with an error),
 * patches are reverted. This way, no other code can interfere with it.
 * @module
 */
import { invert, pow } from '@noble/curves/abstract/modular';
import { bn254 as nobleBn254 } from '@noble/curves/bn254';
import * as P from 'micro-packed';
import {} from './index.js';

/**
 * Creates a patcher that temporarily installs circom-style helper methods
 * (`eq`, `add`, `modPow`, `shr`, ...) on `BigInt.prototype`.
 * `patch()` remembers whatever was previously stored under those names; `restore()` puts the
 * previous values back, or deletes the helper when nothing was there before.
 * Both throw when called out of order (double patch / restore without patch).
 */
function monkeyPatchBigInt() {
  const methods = {
    // Equality
    eq: (a, b) => a === b,
    neq: (a, b) => a !== b,
    greaterOrEquals: (a, b) => a >= b,
    greater: (a, b) => a > b,
    gt: (a, b) => a > b,
    lesserOrEquals: (a, b) => a <= b,
    lesser: (a, b) => a < b,
    lt: (a, b) => a < b,
    // Basic math
    sub: (a, b) => a - b,
    add: (a, b) => a + b,
    mul: (a, b) => a * b,
    div: (a, b) => a / b,
    mod: (a, b) => a % b,
    // Fields
    inverse: (n, modulo) => invert(n, modulo),
    modPow: (a, power, modulo) => pow(a, power, modulo),
    // Binary
    and: (a, b) => a & b,
    shr: (a, b) => a >> BigInt(b),
  };
  let patched = false;
  let orig = {};
  const proto = BigInt.prototype;
  return {
    patch() {
      if (patched) throw new Error('bigint: already patched');
      for (const name in methods) {
        orig[name] = proto[name];
        proto[name] = function (...args) {
          return methods[name](this, ...args);
        };
      }
      patched = true;
    },
    restore() {
      if (!patched) throw new Error('bigint: not patched');
      for (const name in methods) {
        if (orig[name] === undefined) delete proto[name];
        else proto[name] = orig[name];
      }
      orig = {};
      patched = false;
    },
  };
}

// Selector suffix: [1, 2] -> '[1][2]'
const selectorStr = (lst) => lst.map((i) => `[${i}]`).join('');
// Name plus selector suffix: ('in', [0, 1]) -> 'in[0][1]'
const signalStr = (name, selectors) => name + selectorStr(selectors);
// Apply selectors: select(a, [1, 2]) === a[1][2]
const select = (a, selectors) => selectors.reduce((acc, s) => acc[s], a);
// Recursively walks a (possibly nested) array and calls `cb(selectors, leaf)` for every
// non-array leaf, where `selectors` is the index path leading to that leaf.
const forEachLeaf = (values, selectors, cb) => {
  if (!Array.isArray(values)) return cb(selectors, values);
  for (let i = 0; i < values.length; i++) forEachLeaf(values[i], [...selectors, i], cb);
};

/**
 * Compiles a circom program (JSON form) into a witness calculator.
 * Template and function bodies are evaluated with `new Function` (see module warning).
 * @param circJson - compiled circuit; reads `signals`, `components`, `templates`, `functions`,
 *   `nInputs`, `nOutputs`, `nSignals`, `nVars` and `signalName2Idx`.
 * @returns a function `(input) => bigint[]` where `input` maps signal names to values (or nested
 *   arrays of values); it returns the first `nVars` witness entries.
 *   The returned function throws when a constraint does not hold, an input is missing,
 *   or some signal stays unassigned. BigInt prototypes are restored in every case.
 */
export function generateWitness(circJson) {
  // Named ORDER (not P) to avoid shadowing the micro-packed import; the evaluated code only
  // sees it through the `__P__` parameter.
  const ORDER = nobleBn254.fields.Fr.ORDER;
  const MASK = nobleBn254.fields.Fr.MASK;
  const { signals, components } = circJson;
  // Bind P & MASK directly into templates/functions, so we see dependency
  const templates = {};
  for (const t in circJson.templates) {
    templates[t] = new Function('bigInt', '__P__', '__MASK__', 'return ' + circJson.templates[t])(
      BigInt,
      ORDER,
      MASK
    );
  }
  const functions = {};
  for (const f in circJson.functions) {
    functions[f] = {
      params: circJson.functions[f].params,
      func: new Function(
        'bigInt',
        '__P__',
        '__MASK__',
        'return ' + circJson.functions[f].func
      )(BigInt, ORDER, MASK),
    };
  }
  function inputIdx(i) {
    if (i >= circJson.nInputs) throw new Error('Accessing an invalid input: ' + i);
    // Layout: [one, ...outputs, ...inputs, ...]
    return circJson.nOutputs + 1 + i;
  }
  function getSignalIdx(name) {
    if (circJson.signalName2Idx[name] !== undefined) return circJson.signalName2Idx[name];
    // Global isNaN on purpose: numeric strings are accepted as raw signal indices.
    if (!isNaN(name)) return Number(name);
    throw new Error('Invalid signal identifier: ' + name);
  }
  const signalNames = (i) => signals[getSignalIdx(i)].names.join(', ');

  // Runs the circuit for one input object. All state is per call.
  // Must be called while BigInt prototypes are patched.
  function calculate(input) {
    const witness = new Array(circJson.nSignals);
    let currentComponent;
    // Scope stack. The root scope is never written to; it exists so getVar reports
    // 'Variable not defined' instead of dereferencing an undefined scope.
    let scopes = [{}];
    const notInitSignals = {}; // component -> number of its input signals not yet set
    function inScope(newScope, cb) {
      const oldScope = scopes;
      scopes = [scopes[0], newScope];
      const res = cb();
      scopes = oldScope;
      return res;
    }
    function triggerComponent(c) {
      // Goes from 0 to -1 and only decreases afterwards, so a component fires at most once.
      notInitSignals[c]--;
      const oldComponent = currentComponent;
      const { name, template, params } = components[c];
      currentComponent = name;
      const newScope = { ...params };
      inScope(newScope, () => templates[template](ctx));
      currentComponent = oldComponent;
    }
    function setSignalFullName(fullName, value) {
      const sId = getSignalIdx(fullName);
      const firstInit = witness[sId] === undefined;
      witness[sId] = BigInt(value);
      const triggers = signals[sId].triggerComponents;
      // Two passes: first count down pending inputs, then fire every component that is now ready.
      if (firstInit) for (const c of triggers) notInitSignals[c]--;
      for (const c of triggers) if (notInitSignals[c] == 0) triggerComponent(c);
      return witness[sId];
    }
    function getSignalFullName(name) {
      const id = getSignalIdx(name);
      if (witness[id] === undefined) throw new Error('Signal not initialized: ' + name);
      return witness[id];
    }
    const cName = (name) => (name === 'one' ? 'one' : currentComponent + '.' + name);
    // Minimal API that used inside evaluated code
    const ctx = {
      // Pins
      setPin(compName, compSel, sigName, sigSel, value) {
        const name = signalStr(cName(compName), compSel) + '.' + signalStr(sigName, sigSel);
        setSignalFullName(name, value);
      },
      getPin(compName, componentSels, sigName, sigSel) {
        const name = signalStr(cName(compName), componentSels) + '.' + signalStr(sigName, sigSel);
        return getSignalFullName(name);
      },
      // Vars
      setVar(name, sels, value) {
        const scope = scopes[scopes.length - 1];
        if (sels.length === 0) {
          scope[name] = value;
          return value;
        }
        if (scope[name] === undefined) scope[name] = [];
        // Walk (creating missing arrays) down to the parent of the target element.
        let target = scope[name];
        for (let i = 0; i < sels.length - 1; i++) {
          if (target[sels[i]] === undefined) target[sels[i]] = [];
          target = target[sels[i]];
        }
        target[sels[sels.length - 1]] = value;
        return value;
      },
      getVar(name, sels) {
        for (let i = scopes.length - 1; i >= 0; i--) {
          if (scopes[i][name] !== undefined) return select(scopes[i][name], sels);
        }
        throw new Error('Variable not defined: ' + name);
      },
      // Signals
      setSignal(name, sels, value) {
        setSignalFullName(
          signalStr(currentComponent ? currentComponent + '.' + name : name, sels),
          value
        );
      },
      getSignal(name, sels) {
        return getSignalFullName(signalStr(cName(name), sels));
      },
      // Utils
      callFunction(name, params) {
        const fn = functions[name];
        const newScope = {};
        for (let p = 0; p < fn.params.length; p++) newScope[fn.params[p]] = params[p];
        return inScope(newScope, () => fn.func(ctx));
      },
      assert(a, b, errStr = '') {
        a = BigInt(a);
        b = BigInt(b);
        if (a === b) return;
        throw new Error(`Constraint doesn't match ${currentComponent}: ${errStr} -> ${a} != ${b}`);
      },
    };
    // Processing
    for (const c in components) notInitSignals[c] = components[c].inputSignals;
    ctx.setSignal('one', [], BigInt(1));
    // Components that have no pending inputs can run right away.
    for (const c in notInitSignals) if (notInitSignals[c] == 0) triggerComponent(c);
    for (const s in input) {
      currentComponent = 'main';
      forEachLeaf(input[s], [], (selector, value) => {
        if (value === undefined) throw new Error('Signal not defined:' + s);
        ctx.setSignal(s, selector, BigInt(value));
      });
    }
    for (let i = 0; i < circJson.nInputs; i++) {
      const idx = inputIdx(i);
      if (witness[idx] === undefined)
        throw new Error('Input Signal not assigned: ' + signalNames(idx));
    }
    for (let i = 0; i < witness.length; i++) {
      if (witness[i] === undefined) throw new Error('Signal not assigned: ' + signalNames(i));
    }
    return witness.slice(0, circJson.nVars);
  }

  const patcher = monkeyPatchBigInt();
  return function (input) {
    patcher.patch();
    // Always revert the prototype patch, even when evaluation throws; otherwise BigInt stays
    // patched and every later call fails with 'bigint: already patched'.
    try {
      return calculate(input);
    } finally {
      patcher.restore();
    }
  };
}

/** Binary coders for Circom2 */
export const getCoders = (field) => {
  // NOTE: we need to pass field here, even if bigints are variable size, they are fixed to field bytes!
  const fieldBytes = field.BYTES;
  const fieldCoder = P.bigint(fieldBytes, true, false);
  const Header = P.struct({
    // Field-size-prefixed prime; getCircuitInfo compares it against field.ORDER.
    prime: P.prefix(P.U32LE, fieldCoder),
    nWires: P.U32LE, // Total number of wires, including the ONE signal (index 0)
    nPubOut: P.U32LE, // Number of public output wires; they should start at index 1
    nPubIn: P.U32LE, // Number of public input wires; they should start right after the public outputs
    nPrvIn: P.U32LE, // Number of private input wires; they should start right after the public inputs
    nLables: P.U64LE, // Total number of labels (misspelled key kept as-is for compatibility)
    mConstraints: P.U32LE, // Total number of constraints
  });
  // Linear combination: list of (wireId, coefficient) pairs, converted via P.coders.dict()
  const Constraint = P.apply(
    P.array(P.U32LE, P.tuple([P.U32LE, fieldCoder])),
    P.coders.dict() // TODO: dict key is string, not number
  );
  // A*B-C = 0
  const Constraints = P.array(null, P.tuple([Constraint, Constraint, Constraint]));
  const WireMap = P.array(null, P.U64LE);
  // Every section body is prefixed with its U64LE byte length.
  const section = (inner) => P.prefix(P.U64LE, inner);
  const empty = P.bytes(null);
  const R1CSSection = P.mappedTag(P.U32LE, {
    header: [0x01, section(Header)],
    constraint: [0x02, section(Constraints)],
    wire2label: [0x03, section(WireMap)],
    // not implemented: ultra-plonk
    customGatesList: [0x04, section(empty)],
    customGatesApplication: [0x05, section(empty)],
  });
  const R1CS = P.struct({
    magic: P.magic(P.string(4), 'r1cs'),
    version: P.U32LE,
    sections: P.array(P.U32LE, R1CSSection),
  });
  const binWitness = P.array(null, fieldCoder);
  const WTNSHeader = P.struct({
    prime: P.prefix(P.U32LE, fieldCoder),
    size: P.U32LE,
  });
  const WTNSSection = P.mappedTag(P.U32LE, {
    header: [0x01, section(WTNSHeader)],
    witness: [0x02, section(binWitness)],
  });
  const WTNS = P.struct({
    magic: P.magic(P.string(4), 'wtns'),
    version: P.U32LE,
    sections: P.array(P.U32LE, WTNSSection),
  });
  /**
   * Decodes an R1CS file and extracts what the prover needs.
   * Throws if the constraint or header section is missing, or the prime differs from field.ORDER.
   */
  const getCircuitInfo = (bytes) => {
    const data = R1CS.decode(bytes);
    const constraints = data.sections.find((i) => i.TAG === 'constraint');
    if (!constraints) throw new Error('R1CS: cannot find constraints');
    const header = data.sections.find((i) => i.TAG === 'header');
    if (!header) throw new Error('R1CS: cannot find header');
    if (header.data.prime !== field.ORDER) throw new Error('R1CS: wrong field order');
    return {
      nVars: header.data.nWires,
      nPubInputs: header.data.nPubIn,
      nOutputs: header.data.nPubOut,
      constraints: constraints.data,
    };
  };
  return { R1CS, binWitness, WTNS, getCircuitInfo };
};
//# sourceMappingURL=witness.js.map