@thi.ng/dual-algebra
Version:
Multivariate dual number algebra, automatic differentiation
191 lines (190 loc) • 4.12 kB
JavaScript
/**
 * Creates a dual number stored as a plain array of `n + 1` components:
 * the real part at index 0, followed by `n` derivative slots set to 0.
 *
 * @param {number} real - value stored at index 0
 * @param {number} [n=1] - number of derivative slots
 * @param {number} [i=0] - if greater than 0, the slot at this index is set to 1
 * @returns {number[]} the newly created component array
 */
const dual = (real, n = 1, i = 0) => {
  const components = new Array(n + 1);
  components[0] = real;
  components.fill(0, 1);
  if (i > 0) {
    components[i] = 1;
  }
  return components;
};
/**
 * Creates a 2-component dual number `[r, d]`, where `d` is 1 only when
 * `i` is exactly 1 and 0 for any other `i`.
 *
 * @param {number} r - real part
 * @param {number} [i=0] - pass 1 to set the derivative component to 1
 * @returns {number[]} two-element array
 */
const $ = (r, i = 0) => {
  const isVariable = i === 1;
  return [r, isVariable ? 1 : 0];
};
/**
 * Shorthand for `dual(r, 2, i)`: a dual number with 2 derivative slots.
 *
 * @param {number} r - real part
 * @param {number} [i=0] - index forwarded to `dual`
 */
const $2 = (r, i = 0) => {
  return dual(r, 2, i);
};
/**
 * Shorthand for `dual(r, 3, i)`: a dual number with 3 derivative slots.
 *
 * @param {number} r - real part
 * @param {number} [i=0] - index forwarded to `dual`
 */
const $3 = (r, i = 0) => {
  return dual(r, 3, i);
};
/**
 * Shorthand for `dual(r, 4, i)`: a dual number with 4 derivative slots.
 *
 * @param {number} r - real part
 * @param {number} [i=0] - index forwarded to `dual`
 */
const $4 = (r, i = 0) => {
  return dual(r, 4, i);
};
/**
 * Builds an operator that picks one of two implementations per call,
 * based on the `length` of the argument at position `dispatch`.
 *
 * @param {Function} single - used when that argument's length is below 3
 * @param {Function} multi - used otherwise
 * @param {number} [dispatch=0] - index of the argument whose length is inspected
 * @returns {Function} function forwarding all of its arguments to the chosen implementation
 */
const defOp = (single, multi, dispatch = 0) => {
  return (...args) => {
    const impl = args[dispatch].length < 3 ? single : multi;
    return impl(...args);
  };
};
/**
 * Component-wise sum of two dual numbers.
 *
 * The first implementation handles 2-component inputs; the second maps
 * over every component of `a` and adds the same-index component of `b`.
 */
const add = defOp(
  ([ar, ad], [br, bd]) => [ar + br, ad + bd],
  (a, b) => a.map((component, idx) => component + b[idx])
);
/**
 * Component-wise difference `a - b` of two dual numbers.
 *
 * The first implementation handles 2-component inputs; the second maps
 * over every component of `a` and subtracts the same-index component of `b`.
 */
const sub = defOp(
  ([ar, ad], [br, bd]) => [ar - br, ad - bd],
  (a, b) => a.map((component, idx) => component - b[idx])
);
/**
 * Negates every component of a dual number.
 *
 * Zero components are kept as `0` rather than becoming `-0`, so results
 * compare cleanly under strict/deep equality. Both the 2-component and
 * the multi-component implementation apply this normalization.
 */
const neg = defOp(
  ([ar, ad]) => [ar !== 0 ? -ar : 0, ad !== 0 ? -ad : 0],
  (a) => a.map((x) => (x !== 0 ? -x : 0))
);
/**
 * Product of two dual numbers.
 *
 * Real part: `a[0] * b[0]`. Each derivative component `i` follows the
 * product rule: `a[0] * b[i] + a[i] * b[0]`.
 */
const mul = defOp(
  (a, b) => {
    const [ar, ad] = a;
    const [br, bd] = b;
    return [ar * br, ar * bd + ad * br];
  },
  (a, b) => {
    const [ar] = a;
    const [br] = b;
    const out = [ar * br];
    for (let i = 1; i < a.length; i++) {
      out.push(ar * b[i] + a[i] * br);
    }
    return out;
  }
);
/**
 * Quotient `a / b` of two dual numbers.
 *
 * Real part: `a[0] / b[0]`. Each derivative component `i` follows the
 * quotient rule: `(a[i] * b[0] - a[0] * b[i]) / b[0]^2`.
 */
