micro-zk-proofs
Version:
Create & verify zero-knowledge SNARK proofs in parallel, using noble cryptography
319 lines • 13.1 kB
JavaScript
/**
 * Minimal witness program executor for circom programs, based on websnark/wasmsnark/snarkjs.
 * Unsafe: it evaluates circuit code via `new Function` (a form of eval), so it is best run inside worker threads.
 * Depends on **monkey-patched BigInt** prototypes, due to how circom programs are serialized.
 * Prototypes are patched only right before execution, and the patches are reverted once execution finishes,
 * so that other code does not interfere with (or observe) the patched methods.
 * @module
 */
import { invert, pow } from '@noble/curves/abstract/modular';
import { bn254 as nobleBn254 } from '@noble/curves/bn254';
import * as P from 'micro-packed';
import {} from './index.js';
/**
 * Creates a reversible patcher that installs circom helper methods on `BigInt.prototype`.
 *
 * The returned `patch()` saves whatever each method name previously resolved to on the
 * prototype and installs a wrapper that forwards the receiver as the first argument
 * (e.g. `x.add(y)` -> `a + b` with `a = x`). `restore()` puts the saved values back,
 * deleting names that were previously undefined.
 *
 * @returns {{ patch: () => void, restore: () => void }}
 *   `patch` throws 'bigint: already patched' when called twice in a row;
 *   `restore` throws 'bigint: not patched' when nothing is patched.
 */
function monkeyPatchBigInt() {
  const methods = {
    // Equality
    eq: (a, b) => a === b,
    neq: (a, b) => a !== b,
    greaterOrEquals: (a, b) => a >= b,
    greater: (a, b) => a > b,
    gt: (a, b) => a > b,
    lesserOrEquals: (a, b) => a <= b,
    lesser: (a, b) => a < b,
    lt: (a, b) => a < b,
    // Basic math
    sub: (a, b) => a - b,
    add: (a, b) => a + b,
    mul: (a, b) => a * b,
    div: (a, b) => a / b,
    mod: (a, b) => a % b,
    // Fields
    inverse: (n, modulo) => invert(n, modulo),
    modPow: (a, power, modulo) => pow(a, power, modulo),
    // Binary
    and: (a, b) => a & b,
    shr: (a, b) => a >> BigInt(b),
  };
  const proto = BigInt.prototype;
  const names = Object.keys(methods);
  // `null` while unpatched; otherwise maps each method name to its pre-patch value.
  let saved = null;

  const patch = () => {
    if (saved !== null) throw new Error('bigint: already patched');
    saved = {};
    for (const name of names) {
      saved[name] = proto[name];
      const impl = methods[name];
      // A regular function (not an arrow) so `this` is the bigint receiver.
      proto[name] = function (...args) {
        return impl(this, ...args);
      };
    }
  };

  const restore = () => {
    if (saved === null) throw new Error('bigint: not patched');
    for (const name of names) {
      const previous = saved[name];
      if (previous === undefined) {
        delete proto[name];
      } else {
        proto[name] = previous;
      }
    }
    saved = null;
  };

  return { patch, restore };
}
const selectorStr = (lst) => lst.map((i) => `[${i}]`).join('');
const signalStr = (name, selectors) => name + selectorStr(selectors);
// Apply selectors
const select = (a, selectors) => {
for (const s of selectors)
a = a[s];
return a;
};
/**
 * Compiles a JSON-serialized circom circuit into a witness calculator.
 *
 * SECURITY: template and function bodies in `circJson` are JavaScript source compiled with
 * `new Function`; only pass trusted circuits (preferably inside a worker thread).
 *
 * Fields read from `circJson`: `signals`, `components`, `templates`, `functions`,
 * `signalName2Idx`, `nInputs`, `nOutputs`, `nSignals`, `nVars`.
 *
 * While the returned calculator runs, `BigInt.prototype` is monkey-patched (see
 * `monkeyPatchBigInt`); the patch is reverted when it returns, including when it throws.
 *
 * @param circJson - parsed circuit description.
 * @returns a function `(input) => bigint[]`. `input` maps signal names of the `main` component
 *   to values (or nested arrays of values) accepted by `BigInt`; the result contains the first
 *   `nVars` witness values.
 * @throws if an input value is undefined, a signal is never assigned, a variable is not
 *   defined, a constraint assertion fails, or the BigInt patch is already active.
 */
