@reclaimprotocol/zk-symmetric-crypto
Version:
JS Wrappers for Various ZK Snark Circuits
160 lines (159 loc) • 6 kB
JavaScript
import { concatenateUint8Arrays } from '@reclaimprotocol/tls';
import { CONFIG } from "./config.js";
import { ceilToBlockSizeMultiple, getBlockSizeBytes, getCounterForByteOffset, splitCiphertextToBlocks } from "./utils.js";
/**
 * Generate a ZK proof that a ciphertext is a valid encryption of some
 * plaintext under the symmetric algorithm named by `opts.algorithm`.
 *
 * The witness is built via `generateZkWitness`, serialised with
 * `operator.generateWitness` and proven with `operator.groth16Prove`.
 * When `opts` carries a `mask` property, `toprf` and `mask` are added to
 * the witness and the plaintext is left out of the result.
 *
 * @param opts proof options; reads `algorithm`, `operator`, `logger`
 *  and, optionally, `toprf` / `mask` (the rest is forwarded to
 *  `generateZkWitness`)
 * @returns `{ algorithm, proofData, plaintext }` where `plaintext` is
 *  `undefined` whenever `mask` is present in `opts`
 */
export async function generateProof(opts) {
    const { algorithm, operator, logger } = opts;
    const { witness, plaintextArray } = await generateZkWitness(opts);

    // presence of "mask" selects the TOPRF flavour of the witness
    const hasMask = 'mask' in opts;
    const witnessInput = hasMask
        ? { ...witness, toprf: opts.toprf, mask: opts.mask }
        : witness;

    const wtnsSerialised = await operator.generateWitness(witnessInput);
    const proveResult = await operator.groth16Prove(wtnsSerialised, logger);

    return {
        algorithm,
        proofData: proveResult.proof,
        plaintext: hasMask ? undefined : plaintextArray,
    };
}
/**
 * Verify a ZK proof produced by `generateProof`.
 *
 * The public signals are rebuilt from `opts.publicInput` together with the
 * algorithm and plaintext stored on the proof, then handed to
 * `operator.groth16Verify` alongside the proof data. When `opts` carries a
 * `toprf` property it is merged into the public signals first.
 *
 * @param opts verification options; reads `proof` (`algorithm`,
 *  `plaintext`, `proofData`), `publicInput`, `operator`, `logger` and,
 *  optionally, `toprf`
 * @throws {Error} 'invalid proof' when the operator reports a falsy result
 */
export async function verifyProof(opts) {
    const { algorithm, plaintext } = opts.proof;
    const publicSignals = await getPublicSignals({
        algorithm,
        plaintext,
        publicInput: opts.publicInput,
    });

    const { operator, logger } = opts;
    const { proofData } = opts.proof;

    const signals = 'toprf' in opts
        ? { ...publicSignals, toprf: opts.toprf }
        : publicSignals;

    const isValid = await operator.groth16Verify(signals, proofData, logger);
    if (!isValid) {
        throw new Error('invalid proof');
    }
}
/**
 * Build the ZK witness for the symmetric encryption circuit.
 *
 * The witness is the key plus the public signals computed by
 * `getPublicSignals`; it can then be serialised and passed to the
 * operator's groth16Prove function.
 *
 * @returns `{ witness, plaintextArray }` where `plaintextArray` is the
 *  witness's `out` signal
 * @throws {Error} when the key length differs from the algorithm's
 *  configured `keySizeBytes`
 */
export async function generateZkWitness({ algorithm, privateInput: { key }, publicInput, }) {
    const { keySizeBytes: expectedKeyLength } = CONFIG[algorithm];
    if (key.length !== expectedKeyLength) {
        throw new Error(`key must be ${expectedKeyLength} bytes`);
    }

    const publicSignals = await getPublicSignals({ publicInput, algorithm, key });
    const witness = { key, ...publicSignals };

    return { witness, plaintextArray: witness.out };
}
/**
 * Assemble the public signals for one chunk of the circuit.
 *
 * `publicInput` may be a single `{ ciphertext, iv, offsetBytes }` object or
 * an array of them. Every ciphertext is split into blocks, each block is
 * zero-padded to the block size, and after the last input empty blocks are
 * appended until the expected chunk size is reached.
 *
 * The `out` signal is `opts.plaintext` when that is present and truthy;
 * otherwise it is the concatenation of the blocks decrypted with `opts.key`
 * (which is empty when no `key` property was supplied).
 *
 * @returns `{ noncesAndCounters, in, out }`
 * @throws {Error} when no public input is given, an IV has the wrong
 *  length, a block exceeds the block size, or the assembled ciphertext is
 *  not exactly the expected chunk size
 */