const div = defOp(
  (a, b) => {
    const [ar, ad] = a;
    const [br, bd] = b;
    return [ar / br, (ad * br - ar * bd) / (br * br)];
  },
  (a, b) => {
    const [ar] = a;
    const [br] = b;
    // Reciprocal of the squared divisor, computed once and reused per component.
    const invSq = 1 / (br * br);
    const out = [ar / br];
    for (let i = 1; i < a.length; i++) {
      out.push((a[i] * br - ar * b[i]) * invSq);
    }
    return out;
  }
);
/**
 * Absolute value of a dual number.
 *
 * Real part: `Math.abs(a[0])`. Derivative components are multiplied by
 * `Math.sign(a[0])`.
 */
const abs = defOp(
  (a) => {
    const [ar, ad] = a;
    return [Math.abs(ar), ad * Math.sign(ar)];
  },
  (a) => {
    const [ar] = a;
    const sign = Math.sign(ar);
    const out = [Math.abs(ar)];
    for (let i = 1; i < a.length; i++) {
      out.push(sign * a[i]);
    }
    return out;
  }
);
/**
 * Square root of a dual number.
 *
 * Real part: `Math.sqrt(a[0])`. Derivative components are scaled by
 * `0.5 / Math.sqrt(a[0])`.
 */
const sqrt = defOp(
  ([ar, ad]) => {
    const root = Math.sqrt(ar);
    return [root, (0.5 * ad) / root];
  },
  (a) => {
    const root = Math.sqrt(a[0]);
    const factor = 0.5 / root;
    const out = [root];
    for (let i = 1; i < a.length; i++) {
      out.push(factor * a[i]);
    }
    return out;
  }
);
/**
 * Exponential function of a dual number.
 *
 * Real part: `Math.exp(a[0])`. Derivative components are multiplied by
 * that same value.
 */
const exp = defOp(
  (a) => {
    const [ar, ad] = a;
    const e = Math.exp(ar);
    return [e, ad * e];
  },
  (a) => {
    const e = Math.exp(a[0]);
    const out = [e];
    for (let i = 1; i < a.length; i++) {
      out.push(e * a[i]);
    }
    return out;
  }
);
/**
 * Natural logarithm of a dual number.
 *
 * Real part: `Math.log(a[0])`. Derivative components are scaled by
 * `1 / a[0]`, since d/dx ln(x) = 1/x. Both implementations use the same
 * rule, so 2-component and multi-component inputs agree.
 */
const log = defOp(
  ([ar, ad]) => [Math.log(ar), ad / ar],
  (a) => {
    const ar = a[0];
    // The derivative factor is the reciprocal of the input's real part,
    // not of its logarithm.
    const iar = 1 / ar;
    const out = [Math.log(ar)];
    for (let i = a.length; i-- > 1; ) {
      out[i] = iar * a[i];
    }
    return out;
  }
);
/**
 * Raises a dual number to the power `k`, where `k` is used directly as an
 * exponent and multiplier (not as a dual number).
 *
 * Real part: `a[0] ** k`. Derivative components are scaled by
 * `k * a[0] ** (k - 1)`.
 */
const pow = defOp(
  (a, k) => {
    const [ar, ad] = a;
    return [ar ** k, ad * k * ar ** (k - 1)];
  },
  (a, k) => {
    const [ar] = a;
    const factor = k * ar ** (k - 1);
    const out = [ar ** k];
    for (let i = 1; i < a.length; i++) {
      out.push(factor * a[i]);
    }
    return out;
  }
);
/**
 * Sine of a dual number.
 *
 * Real part: `Math.sin(a[0])`. Derivative components are multiplied by
 * `Math.cos(a[0])`.
 */
const sin = defOp(
  (a) => {
    const [ar, ad] = a;
    return [Math.sin(ar), ad * Math.cos(ar)];
  },
  (a) => {
    const [ar] = a;
    const cosine = Math.cos(ar);
    const out = [Math.sin(ar)];
    for (let i = 1; i < a.length; i++) {
      out.push(cosine * a[i]);
    }
    return out;
  }
);
/**
 * Cosine of a dual number.
 *
 * Real part: `Math.cos(a[0])`. Derivative components are multiplied by
 * `-Math.sin(a[0])`.
 */