export function generateWitness(circJson) {
  const P = nobleBn254.fields.Fr.ORDER;
  const MASK = nobleBn254.fields.Fr.MASK;
  const signals = circJson.signals;
  const components = circJson.components;
  // Compile every template/function once. P & MASK are bound as explicit parameters
  // (`__P__`, `__MASK__`) so the generated code's dependency on them is visible here.
  const compile = (src) =>
    new Function('bigInt', '__P__', '__MASK__', 'return ' + src)(BigInt, P, MASK);
  const templates = {};
  for (const t in circJson.templates) templates[t] = compile(circJson.templates[t]);
  const functions = {};
  for (const f in circJson.functions) {
    functions[f] = {
      params: circJson.functions[f].params,
      func: compile(circJson.functions[f].func),
    };
  }
  // Witness index of the i-th input: inputs follow one leading slot (presumably the constant
  // `one` signal) and the `nOutputs` outputs.
  function inputIdx(i) {
    if (i >= circJson.nInputs) throw new Error('Accessing an invalid input: ' + i);
    return circJson.nOutputs + 1 + i;
  }
  // Resolves a full signal name, or a numeric index (possibly given as a string), to an index.
  function getSignalIdx(name) {
    if (circJson.signalName2Idx[name] !== undefined) return circJson.signalName2Idx[name];
    // Global isNaN on purpose: it coerces numeric strings such as '5'.
    if (!isNaN(name)) return Number(name);
    throw new Error('Invalid signal identifier: ' + name);
  }
  const signalNames = (i) => signals[getSignalIdx(i)].names.join(', ');
  // Calls cb(selectors, leaf) for every leaf of a possibly nested array; `selectors` is the
  // index path to that leaf (empty for a non-array value).
  const forEachLeaf = (values, selectors, cb) => {
    if (!Array.isArray(values)) {
      cb(selectors, values);
      return;
    }
    for (let i = 0; i < values.length; i++) forEachLeaf(values[i], [...selectors, i], cb);
  };

  // Runs the circuit for one input. Requires BigInt.prototype to be patched by the caller.
  function computeWitness(input) {
    const witness = new Array(circJson.nSignals);
    let currentComponent;
    // Scope stack: [global scope, innermost template/function scope]. The global scope must be
    // an object: inScope keeps scopes[0] and getVar dereferences every entry.
    let scopes = [{}];
    // Per component: number of input signals still unassigned; it runs when this reaches 0.
    const notInitSignals = {};

    function inScope(newScope, cb) {
      const oldScope = scopes;
      scopes = [scopes[0], newScope];
      const res = cb();
      scopes = oldScope;
      return res;
    }
    function triggerComponent(c) {
      notInitSignals[c]--; // drops below 0, so this component is never triggered again
      const oldComponent = currentComponent;
      const { name, template, params } = components[c];
      currentComponent = name;
      const newScope = {};
      for (const p in params) newScope[p] = params[p];
      inScope(newScope, () => templates[template](ctx));
      currentComponent = oldComponent;
    }
    function setSignalFullName(fullName, value) {
      const sId = getSignalIdx(fullName);
      const firstInit = witness[sId] === undefined;
      witness[sId] = BigInt(value);
      const triggers = signals[sId].triggerComponents;
      // First count this signal as initialized for every dependent component, then run the
      // ones that now have all of their inputs.
      if (firstInit) {
        for (const c of triggers) notInitSignals[c]--;
      }
      for (const c of triggers) {
        // Loose comparison kept: initial counts come straight from circJson.
        if (notInitSignals[c] == 0) triggerComponent(c);
      }
      return witness[sId];
    }
    function getSignalFullName(name) {
      const id = getSignalIdx(name);
      if (witness[id] === undefined) throw new Error('Signal not initialized: ' + name);
      return witness[id];
    }
    const cName = (name) => (name === 'one' ? 'one' : currentComponent + '.' + name);

    // Minimal runtime API used by the compiled template/function code.
    const ctx = {
      // Pins
      setPin(compName, compSel, sigName, sigSel, value) {
        const name = signalStr(cName(compName), compSel) + '.' + signalStr(sigName, sigSel);
        setSignalFullName(name, value);
      },
      getPin(compName, componentSels, sigName, sigSel) {
        const name = signalStr(cName(compName), componentSels) + '.' + signalStr(sigName, sigSel);
        return getSignalFullName(name);
      },
      // Vars
      setVar(name, sels, value) {
        const scope = scopes[scopes.length - 1];
        if (sels.length === 0) {
          scope[name] = value;
          return value;
        }
        if (scope[name] === undefined) scope[name] = [];
        // Walk (creating missing arrays) down to the parent of the last selector.
        let target = scope[name];
        for (let i = 0; i < sels.length - 1; i++) {
          if (target[sels[i]] === undefined) target[sels[i]] = [];
          target = target[sels[i]];
        }
        target[sels[sels.length - 1]] = value;
        return value;
      },
      getVar(name, sels) {
        for (let i = scopes.length - 1; i >= 0; i--) {
          if (scopes[i][name] !== undefined) return select(scopes[i][name], sels);
        }
        throw new Error('Variable not defined: ' + name);
      },
      // Signals
      setSignal(name, sels, value) {
        setSignalFullName(signalStr(currentComponent ? currentComponent + '.' + name : name, sels), value);
      },
      getSignal(name, sels) {
        return getSignalFullName(signalStr(cName(name), sels));
      },
      // Utils
      callFunction(name, params) {
        const paramNames = functions[name].params;
        const newScope = {};
        for (let p = 0; p < paramNames.length; p++) newScope[paramNames[p]] = params[p];
        return inScope(newScope, () => functions[name].func(ctx));
      },
      assert(a, b, errStr = '') {
        const lhs = BigInt(a);
        const rhs = BigInt(b);
        if (lhs === rhs) return;
        throw new Error(`Constraint doesn't match ${currentComponent}: ${errStr} -> ${lhs} != ${rhs}`);
      },
    };

    // Processing
    for (const c in components) notInitSignals[c] = components[c].inputSignals;
    // `one` is assigned before any component runs (currentComponent is still undefined).
    ctx.setSignal('one', [], BigInt(1));
    // Components that are already waiting on zero inputs run now.
    for (const c in notInitSignals) {
      if (notInitSignals[c] == 0) triggerComponent(c);
    }
    // Assign the main component's inputs; each assignment may cascade into further components.
    for (const s in input) {
      currentComponent = 'main';
      forEachLeaf(input[s], [], (selector, value) => {
        if (value === undefined) throw new Error('Signal not defined:' + s);
        ctx.setSignal(s, selector, BigInt(value));
      });
    }
    for (let i = 0; i < circJson.nInputs; i++) {
      const idx = inputIdx(i);
      if (witness[idx] === undefined) {
        throw new Error('Input Signal not assigned: ' + signalNames(idx));
      }
    }
    for (let i = 0; i < witness.length; i++) {
      if (witness[i] === undefined) throw new Error('Signal not assigned: ' + signalNames(i));
    }
    return witness.slice(0, circJson.nVars);
  }

  const patcher = monkeyPatchBigInt();
  return function (input) {
    patcher.patch();
    try {
      return computeWitness(input);
    } finally {
      // Restore even on failure: otherwise BigInt.prototype stays patched and every later
      // call fails with 'bigint: already patched'.
      patcher.restore();
    }
  };
}
/**
 * Binary coders for Circom2 files: `.r1cs` (constraint system) and `.wtns` (witness).
 *
 * @param field - the prime field the files must use. `field.BYTES` fixes the width of every
 *   encoded field element; `field.ORDER` is compared with the R1CS header prime.
 * @returns `{ R1CS, binWitness, WTNS, getCircuitInfo }`:
 *   - `R1CS` / `WTNS`: coders for whole `.r1cs` / `.wtns` files (magic, version, tagged sections);
 *   - `binWitness`: coder for a flat array of field elements;
 *   - `getCircuitInfo(bytes)`: decodes an `.r1cs` file and returns
 *     `{ nVars, nPubInputs, nOutputs, constraints }`. It throws if the header or constraint
 *     section is missing, if the prime differs from `field.ORDER`, or if the number of
 *     constraints differs from the header's `mConstraints`.
 */
