scalar-autograd
Version:
Scalar-based reverse-mode automatic differentiation in TypeScript.
378 lines (377 loc) • 11 kB
JavaScript
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.V = void 0;
const ValueArithmetic_1 = require("./ValueArithmetic");
const ValueTrig_1 = require("./ValueTrig");
const ValueActivation_1 = require("./ValueActivation");
const ValueComparison_1 = require("./ValueComparison");
const Value_1 = require("./Value");
/**
 * Static facade over the Value operation modules.
 *
 * Every operation accepts either plain numbers or Value instances; numbers
 * are wrapped through {@link V.ensureValue} before the call is forwarded to
 * ValueArithmetic, ValueTrig, ValueActivation or ValueComparison.
 */
class V {
    /**
     * Normalises an operand.
     * @param x A number or an existing Value
     * @returns `new Value(x)` when x is a number, otherwise x unchanged
     */
    static ensureValue(x) {
        if (typeof x !== 'number') {
            return x;
        }
        return new Value_1.Value(x);
    }
    /**
     * Builds a constant: a Value constructed with its third argument set to false
     * (non-differentiable).
     * @param value Numeric payload
     * @param label Optional label, defaults to the empty string
     * @returns The newly constructed constant Value
     */
    static C(value, label = "") {
        const requiresGrad = false;
        return new Value_1.Value(value, label, requiresGrad);
    }
    /**
     * Builds a weight: a Value constructed with its third argument set to true
     * (differentiable).
     * @param value Numeric payload
     * @param label Optional label, defaults to the empty string
     * @returns The newly constructed differentiable Value
     */
    static W(value, label = "") {
        const requiresGrad = true;
        return new Value_1.Value(value, label, requiresGrad);
    }
    /**
     * Addition; forwards to ValueArithmetic.add.
     * @param a Left operand (number or Value)
     * @param b Right operand (number or Value)
     * @returns Value holding the sum
     */
    static add(a, b) {
        const lhs = V.ensureValue(a);
        const rhs = V.ensureValue(b);
        return ValueArithmetic_1.ValueArithmetic.add(lhs, rhs);
    }
    /**
     * Multiplication; forwards to ValueArithmetic.mul.
     * @param a Left operand (number or Value)
     * @param b Right operand (number or Value)
     * @returns Value holding the product
     */
    static mul(a, b) {
        const lhs = V.ensureValue(a);
        const rhs = V.ensureValue(b);
        return ValueArithmetic_1.ValueArithmetic.mul(lhs, rhs);
    }
    /**
     * Subtraction; forwards to ValueArithmetic.sub.
     * @param a Minuend (number or Value)
     * @param b Subtrahend (number or Value)
     * @returns Value holding the difference
     */
    static sub(a, b) {
        const lhs = V.ensureValue(a);
        const rhs = V.ensureValue(b);
        return ValueArithmetic_1.ValueArithmetic.sub(lhs, rhs);
    }
    /**
     * Division; forwards to ValueArithmetic.div.
     * @param a Dividend (number or Value)
     * @param b Divisor (number or Value)
     * @param eps Small epsilon guarding against division by zero (default 1e-12)
     * @returns Value holding the quotient
     */
    static div(a, b, eps = 1e-12) {
        const dividend = V.ensureValue(a);
        const divisor = V.ensureValue(b);
        return ValueArithmetic_1.ValueArithmetic.div(dividend, divisor, eps);
    }
    /**
     * Power with a plain exponent; forwards to ValueArithmetic.pow.
     * Only the base is coerced; the exponent is passed through as given.
     * @param a Base (number or Value)
     * @param exp Exponent
     * @returns Value holding the result
     */
    static pow(a, exp) {
        const base = V.ensureValue(a);
        return ValueArithmetic_1.ValueArithmetic.pow(base, exp);
    }
    /**
     * Power with a Value exponent; forwards to ValueArithmetic.powValue.
     * @param a Base (number or Value)
     * @param b Exponent (number or Value)
     * @param eps Small epsilon used for the logarithm (default 1e-12)
     * @returns Value holding the result
     */
    static powValue(a, b, eps = 1e-12) {
        const base = V.ensureValue(a);
        const exponent = V.ensureValue(b);
        return ValueArithmetic_1.ValueArithmetic.powValue(base, exponent, eps);
    }
    /**
     * Modulo; forwards to ValueArithmetic.mod.
     * @param a Dividend (number or Value)
     * @param b Divisor (number or Value)
     * @returns Value holding the remainder
     */
    static mod(a, b) {
        const dividend = V.ensureValue(a);
        const divisor = V.ensureValue(b);
        return ValueArithmetic_1.ValueArithmetic.mod(dividend, divisor);
    }
    /**
     * Absolute value; forwards to ValueArithmetic.abs.
     * @param a Input (number or Value)
     * @returns Value holding the absolute value
     */
    static abs(a) {
        const operand = V.ensureValue(a);
        return ValueArithmetic_1.ValueArithmetic.abs(operand);
    }
    /**
     * Exponential; forwards to ValueArithmetic.exp.
     * @param a Input (number or Value)
     * @returns Value holding e^a
     */
    static exp(a) {
        const operand = V.ensureValue(a);
        return ValueArithmetic_1.ValueArithmetic.exp(operand);
    }
    /**
     * Natural logarithm; forwards to ValueArithmetic.log.
     * @param a Input (number or Value)
     * @param eps Small epsilon for numerical stability (default 1e-12)
     * @returns Value holding ln(a)
     */
    static log(a, eps = 1e-12) {
        const operand = V.ensureValue(a);
        return ValueArithmetic_1.ValueArithmetic.log(operand, eps);
    }
    /**
     * Smaller of two operands; forwards to ValueArithmetic.min.
     * @param a First operand (number or Value)
     * @param b Second operand (number or Value)
     * @returns Value holding the minimum
     */
    static min(a, b) {
        const lhs = V.ensureValue(a);
        const rhs = V.ensureValue(b);
        return ValueArithmetic_1.ValueArithmetic.min(lhs, rhs);
    }
    /**
     * Larger of two operands; forwards to ValueArithmetic.max.
     * @param a First operand (number or Value)
     * @param b Second operand (number or Value)
     * @returns Value holding the maximum
     */
    static max(a, b) {
        const lhs = V.ensureValue(a);
        const rhs = V.ensureValue(b);
        return ValueArithmetic_1.ValueArithmetic.max(lhs, rhs);
    }
    /**
     * Floor; forwards to ValueArithmetic.floor.
     * @param a Input (number or Value)
     * @returns Value holding the floor
     */
    static floor(a) {
        const operand = V.ensureValue(a);
        return ValueArithmetic_1.ValueArithmetic.floor(operand);
    }
    /**
     * Ceiling; forwards to ValueArithmetic.ceil.
     * @param a Input (number or Value)
     * @returns Value holding the ceiling
     */
    static ceil(a) {
        const operand = V.ensureValue(a);
        return ValueArithmetic_1.ValueArithmetic.ceil(operand);
    }
    /**
     * Rounding; forwards to ValueArithmetic.round.
     * @param a Input (number or Value)
     * @returns Value rounded to the nearest integer
     */
    static round(a) {
        const operand = V.ensureValue(a);
        return ValueArithmetic_1.ValueArithmetic.round(operand);
    }
    /**
     * Square; forwards to ValueArithmetic.square.
     * @param a Input (number or Value)
     * @returns Value holding a^2
     */
    static square(a) {
        const operand = V.ensureValue(a);
        return ValueArithmetic_1.ValueArithmetic.square(operand);
    }
    /**
     * Cube; forwards to ValueArithmetic.cube.
     * @param a Input (number or Value)
     * @returns Value holding a^3
     */
    static cube(a) {
        const operand = V.ensureValue(a);
        return ValueArithmetic_1.ValueArithmetic.cube(operand);
    }
    /**
     * Reciprocal; forwards to ValueArithmetic.reciprocal.
     * @param a Input (number or Value)
     * @param eps Small epsilon guarding against division by zero (default 1e-12)
     * @returns Value holding 1/a
     */
    static reciprocal(a, eps = 1e-12) {
        const operand = V.ensureValue(a);
        return ValueArithmetic_1.ValueArithmetic.reciprocal(operand, eps);
    }
    /**
     * Clamp; forwards to ValueArithmetic.clamp.
     * Only the input is coerced; both bounds are passed through as given.
     * @param a Input (number or Value)
     * @param min Lower bound
     * @param max Upper bound
     * @returns Value clamped between min and max
     */
    static clamp(a, min, max) {
        const operand = V.ensureValue(a);
        return ValueArithmetic_1.ValueArithmetic.clamp(operand, min, max);
    }
    /**
     * Negation; forwards to ValueArithmetic.neg.
     * @param a Input (number or Value)
     * @returns Value holding the negation
     */
    static neg(a) {
        const operand = V.ensureValue(a);
        return ValueArithmetic_1.ValueArithmetic.neg(operand);
    }
    /**
     * Sum over an array; forwards to ValueArithmetic.sum.
     * @param vals Array whose entries are numbers or Values
     * @returns Value holding the sum
     */
    static sum(vals) {
        const terms = vals.map(V.ensureValue);
        return ValueArithmetic_1.ValueArithmetic.sum(terms);
    }
    /**
     * Mean over an array; forwards to ValueArithmetic.mean.
     * @param vals Array whose entries are numbers or Values
     * @returns Value holding the mean
     */
    static mean(vals) {
        const terms = vals.map(V.ensureValue);
        return ValueArithmetic_1.ValueArithmetic.mean(terms);
    }
    /**
     * Sine; forwards to ValueTrig.sin.
     * @param x Input (number or Value)
     * @returns Value holding sin(x)
     */
    static sin(x) {
        const angle = V.ensureValue(x);
        return ValueTrig_1.ValueTrig.sin(angle);
    }
    /**
     * Cosine; forwards to ValueTrig.cos.
     * @param x Input (number or Value)
     * @returns Value holding cos(x)
     */
    static cos(x) {
        const angle = V.ensureValue(x);
        return ValueTrig_1.ValueTrig.cos(angle);
    }
    /**
     * Tangent; forwards to ValueTrig.tan.
     * @param x Input (number or Value)
     * @returns Value holding tan(x)
     */
    static tan(x) {
        const angle = V.ensureValue(x);
        return ValueTrig_1.ValueTrig.tan(angle);
    }
    /**
     * Arcsine; forwards to ValueTrig.asin.
     * @param x Input (number or Value)
     * @returns Value holding asin(x)
     */
    static asin(x) {
        const operand = V.ensureValue(x);
        return ValueTrig_1.ValueTrig.asin(operand);
    }
    /**
     * Arccosine; forwards to ValueTrig.acos.
     * @param x Input (number or Value)
     * @returns Value holding acos(x)
     */
    static acos(x) {
        const operand = V.ensureValue(x);
        return ValueTrig_1.ValueTrig.acos(operand);
    }
    /**
     * Arctangent; forwards to ValueTrig.atan.
     * @param x Input (number or Value)
     * @returns Value holding atan(x)
     */
    static atan(x) {
        const operand = V.ensureValue(x);
        return ValueTrig_1.ValueTrig.atan(operand);
    }
    /**
     * ReLU activation; forwards to ValueActivation.relu.
     * @param x Input (number or Value)
     * @returns Value holding max(0, x)
     */
    static relu(x) {
        const preActivation = V.ensureValue(x);
        return ValueActivation_1.ValueActivation.relu(preActivation);
    }
    /**
     * Softplus activation; forwards to ValueActivation.softplus.
     * @param x Input (number or Value)
     * @returns Value holding ln(1 + e^x)
     */
    static softplus(x) {
        const preActivation = V.ensureValue(x);
        return ValueActivation_1.ValueActivation.softplus(preActivation);
    }
    /**
     * Hyperbolic tangent; forwards to ValueActivation.tanh.
     * @param x Input (number or Value)
     * @returns Value holding tanh(x)
     */
    static tanh(x) {
        const preActivation = V.ensureValue(x);
        return ValueActivation_1.ValueActivation.tanh(preActivation);
    }
    /**
     * Sigmoid activation; forwards to ValueActivation.sigmoid.
     * @param x Input (number or Value)
     * @returns Value holding 1/(1 + e^(-x))
     */
    static sigmoid(x) {
        const preActivation = V.ensureValue(x);
        return ValueActivation_1.ValueActivation.sigmoid(preActivation);
    }
    /**
     * Equality comparison; forwards to ValueComparison.eq.
     * @param a Left operand (number or Value)
     * @param b Right operand (number or Value)
     * @returns Value holding 1 if equal, 0 otherwise
     */
    static eq(a, b) {
        const lhs = V.ensureValue(a);
        const rhs = V.ensureValue(b);
        return ValueComparison_1.ValueComparison.eq(lhs, rhs);
    }
    /**
     * Inequality comparison; forwards to ValueComparison.neq.
     * @param a Left operand (number or Value)
     * @param b Right operand (number or Value)
     * @returns Value holding 1 if not equal, 0 otherwise
     */
    static neq(a, b) {
        const lhs = V.ensureValue(a);
        const rhs = V.ensureValue(b);
        return ValueComparison_1.ValueComparison.neq(lhs, rhs);
    }
    /**
     * Greater-than comparison; forwards to ValueComparison.gt.
     * @param a Left operand (number or Value)
     * @param b Right operand (number or Value)
     * @returns Value holding 1 if a > b, 0 otherwise
     */
    static gt(a, b) {
        const lhs = V.ensureValue(a);
        const rhs = V.ensureValue(b);
        return ValueComparison_1.ValueComparison.gt(lhs, rhs);
    }
    /**
     * Less-than comparison; forwards to ValueComparison.lt.
     * @param a Left operand (number or Value)
     * @param b Right operand (number or Value)
     * @returns Value holding 1 if a < b, 0 otherwise
     */
    static lt(a, b) {
        const lhs = V.ensureValue(a);
        const rhs = V.ensureValue(b);
        return ValueComparison_1.ValueComparison.lt(lhs, rhs);
    }
    /**
     * Greater-than-or-equal comparison; forwards to ValueComparison.gte.
     * @param a Left operand (number or Value)
     * @param b Right operand (number or Value)
     * @returns Value holding 1 if a >= b, 0 otherwise
     */
    static gte(a, b) {
        const lhs = V.ensureValue(a);
        const rhs = V.ensureValue(b);
        return ValueComparison_1.ValueComparison.gte(lhs, rhs);
    }
    /**
     * Less-than-or-equal comparison; forwards to ValueComparison.lte.
     * @param a Left operand (number or Value)
     * @param b Right operand (number or Value)
     * @returns Value holding 1 if a <= b, 0 otherwise
     */
    static lte(a, b) {
        const lhs = V.ensureValue(a);
        const rhs = V.ensureValue(b);
        return ValueComparison_1.ValueComparison.lte(lhs, rhs);
    }
    /**
     * Conditional selection; forwards to ValueComparison.ifThenElse.
     * All three arguments are coerced, so plain numbers are accepted for each.
     * NOTE(review): the selection rule itself lives in ValueComparison and is
     * not visible here - confirm how cond is interpreted.
     * @param cond Condition (number or Value)
     * @param thenVal Operand for the "then" branch (number or Value)
     * @param elseVal Operand for the "else" branch (number or Value)
     * @returns Whatever ValueComparison.ifThenElse returns
     */
    static ifThenElse(cond, thenVal, elseVal) {
        const condition = V.ensureValue(cond);
        const whenTrue = V.ensureValue(thenVal);
        const whenFalse = V.ensureValue(elseVal);
        return ValueComparison_1.ValueComparison.ifThenElse(condition, whenTrue, whenFalse);
    }
    /**
     * Square root; forwards to ValueArithmetic.sqrt.
     * @param a Input (number or Value)
     * @returns Value holding sqrt(a)
     */
    static sqrt(a) {
        const operand = V.ensureValue(a);
        return ValueArithmetic_1.ValueArithmetic.sqrt(operand);
    }
    /**
     * Sign; forwards to ValueArithmetic.sign.
     * @param a Input (number or Value)
     * @returns Value holding sign(a)
     */
    static sign(a) {
        const operand = V.ensureValue(a);
        return ValueArithmetic_1.ValueArithmetic.sign(operand);
    }
}
exports.V = V;