// o1js
// Version:
// TypeScript framework for zk-SNARKs and zkApps
// 346 lines (290 loc) • 10 kB
// text/typescript
import { Provable } from '../provable.js';
import { Fp } from '../../../bindings/crypto/finite-field.js';
import { Field } from '../field.js';
import { Gates } from '../gates.js';
import { assert, divideWithRemainder, toVar, bitSlice } from './common.js';
import { rangeCheck32, rangeCheck64 } from './range-check.js';
import { divMod32 } from './arithmetic.js';
import { exists } from '../../provable/core/exists.js';
export {
xor,
not,
rotate64,
rotate32,
and,
rightShift64,
leftShift64,
leftShift32,
};
/**
 * Bitwise NOT of `a`, interpreted as a `length`-bit value.
 *
 * - Constant input: evaluated directly with `Fp.not`. The value must first fit
 *   into `length` rounded up to a multiple of 16 bits.
 * - `checked = true`: proven by XOR-ing `a` with the `length`-bit all-ones
 *   mask through the XOR gadget chain.
 * - `checked = false` (default): computed as `(2^length - 1) - a`. This equals
 *   the bitwise complement only if the caller already knows `a < 2^length`.
 *
 * @param a value to complement
 * @param length bit length; must be positive and below `Field.sizeInBits`
 * @param checked use the XOR gadget instead of a plain subtraction
 * @returns the complemented value as a `Field`
 */
function not(a: Field, length: number, checked: boolean = false) {
  assert(length > 0, `Input length needs to be positive values.`);
  assert(
    length < Field.sizeInBits,
    `Length ${length} exceeds maximum of ${Field.sizeInBits} bits.`
  );

  // lookup tables operate on 16-bit chunks, so round the width up to a multiple of 16
  const paddedBits = 16 * Math.ceil(length / 16);

  if (a.isConstant()) {
    const value = a.toBigInt();
    assert(
      value < 1n << BigInt(paddedBits),
      `${value} does not fit into ${paddedBits} bits`
    );
    return new Field(Fp.not(value, length));
  }

  // mask with the low `length` bits set
  const allOnes = new Field((1n << BigInt(length)) - 1n);
  return checked ? xor(a, allOnes, length) : allOnes.sub(a).seal();
}
/**
 * Bitwise XOR of two values of at most `length` bits.
 *
 * Constant operands are XOR-ed directly. Otherwise the result is witnessed
 * and constrained by a chain of XOR gates covering `length` rounded up to a
 * multiple of 16 bits.
 *
 * @param a first operand
 * @param b second operand
 * @param length bit length of the operands; 1..254
 * @returns `a ^ b` as a `Field`
 */
function xor(a: Field, b: Field, length: number) {
  assert(length > 0, `Input lengths need to be positive values.`);
  assert(length <= 254, `Length ${length} exceeds maximum of 254 bits.`);

  // the XOR gate chain consumes 16 bits per row
  const paddedBits = 16 * Math.ceil(length / 16);

  if (a.isConstant() && b.isConstant()) {
    const limit = 1n << BigInt(paddedBits);
    const [x, y] = [a.toBigInt(), b.toBigInt()];
    assert(x < limit, `${a} does not fit into ${paddedBits} bits`);
    assert(y < limit, `${b} does not fit into ${paddedBits} bits`);
    return new Field(x ^ y);
  }

  const result = Provable.witness(Field, () => a.toBigInt() ^ b.toBigInt());
  buildXor(a, b, result, paddedBits);
  return result;
}
/**
 * Emits a chain of XOR gates proving `out = a ^ b`, 16 bits per iteration,
 * followed by a terminating zero row.
 *
 * Each iteration witnesses 15 values:
 * - the four low 4-bit slices of each of `a`, `b` and `out`;
 * - the three inputs shifted right by 16 (`aNext`, `bNext`, `outNext`), which
 *   become the inputs of the next iteration.
 *
 * After the loop, the final remainders are asserted to be zero.
 * NOTE(review): presumably this, together with the XOR gate's decomposition
 * constraint, is what bounds `a`, `b` and `out` to `padLength` bits. The gate
 * semantics live in `Gates.xor` and are not visible here; confirm there.
 *
 * The order of the `exists` / `Gates.*` calls determines the emitted rows, so
 * statements here should not be reordered.
 *
 * @param a first input (or, on later iterations, its remaining high part)
 * @param b second input
 * @param out witnessed XOR result to be constrained
 * @param padLength number of bits to cover. It must be a non-negative
 *   multiple of 16; otherwise `padLength !== 0` never becomes false and the
 *   loop does not terminate.
 */
