UNPKG

@btc-vision/btc-runtime

Version:

Bitcoin Smart Contract Runtime

626 lines (500 loc) 17.4 kB
import { u128, u256 } from '@btc-vision/as-bignum/assembly';

/**
 * Overflow-checked arithmetic and bit-manipulation helpers for `u64`, `u128`
 * and `u256` values.
 *
 * Checked operations throw an `Error` instead of silently wrapping when the
 * exact result does not fit in the operand width. None of the helpers below
 * mutate their arguments.
 */
export class SafeMath {
    /** Shared zero value. Do not mutate it; use `new u256()` when a fresh zero is needed. */
    public static ZERO: u256 = u256.fromU32(0);

    /**
     * Returns `a + b`.
     * @throws Error if the sum does not fit in 256 bits.
     */
    public static add(a: u256, b: u256): u256 {
        const c: u256 = u256.add(a, b);
        // u256.add wraps modulo 2^256; a wrapped sum is always smaller than `a`.
        if (c < a) {
            throw new Error('SafeMath: addition overflow');
        }
        return c;
    }

    /**
     * Returns `a + b` for 128-bit values.
     * @throws Error if the sum does not fit in 128 bits.
     */
    public static add128(a: u128, b: u128): u128 {
        const c: u128 = u128.add(a, b);
        if (c < a) {
            throw new Error('SafeMath: addition overflow');
        }
        return c;
    }

    /**
     * Returns `a + b` for 64-bit values.
     * @throws Error if the sum does not fit in 64 bits.
     */
    public static add64(a: u64, b: u64): u64 {
        const c: u64 = a + b;
        if (c < a) {
            throw new Error('SafeMath: addition overflow');
        }
        return c;
    }

    /**
     * Returns `a - b`.
     * @throws Error if `b > a` (the result would be negative).
     */
    public static sub(a: u256, b: u256): u256 {
        if (a < b) {
            throw new Error('SafeMath: subtraction overflow');
        }
        return u256.sub(a, b);
    }

    /**
     * Returns `a - b` for 128-bit values.
     * @throws Error if `b > a`.
     */
    public static sub128(a: u128, b: u128): u128 {
        if (a < b) {
            throw new Error('SafeMath: subtraction overflow');
        }
        return u128.sub(a, b);
    }

    /**
     * Returns `a - b` for 64-bit values.
     * @throws Error if `b > a`.
     */
    public static sub64(a: u64, b: u64): u64 {
        if (a < b) {
            throw new Error('SafeMath: subtraction overflow');
        }
        return a - b;
    }

    /**
     * Computes `(a * b) % modulus` exactly, even when `a * b` does not fit in
     * 256 bits.
     * @throws Error if `modulus` is zero.
     */
    public static mulmod(a: u256, b: u256, modulus: u256): u256 {
        if (u256.eq(modulus, u256.Zero)) throw new Error('SafeMath: modulo by zero');

        // Fast path: a < 2^bitLen(a) and b < 2^bitLen(b), so the product fits
        // in 256 bits whenever the bit lengths sum to at most 256.
        if (SafeMath.bitLength256(a) + SafeMath.bitLength256(b) <= 256) {
            return SafeMath.mod(u256.mul(a, b), modulus);
        }

        // Slow path: double-and-add over the bits of `y`, most significant
        // first, keeping every partial result reduced below `modulus`.
        const x = SafeMath.mod(a, modulus);
        const y = SafeMath.mod(b, modulus);
        let result = new u256();
        for (let bit = <i32>SafeMath.bitLength256(y) - 1; bit >= 0; bit--) {
            result = SafeMath.addModReduced(result, result, modulus);
            if (SafeMath.testBit(y, bit)) {
                result = SafeMath.addModReduced(result, x, modulus);
            }
        }
        return result;
    }

    /**
     * Returns `a % b`.
     * @throws Error if `b` is zero.
     */
    @unsafe
    @operator('%')
    public static mod(a: u256, b: u256): u256 {
        if (u256.eq(b, u256.Zero)) {
            throw new Error('SafeMath: modulo by zero');
        }
        const divResult = SafeMath.div(a, b);
        const product = SafeMath.mul(divResult, b);
        return SafeMath.sub(a, product);
    }