export async function getPublicSignals({ publicInput, algorithm, ...opts }) {
    const { ivSizeBytes } = CONFIG[algorithm];
    const noncesAndCounters = [];
    const ciphertextBlocks = [];
    const plaintextBlocks = [];
    const blockSize = getBlockSizeBytes(algorithm);
    const expSize = getExpectedChunkSizeBytes(algorithm);

    // Registers a single block: its nonce/counter, the padded ciphertext
    // and -- only when a key was supplied -- the matching plaintext.
    const appendBlock = async ({ ciphertext, iv, offsetBytes = 0 }) => {
        if (iv.length !== ivSizeBytes) {
            throw new Error(`iv must be ${ivSizeBytes} bytes`);
        }

        noncesAndCounters.push({
            nonce: iv,
            counter: getCounterForByteOffset(algorithm, offsetBytes),
            boundary: undefined,
        });

        const paddedBlock = padCiphertextToSize(ciphertext, blockSize);
        ciphertextBlocks.push(paddedBlock);

        if (!('key' in opts)) {
            return;
        }

        const plaintextArray = await decryptCiphertext({
            algorithm,
            key: opts.key,
            iv,
            startOffset: offsetBytes,
            ciphertext: paddedBlock,
        });
        plaintextBlocks.push(plaintextArray);
    };

    const inputs = Array.isArray(publicInput) ? publicInput : [publicInput];
    if (!inputs.length) {
        throw new Error('at least one public input is required');
    }

    const lastIndex = inputs.length - 1;
    for (let inputIndex = 0; inputIndex < inputs.length; inputIndex++) {
        const { ciphertext, iv, offsetBytes = 0 } = inputs[inputIndex];

        for (const block of splitCiphertextToBlocks(algorithm, ciphertext, iv)) {
            await appendBlock({
                ...block,
                offsetBytes: offsetBytes + (block.offsetBytes || 0),
            });
        }

        // padding is only ever appended after the final input
        if (inputIndex < lastIndex) {
            continue;
        }

        let bytesDone = 0;
        for (const block of ciphertextBlocks) {
            bytesDone += block.length;
        }

        if (bytesDone >= expSize) {
            continue;
        }

        const padding = expSize - bytesDone;
        const paddingStart = offsetBytes
            + ceilToBlockSizeMultiple(ciphertext.length, algorithm);
        for (let padded = 0; padded < padding; padded += blockSize) {
            await appendBlock({
                ciphertext: new Uint8Array(),
                iv,
                offsetBytes: paddingStart + padded,
            });
        }
    }

    const inSignal = concatenateUint8Arrays(ciphertextBlocks);
    const outSignal = 'plaintext' in opts && opts.plaintext
        ? opts.plaintext
        : concatenateUint8Arrays(plaintextBlocks);
    const pubSigs = { noncesAndCounters, in: inSignal, out: outSignal };

    if (pubSigs.in.length !== getExpectedChunkSizeBytes(algorithm)) {
        throw new Error(`Ciphertext must be exactly ${expSize}b, got ${pubSigs.in.length}b`);
    }

    return pubSigs;
}
/**
 * Right-pad a ciphertext with zero bytes up to `size` bytes.
 *
 * @param ciphertext bytes to pad
 * @param size target length in bytes
 * @returns the same array when it needs no padding, otherwise a new
 *  `Uint8Array` of `size` bytes starting with `ciphertext`
 * @throws {Error} when the ciphertext is longer than `size`
 */
function padCiphertextToSize(ciphertext, size) {
    const currentLength = ciphertext.length;
    if (currentLength > size) {
        throw new Error(`ciphertext must be <= ${size}b`);
    }

    // nothing to pad: hand back the caller's array untouched
    if (!(currentLength < size)) {
        return ciphertext;
    }

    // a fresh Uint8Array is zero-filled, so only the prefix needs copying
    const padded = new Uint8Array(size);
    padded.set(ciphertext);
    return padded;
}
/**
 * Size in bytes of one full chunk for the given algorithm: the
 * algorithm's block size multiplied by its configured `blocksPerChunk`.
 */
function getExpectedChunkSizeBytes(alg) {
    const { blocksPerChunk: blocksInChunk } = CONFIG[alg];
    const bytesPerBlock = getBlockSizeBytes(alg);
    return bytesPerBlock * blocksInChunk;
}
/**
 * Run the algorithm's `encrypt` function over a ciphertext that sits
 * `startOffset` bytes into the stream and return the bytes from that
 * offset onwards.
 *
 * The input buffer is prefixed with `startOffset` zero bytes so the
 * ciphertext lands at its stream position; the output for that prefix is
 * discarded.
 */
async function decryptCiphertext({ algorithm, key, iv, startOffset, ciphertext, }) {
    const { encrypt: applyCipher } = CONFIG[algorithm];

    // leading bytes are placeholders -- their content is irrelevant
    const streamInput = new Uint8Array(startOffset + ciphertext.length);
    streamInput.set(ciphertext, startOffset);

    const streamOutput = await applyCipher({ key, iv, in: streamInput });
    return streamOutput.slice(startOffset);
}