UNPKG

snarky-smt

Version:

Sparse Merkle Tree for SnarkyJS

217 lines (216 loc) 8.25 kB
import { Circuit, Field, Poseidon, Struct } from 'snarkyjs';
import { EMPTY_VALUE } from '../constant';
import { BaseMerkleProof } from './proofs';
import { ProvableMerkleTreeUtils } from './verify_circuit';

export { ProvableDeepMerkleSubTree };

/**
 * ProvableDeepMerkleSubTree is a deep merkle subtree for working on only a few leafs in circuit.
 *
 * @class ProvableDeepMerkleSubTree
 * @template V
 */
class ProvableDeepMerkleSubTree {
  /**
   * Creates an instance of ProvableDeepMerkleSubTree.
   * @param {Field} root merkle root
   * @param {number} height height of tree
   * @param {Provable<V>} valueType
   * @param {{ hasher?: Hasher; hashValue: boolean }} [options={
   *     hasher: Poseidon.hash,
   *     hashValue: true,
   *   }] hasher: The hash function to use, defaults to Poseidon.hash;
   *  hashValue: whether to hash the value, the default is true.
   * @memberof ProvableDeepMerkleSubTree
   */
  constructor(root, height, valueType, options = { hasher: Poseidon.hash, hashValue: true }) {
    this.root = root;
    // node hash (string) -> [left, right] children, or [leaf] for a leaf entry
    this.nodeStore = new Map();
    // leaf index (string) -> leaf field (see getValueField)
    this.valueStore = new Map();
    this.height = height;
    this.hasher = options.hasher !== undefined ? options.hasher : Poseidon.hash;
    // NOTE(review): if an options object is passed without `hashValue`, this is
    // undefined and values are NOT hashed, despite the documented default of
    // true. Left unchanged so leaf hashing stays consistent with the tree that
    // produced the proofs; confirm against the full-tree implementation.
    this.hashValue = options.hashValue;
    this.valueType = valueType;
  }

  /**
   * Map a value to the field stored at its leaf.
   *
   * @param {V} [value] the value, or undefined for an empty leaf
   * @return {*} {Field} EMPTY_VALUE when value is undefined; otherwise the hash
   *   of the value's fields when hashValue is set, else its first field
   * @memberof ProvableDeepMerkleSubTree
   */
  getValueField(value) {
    if (value === undefined) {
      return EMPTY_VALUE;
    }
    const valueFields = this.valueType.toFields(value);
    return this.hashValue ? this.hasher(valueFields) : valueFields[0];
  }

  /**
   * Get current root.
   *
   * @return {*} {Field}
   * @memberof ProvableDeepMerkleSubTree
   */
  getRoot() {
    return this.root;
  }

  /**
   * Get height of the tree.
   *
   * @return {*} {number}
   * @memberof ProvableDeepMerkleSubTree
   */
  getHeight() {
    return this.height;
  }

  /**
   * Add a branch to the tree, a branch is generated by smt.prove.
   * Only the prover-side node/value stores are updated; no constraints are added.
   *
   * @param {BaseMerkleProof} proof
   * @param {Field} index
   * @param {V} [value]
   * @memberof ProvableDeepMerkleSubTree
   */
  addBranch(proof, index, value) {
    Circuit.asProver(() => {
      const valueField = this.getValueField(value);
      const updates = getUpdatesBySideNodes(proof.sideNodes, index, valueField, this.height, this.hasher);
      for (const [nodeHash, children] of updates) {
        this.nodeStore.set(nodeHash.toString(), children);
      }
      this.valueStore.set(index.toString(), valueField);
    });
  }

  /**
   * Create a merkle proof for a key against the current root.
   *
   * @param {Field} index
   * @return {*} {BaseMerkleProof}
   * @throws {Error} (inside the witness) if no branch for index was added
   * @memberof ProvableDeepMerkleSubTree
   */
  prove(index) {
    return Circuit.witness(BaseMerkleProof, () => {
      const pathStr = index.toString();
      if (this.valueStore.get(pathStr) === undefined) {
        throw new Error(`The DeepSubTree does not contain a branch of the path: ${pathStr}`);
      }
      const pathBits = index.toBits(this.height);
      const sideNodes = collectSideNodes(this.nodeStore, this.root, pathBits, this.height);
      class MerkleProof_ extends ProvableMerkleTreeUtils.MerkleProof(this.height) {}
      return new MerkleProof_(this.root, sideNodes).toConstant();
    });
  }

