UNPKG

scalar-autograd

Version:

Scalar-based reverse-mode automatic differentiation in TypeScript.

378 lines (377 loc) 11 kB
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.V = void 0;
const ValueArithmetic_1 = require("./ValueArithmetic");
const ValueTrig_1 = require("./ValueTrig");
const ValueActivation_1 = require("./ValueActivation");
const ValueComparison_1 = require("./ValueComparison");
const Value_1 = require("./Value");

/**
 * Static facade over the Value operation modules.
 *
 * Every operation accepts either a raw number or an existing Value for its
 * differentiable operands; raw numbers are wrapped via `V.ensureValue` and the
 * call is then forwarded to ValueArithmetic, ValueTrig, ValueActivation or
 * ValueComparison.
 *
 * The `X_1.X` namespace lookups are deliberately performed at call time (the
 * compiler's emit shape) rather than destructured at module load.
 * NOTE(review): presumably this matters if the required modules form a
 * circular dependency with this one - confirm before "simplifying" it.
 */
class V {
    /**
     * Wraps a raw number in a Value; anything else is handed back untouched.
     * @param x A number or an already-constructed Value
     * @returns A new Value when `x` is a number, otherwise `x` itself
     */
    static ensureValue(x) {
        if (typeof x === 'number') {
            return new Value_1.Value(x);
        }
        return x;
    }

    /**
     * Builds a constant (non-differentiable) Value.
     * @param value Numeric payload
     * @param label Optional label, defaults to the empty string
     * @returns The newly created constant Value
     */
    static C(value, label = "") {
        const differentiable = false;
        return new Value_1.Value(value, label, differentiable);
    }

    /**
     * Builds a weight (differentiable) Value.
     * @param value Numeric payload
     * @param label Optional label, defaults to the empty string
     * @returns The newly created differentiable Value
     */
    static W(value, label = "") {
        const differentiable = true;
        return new Value_1.Value(value, label, differentiable);
    }

    /**
     * Adds two operands.
     * @param a Left operand (number or Value)
     * @param b Right operand (number or Value)
     * @returns New Value holding the sum
     */
    static add(a, b) {
        const lhs = V.ensureValue(a);
        const rhs = V.ensureValue(b);
        return ValueArithmetic_1.ValueArithmetic.add(lhs, rhs);
    }

    /**
     * Multiplies two operands.
     * @param a Left operand (number or Value)
     * @param b Right operand (number or Value)
     * @returns New Value holding the product
     */
    static mul(a, b) {
        const lhs = V.ensureValue(a);
        const rhs = V.ensureValue(b);
        return ValueArithmetic_1.ValueArithmetic.mul(lhs, rhs);
    }

    /**
     * Subtracts `b` from `a`.
     * @param a Minuend (number or Value)
     * @param b Subtrahend (number or Value)
     * @returns New Value holding the difference
     */
    static sub(a, b) {
        const lhs = V.ensureValue(a);
        const rhs = V.ensureValue(b);
        return ValueArithmetic_1.ValueArithmetic.sub(lhs, rhs);
    }

    /**
     * Divides `a` by `b`.
     * @param a Dividend (number or Value)
     * @param b Divisor (number or Value)
     * @param eps Small epsilon guarding against division by zero (default 1e-12)
     * @returns New Value holding the quotient
     */
    static div(a, b, eps = 1e-12) {
        const dividend = V.ensureValue(a);
        const divisor = V.ensureValue(b);
        return ValueArithmetic_1.ValueArithmetic.div(dividend, divisor, eps);
    }

    /**
     * Raises `a` to a plain numeric exponent.
     * @param a Base (number or Value)
     * @param exp Exponent, forwarded as-is without wrapping
     * @returns New Value holding the result
     */
    static pow(a, exp) {
        const base = V.ensureValue(a);
        return ValueArithmetic_1.ValueArithmetic.pow(base, exp);
    }

