UNPKG

scalar-autograd

Version:

Scalar-based reverse-mode automatic differentiation in TypeScript.

191 lines (190 loc) 7.43 kB
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.ValueArithmetic = void 0;
const Value_1 = require("./Value");
/**
 * Static arithmetic primitives for scalar reverse-mode automatic differentiation.
 *
 * Each operation computes its forward value immediately and passes
 * `Value.make(data, left, right, backwardFactory, label)` a factory of the form
 * `(out) => () => { ... }`. The inner closure adds (local derivative * out.grad)
 * into the `grad` of every input whose `requiresGrad` flag is set. Unary
 * operations pass `null` as the second parent.
 */
class ValueArithmetic {
    /** a + b. d/da = 1, d/db = 1. */
    static add(a, b) {
        return Value_1.Value.make(a.data + b.data, a, b, (out) => () => {
            if (a.requiresGrad) a.grad += out.grad;
            if (b.requiresGrad) b.grad += out.grad;
        }, `(${a.label}+${b.label})`);
    }
    /**
     * Square root. d/da = 0.5 / sqrt(a), which is Infinity at a === 0.
     * @throws {Error} if a.data is negative.
     */
    static sqrt(a) {
        if (a.data < 0) {
            throw new Error(`Cannot take sqrt of negative number: ${a.data}`);
        }
        const root = Math.sqrt(a.data);
        return Value_1.Value.make(root, a, null, (out) => () => {
            if (a.requiresGrad) a.grad += 0.5 / root * out.grad;
        }, `sqrt(${a.label})`);
    }
    /** a * b. d/da = b, d/db = a (read from the inputs when backward runs). */
    static mul(a, b) {
        return Value_1.Value.make(a.data * b.data, a, b, (out) => () => {
            if (a.requiresGrad) a.grad += b.data * out.grad;
            if (b.requiresGrad) b.grad += a.data * out.grad;
        }, `(${a.label}*${b.label})`);
    }
    /** a - b. d/da = 1, d/db = -1. */
    static sub(a, b) {
        return Value_1.Value.make(a.data - b.data, a, b, (out) => () => {
            if (a.requiresGrad) a.grad += out.grad;
            if (b.requiresGrad) b.grad -= out.grad;
        }, `(${a.label}-${b.label})`);
    }
    /**
     * a / b. d/da = 1 / b, d/db = -a / b^2.
     * @param {number} [eps=1e-12] - denominators with |b| < eps are rejected.
     * @throws {Error} if |b.data| < eps.
     */
    static div(a, b, eps = 1e-12) {
        if (Math.abs(b.data) < eps) {
            throw new Error(`Division by zero or near-zero encountered in div: denominator=${b.data}`);
        }
        const denom = b.data;
        return Value_1.Value.make(a.data / denom, a, b, (out) => () => {
            if (a.requiresGrad) a.grad += (1 / denom) * out.grad;
            if (b.requiresGrad) b.grad -= (a.data / (denom ** 2)) * out.grad;
        }, `(${a.label}/${b.label})`);
    }
    /**
     * a ^ exp for a constant numeric exponent. d/da = exp * a^(exp - 1).
     * @throws {Error} if exp is not a finite number, or if a is negative and
     *   exp is not (within 1e-12) an integer.
     */
    static pow(a, exp) {
        if (typeof exp !== "number" || Number.isNaN(exp) || !Number.isFinite(exp)) {
            throw new Error(`Exponent must be a finite number, got ${exp}`);
        }
        if (a.data < 0 && Math.abs(exp % 1) > 1e-12) {
            throw new Error(`Cannot raise negative base (${a.data}) to non-integer exponent (${exp})`);
        }
        const base = a.data;
        return Value_1.Value.make(Math.pow(base, exp), a, null, (out) => () => {
            if (a.requiresGrad) a.grad += exp * Math.pow(base, exp - 1) * out.grad;
        }, `(${a.label}^${exp})`);
    }
    /**
     * a ^ b where both base and exponent are differentiable Values.
     * d/da = b * a^(b - 1); d/db = a^b * ln(a) (only defined for a > 0).
     * @param {number} [eps=1e-12] - tolerance for deciding that b is an integer
     *   when a is negative.
     * @throws {Error} if a < 0 and b is non-integer, or if a === 0 and b <= 0.
     */
    static powValue(a, b, eps = 1e-12) {
        if (a.data < 0 && Math.abs(b.data % 1) > eps) {
            throw new Error(`Cannot raise negative base (${a.data}) to non-integer exponent (${b.data})`);
        }
        if (a.data === 0 && b.data <= 0) {
            throw new Error(`0 cannot be raised to zero or negative power: ${b.data}`);
        }
        const base = a.data;
        const result = Math.pow(base, b.data);
        return Value_1.Value.make(result, a, b, (out) => () => {
            if (a.requiresGrad) a.grad += b.data * Math.pow(base, b.data - 1) * out.grad;
            // The exponent gradient a^b * ln(a) is real-valued only for a > 0.
            // For a === 0 (guarded above to b > 0) the result is 0, so the
            // contribution is 0. For a < 0 (integer b only) it is undefined, so
            // nothing is propagated rather than a meaningless ln(eps) term.
            if (b.requiresGrad && base > 0) b.grad += Math.log(base) * result * out.grad;
        }, `(${a.label}^${b.label})`);
    }
    /**
     * a % b using JavaScript's truncated remainder: a % b = a - b * trunc(a / b).
     * Hence d/da = 1 and d/db = -trunc(a / b), away from the points where
     * a / b is an integer (where the remainder is discontinuous).
     * @throws {Error} if b.data is not a number or is 0.
     */
    static mod(a, b) {
        if (typeof b.data !== 'number' || b.data === 0) {
            throw new Error(`Modulo by zero encountered`);
        }
        return Value_1.Value.make(a.data % b.data, a, b, (out) => () => {
            if (a.requiresGrad) a.grad += out.grad;
            if (b.requiresGrad) b.grad -= Math.trunc(a.data / b.data) * out.grad;
        }, `(${a.label}%${b.label})`);
    }
    /** |a|. Uses +1 as the derivative for a >= 0 (including 0) and -1 otherwise. */
    static abs(a) {
        const d = Math.abs(a.data);
        return Value_1.Value.make(d, a, null, (out) => () => {
            if (a.requiresGrad) a.grad += (a.data >= 0 ? 1 : -1) * out.grad;
        }, `abs(${a.label})`);
    }
    /** e^a. d/da = e^a. */
    static exp(a) {
        const e = Math.exp(a.data);
        return Value_1.Value.make(e, a, null, (out) => () => {
            if (a.requiresGrad) a.grad += e * out.grad;
        }, `exp(${a.label})`);
    }
    /**
     * Natural log. Inputs in (0, eps) are clamped up to eps, so both the
     * forward value and the gradient 1 / a are computed at eps for them.
     * @param {number} [eps=1e-12] - lower clamp applied to positive inputs.
     * @throws {Error} if a.data <= 0.
     */
    static log(a, eps = 1e-12) {
        if (a.data <= 0) {
            throw new Error(`Logarithm undefined for non-positive value: ${a.data}`);
        }
        const safe = Math.max(a.data, eps);
        const l = Math.log(safe);
        return Value_1.Value.make(l, a, null, (out) => () => {
            if (a.requiresGrad) a.grad += (1 / safe) * out.grad;
        }, `log(${a.label})`);
    }
    /**
     * min(a, b). The gradient flows to the strictly smaller input.
     * NOTE(review): at an exact tie neither input receives gradient.
     */
    static min(a, b) {
        const d = Math.min(a.data, b.data);
        return Value_1.Value.make(d, a, b, (out) => () => {
            if (a.requiresGrad) a.grad += (a.data < b.data ? 1 : 0) * out.grad;
            if (b.requiresGrad) b.grad += (b.data < a.data ? 1 : 0) * out.grad;
        }, `min(${a.label},${b.label})`);
    }
    /**
     * max(a, b). The gradient flows to the strictly larger input.
     * NOTE(review): at an exact tie neither input receives gradient.
     */
    static max(a, b) {
        const d = Math.max(a.data, b.data);
        return Value_1.Value.make(d, a, b, (out) => () => {
            if (a.requiresGrad) a.grad += (a.data > b.data ? 1 : 0) * out.grad;
            if (b.requiresGrad) b.grad += (b.data > a.data ? 1 : 0) * out.grad;
        }, `max(${a.label},${b.label})`);
    }
    /** floor(a). Piecewise constant, so the backward pass is a no-op. */
    static floor(a) {
        const fl = Math.floor(a.data);
        return Value_1.Value.make(fl, a, null, () => () => { }, `floor(${a.label})`);
    }
    /** ceil(a). Piecewise constant, so the backward pass is a no-op. */
    static ceil(a) {
        const cl = Math.ceil(a.data);
        return Value_1.Value.make(cl, a, null, () => () => { }, `ceil(${a.label})`);
    }
    /** round(a). Piecewise constant, so the backward pass is a no-op. */
    static round(a) {
        const rd = Math.round(a.data);
        return Value_1.Value.make(rd, a, null, () => () => { }, `round(${a.label})`);
    }
    /** a^2, delegated to `pow`. */
    static square(a) {
        return ValueArithmetic.pow(a, 2);
    }
    /** a^3, delegated to `pow`. */
    static cube(a) {
        return ValueArithmetic.pow(a, 3);
    }
    /**
     * 1 / a. d/da = -1 / a^2.
     * @param {number} [eps=1e-12] - inputs with |a| < eps are rejected.
     * @throws {Error} if |a.data| < eps.
     */
    static reciprocal(a, eps = 1e-12) {
        if (Math.abs(a.data) < eps) {
            throw new Error(`Reciprocal of zero or near-zero detected`);
        }
        return Value_1.Value.make(1 / a.data, a, null, (out) => () => {
            if (a.requiresGrad) a.grad += -1 / (a.data * a.data) * out.grad;
        }, `reciprocal(${a.label})`);
    }
    /**
     * Clamp a into [min, max] (plain numbers). The gradient passes through only
     * while a is strictly inside the interval. At or beyond a bound it is 0.
     */
    static clamp(a, min, max) {
        const val = Math.max(min, Math.min(a.data, max));
        return Value_1.Value.make(val, a, null, (out) => () => {
            if (a.requiresGrad) a.grad += (a.data > min && a.data < max ? 1 : 0) * out.grad;
        }, `clamp(${a.label},${min},${max})`);
    }
    /** Sum of Values via repeated `add`. An empty list yields a fresh Value(0). */
    static sum(vals) {
        if (!vals.length) return new Value_1.Value(0);
        return vals.reduce((acc, v) => acc.add(v));
    }
    /** Arithmetic mean as `sum(vals).div(vals.length)`. An empty list yields Value(0). */
    static mean(vals) {
        if (!vals.length) return new Value_1.Value(0);
        return ValueArithmetic.sum(vals).div(vals.length);
    }
    /** -a. d/da = -1. */
    static neg(a) {
        return Value_1.Value.make(-a.data, a, null, (out) => () => {
            if (a.requiresGrad) a.grad -= out.grad;
        }, `(-${a.label})`);
    }
    /**
     * sign(a). The derivative is 0 for a !== 0 and is treated as 0 at a === 0,
     * so the backward pass is a no-op, like floor/ceil/round. (Adding
     * `0 * out.grad` would turn a non-finite upstream gradient into NaN.)
     */
    static sign(a) {
        const s = Math.sign(a.data);
        return Value_1.Value.make(s, a, null, () => () => { }, `sign(${a.label})`);
    }
}
exports.ValueArithmetic = ValueArithmetic;