scalar-autograd
Version:
Scalar-based reverse-mode automatic differentiation in TypeScript.
48 lines (47 loc) • 1.83 kB
JavaScript
;
Object.defineProperty(exports, "__esModule", { value: true });
exports.ValueComparison = void 0;
const Value_1 = require("./Value");
/**
 * Builds a result node for a discrete binary comparison.
 *
 * The forward value is 1 when `predicate(a.data, b.data)` is truthy and 0
 * otherwise. The backward closure is a no-op: a step function has zero
 * gradient almost everywhere, so no gradient flows back to `a` or `b`
 * through this node.
 *
 * @param {Value} a - Left operand, passed to `Value.make` as an input node.
 * @param {Value} b - Right operand, passed to `Value.make` as an input node.
 * @param {(x: number, y: number) => boolean} predicate - Comparison applied to the raw `data` fields.
 * @param {string} symbol - Operator text placed between the operand labels, e.g. '=='.
 * @returns {Value} The comparison node, labelled `(<a.label><symbol><b.label>)`.
 */
function makeComparison(a, b, predicate, symbol) {
    return Value_1.Value.make(predicate(a.data, b.data) ? 1 : 0, a, b, () => () => {
        // No gradient - discrete operation
    }, `(${a.label}${symbol}${b.label})`);
}
/**
 * Comparison and conditional-selection operations on `Value` nodes.
 *
 * The comparisons (`eq`, `neq`, `gt`, `lt`, `gte`, `lte`) produce 1 or 0 and
 * pass no gradient to their operands. `ifThenElse` forwards the value of one
 * branch and routes the incoming gradient to that branch only.
 */
class ValueComparison {
    /**
     * Exact equality test: 1 if `a.data === b.data`, else 0.
     * Uses strict equality with no tolerance, so NaN never compares equal.
     * @param {Value} a
     * @param {Value} b
     * @returns {Value} Discrete node with no gradient.
     */
    static eq(a, b) {
        return makeComparison(a, b, (x, y) => x === y, '==');
    }
    /**
     * Selects `thenVal` when `cond.data` is truthy (non-zero and not NaN),
     * otherwise `elseVal`.
     *
     * The branch is decided once, at forward time. The same node is used for
     * the output value, recorded as the result's child, and receives the
     * gradient during backward. This keeps the gradient on the recorded child
     * even if `cond.data` is modified before backward runs. `cond` itself
     * receives no gradient.
     *
     * @param {Value} cond - Condition node; only its truthiness is used.
     * @param {Value} thenVal - Value produced when the condition holds.
     * @param {Value} elseVal - Value produced otherwise.
     * @returns {Value} Node whose data equals the chosen branch's data.
     */
    static ifThenElse(cond, thenVal, elseVal) {
        const chosen = cond.data ? thenVal : elseVal;
        return Value_1.Value.make(chosen.data, cond, chosen, (out) => () => {
            chosen.grad += out.grad;
        }, `if(${cond.label}){${thenVal.label}}else{${elseVal.label}}`);
    }
    /**
     * Exact inequality test: 1 if `a.data !== b.data`, else 0.
     * @param {Value} a
     * @param {Value} b
     * @returns {Value} Discrete node with no gradient.
     */
    static neq(a, b) {
        return makeComparison(a, b, (x, y) => x !== y, '!=');
    }
    /**
     * 1 if `a.data > b.data`, else 0.
     * @param {Value} a
     * @param {Value} b
     * @returns {Value} Discrete node with no gradient.
     */
    static gt(a, b) {
        return makeComparison(a, b, (x, y) => x > y, '>');
    }
    /**
     * 1 if `a.data < b.data`, else 0.
     * @param {Value} a
     * @param {Value} b
     * @returns {Value} Discrete node with no gradient.
     */
    static lt(a, b) {
        return makeComparison(a, b, (x, y) => x < y, '<');
    }
    /**
     * 1 if `a.data >= b.data`, else 0.
     * @param {Value} a
     * @param {Value} b
     * @returns {Value} Discrete node with no gradient.
     */
    static gte(a, b) {
        return makeComparison(a, b, (x, y) => x >= y, '>=');
    }
    /**
     * 1 if `a.data <= b.data`, else 0.
     * @param {Value} a
     * @param {Value} b
     * @returns {Value} Discrete node with no gradient.
     */
    static lte(a, b) {
        return makeComparison(a, b, (x, y) => x <= y, '<=');
    }
}
exports.ValueComparison = ValueComparison;