    /**
     * Modular inverse of `k` modulo `p` via the extended Euclidean algorithm.
     *
     * The result is only a true inverse when gcd(k, p) == 1. This is not
     * checked; callers must guarantee it (e.g. `p` prime and `k % p != 0`).
     * @throws Error if `p` is zero.
     */
    public static modInverse(k: u256, p: u256): u256 {
        let s = u256.Zero;
        let oldS = u256.One;
        let r = p;
        let oldR = k;

        while (!r.isZero()) {
            const quotient = SafeMath.div(oldR, r);

            // r_{i+1} = r_{i-1} - q * r_i  (never negative for remainders)
            const nextR = u256.sub(oldR, u256.mul(quotient, r));
            oldR = r;
            r = nextR;

            // s_{i+1} = s_{i-1} - q * s_i, tracked modulo 2^256 (two's complement)
            // because the Bezout coefficients alternate in sign.
            const nextS = u256.sub(oldS, u256.mul(quotient, s));
            oldS = s;
            s = nextS;
        }

        // `oldR` is gcd(k, p). The Bezout coefficient `oldS` satisfies
        // |oldS| <= p / 2 < 2^255, so its top bit is its sign. A negative
        // coefficient must be reduced by magnitude: (-m) mod p == p - (m mod p).
        // Reducing the raw wrapped value (2^256 - m) would give a wrong answer.
        if ((oldS.hi2 >>> 63) != 0) {
            const magnitude = SafeMath.mod(u256.sub(u256.Zero, oldS), p);
            return magnitude.isZero() ? new u256() : u256.sub(p, magnitude);
        }
        return SafeMath.mod(oldS, p);
    }

    /** Returns true when the lowest bit of `a` is clear. */
    public static isEven(a: u256): bool {
        return u256.and(a, u256.One) == u256.Zero;
    }

    /**
     * Returns `base ** exponent` by square-and-multiply (`pow(x, 0) == 1`).
     * @throws Error if the result does not fit in 256 bits.
     */
    public static pow(base: u256, exponent: u256): u256 {
        let result: u256 = u256.One;
        let factor = base;
        let remaining = exponent;
        while (!remaining.isZero()) {
            if (!u256.and(remaining, u256.One).isZero()) {
                result = SafeMath.mul(result, factor);
            }
            remaining = u256.shr(remaining, 1);
            // Only square when more exponent bits remain. Squaring after the
            // last bit would overflow spuriously, e.g. pow(2^128, 1).
            if (!remaining.isZero()) {
                factor = SafeMath.mul(factor, factor);
            }
        }
        return result;
    }

    /**
     * Returns `a * b`.
     * @throws Error if the product does not fit in 256 bits.
     */
    public static mul(a: u256, b: u256): u256 {
        if (a.isZero() || b.isZero()) {
            return new u256();
        }
        const c: u256 = u256.mul(a, b);
        // u256.mul wraps modulo 2^256; the product is exact iff c / a == b.
        if (u256.ne(SafeMath.div(c, a), b)) {
            throw new Error('SafeMath: multiplication overflow');
        }
        return c;
    }

    /**
     * Returns `a * b` for 128-bit values.
     * @throws Error if the product does not fit in 128 bits.
     */
    public static mul128(a: u128, b: u128): u128 {
        if (a.isZero() || b.isZero()) {
            return new u128();
        }
        const c: u128 = u128.mul(a, b);
        if (u128.ne(SafeMath.div128(c, a), b)) {
            throw new Error('SafeMath: multiplication overflow');
        }
        return c;
    }

    /**
     * Returns `a * b` for 64-bit values.
     * @throws Error if the product does not fit in 64 bits.
     */
    public static mul64(a: u64, b: u64): u64 {
        if (a === 0 || b === 0) {
            return 0;
        }
        const c: u64 = a * b;
        if (c / a !== b) {
            throw new Error('SafeMath: multiplication overflow');
        }
        return c;
    }

    /**
     * Returns `floor(a / b)` for 64-bit values.
     * @throws Error if `b` is zero.
     */
    public static div64(a: u64, b: u64): u64 {
        if (b === 0) {
            throw new Error('Division by zero');
        }
        return a / b;
    }

    /**
     * Returns `floor(a / b)` for 128-bit values (binary long division).
     * @throws Error if `b` is zero.
     */
    public static div128(a: u128, b: u128): u128 {
        if (b.isZero()) {
            throw new Error('Division by zero');
        }
        if (a.isZero()) {
            return new u128();
        }
        if (u128.lt(a, b)) {
            return new u128();
        }
        if (u128.eq(a, b)) {
            return new u128(1);
        }

        let n = a.clone();
        let d = b.clone();
        let result = new u128();

        // Align the divisor's top bit with the dividend's top bit.
        const shift = u128.clz(d) - u128.clz(n);
        d = SafeMath.shl128(d, shift);

        for (let i = shift; i >= 0; i--) {
            if (u128.ge(n, d)) {
                n = u128.sub(n, d);
                result = u128.or(result, SafeMath.shl128(u128.One, i));
            }
            // Move the divisor one bit right for the next quotient bit.
            d = u128.shr(d, 1);
        }
        return result;
    }