    /**
     * Raises `a` to an exponent that is itself a Value.
     * @param a Base (number or Value)
     * @param b Exponent (number or Value)
     * @param eps Small epsilon used for the logarithm (default 1e-12)
     * @returns New Value holding the result
     */
    static powValue(a, b, eps = 1e-12) {
        const base = V.ensureValue(a);
        const exponent = V.ensureValue(b);
        return ValueArithmetic_1.ValueArithmetic.powValue(base, exponent, eps);
    }

    /**
     * Computes the remainder of `a` divided by `b`.
     * @param a Dividend (number or Value)
     * @param b Divisor (number or Value)
     * @returns New Value holding the remainder
     */
    static mod(a, b) {
        const dividend = V.ensureValue(a);
        const divisor = V.ensureValue(b);
        return ValueArithmetic_1.ValueArithmetic.mod(dividend, divisor);
    }

    /**
     * Absolute value.
     * @param a Input (number or Value)
     * @returns New Value holding |a|
     */
    static abs(a) {
        const operand = V.ensureValue(a);
        return ValueArithmetic_1.ValueArithmetic.abs(operand);
    }

    /**
     * Exponential function.
     * @param a Input (number or Value)
     * @returns New Value holding e^a
     */
    static exp(a) {
        const operand = V.ensureValue(a);
        return ValueArithmetic_1.ValueArithmetic.exp(operand);
    }

    /**
     * Natural logarithm.
     * @param a Input (number or Value)
     * @param eps Small epsilon for numerical stability (default 1e-12)
     * @returns New Value holding ln(a)
     */
    static log(a, eps = 1e-12) {
        const operand = V.ensureValue(a);
        return ValueArithmetic_1.ValueArithmetic.log(operand, eps);
    }

    /**
     * Smaller of two operands.
     * @param a First operand (number or Value)
     * @param b Second operand (number or Value)
     * @returns New Value holding the minimum
     */
    static min(a, b) {
        const lhs = V.ensureValue(a);
        const rhs = V.ensureValue(b);
        return ValueArithmetic_1.ValueArithmetic.min(lhs, rhs);
    }

    /**
     * Larger of two operands.
     * @param a First operand (number or Value)
     * @param b Second operand (number or Value)
     * @returns New Value holding the maximum
     */
    static max(a, b) {
        const lhs = V.ensureValue(a);
        const rhs = V.ensureValue(b);
        return ValueArithmetic_1.ValueArithmetic.max(lhs, rhs);
    }

    /**
     * Floor function.
     * @param a Input (number or Value)
     * @returns New Value holding the floor of `a`
     */
    static floor(a) {
        const operand = V.ensureValue(a);
        return ValueArithmetic_1.ValueArithmetic.floor(operand);
    }

    /**
     * Ceiling function.
     * @param a Input (number or Value)
     * @returns New Value holding the ceiling of `a`
     */
    static ceil(a) {
        const operand = V.ensureValue(a);
        return ValueArithmetic_1.ValueArithmetic.ceil(operand);
    }

    /**
     * Rounding function.
     * @param a Input (number or Value)
     * @returns New Value rounded to the nearest integer
     */
    static round(a) {
        const operand = V.ensureValue(a);
        return ValueArithmetic_1.ValueArithmetic.round(operand);
    }

    /**
     * Square function.
     * @param a Input (number or Value)
     * @returns New Value holding a^2
     */
    static square(a) {
        const operand = V.ensureValue(a);
        return ValueArithmetic_1.ValueArithmetic.square(operand);
    }

    /**
     * Cube function.
     * @param a Input (number or Value)
     * @returns New Value holding a^3
     */
    static cube(a) {
        const operand = V.ensureValue(a);
        return ValueArithmetic_1.ValueArithmetic.cube(operand);
    }

