o1js
Version:
TypeScript framework for zk-SNARKs and zkApps
313 lines (268 loc) • 10.3 kB
text/typescript
import type { Field } from '../field.js';
import type { Bool } from '../bool.js';
import { Fp, Fq } from '../../../bindings/crypto/finite-field.js';
import { PallasAffine } from '../../../bindings/crypto/elliptic-curve.js';
import { isOddAndHigh } from './comparison.js';
import { Field3, ForeignField } from './foreign-field.js';
import { exists } from '../core/exists.js';
import { assert, bit, bitSlice, isConstant } from './common.js';
import { l, rangeCheck64 } from './range-check.js';
import {
createBool,
createBoolUnsafe,
createField,
getField,
} from '../core/field-constructor.js';
import { Snarky } from '../../../snarky.js';
import { Provable } from '../provable.js';
import { MlPair } from '../../ml/base.js';
import { provable } from '../types/provable-derivers.js';
export {
scaleField,
fieldToShiftedScalar,
field3ToShiftedScalar,
scaleShifted,
add,
ShiftedScalar,
};
type Point = { x: Field; y: Field };
type ShiftedScalar = { lowBit: Bool; high254: Field };
/**
* Dedicated gadget to scale a point by a scalar, where the scalar is represented as a _native_ Field.
*/
function scaleField(P: Point, s: Field): Point {
// constant case
let { x, y } = P;
if (x.isConstant() && y.isConstant() && s.isConstant()) {
let sP = PallasAffine.scale(
PallasAffine.fromNonzero({ x: x.toBigInt(), y: y.toBigInt() }),
s.toBigInt()
);
return { x: createField(sP.x), y: createField(sP.y) };
}
const Field = getField();
const Point = provable({ x: Field, y: Field });
/**
* Strategy:
* - use a (1, 254) split and compute s - 2^255 with manual add-and-carry
* - use all 255 rounds of `scaleFastUnpack` for the high part
* - pass in s or a dummy replacement if s = 0, 1 (which are the disallowed values)
* - return sP or 0P = 0 or 1P = P
*/
// compute t = s + (-2^255 mod q) in (1, 254) arithmetic
let { isOdd: sLo, high: sHi } = isOddAndHigh(s);
let shift = Fq.mod(-(1n << 255n));
assert((shift & 1n) === 0n); // shift happens to be even, so we don't need to worry about a carry
let shiftHi = shift >> 1n;
let tLo = sLo;
let tHi = sHi.add(shiftHi).seal();
// tHi does not overflow:
// tHi = sHi + shiftHi < p/2 + p/2 = p
// sHi < p/2 is guaranteed by isOddAndHigh
assert(shiftHi < Fp.modulus / 2n);
// the 4 values for s not supported by `scaleFastUnpack` are q-2, q-1, 0, 1
// since s came from a `Field`, we can exclude q-2, q-1
// s = 0 or 1 iff sHi = 0
let isEdgeCase = sHi.equals(0n);
let tHiSafe = Provable.if(isEdgeCase, createField(0n), tHi);
// R = (2*(t >> 1) + 1 + 2^255)P
// also returns a 255-bit representation of tHi
let [, RMl, [, ...tHiBitsMl]] = Snarky.group.scaleFastUnpack(
[0, x.value, y.value],
[0, tHiSafe.value],
255
);
let R = { x: createField(RMl[1]), y: createField(RMl[2]) };
// prove that tHi has only 254 bits set
createField(tHiBitsMl[254]).assertEquals(0n);
// R = tLo ? R : R - P = (t + 2^255)P = sP
let RminusP = addNonZero(R, negate(P)); // can only be zero if s = 0, which we handle later
R = Provable.if(tLo, Point, R, RminusP);
// now handle the two edge cases s=0 and s=1
let zero = createField(0n);
let zeroPoint = { x: zero, y: zero };
let edgeCaseResult = Provable.if(sLo, Point, P, zeroPoint);
return Provable.if(isEdgeCase, Point, edgeCaseResult, R);
}
/**
* Internal helper to compute `(t + 2^255)*P`.
* `t` is expected to be split into 254 high bits (t >> 1) and a low bit (t & 1).
*
* The gadget proves that `tHi` is in [0, 2^254) but assumes that `tLo` is a single bit.
*
* Optionally, you can specify a different number of high bits by passing in `numHighBits`.
*/
function scaleShifted(
{ x, y }: Point,
{ lowBit: tLo, high254: tHi }: ShiftedScalar,
numHighBits = 254
): Point {
// constant case
if (isConstant(x, y, tHi, tLo)) {
let sP = PallasAffine.scale(
PallasAffine.fromNonzero({ x: x.toBigInt(), y: y.toBigInt() }),
Fq.mod(tLo.toField().toBigInt() + 2n * tHi.toBigInt() + (1n << 255n))
);
return { x: createField(sP.x), y: createField(sP.y) };
}
const Field = getField();
const Point = provable({ x: Field, y: Field });
let zero = createField(0n);
/**
* Strategy:
* - use all 255 rounds of `scaleFastUnpack` for the high part
* - handle two disallowed tHi values separately: -2^254, -2^254 - 1
* - don't handle disallowed tHi = -2^254 - 1/2 because it wouldn't normally be used, as it's > q/2
*/
let equalsMinusShift = tHi.equals(Fq.modulus - (1n << 254n));
let equalsMinusShiftMinus1 = tHi.equals(Fq.modulus - (1n << 254n) - 1n);
let isEdgeCase = equalsMinusShift.or(equalsMinusShiftMinus1);
let tHiSafe = Provable.if(isEdgeCase, zero, tHi);
// R = (2*(t >> 1) + 1 + 2^255)P
// also returns a 255-bit representation of tHi
let [, RMl, [, ...tHiBitsMl]] = Snarky.group.scaleFastUnpack(
[0, x.value, y.value],
[0, tHiSafe.value],
255
);
let P = { x, y };
let R = { x: createField(RMl[1]), y: createField(RMl[2]) };
// prove that tHi has only `numHighBits` bits set
for (let i = numHighBits; i < 255; i++) {
createField(tHiBitsMl[i]).assertEquals(zero);
}
// handle edge cases
// 2*(-2^254) + 1 + 2^255 = 1
// 2*(-2^254 - 1) + 1 + 2^255 = -1
// so the result is (x,+-y)
let edgeCaseY = y.mul(equalsMinusShift.toField().mul(2n).sub(1n)); // y*(2b - 1) = y or -y
let edgeCaseResult = { x, y: edgeCaseY };
R = Provable.if(isEdgeCase, Point, edgeCaseResult, R);
// R = tLo ? R : R - P = (t + 2^255)P
// we also handle a zero R-P result to make the 0 scalar work
let { result: RminusP, isInfinity } = add(R, negate(P));
RminusP = Provable.if(isInfinity, Point, { x: zero, y: zero }, RminusP);
R = Provable.if(tLo, Point, R, RminusP);
return R;
}
/**
* Converts a field element s to a shifted representation t = s - 2^254 mod q,
* where t is represented as a low bit and a 254-bit high part.
*
* This is the representation we use for scalars, since it can be used as input to `scaleShifted()`.
*/
function fieldToShiftedScalar(s: Field): ShiftedScalar {
// constant case
if (s.isConstant()) {
let t = Fq.mod(s.toBigInt() - (1n << 255n));
let lowBit = createBool((t & 1n) === 1n);
let high254 = createField(t >> 1n);
return { lowBit, high254 };
}
// compute t = s + (-2^255 mod q) in (1, 254) arithmetic
let { isOdd: sLo, high: sHi } = isOddAndHigh(s);
let shift = Fq.mod(-(1n << 255n));
assert((shift & 1n) === 0n); // shift happens to be even, so we don't need to worry about a carry
let shiftHi = shift >> 1n;
let tLo = sLo;
let tHi = sHi.add(shiftHi).seal();
// tHi does not overflow:
// tHi = sHi + shiftHi < p/2 + p/2 = p
// sHi < p/2 is guaranteed by isOddAndHigh
assert(shiftHi < Fp.modulus / 2n);
return { lowBit: tLo, high254: tHi };
}
/**
* Converts a 3-limb bigint to a shifted representation t = s - 2^255 mod q,
* where t is represented as a low bit and a 254-bit high part.
*/
function field3ToShiftedScalar(s: Field3): ShiftedScalar {
// constant case
if (Field3.isConstant(s)) {
let t = Fq.mod(Field3.toBigint(s) - (1n << 255n));
let lowBit = createBool((t & 1n) === 1n);
let high254 = createField(t >> 1n);
return { lowBit, high254 };
}
// compute t = s - (2^255 mod q) using foreign field subtraction
let twoTo255 = Field3.from(Fq.mod(1n << 255n));
let t = ForeignField.sub(s, twoTo255, Fq.modulus);
let [t0, t1, t2] = t;
// to fully constrain the output scalar, we need to prove that t is canonical
// otherwise, the subtraction above can add +q to the result, which yields an alternative bit representation
// this also provides a bound on the high part, to that the computation of tHi can't overflow
ForeignField.assertLessThan(t, Fq.modulus);
// split t into 254 high bits and a low bit
// => split t0 into [1, 87] => split t0 into [1, 64, 23] so we can efficiently range-check
let [tLo, tHi00, tHi01] = exists(3, () => {
let t = t0.toBigInt();
return [bit(t, 0), bitSlice(t, 1, 64), bitSlice(t, 65, 23)];
});
let tLoBool = tLo.assertBool();
rangeCheck64(tHi00);
rangeCheck64(tHi01);
// prove (tLo, tHi0) split
// since we know that t0 < 2^88 and tHi0 < 2^128, this even proves that tHi0 < 2^87
// (the bound on tHi0 is necessary so that 2*tHi0 can't overflow)
let tHi0 = tHi00.add(tHi01.mul(1n << 64n));
tLo.add(tHi0.mul(2n)).assertEquals(t0);
// pack tHi
// this can't overflow the native field because:
// -) we showed t < q
// -) the three combined limbs here represent the bigint (t >> 1) < q/2 < p
let tHi = tHi0
.add(t1.mul(1n << (l - 1n)))
.add(t2.mul(1n << (2n * l - 1n)))
.seal();
return { lowBit: tLoBool, high254: tHi };
}
/**
* Wraps the `EC_add` gate to perform complete addition of two non-zero curve points.
*/
function add(g: Point, h: Point): { result: Point; isInfinity: Bool } {
// compute witnesses
let witnesses = exists(7, () => {
let x1 = g.x.toBigInt();
let y1 = g.y.toBigInt();
let x2 = h.x.toBigInt();
let y2 = h.y.toBigInt();
let sameX = BigInt(x1 === x2);
let inf = BigInt(sameX && y1 !== y2);
let infZ = sameX ? Fp.inverse(y2 - y1) ?? 0n : 0n;
let x21Inv = Fp.inverse(x2 - x1) ?? 0n;
let slopeDouble = Fp.div(3n * x1 ** 2n, 2n * y1) ?? 0n;
let slopeAdd = Fp.mul(y2 - y1, x21Inv);
let s = sameX ? slopeDouble : slopeAdd;
let x3 = Fp.mod(s ** 2n - x1 - x2);
let y3 = Fp.mod(s * (x1 - x3) - y1);
return [sameX, inf, infZ, x21Inv, s, x3, y3];
});
let [same_x, inf, inf_z, x21_inv, s, x3, y3] = witnesses;
Snarky.gates.ecAdd(
MlPair(g.x.seal().value, g.y.seal().value),
MlPair(h.x.seal().value, h.y.seal().value),
MlPair(x3.value, y3.value),
inf.value,
same_x.value,
s.value,
inf_z.value,
x21_inv.value
);
// the ecAdd gate constrains `inf` to be boolean
let isInfinity = createBoolUnsafe(inf);
return { result: { x: x3, y: y3 }, isInfinity };
}
/**
* Addition that asserts the result is non-zero.
*/
function addNonZero(g: Point, h: Point) {
let { result, isInfinity } = add(g, h);
isInfinity.assertFalse();
return result;
}
/**
* Negates a point.
*/
function negate(g: Point): Point {
return { x: g.x, y: g.y.neg() };
}