    /**
     * Returns `floor(a / b)` for 256-bit values (binary long division).
     * @throws Error if `b` is zero.
     */
    @unsafe
    @operator('/')
    public static div(a: u256, b: u256): u256 {
        if (b.isZero()) {
            throw new Error('Division by zero');
        }
        if (a.isZero()) {
            return new u256();
        }
        if (u256.lt(a, b)) {
            return new u256();
        }
        if (u256.eq(a, b)) {
            return new u256(1);
        }

        let n = a.clone();
        let d = b.clone();
        let result = new u256();

        // Align the divisor's top bit with the dividend's top bit.
        const shift = u256.clz(d) - u256.clz(n);
        d = SafeMath.shl(d, shift);

        for (let i = shift; i >= 0; i--) {
            if (u256.ge(n, d)) {
                n = u256.sub(n, d);
                result = u256.or(result, SafeMath.shl(u256.One, i));
            }
            // Move the divisor one bit right for the next quotient bit.
            d = u256.shr(d, 1);
        }
        return result;
    }

    public static min64(a: u64, b: u64): u64 {
        return a < b ? a : b;
    }

    public static max64(a: u64, b: u64): u64 {
        return a > b ? a : b;
    }

    public static min128(a: u128, b: u128): u128 {
        return u128.lt(a, b) ? a : b;
    }

    public static max128(a: u128, b: u128): u128 {
        return u128.gt(a, b) ? a : b;
    }

    public static min(a: u256, b: u256): u256 {
        return u256.lt(a, b) ? a : b;
    }

    public static max(a: u256, b: u256): u256 {
        return u256.gt(a, b) ? a : b;
    }

    /**
     * Integer square root: the largest `z` with `z * z <= y`, computed with
     * the Babylonian (Newton) iteration starting from `y / 2 + 1`.
     */
    @unsafe
    public static sqrt(y: u256): u256 {
        if (u256.gt(y, u256.fromU32(3))) {
            const two = u256.fromU32(2);
            let z = y;
            let x = SafeMath.add(SafeMath.div(y, two), u256.One);
            while (u256.lt(x, z)) {
                z = x;
                const quotient = SafeMath.div(y, x);
                x = SafeMath.div(u256.add(quotient, x), two);
            }
            return z;
        } else if (!u256.eq(y, u256.Zero)) {
            return u256.One;
        }
        return u256.Zero;
    }

    /**
     * Logical left shift of a 256-bit value. Bits shifted past bit 255 are
     * discarded. A shift of 0 returns a clone. Negative shifts and shifts of
     * 256 or more return zero.
     */
    @unsafe
    public static shl(value: u256, shift: i32): u256 {
        if (shift <= 0) {
            return shift == 0 ? value.clone() : new u256();
        }
        if (shift >= 256) {
            return new u256();
        }

        const bitsPerSegment = 64;
        const segmentShift = (shift / bitsPerSegment) | 0;
        const bitShift = shift % bitsPerSegment;

        const segments = [value.lo1, value.lo2, value.hi1, value.hi2];
        const result = SafeMath.shlSegment(segments, segmentShift, bitShift, bitsPerSegment, 4);
        return new u256(result[0], result[1], result[2], result[3]);
    }

    /**
     * Logical left shift of a 128-bit value. A shift of 0 returns a clone.
     * Negative shifts and shifts of 128 or more return zero.
     */
    public static shl128(value: u128, shift: i32): u128 {
        if (shift <= 0) {
            return shift == 0 ? value.clone() : new u128();
        }
        if (shift >= 128) {
            return new u128();
        }

        const bitsPerSegment = 64;
        const segmentShift = (shift / bitsPerSegment) | 0;
        const bitShift = shift % bitsPerSegment;

        const segments = [value.lo, value.hi];
        const result = SafeMath.shlSegment(segments, segmentShift, bitShift, bitsPerSegment, 2);
        return new u128(result[0], result[1]);
    }

    public static and(a: u256, b: u256): u256 {
        return u256.and(a, b);
    }

    public static or(a: u256, b: u256): u256 {
        return u256.or(a, b);
    }