export const getCoders = (field) => {
  // NOTE: `field` is required: bigints are variable-size, but these formats store every field
  // element in exactly `field.BYTES` bytes.
  const fieldBytes = field.BYTES;
  const fieldCoder = P.bigint(fieldBytes, true, false); // little-endian, unsigned
  // R1CS header section. The coder does not validate `prime`; getCircuitInfo checks it
  // against field.ORDER.
  const Header = P.struct({
    prime: P.prefix(P.U32LE, fieldCoder), // Field modulus, prefixed by its U32 byte size
    nWires: P.U32LE, // Total number of wires, including the ONE signal (index 0)
    nPubOut: P.U32LE, // Number of public output wires; they start at index 1
    nPubIn: P.U32LE, // Number of public input wires; they start right after the public outputs
    nPrvIn: P.U32LE, // Number of private input wires; they start right after the public inputs
    nLables: P.U64LE, // Number of labels (key spelling kept: it is part of the decoded data)
    mConstraints: P.U32LE, // Number of constraints
  });
  // Linear combination: U32 count, then (wire index, coefficient) pairs, converted to a
  // dictionary keyed by wire index.
  // TODO: dict keys come out as strings, not numbers.
  const Constraint = P.apply(P.array(P.U32LE, P.tuple([P.U32LE, fieldCoder])), P.coders.dict());
  // Each constraint is a triple of linear combinations (A, B, C) with A*B-C = 0.
  const Constraints = P.array(null, P.tuple([Constraint, Constraint, Constraint]));
  // wire2label section body: U64 label ids until the end of the section.
  const WireMap = P.array(null, P.U64LE);
  // Every section body is prefixed with its U64 byte length.
  const section = (inner) => P.prefix(P.U64LE, inner);
  const empty = P.bytes(null);
  const R1CSSection = P.mappedTag(P.U32LE, {
    header: [0x01, section(Header)],
    constraint: [0x02, section(Constraints)],
    wire2label: [0x03, section(WireMap)],
    // not implemented (ultra-plonk custom gates): kept as raw bytes
    customGatesList: [0x04, section(empty)],
    customGatesApplication: [0x05, section(empty)],
  });
  const R1CS = P.struct({
    magic: P.magic(P.string(4), 'r1cs'),
    version: P.U32LE,
    sections: P.array(P.U32LE, R1CSSection),
  });
  // Flat list of field elements until the end of the input (also the WTNS witness section body).
  const binWitness = P.array(null, fieldCoder);
  const WTNSHeader = P.struct({
    prime: P.prefix(P.U32LE, fieldCoder),
    size: P.U32LE,
  });
  const WTNSSection = P.mappedTag(P.U32LE, {
    header: [0x01, section(WTNSHeader)],
    witness: [0x02, section(binWitness)],
  });
  const WTNS = P.struct({
    magic: P.magic(P.string(4), 'wtns'),
    version: P.U32LE,
    sections: P.array(P.U32LE, WTNSSection),
  });
  const getCircuitInfo = (bytes) => {
    const data = R1CS.decode(bytes);
    const constraints = data.sections.find((i) => i.TAG === 'constraint');
    if (!constraints) throw new Error('R1CS: cannot find constraints');
    const header = data.sections.find((i) => i.TAG === 'header');
    if (!header) throw new Error('R1CS: cannot find header');
    if (header.data.prime !== field.ORDER) throw new Error('R1CS: wrong field order');
    // A mismatch means a truncated or corrupted constraint section.
    if (constraints.data.length !== header.data.mConstraints) {
      throw new Error('R1CS: constraint count mismatch');
    }
    return {
      nVars: header.data.nWires,
      nPubInputs: header.data.nPubIn,
      nOutputs: header.data.nPubOut,
      constraints: constraints.data,
    };
  };
  return { R1CS, binWitness, WTNS, getCircuitInfo };
};
//# sourceMappingURL=witness.js.map