UNPKG

@thi.ng/dual-algebra

Version:

Multivariate dual-number algebra and automatic differentiation

191 lines (190 loc) 4.12 kB
/**
 * Creates a multivariate dual number. The result is an array with the real
 * part at index 0, followed by `n` dual components initialized to 0.
 *
 * @param {number} real - real part
 * @param {number} [n=1] - number of dual components
 * @param {number} [i=0] - 1-based index of the dual component to set to 1
 *   (i.e. which variable this value represents); 0 leaves every dual part at
 *   0, producing a constant
 * @returns {number[]} array of length `n + 1`
 */
const dual = (real, n = 1, i = 0) => {
  const out = new Array(n + 1).fill(0, 1);
  out[0] = real;
  // Only seed a dual slot when a variable index was requested.
  if (i > 0) {
    out[i] = 1;
  }
  return out;
};

/**
 * Creates a 1-dual number `[r, d]`, where `d` is 1 iff `i === 1` (r is the
 * variable) and 0 otherwise (r is a constant).
 *
 * @param {number} r - real part
 * @param {number} [i=0] - pass 1 to mark this value as the variable
 * @returns {number[]} `[r, 0 | 1]`
 */
const $ = (r, i = 0) => [r, i === 1 ? 1 : 0];

/** Shorthand for `dual(r, 2, i)`: real part plus 2 dual components. */
const $2 = (r, i = 0) => dual(r, 2, i);

/** Shorthand for `dual(r, 3, i)`: real part plus 3 dual components. */
const $3 = (r, i = 0) => dual(r, 3, i);

/** Shorthand for `dual(r, 4, i)`: real part plus 4 dual components. */
const $4 = (r, i = 0) => dual(r, 4, i);

/**
 * Builds an operator that picks an implementation based on the length of one
 * of its arguments. Inputs shorter than 3 (i.e. `[real, dual]`) go to
 * `single`; anything longer goes to `multi`.
 *
 * @param {Function} single - fast path for 1-dual numbers
 * @param {Function} multi - general path for any number of dual components
 * @param {number} [dispatch=0] - index of the argument whose length decides
 * @returns {Function} variadic operator forwarding all args to the chosen impl
 */
const defOp = (single, multi, dispatch = 0) =>
  (...args) => (args[dispatch].length < 3 ? single(...args) : multi(...args));

/** Component-wise sum `a + b`. */
const add = defOp(
  (a, b) => [a[0] + b[0], a[1] + b[1]],
  (a, b) => a.map((x, i) => x + b[i]),
);

/** Component-wise difference `a - b`. */
const sub = defOp(
  (a, b) => [a[0] - b[0], a[1] - b[1]],
  (a, b) => a.map((x, i) => x - b[i]),
);

/**
 * Negation `-a`. The multi path maps zero components to +0 rather than -0;
 * the 1-dual path negates directly.
 */
const neg = defOp(
  (a) => [-a[0], -a[1]],
  (a) => a.map((x) => (x !== 0 ? -x : 0)),
);

/** Product rule: `(a * b)' = a * b' + a' * b`. */
const mul = defOp(
  ([ar, ad], [br, bd]) => [ar * br, ar * bd + ad * br],
  (a, b) => {
    const ar = a[0];
    const br = b[0];
    const out = [ar * br];
    for (let i = a.length; i-- > 1; ) {
      out[i] = ar * b[i] + a[i] * br;
    }
    return out;
  },
);

/** Quotient rule: `(a / b)' = (a' * b - a * b') / b^2`. */
const div = defOp(
  ([ar, ad], [br, bd]) => [ar / br, (ad * br - ar * bd) / (br * br)],
  (a, b) => {
    const ar = a[0];
    const br = b[0];
    const ibr = 1 / (br * br);
    const out = [ar / br];
    for (let i = a.length; i-- > 1; ) {
      out[i] = (a[i] * br - ar * b[i]) * ibr;
    }
    return out;
  },
);

/**
 * Absolute value. Dual parts are scaled by `Math.sign(real)`, so they become
 * 0 when the real part is 0.
 */
const abs = defOp(
  ([ar, ad]) => [Math.abs(ar), ad * Math.sign(ar)],
  (a) => {
    const s = Math.sign(a[0]);
    const out = [Math.abs(a[0])];
    for (let i = a.length; i-- > 1; ) {
      out[i] = s * a[i];
    }
    return out;
  },
);

/** Square root: `sqrt(a)' = a' / (2 * sqrt(a))`. */
const sqrt = defOp(
  (a) => {
    const s = Math.sqrt(a[0]);
    return [s, (0.5 * a[1]) / s];
  },
  (a) => {
    const s = Math.sqrt(a[0]);
    const si = 0.5 / s;
    const out = [s];
    for (let i = a.length; i-- > 1; ) {
      out[i] = si * a[i];
    }
    return out;
  },
);