    public static xor(a: u256, b: u256): u256 {
        return u256.xor(a, b);
    }

    /**
     * Logical right shift of a 256-bit value.
     *
     * A shift of 0 returns `a` itself. Negative shifts and shifts of 256 or
     * more return zero, matching `shl`. The shift is no longer masked to
     * 0..255, which previously made a shift of 256 return `a` unchanged.
     */
    public static shr(a: u256, shift: i32): u256 {
        if (shift == 0) return a;
        if (shift < 0 || shift >= 256) return new u256();

        const w = shift >>> 6; // whole 64-bit words to drop (0..3)
        const b = shift & 63; // remaining bit shift within a word

        let lo1 = a.lo1;
        let lo2 = a.lo2;
        let hi1 = a.hi1;
        let hi2 = a.hi2;

        // Move words down by `w` positions, filling the vacated top words with zero.
        if (w == 3) {
            lo1 = hi2;
            lo2 = 0;
            hi1 = 0;
            hi2 = 0;
        } else if (w == 2) {
            lo1 = hi1;
            lo2 = hi2;
            hi1 = 0;
            hi2 = 0;
        } else if (w == 1) {
            lo1 = lo2;
            lo2 = hi1;
            hi1 = hi2;
            hi2 = 0;
        }

        // Shift within words, carrying the low bits of each higher word down.
        if (b > 0) {
            const carryLo1 = lo2 << (64 - b);
            const carryLo2 = hi1 << (64 - b);
            const carryHi1 = hi2 << (64 - b);

            lo1 = (lo1 >>> b) | carryLo1;
            lo2 = (lo2 >>> b) | carryLo2;
            hi1 = (hi1 >>> b) | carryHi1;
            hi2 = hi2 >>> b;
        }

        return new u256(lo1, lo2, hi1, hi2);
    }

    /**
     * Returns `value + 1` as a new value. The argument is left untouched
     * (the previous `preInc()` implementation mutated it in place).
     * @throws Error if `value` is already `u256.Max`.
     */
    public static inc(value: u256): u256 {
        if (u256.eq(value, u256.Max)) {
            throw new Error('SafeMath: increment overflow');
        }
        return u256.add(value, u256.One);
    }

    /**
     * Returns `value - 1` as a new value. The argument is left untouched
     * (the previous `preDec()` implementation mutated it in place).
     * @throws Error if `value` is zero.
     */
    public static dec(value: u256): u256 {
        if (u256.eq(value, u256.Zero)) {
            throw new Error('SafeMath: decrement overflow');
        }
        return u256.sub(value, u256.One);
    }

    /**
     * Returns `floor(log2(x))`, i.e. the index of the highest set bit.
     * Returns 0 for `x == 0` and `x == 1`.
     */
    @unsafe
    public static approximateLog2(x: u256): u256 {
        const bitLen = SafeMath.bitLength256(x);
        return bitLen <= 1 ? u256.Zero : u256.fromU32(bitLen - 1);
    }

    /** Number of significant bits in `x` (0 for `x == 0`, 256 for `u256.Max`). */
    public static bitLength256(x: u256): u32 {
        if (u256.eq(x, u256.Zero)) {
            return 0;
        }
        if (x.hi2 != 0) {
            return 192 + SafeMath.bitLength64(x.hi2);
        }
        if (x.hi1 != 0) {
            return 128 + SafeMath.bitLength64(x.hi1);
        }
        if (x.lo2 != 0) {
            return 64 + SafeMath.bitLength64(x.lo2);
        }
        return SafeMath.bitLength64(x.lo1);
    }

    /**
     * Coarse natural logarithm scaled by 1e6: `floor(log2(x)) * ln(2) * 1e6`.
     * Returns 0 for `x == 0` (undefined mathematically) and `x == 1`.
     */
    public static approxLog(x: u256): u256 {
        if (x.isZero() || u256.eq(x, u256.One)) {
            return u256.Zero;
        }

        const bitLen: u32 = SafeMath.bitLength256(x);
        if (bitLen <= 1) {
            return u256.Zero;
        }

        const LN2_SCALED: u64 = 693147; // ln(2) * 1e6
        const log2Count: u64 = (bitLen - 1) as u64; // floor(log2(x))
        return SafeMath.mul(u256.fromU64(log2Count), u256.fromU64(LN2_SCALED));
    }

