UNPKG

snarky-smt

Version:

Sparse Merkle Tree for SnarkyJS

222 lines (221 loc) 8.51 kB
import { Circuit, Field, Poseidon, Struct } from 'snarkyjs';
import { EMPTY_VALUE, SMT_DEPTH } from '../constant';
import { SparseMerkleProof } from './proofs';

export { ProvableDeepSparseMerkleSubTree };

/**
 * Struct wrapping an array of SMT_DEPTH Field elements, used as the witness
 * type for the side nodes collected in `update`.
 */
class SMTSideNodes extends Struct({
  arr: Circuit.array(Field, SMT_DEPTH),
}) {}

/**
 * ProvableDeepSparseMerkleSubTree is a deep sparse merkle subtree for working
 * on only a few leafs in circuit.
 *
 * @class ProvableDeepSparseMerkleSubTree
 * @template K
 * @template V
 */
class ProvableDeepSparseMerkleSubTree {
  /**
   * Creates an instance of ProvableDeepSparseMerkleSubTree.
   *
   * @param {Field} root merkle root
   * @param {Provable<K>} keyType
   * @param {Provable<V>} valueType
   * @param {{ hasher: Hasher; hashKey: boolean; hashValue: boolean }} [options]
   *   hasher: the hash function to use; hashKey: whether to hash the key;
   *   hashValue: whether to hash the value. When `options` is omitted, the
   *   defaults are Poseidon.hash, true and true. When an options object is
   *   supplied, `hashKey` and `hashValue` are read from it as-is (a missing
   *   flag is falsy), and a missing `hasher` falls back to Poseidon.hash.
   * @memberof ProvableDeepSparseMerkleSubTree
   */
  constructor(
    root,
    keyType,
    valueType,
    options = {
      hasher: Poseidon.hash,
      hashKey: true,
      hashValue: true,
    }
  ) {
    // Only `hasher` gets a fallback: an undefined hasher could never work,
    // whereas changing how missing flags are read would change leaf encoding.
    const { hasher = Poseidon.hash, hashKey, hashValue } = options;

    this.root = root;
    // node hash (as string) -> children; leaf entries hold a single element.
    this.nodeStore = new Map();
    // key field (as string) -> value field currently stored at that path.
    this.valueStore = new Map();
    this.hasher = hasher;
    this.config = { hashKey, hashValue };
    this.keyType = keyType;
    this.valueType = valueType;
  }

  /**
   * Get current root.
   *
   * @return {*} {Field}
   * @memberof ProvableDeepSparseMerkleSubTree
   */
  getRoot() {
    return this.root;
  }

  /**
   * Get height of the tree.
   *
   * @return {*} {number}
   * @memberof ProvableDeepSparseMerkleSubTree
   */
  getHeight() {
    return SMT_DEPTH;
  }

  /**
   * Map a key to the Field used as its path: the hash of the key's fields
   * when `hashKey` is enabled, otherwise the key's first field.
   *
   * @param {K} key
   * @return {*} {Field}
   * @memberof ProvableDeepSparseMerkleSubTree
   */
  getKeyField(key) {
    const keyFields = this.keyType.toFields(key);
    if (this.config.hashKey) {
      return this.hasher(keyFields);
    }
    return keyFields[0];
  }

  /**
   * Map a value to the Field stored at the leaf: EMPTY_VALUE when the value
   * is falsy (e.g. omitted), the hash of the value's fields when `hashValue`
   * is enabled, otherwise the value's first field.
   *
   * @param {V} [value]
   * @return {*} {Field}
   * @memberof ProvableDeepSparseMerkleSubTree
   */
  getValueField(value) {
    if (!value) {
      return EMPTY_VALUE;
    }
    const valueFields = this.valueType.toFields(value);
    if (this.config.hashValue) {
      return this.hasher(valueFields);
    }
    return valueFields[0];
  }

  /**
   * Add a branch to the tree, a branch is generated by smt.prove.
   *
   * @param {SparseMerkleProof} proof
   * @param {K} key
   * @param {V} [value]
   * @memberof ProvableDeepSparseMerkleSubTree
   */
  addBranch(proof, key, value) {
    Circuit.asProver(() => {
      const keyField = this.getKeyField(key);
      const valueField = this.getValueField(value);
      const updates = getUpdatesBySideNodes(
        proof.sideNodes,
        keyField,
        valueField,
        this.hasher
      );
      for (const [nodeHash, children] of updates) {
        this.nodeStore.set(nodeHash.toString(), children);
      }
      this.valueStore.set(keyField.toString(), valueField);
    });
  }

  /**
   * Create a merkle proof for a key against the current root.
   *
   * @param {K} key
   * @return {*} {SparseMerkleProof}
   * @throws {Error} if no branch for the key has been added, or if a node on
   *   the path from the root is missing from the node store
   * @memberof ProvableDeepSparseMerkleSubTree
   */
  prove(key) {
    return Circuit.witness(SparseMerkleProof, () => {
      const keyField = this.getKeyField(key);
      const pathStr = keyField.toString();
      if (this.valueStore.get(pathStr) === undefined) {
        throw new Error(
          `The DeepSubTree does not contain a branch of the path: ${pathStr}`
        );
      }

      const height = this.getHeight();
      const pathBits = keyField.toBits(height);
      const sideNodes = collectSideNodes(
        this.nodeStore,
        this.root,
        pathBits,
        height
      );
      return { sideNodes, root: this.root };
    });
  }