/** Exponential: `exp(a)' = a' * exp(a)`. */
const exp = defOp(
  ([ar, ad]) => {
    const e = Math.exp(ar);
    return [e, ad * e];
  },
  (a) => {
    const ar = Math.exp(a[0]);
    const out = [ar];
    for (let i = a.length; i-- > 1; ) {
      out[i] = ar * a[i];
    }
    return out;
  },
);

/**
 * Natural logarithm: `log(a)' = a' / a`. Both paths scale the dual parts by
 * the reciprocal of the *input* real part, not of its logarithm.
 */
const log = defOp(
  ([ar, ad]) => [Math.log(ar), ad / ar],
  (a) => {
    const ar = a[0];
    // d/dx ln(x) = 1/x, so use the input's real part here, not Math.log(ar).
    const iar = 1 / ar;
    const out = [Math.log(ar)];
    for (let i = a.length; i-- > 1; ) {
      out[i] = iar * a[i];
    }
    return out;
  },
);

/**
 * Power with a plain-number exponent: `(a^k)' = k * a^(k-1) * a'`.
 *
 * @param {number[]} a - dual number base
 * @param {number} k - scalar exponent (not a dual number)
 */
const pow = defOp(
  ([ar, ad], k) => [ar ** k, ad * k * ar ** (k - 1)],
  (a, k) => {
    const f = k * a[0] ** (k - 1);
    const out = [a[0] ** k];
    for (let i = a.length; i-- > 1; ) {
      out[i] = f * a[i];
    }
    return out;
  },
);

/** Sine: `sin(a)' = a' * cos(a)`. */
const sin = defOp(
  ([ar, ad]) => [Math.sin(ar), ad * Math.cos(ar)],
  (a) => {
    const c = Math.cos(a[0]);
    const out = [Math.sin(a[0])];
    for (let i = a.length; i-- > 1; ) {
      out[i] = c * a[i];
    }
    return out;
  },
);

/** Cosine: `cos(a)' = -a' * sin(a)`. */
const cos = defOp(
  ([ar, ad]) => [Math.cos(ar), -ad * Math.sin(ar)],
  (a) => {
    const s = -Math.sin(a[0]);
    const out = [Math.cos(a[0])];
    for (let i = a.length; i-- > 1; ) {
      out[i] = s * a[i];
    }
    return out;
  },
);

/** Tangent: `tan(a)' = a' / cos(a)^2`. */
const tan = defOp(
  ([ar, ad]) => {
    const c = Math.cos(ar);
    return [Math.tan(ar), ad / (c * c)];
  },
  (a) => {
    const c = Math.cos(a[0]);
    const ic = 1 / (c * c);
    const out = [Math.tan(a[0])];
    for (let i = a.length; i-- > 1; ) {
      out[i] = ic * a[i];
    }
    return out;
  },
);

/** Arctangent: `atan(a)' = a' / (1 + a^2)`. */
const atan = defOp(
  ([ar, ad]) => [Math.atan(ar), ad / (1 + ar * ar)],
  (a) => {
    const ar = a[0];
    const iar = 1 / (1 + ar * ar);
    const out = [Math.atan(ar)];
    for (let i = a.length; i-- > 1; ) {
      out[i] = iar * a[i];
    }
    return out;
  },
);

/**
 * Linear interpolation `a + (b - a) * t`. `t` must itself be a dual number
 * with the same number of components as `a` and `b` (e.g. `$(0.5)` for
 * 1-dual inputs), since `mul` reads every component of it.
 */
const mix = (a, b, t) => add(a, mul(sub(b, a), t));

/**
 * Wraps a function of 2 dual arguments so it can be called with plain
 * numbers. `x` is seeded as `[x, 1, 0]` and `y` as `[y, 0, 1]`, so the
 * result's dual parts are the partial derivatives w.r.t. x and y.
 */
const evalFn2 = (fn) => (x, y) => fn([x, 1, 0], [y, 0, 1]);

/** Like `evalFn2`, for 3 variables seeded as unit dual vectors. */
const evalFn3 = (fn) => (x, y, z) =>
  fn([x, 1, 0, 0], [y, 0, 1, 0], [z, 0, 0, 1]);

/** Like `evalFn2`, for 4 variables seeded as unit dual vectors. */
const evalFn4 = (fn) => (x, y, z, w) =>
  fn([x, 1, 0, 0, 0], [y, 0, 1, 0, 0], [z, 0, 0, 1, 0], [w, 0, 0, 0, 1]);

export { $, $2, $3, $4, abs, add, atan, cos, defOp, div, dual, evalFn2, evalFn3, evalFn4, exp, log, mix, mul, neg, pow, sin, sqrt, sub, tan };