    /**
     * Returns ln(x) * 1e6 for x > 1; returns 0 for x == 0 or x == 1.
     * Uses ln(x) = k * ln(2) + ln(1 + r), where k = floor(log2(x)) and
     * r = (x - 2^k) / 2^k.
     */
    @unsafe
    // NOTE: marked UNTESTED by the original author.
    public static preciseLog(x: u256): u256 {
        if (x.isZero() || u256.eq(x, u256.One)) {
            return u256.Zero;
        }

        const bitLen = SafeMath.bitLength256(x);
        if (bitLen <= 1) {
            return u256.Zero;
        }

        const k: u32 = bitLen - 1; // integer part of log2(x)
        const LN2_SCALED = u256.fromU64(693147); // ln(2) * 1e6
        const base: u256 = SafeMath.mul(u256.fromU32(k), LN2_SCALED);

        const pow2k = SafeMath.shl(u256.One, <i32>k);
        const xPrime = SafeMath.sub(x, pow2k);
        if (xPrime.isZero()) {
            // x is exactly 2^k: no fractional part.
            return base;
        }

        // rScaled = ((x - 2^k) * 1e6) / 2^k, in 0..999999
        const xPrimeTimes1e6 = SafeMath.mul(xPrime, u256.fromU64(1_000_000));
        const rScaled = SafeMath.div(xPrimeTimes1e6, pow2k);

        const frac: u64 = SafeMath.polyLn1p3(rScaled.toU64());
        return SafeMath.add(base, u256.fromU64(frac));
    }

    /**
     * Returns 10^exponent.
     * @throws Error if the result does not fit in 256 bits (exponent > 77).
     */
    public static pow10(exponent: u8): u256 {
        let result: u256 = u256.One;
        for (let i: u8 = 0; i < exponent; i++) {
            result = SafeMath.mul(result, u256.fromU32(10));
        }
        return result;
    }

    /**
     * Three-term Taylor approximation ln(1 + z) ~ z - z^2/2 + z^3/3 for z in [0, 1).
     * Input and output use a 1e6 fixed-point scale (rScaled = z * 1e6).
     */
    @unsafe
    // NOTE: marked UNTESTED by the original author.
    public static polyLn1p3(rScaled: u64): u64 {
        const term1: u64 = rScaled;

        // z^2 / 2, rescaled once (z2 is at most ~1e12).
        const z2 = term1 * term1;
        const z2Div = (z2 / 1_000_000) >>> 1;

        // z^3 / 3, rescaled twice (z3 is at most ~1e18, within u64).
        const z3 = z2 * term1;
        const z3Div = z3 / 1_000_000_000_000 / 3;

        return term1 - z2Div + z3Div;
    }

    /** Number of significant bits in a 64-bit value (0 for 0). */
    private static bitLength64(value: u64): u32 {
        // clz<u64>(0) == 64, so zero correctly yields 0.
        return <u32>(64 - clz<u64>(value));
    }

    /**
     * Shifts a little-endian array of 64-bit limbs left by
     * `segmentShift` whole limbs plus `bitShift` bits, dropping overflow.
     */
    private static shlSegment(
        segments: u64[],
        segmentShift: i32,
        bitShift: i32,
        bitsPerSegment: i32,
        fillCount: u8,
    ): u64[] {
        const result = new Array<u64>(fillCount).fill(0);

        for (let i = 0; i < segments.length; i++) {
            if (i + segmentShift < segments.length) {
                result[i + segmentShift] |= segments[i] << bitShift;
            }
            if (bitShift != 0 && i + segmentShift + 1 < segments.length) {
                result[i + segmentShift + 1] |= segments[i] >>> (bitsPerSegment - bitShift);
            }
        }

        return result;
    }

    /**
     * Returns `(x + y) mod m` for `x, y < m`. The intermediate sum may wrap
     * past 2^256. In that case the true sum is below 2m, so one wrapping
     * subtraction of `m` yields the exact reduced value.
     */
    private static addModReduced(x: u256, y: u256, m: u256): u256 {
        const sum = u256.add(x, y);
        if (u256.lt(sum, x) || u256.ge(sum, m)) {
            return u256.sub(sum, m);
        }
        return sum;
    }

    /** Returns true when bit `bit` (0 = least significant, 0..255) of `value` is set. */
    private static testBit(value: u256, bit: i32): bool {
        const word = bit >>> 6;
        const limb: u64 =
            word == 0 ? value.lo1 : word == 1 ? value.lo2 : word == 2 ? value.hi1 : value.hi2;
        return ((limb >>> <u64>(bit & 63)) & 1) != 0;
    }
}