function buildXor(a: Field, b: Field, out: Field, padLength: number) {
  // construct the chain of XORs until padLength is 0
  while (padLength !== 0) {
    // witness the low 16 bits of a, b and out as four 4-bit slices each,
    // plus the values shifted right by 16 for the next iteration
    let slices = exists(15, () => {
      let a0 = a.toBigInt();
      let b0 = b.toBigInt();
      let out0 = out.toBigInt();
      return [
        // slices of a
        bitSlice(a0, 0, 4),
        bitSlice(a0, 4, 4),
        bitSlice(a0, 8, 4),
        bitSlice(a0, 12, 4),
        // slices of b
        bitSlice(b0, 0, 4),
        bitSlice(b0, 4, 4),
        bitSlice(b0, 8, 4),
        bitSlice(b0, 12, 4),
        // slices of expected output
        bitSlice(out0, 0, 4),
        bitSlice(out0, 4, 4),
        bitSlice(out0, 8, 4),
        bitSlice(out0, 12, 4),
        // next values (remaining high bits)
        a0 >> 16n,
        b0 >> 16n,
        out0 >> 16n,
      ];
    });
    // prettier-ignore
    let [
      in1_0, in1_1, in1_2, in1_3,
      in2_0, in2_1, in2_2, in2_3,
      out0, out1, out2, out3,
      aNext, bNext, outNext
    ] = slices;
    // assert that the xor of the slices is correct, 16 bit at a time
    // prettier-ignore
    Gates.xor(
      a, b, out,
      in1_0, in1_1, in1_2, in1_3,
      in2_0, in2_1, in2_2, in2_3,
      out0, out1, out2, out3
    );
    // update the values for the next loop iteration
    a = aNext;
    b = bNext;
    out = outNext;
    padLength = padLength - 16;
  }
  // end of the chain: all padLength bits are consumed, so the remaining high
  // parts of a, b and out must be zero; add the zero row and assert it
  Gates.zero(a, b, out);
  let zero = new Field(0);
  zero.assertEquals(a);
  zero.assertEquals(b);
  zero.assertEquals(out);
}
/**
 * Bitwise AND of two values of at most `length` bits.
 *
 * Constant operands are AND-ed directly. Otherwise the result is witnessed
 * and tied to the XOR gadget through the identity
 *   a + b = 2 * (a AND b) + (a XOR b).
 *
 * NOTE(review): this function accepts `length <= Field.sizeInBits`, but the
 * variable path calls `xor`, which rejects lengths above 254. The largest
 * length therefore only works for constant inputs.
 *
 * @param a first operand
 * @param b second operand
 * @param length bit length of the operands
 * @returns `a & b` as a `Field`
 */
function and(a: Field, b: Field, length: number) {
  assert(length > 0, `Input lengths need to be positive values.`);
  assert(
    length <= Field.sizeInBits,
    `Length ${length} exceeds maximum of ${Field.sizeInBits} bits.`
  );

  const paddedBits = 16 * Math.ceil(length / 16);

  if (a.isConstant() && b.isConstant()) {
    const limit = 1n << BigInt(paddedBits);
    const [x, y] = [a.toBigInt(), b.toBigInt()];
    assert(x < limit, `${a} does not fit into ${paddedBits} bits`);
    assert(y < limit, `${b} does not fit into ${paddedBits} bits`);
    return new Field(x & y);
  }

  const result = Provable.witness(Field, () => a.toBigInt() & b.toBigInt());
  // see https://o1-labs.github.io/proof-systems/specs/kimchi.html?highlight=gates#and
  const aPlusB = a.add(b);
  const aXorB = xor(a, b, length);
  result.mul(2).add(aXorB).assertEquals(aPlusB);
  return result;
}
/**
 * Rotates a 64-bit value by `bits` positions (0..64) in the given direction.
 *
 * Constant inputs are rotated with `Fp.rot` after checking that they fit in
 * 64 bits. Variable inputs go through the `rot64` gadget; only its `rotated`
 * output is returned.
 *
 * @param field value to rotate
 * @param bits rotation amount, 0..64
 * @param direction `'left'` (default) or `'right'`
 * @returns the rotated value
 */
