UNPKG

@reclaimprotocol/zk-symmetric-crypto

Version:
160 lines (159 loc) 6 kB
import { concatenateUint8Arrays } from '@reclaimprotocol/tls';
import { CONFIG } from './config.js';
import {
  ceilToBlockSizeMultiple,
  getBlockSizeBytes,
  getCounterForByteOffset,
  splitCiphertextToBlocks,
} from './utils.js';

/**
 * Generate a ZK proof that a ciphertext is a valid encryption of the given
 * plaintext under the symmetric cipher named by `opts.algorithm`.
 * Cipher parameters come from `CONFIG[algorithm]`.
 * The plaintext can be partially redacted.
 *
 * When `opts` contains a `mask` key, the witness is extended with
 * `opts.toprf` and `opts.mask`, and the plaintext is withheld from the result.
 *
 * @param {object} opts
 * @param {string} opts.algorithm - key into CONFIG
 * @param {object} opts.operator - must provide `generateWitness()` and `groth16Prove()`
 * @param {object} [opts.logger] - forwarded to `operator.groth16Prove`
 * @param {{ key: Uint8Array }} opts.privateInput - see generateZkWitness
 * @param {object|object[]} opts.publicInput - see getPublicSignals
 * @returns {Promise<{ algorithm: string, proofData: *, plaintext: (Uint8Array|undefined) }>}
 */
export async function generateProof(opts) {
  const { algorithm, operator, logger } = opts;
  // a mask means the plaintext must not be included in the returned proof
  const hasMask = 'mask' in opts;
  const { witness, plaintextArray } = await generateZkWitness(opts);

  let wtnsSerialised;
  if (hasMask) {
    wtnsSerialised = await operator.generateWitness({
      ...witness,
      toprf: opts.toprf,
      mask: opts.mask,
    });
  } else {
    // @ts-expect-error
    wtnsSerialised = await operator.generateWitness(witness);
  }

  const { proof } = await operator.groth16Prove(wtnsSerialised, logger);
  return {
    algorithm,
    proofData: proof,
    plaintext: hasMask ? undefined : plaintextArray,
  };
}

/**
 * Verify a ZK proof produced by `generateProof`.
 *
 * The public signals are rebuilt from `opts.publicInput` and the plaintext
 * stored in the proof, then handed to the operator for verification.
 *
 * @param {object} opts
 * @param {{ algorithm: string, proofData: *, plaintext?: Uint8Array }} opts.proof
 *   - the object returned by generateProof
 * @param {object|object[]} opts.publicInput - public input the proof was made for
 * @param {object} opts.operator - must provide `groth16Verify()`
 * @param {object} [opts.logger] - forwarded to `operator.groth16Verify`
 * @param {*} [opts.toprf] - if the key is present (even when undefined), it is
 *   merged into the public signals
 * @returns {Promise<void>}
 * @throws {Error} 'invalid proof' when the operator rejects the proof
 */
export async function verifyProof(opts) {
  const publicSignals = await getPublicSignals({
    algorithm: opts.proof.algorithm,
    plaintext: opts.proof.plaintext,
    publicInput: opts.publicInput,
  });
  const { proof: { proofData }, operator, logger } = opts;

  let verified;
  if ('toprf' in opts) {
    verified = await operator.groth16Verify(
      { ...publicSignals, toprf: opts.toprf },
      proofData,
      logger,
    );
  } else {
    // the rebuilt public signals are passed to the operator unchanged
    verified = await operator.groth16Verify(
      // @ts-expect-error
      publicSignals,
      proofData,
      logger,
    );
  }

  if (!verified) {
    throw new Error('invalid proof');
  }
}