    /**
     * Reciprocal function.
     * @param a Input (number or Value)
     * @param eps Small epsilon guarding against division by zero (default 1e-12)
     * @returns New Value holding 1/a
     */
    static reciprocal(a, eps = 1e-12) {
        const operand = V.ensureValue(a);
        return ValueArithmetic_1.ValueArithmetic.reciprocal(operand, eps);
    }

    /**
     * Clamps `a` between two bounds.
     * @param a Input (number or Value)
     * @param min Lower bound, forwarded as-is without wrapping
     * @param max Upper bound, forwarded as-is without wrapping
     * @returns New Value clamped between `min` and `max`
     */
    static clamp(a, min, max) {
        const operand = V.ensureValue(a);
        return ValueArithmetic_1.ValueArithmetic.clamp(operand, min, max);
    }

    /**
     * Negation.
     * @param a Input (number or Value)
     * @returns New Value holding -a
     */
    static neg(a) {
        const operand = V.ensureValue(a);
        return ValueArithmetic_1.ValueArithmetic.neg(operand);
    }

    /**
     * Sums a list of operands.
     * @param vals Array whose entries are numbers or Values
     * @returns New Value holding the sum
     */
    static sum(vals) {
        const terms = vals.map((entry) => V.ensureValue(entry));
        return ValueArithmetic_1.ValueArithmetic.sum(terms);
    }

    /**
     * Averages a list of operands.
     * @param vals Array whose entries are numbers or Values
     * @returns New Value holding the mean
     */
    static mean(vals) {
        const terms = vals.map((entry) => V.ensureValue(entry));
        return ValueArithmetic_1.ValueArithmetic.mean(terms);
    }

    /**
     * Sine.
     * @param x Input (number or Value)
     * @returns New Value holding sin(x)
     */
    static sin(x) {
        const angle = V.ensureValue(x);
        return ValueTrig_1.ValueTrig.sin(angle);
    }

    /**
     * Cosine.
     * @param x Input (number or Value)
     * @returns New Value holding cos(x)
     */
    static cos(x) {
        const angle = V.ensureValue(x);
        return ValueTrig_1.ValueTrig.cos(angle);
    }

    /**
     * Tangent.
     * @param x Input (number or Value)
     * @returns New Value holding tan(x)
     */
    static tan(x) {
        const angle = V.ensureValue(x);
        return ValueTrig_1.ValueTrig.tan(angle);
    }

    /**
     * Arcsine.
     * @param x Input (number or Value)
     * @returns New Value holding asin(x)
     */
    static asin(x) {
        const operand = V.ensureValue(x);
        return ValueTrig_1.ValueTrig.asin(operand);
    }

    /**
     * Arccosine.
     * @param x Input (number or Value)
     * @returns New Value holding acos(x)
     */
    static acos(x) {
        const operand = V.ensureValue(x);
        return ValueTrig_1.ValueTrig.acos(operand);
    }

    /**
     * Arctangent.
     * @param x Input (number or Value)
     * @returns New Value holding atan(x)
     */
    static atan(x) {
        const operand = V.ensureValue(x);
        return ValueTrig_1.ValueTrig.atan(operand);
    }

    /**
     * ReLU activation.
     * @param x Input (number or Value)
     * @returns New Value holding max(0, x)
     */
    static relu(x) {
        const input = V.ensureValue(x);
        return ValueActivation_1.ValueActivation.relu(input);
    }

    /**
     * Softplus activation.
     * @param x Input (number or Value)
     * @returns New Value holding ln(1 + e^x)
     */
    static softplus(x) {
        const input = V.ensureValue(x);
        return ValueActivation_1.ValueActivation.softplus(input);
    }

    /**
     * Hyperbolic tangent.
     * @param x Input (number or Value)
     * @returns New Value holding tanh(x)
     */
    static tanh(x) {
        const input = V.ensureValue(x);
        return ValueActivation_1.ValueActivation.tanh(input);
    }