function rotate64(
  field: Field,
  bits: number,
  direction: 'left' | 'right' = 'left'
) {
  assert(
    bits >= 0 && bits <= 64,
    `rotation: expected bits to be between 0 and 64, got ${bits}`
  );

  if (!field.isConstant()) {
    return rot64(field, bits, direction)[0];
  }

  const value = field.toBigInt();
  assert(
    value < 1n << 64n,
    `rotation: expected field to be at most 64 bits, got ${value}`
  );
  return new Field(Fp.rot(value, BigInt(bits), direction));
}
/**
 * Rotates a 32-bit value by `bits` positions (0..32) in the given direction.
 *
 * Constant inputs are rotated with `Fp.rot` (32-bit width) after checking
 * that they fit in 32 bits.
 *
 * Variable inputs use the decomposition
 *   field * 2^s = excess * 2^32 + shifted,  rotated = shifted + excess,
 * where `s = bits` for a left rotation and `s = 32 - bits` for a right
 * rotation. `divMod32` produces `excess` (quotient) and `shifted`
 * (remainder), and the result is range-checked to 32 bits.
 *
 * NOTE(review): `field` itself is not range-checked here. Presumably callers
 * guarantee `field < 2^32`; confirm at call sites.
 *
 * @param field value to rotate
 * @param bits rotation amount, 0..32 (0 and 32 are the identity)
 * @param direction `'left'` (default) or `'right'`
 * @returns the rotated value
 */
function rotate32(
  field: Field,
  bits: number,
  direction: 'left' | 'right' = 'left'
) {
  // Accept the full range the message advertises, matching rotate64's 0..64.
  // Previously bits = 0 was rejected even though the error said "between 0 and 32".
  assert(bits >= 0 && bits <= 32, 'bits must be between 0 and 32');

  if (field.isConstant()) {
    assert(
      field.toBigInt() < 1n << 32n,
      `rotation: expected field to be at most 32 bits, got ${field.toBigInt()}`
    );
    return new Field(Fp.rot(field.toBigInt(), BigInt(bits), direction, 32n));
  }

  // a right rotation by `bits` equals a left rotation by `32 - bits`
  const shiftBits = direction === 'left' ? bits : 32 - bits;
  const { quotient: excess, remainder: shifted } = divMod32(
    field.mul(1n << BigInt(shiftBits))
  );

  const rotated = shifted.add(excess);
  rangeCheck32(rotated);
  return rotated;
}
/**
 * Core 64-bit rotation gadget shared by `rotate64`, `rightShift64` and
 * `leftShift64`.
 *
 * It witnesses and constrains the decomposition
 *   field * 2^r = excess * 2^64 + shifted,   rotated = excess + shifted,
 * where `r = bits` for a left rotation and `r = 64 - bits` for a right
 * rotation. It also witnesses `bound = excess + 2^64 - 2^r`. The bound is
 * split into 4 x 12-bit and 8 x 2-bit limbs, which are passed to
 * `Gates.rotate` so that `excess < 2^r` can be enforced.
 * NOTE(review): the exact gate constraints live in `Gates.rotate`; confirm there.
 *
 * The order of the witness / gate / range-check calls determines the emitted
 * rows (see the `toVar(0n)` note below), so statements here should not be
 * reordered.
 *
 * @param field 64-bit input. NOTE(review): it is NOT range-checked in this
 *   function. Presumably callers guarantee `field < 2^64`; confirm at call sites.
 * @param bits rotation amount (callers in this file assert 0..64)
 * @param direction `'left'` (default) or `'right'`
 * @returns `[rotated, excess, shifted]`:
 *   - `excess` is `field >> bits` for a right rotation;
 *   - `shifted` is `(field << bits) mod 2^64` for a left rotation.
 */