/**
 * Generate a ZK witness for the symmetric encryption circuit.
 * This witness can then be used to generate a ZK proof,
 * using the operator's groth16Prove function.
 *
 * @param {object} opts
 * @param {string} opts.algorithm - key into CONFIG
 * @param {{ key: Uint8Array }} opts.privateInput - `key` must be exactly
 *   `CONFIG[algorithm].keySizeBytes` long
 * @param {object|object[]} opts.publicInput - see getPublicSignals
 * @returns {Promise<{ witness: object, plaintextArray: Uint8Array }>}
 *   `witness` is `{ key, noncesAndCounters, in, out }`, and `plaintextArray`
 *   is `witness.out` (the decrypted chunk).
 * @throws {Error} if the key has the wrong length
 */
export async function generateZkWitness({
  algorithm,
  privateInput: { key },
  publicInput,
}) {
  const { keySizeBytes } = CONFIG[algorithm];
  if (key.length !== keySizeBytes) {
    throw new Error(`key must be ${keySizeBytes} bytes`);
  }

  const publicSignals = await getPublicSignals({ publicInput, algorithm, key });
  const witness = { key, ...publicSignals };
  return { witness, plaintextArray: witness.out };
}

/**
 * Build the circuit's public signals for one fixed-size chunk of ciphertext.
 *
 * Each public input's ciphertext is split into blocks, and each block is
 * zero-padded to the cipher's block size. After the last input, empty blocks
 * are appended until the total reaches the chunk size
 * (`blockSize * CONFIG[algorithm].blocksPerChunk`).
 *
 * `out` is `opts.plaintext` when that is truthy. Otherwise `out` is the
 * concatenation of the decrypted blocks, and decryption only happens when
 * `opts.key` is given.
 *
 * @param {object} opts
 * @param {object|object[]} opts.publicInput - one or many
 *   `{ ciphertext, iv, offsetBytes? }`; `offsetBytes` defaults to 0
 * @param {string} opts.algorithm - key into CONFIG
 * @param {Uint8Array} [opts.key] - if present, blocks are decrypted
 * @param {Uint8Array} [opts.plaintext] - used as `out` when truthy
 * @returns {Promise<{ noncesAndCounters: object[], in: Uint8Array, out: Uint8Array }>}
 * @throws {Error} on an empty input list, a wrong IV length, a block larger
 *   than the block size, or a total ciphertext that does not exactly fill
 *   one chunk
 */
export async function getPublicSignals({ publicInput, algorithm, ...opts }) {
  const { ivSizeBytes } = CONFIG[algorithm];
  const blockSize = getBlockSizeBytes(algorithm);
  const expSize = getExpectedChunkSizeBytes(algorithm);

  const ciphertextBlocks = [];
  const plaintextBlocks = [];
  const noncesAndCounters = [];

  const inputs = Array.isArray(publicInput) ? publicInput : [publicInput];
  if (!inputs.length) {
    throw new Error('at least one public input is required');
  }

  const lastIndex = inputs.length - 1;
  for (const [index, { ciphertext, iv, offsetBytes = 0 }] of inputs.entries()) {
    const blocks = splitCiphertextToBlocks(algorithm, ciphertext, iv);
    for (const block of blocks) {
      // a block's own offset is added on top of its input's offset
      await addCiphertextBlock({
        ...block,
        offsetBytes: offsetBytes + (block.offsetBytes || 0),
      });
    }

    // only the final input is followed by padding up to the chunk size
    if (index === lastIndex) {
      await padToChunkSize({ ciphertext, iv, offsetBytes });
    }
  }

  const pubSigs = {
    noncesAndCounters,
    in: concatenateUint8Arrays(ciphertextBlocks),
    out: 'plaintext' in opts && opts.plaintext
      ? opts.plaintext
      : concatenateUint8Arrays(plaintextBlocks),
  };
  if (pubSigs.in.length !== expSize) {
    throw new Error(`Ciphertext must be exactly ${expSize}b, got ${pubSigs.in.length}b`);
  }

  return pubSigs;

  /**
   * Append empty (all-zero after padding) blocks after the last input until
   * the collected ciphertext reaches `expSize` bytes. Does nothing if it
   * already has at least that many.
   */
  async function padToChunkSize({ ciphertext, iv, offsetBytes }) {
    const bytesDone = ciphertextBlocks.reduce((total, blk) => total + blk.length, 0);
    if (bytesDone >= expSize) {
      return;
    }

    const paddingBytes = expSize - bytesDone;
    // padding blocks start right after the last input's block-rounded ciphertext
    const paddingStart = offsetBytes + ceilToBlockSizeMultiple(ciphertext.length, algorithm);
    for (let padOffset = 0; padOffset < paddingBytes; padOffset += blockSize) {
      await addCiphertextBlock({
        ciphertext: new Uint8Array(),
        iv,
        offsetBytes: paddingStart + padOffset,
      });
    }
  }

  /**
   * Process one block:
   * - validate its IV length;
   * - record its nonce and starting counter;
   * - zero-pad it to `blockSize` and store it;
   * - decrypt it when a key was supplied.
   */
  async function addCiphertextBlock({ ciphertext, iv, offsetBytes = 0 }) {
    if (iv.length !== ivSizeBytes) {
      throw new Error(`iv must be ${ivSizeBytes} bytes`);
    }

    const startCounter = getCounterForByteOffset(algorithm, offsetBytes);
    noncesAndCounters.push({ nonce: iv, counter: startCounter, boundary: undefined });

    const paddedCiphertext = padCiphertextToSize(ciphertext, blockSize);
    ciphertextBlocks.push(paddedCiphertext);

    if ('key' in opts) {
      const plaintextArray = await decryptCiphertext({
        algorithm,
        key: opts.key,
        iv,
        startOffset: offsetBytes,
        ciphertext: paddedCiphertext,
      });
      plaintextBlocks.push(plaintextArray);
    }
  }
}