    /**
     * Sigmoid activation.
     * @param x Input (number or Value)
     * @returns New Value holding 1/(1 + e^(-x))
     */
    static sigmoid(x) {
        const input = V.ensureValue(x);
        return ValueActivation_1.ValueActivation.sigmoid(input);
    }

    /**
     * Equality comparison.
     * @param a First operand (number or Value)
     * @param b Second operand (number or Value)
     * @returns New Value holding 1 if equal, 0 otherwise
     */
    static eq(a, b) {
        const lhs = V.ensureValue(a);
        const rhs = V.ensureValue(b);
        return ValueComparison_1.ValueComparison.eq(lhs, rhs);
    }

    /**
     * Inequality comparison.
     * @param a First operand (number or Value)
     * @param b Second operand (number or Value)
     * @returns New Value holding 1 if not equal, 0 otherwise
     */
    static neq(a, b) {
        const lhs = V.ensureValue(a);
        const rhs = V.ensureValue(b);
        return ValueComparison_1.ValueComparison.neq(lhs, rhs);
    }

    /**
     * Greater-than comparison.
     * @param a First operand (number or Value)
     * @param b Second operand (number or Value)
     * @returns New Value holding 1 if a > b, 0 otherwise
     */
    static gt(a, b) {
        const lhs = V.ensureValue(a);
        const rhs = V.ensureValue(b);
        return ValueComparison_1.ValueComparison.gt(lhs, rhs);
    }

    /**
     * Less-than comparison.
     * @param a First operand (number or Value)
     * @param b Second operand (number or Value)
     * @returns New Value holding 1 if a < b, 0 otherwise
     */
    static lt(a, b) {
        const lhs = V.ensureValue(a);
        const rhs = V.ensureValue(b);
        return ValueComparison_1.ValueComparison.lt(lhs, rhs);
    }

    /**
     * Greater-than-or-equal comparison.
     * @param a First operand (number or Value)
     * @param b Second operand (number or Value)
     * @returns New Value holding 1 if a >= b, 0 otherwise
     */
    static gte(a, b) {
        const lhs = V.ensureValue(a);
        const rhs = V.ensureValue(b);
        return ValueComparison_1.ValueComparison.gte(lhs, rhs);
    }

    /**
     * Less-than-or-equal comparison.
     * @param a First operand (number or Value)
     * @param b Second operand (number or Value)
     * @returns New Value holding 1 if a <= b, 0 otherwise
     */
    static lte(a, b) {
        const lhs = V.ensureValue(a);
        const rhs = V.ensureValue(b);
        return ValueComparison_1.ValueComparison.lte(lhs, rhs);
    }

    /**
     * Conditional selection, forwarded to ValueComparison.ifThenElse.
     * All three arguments are wrapped with `V.ensureValue` first, so each may
     * be a number or a Value.
     * @param cond Condition operand
     * @param thenVal Operand passed as the "then" branch
     * @param elseVal Operand passed as the "else" branch
     * @returns Whatever ValueComparison.ifThenElse returns for the operands
     */
    static ifThenElse(cond, thenVal, elseVal) {
        const condition = V.ensureValue(cond);
        const whenTrue = V.ensureValue(thenVal);
        const whenFalse = V.ensureValue(elseVal);
        return ValueComparison_1.ValueComparison.ifThenElse(condition, whenTrue, whenFalse);
    }

    /**
     * Square root.
     * @param a Input (number or Value)
     * @returns New Value holding sqrt(a)
     */
    static sqrt(a) {
        const operand = V.ensureValue(a);
        return ValueArithmetic_1.ValueArithmetic.sqrt(operand);
    }

    /**
     * Sign function.
     * @param a Input (number or Value)
     * @returns New Value holding sign(a)
     */
    static sign(a) {
        const operand = V.ensureValue(a);
        return ValueArithmetic_1.ValueArithmetic.sign(operand);
    }
}
exports.V = V;