function rot64(
  field: Field,
  bits: number,
  direction: 'left' | 'right' = 'left'
): [Field, Field, Field] {
  // a right rotation by `bits` is a left rotation by `64 - bits`
  const rotationBits = direction === 'right' ? 64 - bits : bits;
  const big2Power64 = 1n << 64n;
  const big2PowerRot = 1n << BigInt(rotationBits);
  const [rotated, excess, shifted, bound] = Provable.witness(
    Provable.Array(Field, 4),
    () => {
      const f = field.toBigInt();
      // Obtain rotated output, excess, and shifted for the equation:
      // f * 2^rot = excess * 2^64 + shifted
      const { quotient: excess, remainder: shifted } = divideWithRemainder(
        f * big2PowerRot,
        big2Power64
      );
      // Compute rotated value as: rotated = excess + shifted
      const rotated = shifted + excess;
      // Compute bound to check excess < 2^rot
      const bound = excess + big2Power64 - big2PowerRot;
      return [rotated, excess, shifted, bound];
    }
  );
  // flush zero var to prevent broken gate chain (zero is used in rangeCheck64)
  // TODO this is an abstraction leak, but not clear to me how to improve
  toVar(0n);
  // slice the bound into chunks: four 12-bit limbs (bits 16-64) followed by
  // eight 2-bit crumbs (bits 0-16), most significant first
  let boundSlices = exists(12, () => {
    let bound0 = bound.toBigInt();
    return [
      bitSlice(bound0, 52, 12), // bits 52-64
      bitSlice(bound0, 40, 12), // bits 40-52
      bitSlice(bound0, 28, 12), // bits 28-40
      bitSlice(bound0, 16, 12), // bits 16-28
      bitSlice(bound0, 14, 2), // bits 14-16
      bitSlice(bound0, 12, 2), // bits 12-14
      bitSlice(bound0, 10, 2), // bits 10-12
      bitSlice(bound0, 8, 2), // bits 8-10
      bitSlice(bound0, 6, 2), // bits 6-8
      bitSlice(bound0, 4, 2), // bits 4-6
      bitSlice(bound0, 2, 2), // bits 2-4
      bitSlice(bound0, 0, 2), // bits 0-2
    ];
  });
  let [b52, b40, b28, b16, b14, b12, b10, b8, b6, b4, b2, b0] = boundSlices;
  // Compute current row
  Gates.rotate(
    field,
    rotated,
    excess,
    [b52, b40, b28, b16],
    [b14, b12, b10, b8, b6, b4, b2, b0],
    big2PowerRot
  );
  // Compute next row
  rangeCheck64(shifted);
  // note: range-checking `shifted` and `field` is enough.
  // Only `shifted` is range-checked in this function; `field` is expected to
  // be range-checked by the caller.
  // * excess < 2^rot follows from the bound check and the rotation equation in the gate
  // * rotated < 2^64 follows from rotated = excess + shifted (because shifted has to be a multiple of 2^rot)
  // for a proof, see https://github.com/o1-labs/o1js/pull/1201
  return [rotated, excess, shifted];
}
/**
 * Logical right shift of a 64-bit value by `bits` positions (0..64).
 *
 * For variable inputs this reuses the `rot64` gadget with a right rotation.
 * With `r = 64 - bits`, its `excess` output is `floor(field * 2^r / 2^64)`,
 * which equals `field >> bits`.
 *
 * @param field value to shift
 * @param bits shift amount, 0..64
 * @returns `field >> bits`
 */
function rightShift64(field: Field, bits: number) {
  assert(
    bits >= 0 && bits <= 64,
    `rightShift: expected bits to be between 0 and 64, got ${bits}`
  );

  if (!field.isConstant()) {
    return rot64(field, bits, 'right')[1];
  }

  const value = field.toBigInt();
  assert(
    value < 1n << 64n,
    `rightShift: expected field to be at most 64 bits, got ${value}`
  );
  return new Field(Fp.rightShift(value, bits));
}
/**
 * Left shift of a 64-bit value by `bits` positions (0..64), truncated to 64 bits.
 *
 * For variable inputs this reuses the `rot64` gadget with a left rotation.
 * Its `shifted` output is `(field * 2^bits) mod 2^64`.
 *
 * @param field value to shift
 * @param bits shift amount, 0..64
 * @returns `(field << bits) mod 2^64`
 */
function leftShift64(field: Field, bits: number) {
  // error messages previously said "rightShift:" (copy-paste from rightShift64)
  assert(
    bits >= 0 && bits <= 64,
    `leftShift: expected bits to be between 0 and 64, got ${bits}`
  );
  if (field.isConstant()) {
    assert(
      field.toBigInt() < 1n << 64n,
      `leftShift: expected field to be at most 64 bits, got ${field.toBigInt()}`
    );
    return new Field(Fp.leftShift(field.toBigInt(), bits));
  }
  const [, , shifted] = rot64(field, bits, 'left');
  return shifted;
}
/**
 * Left shift of a 32-bit value by `bits` positions (0..32), truncated to 32 bits.
 *
 * The value `field * 2^bits` is split by `divMod32`; the remainder is
 * `(field << bits) mod 2^32`.
 *
 * @param field value to shift
 * @param bits shift amount, 0..32
 * @returns `(field << bits) mod 2^32`
 */
function leftShift32(field: Field, bits: number) {
  // Without this check, a negative `bits` makes `1n << BigInt(bits)` evaluate
  // to 0n (a BigInt shift by a negative count shifts right), so the result is
  // silently 0. A non-integer `bits` instead throws an opaque RangeError from
  // BigInt(). The range mirrors leftShift64's 0..64 check.
  assert(
    bits >= 0 && bits <= 32,
    `leftShift: expected bits to be between 0 and 32, got ${bits}`
  );
  const { remainder: shifted } = divMod32(field.mul(1n << BigInt(bits)));
  return shifted;
}