/**
 * Right-pad `ciphertext` with zero bytes to `size` bytes.
 * Returns the same array (no copy) when it is not shorter than `size`.
 *
 * @param {Uint8Array} ciphertext
 * @param {number} size
 * @returns {Uint8Array}
 * @throws {Error} if `ciphertext` is longer than `size`
 */
function padCiphertextToSize(ciphertext, size) {
  if (ciphertext.length > size) {
    throw new Error(`ciphertext must be <= ${size}b`);
  }

  if (ciphertext.length < size) {
    const padded = new Uint8Array(size);
    padded.set(ciphertext);
    return padded;
  }

  return ciphertext;
}

/**
 * Number of ciphertext bytes the circuit consumes per proof:
 * the block size times `CONFIG[alg].blocksPerChunk`.
 *
 * @param {string} alg - key into CONFIG
 * @returns {number}
 */
function getExpectedChunkSizeBytes(alg) {
  const { blocksPerChunk } = CONFIG[alg];
  return getBlockSizeBytes(alg) * blocksPerChunk;
}

/**
 * Decrypt `ciphertext` located `startOffset` bytes into the encrypted stream.
 *
 * The cipher's `encrypt` is run over `startOffset` placeholder zero bytes
 * followed by `ciphertext`. The output for the placeholder prefix is then
 * discarded. Presumably this lines the block up with the correct keystream
 * position and relies on encryption == decryption for this cipher
 * (TODO confirm for every CONFIG entry).
 *
 * NOTE(review): the allocated buffer and the encryption work grow linearly
 * with `startOffset`, although only the trailing `ciphertext.length` bytes
 * are kept.
 *
 * @returns {Promise<Uint8Array>} the `ciphertext.length` output bytes
 */
async function decryptCiphertext({
  algorithm,
  key,
  iv,
  startOffset,
  ciphertext,
}) {
  const { encrypt } = CONFIG[algorithm];
  // the bytes before `startOffset` are placeholders; their output is dropped
  const input = new Uint8Array(startOffset + ciphertext.length);
  input.set(ciphertext, startOffset);
  const output = await encrypt({ key, iv, in: input });
  return output.slice(startOffset);
}