const cos = defOp(
  (a) => {
    const [ar, ad] = a;
    return [Math.cos(ar), -ad * Math.sin(ar)];
  },
  (a) => {
    const [ar] = a;
    const negSine = -Math.sin(ar);
    const out = [Math.cos(ar)];
    for (let i = 1; i < a.length; i++) {
      out.push(negSine * a[i]);
    }
    return out;
  }
);
/**
 * Tangent of a dual number.
 *
 * Real part: `Math.tan(a[0])`. Derivative components are scaled by
 * `1 / cos(a[0])^2`.
 */
const tan = defOp(
  (a) => {
    const [ar, ad] = a;
    const cosine = Math.cos(ar);
    return [Math.tan(ar), ad / (cosine * cosine)];
  },
  (a) => {
    const [ar] = a;
    const cosine = Math.cos(ar);
    const factor = 1 / (cosine * cosine);
    const out = [Math.tan(ar)];
    for (let i = 1; i < a.length; i++) {
      out.push(factor * a[i]);
    }
    return out;
  }
);
/**
 * Arctangent of a dual number.
 *
 * Real part: `Math.atan(a[0])`. Derivative components are scaled by
 * `1 / (1 + a[0]^2)`.
 */
const atan = defOp(
  (a) => {
    const [ar, ad] = a;
    return [Math.atan(ar), ad / (1 + ar * ar)];
  },
  (a) => {
    const [ar] = a;
    const factor = 1 / (1 + ar * ar);
    const out = [Math.atan(ar)];
    for (let i = 1; i < a.length; i++) {
      out.push(factor * a[i]);
    }
    return out;
  }
);
/**
 * Linear interpolation between dual numbers, computed as `a + (b - a) * t`
 * using the dual `sub`, `mul` and `add` operators.
 *
 * @param a - start value (dual number)
 * @param b - end value (dual number)
 * @param t - interpolation factor, passed to `mul` as a dual number
 */
const mix = (a, b, t) => {
  const delta = sub(b, a);
  const scaled = mul(delta, t);
  return add(a, scaled);
};
/**
 * Wraps a function of two dual numbers so it can be called with two plain
 * values. `x` is passed as `[x, 1, 0]` and `y` as `[y, 0, 1]`.
 *
 * @param {Function} fn - function receiving the two 3-component arrays
 * @returns {Function} `(x, y) => fn(...)`, returning whatever `fn` returns
 */
const evalFn2 = (fn) => {
  return (x, y) => {
    const dx = [x, 1, 0];
    const dy = [y, 0, 1];
    return fn(dx, dy);
  };
};
/**
 * Wraps a function of three dual numbers so it can be called with three
 * plain values. `x`, `y`, `z` are passed as `[x, 1, 0, 0]`, `[y, 0, 1, 0]`
 * and `[z, 0, 0, 1]`.
 *
 * @param {Function} fn - function receiving the three 4-component arrays
 * @returns {Function} `(x, y, z) => fn(...)`, returning whatever `fn` returns
 */
const evalFn3 = (fn) => {
  return (x, y, z) => {
    const dx = [x, 1, 0, 0];
    const dy = [y, 0, 1, 0];
    const dz = [z, 0, 0, 1];
    return fn(dx, dy, dz);
  };
};
/**
 * Wraps a function of four dual numbers so it can be called with four
 * plain values. `x`, `y`, `z`, `w` are passed as `[x, 1, 0, 0, 0]`,
 * `[y, 0, 1, 0, 0]`, `[z, 0, 0, 1, 0]` and `[w, 0, 0, 0, 1]`.
 *
 * @param {Function} fn - function receiving the four 5-component arrays
 * @returns {Function} `(x, y, z, w) => fn(...)`, returning whatever `fn` returns
 */
const evalFn4 = (fn) => {
  return (x, y, z, w) => {
    const dx = [x, 1, 0, 0, 0];
    const dy = [y, 0, 1, 0, 0];
    const dz = [z, 0, 0, 1, 0];
    const dw = [w, 0, 0, 0, 1];
    return fn(dx, dy, dz, dw);
  };
};
export {
$,
$2,
$3,
$4,
abs,
add,
atan,
cos,
defOp,
div,
dual,
evalFn2,
evalFn3,
evalFn4,
exp,
log,
mix,
mul,
neg,
pow,
sin,
sqrt,
sub,
tan
};