  /**
   * Update a new value for a key in the tree and return the new root of the tree.
   * The witnessed side nodes are constrained by recomputing the current root
   * from them and the witnessed old leaf before the new root is computed.
   *
   * @param {Field} index
   * @param {V} [value]
   * @return {*} {Field}
   * @memberof ProvableDeepMerkleSubTree
   */
  update(index, value) {
    const path = index;
    const pathBits = path.toBits(this.height);
    const valueField = this.getValueField(value);

    class SideNodes extends Struct({
      arr: Circuit.array(Field, this.height),
    }) {}

    const fieldArr = Circuit.witness(SideNodes, () => {
      const sideNodes = collectSideNodes(this.nodeStore, this.root, pathBits, this.height);
      return new SideNodes({ arr: sideNodes });
    });
    const sideNodes = fieldArr.arr;

    const oldValueHash = Circuit.witness(Field, () => {
      const storedValueHash = this.valueStore.get(path.toString());
      if (storedValueHash === undefined) {
        throw new Error('oldValueHash does not exist');
      }
      return storedValueHash.toConstant();
    });

    // Use the configured hasher (not a hard-coded Poseidon.hash) so this check
    // agrees with how addBranch stored the nodes and how the new root is hashed.
    impliedRootForHeightInCircuit(sideNodes, pathBits, oldValueHash, this.height, this.hasher).assertEquals(
      this.root
    );

    let currentHash = valueField;
    Circuit.asProver(() => {
      this.nodeStore.set(currentHash.toString(), [currentHash]);
    });
    for (let i = this.height - 1; i >= 0; i--) {
      const sideNode = sideNodes[i];
      const currentValue = Circuit.if(pathBits[i], [sideNode, currentHash], [currentHash, sideNode]);
      currentHash = this.hasher(currentValue);
      Circuit.asProver(() => {
        this.nodeStore.set(currentHash.toString(), currentValue);
      });
    }
    Circuit.asProver(() => {
      this.valueStore.set(path.toString(), valueField);
    });

    this.root = currentHash;
    return this.root;
  }
}

/**
 * Walk the node store from `root` down along `pathBits`, collecting the sibling
 * hash at each level (index 0 is the sibling nearest the root).
 * Calls `toBoolean()` on the path bits, so the bits must have readable values
 * (this file only calls it from inside Circuit.witness callbacks).
 *
 * @param {Map<string, Field[]>} nodeStore node hash -> children
 * @param {Field} root hash to start the walk from
 * @param {Bool[]} pathBits path bits, one per level
 * @param {number} height number of levels to walk
 * @return {Field[]} the side nodes along the path
 * @throws {Error} if a node on the path was never added via addBranch
 */
function collectSideNodes(nodeStore, root, pathBits, height) {
  const sideNodes = [];
  let nodeHash = root;
  for (let i = 0; i < height; i++) {
    const children = nodeStore.get(nodeHash.toString());
    if (children === undefined) {
      throw new Error('Make sure you have added the correct proof, key and value using the addBranch method');
    }
    if (pathBits[i].toBoolean()) {
      sideNodes.push(children[0]);
      nodeHash = children[1];
    } else {
      sideNodes.push(children[1]);
      nodeHash = children[0];
    }
  }
  return sideNodes;
}

/**
 * Recompute, in circuit, the root implied by a leaf and its side nodes.
 *
 * @param {Field[]} sideNodes side nodes, index 0 nearest the root
 * @param {Bool[]} pathBits path bits, one per level
 * @param {Field} leaf leaf field
 * @param {number} height tree height
 * @param {Hasher} [hasher=Poseidon.hash] hash used to combine [left, right]
 * @return {Field} the implied root
 */
function impliedRootForHeightInCircuit(sideNodes, pathBits, leaf, height, hasher = Poseidon.hash) {
  let impliedRoot = leaf;
  for (let i = height - 1; i >= 0; i--) {
    const sideNode = sideNodes[i];
    const [left, right] = Circuit.if(pathBits[i], [sideNode, impliedRoot], [impliedRoot, sideNode]);
    impliedRoot = hasher([left, right]);
  }
  return impliedRoot;
}

/**
 * Compute (outside the circuit) every node on the path from a leaf to the root.
 *
 * @param {Field[]} sideNodes side nodes, index 0 nearest the root
 * @param {Field} keyHashOrKeyField leaf index whose bits select left/right
 * @param {Field} valueHashOrValueField leaf field
 * @param {number} height tree height
 * @param {Hasher} [hasher=Poseidon.hash]
 * @return {Array<[Field, Field[]]>} [nodeHash, children] pairs, leaf first
 *   (a leaf entry's children array is [leaf])
 */
function getUpdatesBySideNodes(sideNodes, keyHashOrKeyField, valueHashOrValueField, height, hasher = Poseidon.hash) {
  let currentHash = valueHashOrValueField;
  const updates = [];
  const pathBits = keyHashOrKeyField.toBits(height);
  updates.push([currentHash, [currentHash]]);
  for (let i = height - 1; i >= 0; i--) {
    const node = sideNodes[i];
    const currentValue = pathBits[i].toBoolean() ? [node, currentHash] : [currentHash, node];
    currentHash = hasher(currentValue);
    updates.push([currentHash, currentValue]);
  }
  return updates;
}