  /**
   * Update a new value for a key in the tree and return the new root of the tree.
   *
   * The side nodes and the old value are witnessed from the local stores and
   * asserted to reproduce the current root before the new root is computed.
   *
   * @param {K} key
   * @param {V} [value]
   * @return {*} {Field}
   * @throws {Error} inside the witness callbacks if the path or the old value
   *   is not present in the local stores
   * @memberof ProvableDeepSparseMerkleSubTree
   */
  update(key, value) {
    const path = this.getKeyField(key);
    const valueField = this.getValueField(value);
    const treeHeight = this.getHeight();
    const pathBits = path.toBits(treeHeight);

    const sideNodesArr = Circuit.witness(SMTSideNodes, () => ({
      arr: collectSideNodes(this.nodeStore, this.root, pathBits, treeHeight),
    }));
    const sideNodes = sideNodesArr.arr;

    const oldValueHash = Circuit.witness(Field, () => {
      const stored = this.valueStore.get(path.toString());
      if (stored === undefined) {
        throw new Error('oldValueHash does not exist');
      }
      return stored.toConstant();
    });

    // Check the witnessed branch against the current root using the SAME
    // hasher the tree was built with; a hard-coded Poseidon.hash here would
    // disagree with branches added under a custom hasher.
    impliedRootInCircuit(
      sideNodes,
      pathBits,
      oldValueHash,
      this.hasher
    ).assertEquals(this.root);

    let currentHash = valueField;
    Circuit.asProver(() => {
      this.nodeStore.set(valueField.toString(), [valueField]);
    });

    // Rehash from the leaf up to the root, recording every new node.
    for (let i = treeHeight - 1; i >= 0; i--) {
      const sideNode = sideNodes[i];
      const children = Circuit.if(
        pathBits[i],
        [sideNode, currentHash],
        [currentHash, sideNode]
      );
      const parentHash = this.hasher(children);
      Circuit.asProver(() => {
        this.nodeStore.set(parentHash.toString(), children);
      });
      currentHash = parentHash;
    }

    Circuit.asProver(() => {
      this.valueStore.set(path.toString(), valueField);
    });

    this.root = currentHash;
    return this.root;
  }
}

/**
 * Walk the node store from `root` towards the leaf selected by `pathBits`,
 * collecting the sibling (side node) met at each of the `height` levels.
 * Index 0 of the result is the side node directly below the root.
 *
 * @param {Map<string, Field[]>} nodeStore node hash (as string) -> children
 * @param {Field} root hash to start the walk from
 * @param {Bool[]} pathBits path bits; a true bit descends into children[1]
 * @param {number} height number of levels to descend
 * @return {Field[]} side nodes, ordered from the root downwards
 * @throws {Error} if a node on the path is missing from the store
 */
function collectSideNodes(nodeStore, root, pathBits, height) {
  const sideNodes = [];
  let nodeHash = root;
  for (let i = 0; i < height; i++) {
    const children = nodeStore.get(nodeHash.toString());
    if (children === undefined) {
      throw new Error(
        'Make sure you have added the correct proof, key and value using the addBranch method'
      );
    }
    if (pathBits[i].toBoolean()) {
      sideNodes.push(children[0]);
      nodeHash = children[1];
    } else {
      sideNodes.push(children[1]);
      nodeHash = children[0];
    }
  }
  return sideNodes;
}

/**
 * Recompute, in circuit, the root implied by a leaf and its side nodes.
 *
 * @param {Field[]} sideNodes side nodes, index 0 directly below the root
 * @param {Bool[]} pathBits path bits; a true bit puts the running hash on the right
 * @param {Field} leaf leaf value to start hashing from
 * @param {Hasher} [hasher=Poseidon.hash] hash function applied to each [left, right] pair
 * @return {Field} the implied root
 */
function impliedRootInCircuit(sideNodes, pathBits, leaf, hasher = Poseidon.hash) {
  let impliedRoot = leaf;
  for (let i = SMT_DEPTH - 1; i >= 0; i--) {
    const sideNode = sideNodes[i];
    const [left, right] = Circuit.if(
      pathBits[i],
      [sideNode, impliedRoot],
      [impliedRoot, sideNode]
    );
    impliedRoot = hasher([left, right]);
  }
  return impliedRoot;
}

/**
 * Build the list of node-store entries for one branch, hashing from the leaf
 * up to the root. The first entry is the leaf itself, stored as a
 * single-element array; each following entry is [parentHash, [left, right]].
 *
 * @param {Field[]} sideNodes side nodes, index 0 directly below the root
 * @param {Field} keyHashOrKeyField path of the leaf
 * @param {Field} valueHashOrValueField value stored at the leaf
 * @param {Hasher} [hasher=Poseidon.hash] hash function applied to each pair
 * @return {Array<[Field, Field[]]>} entries ordered from the leaf to the root
 */
function getUpdatesBySideNodes(
  sideNodes,
  keyHashOrKeyField,
  valueHashOrValueField,
  hasher = Poseidon.hash
) {
  const pathBits = keyHashOrKeyField.toBits(SMT_DEPTH);
  let currentHash = valueHashOrValueField;
  const updates = [[currentHash, [currentHash]]];

  for (let i = SMT_DEPTH - 1; i >= 0; i--) {
    const node = sideNodes[i];
    const children = pathBits[i].toBoolean()
      ? [node, currentHash]
      : [currentHash, node];
    currentHash = hasher(children);
    updates.push([currentHash, children]);
  }
  return updates;
}