commons-math-interpolation
Version:
A partial port of the Apache Commons Math Interpolation package, including Akima cubic spline interpolation and LOESS/LOWESS local regression.
249 lines (248 loc) • 7.87 kB
JavaScript
import { checkMonotonicallyIncreasing, checkFinite, getMedian } from "./Utils.js";
import { createInterpolatorWithFallback } from "./Index.js";
/**
 * Builds an interpolator on top of LOESS smoothing.
 *
 * The points are first smoothed with `smooth(parms)`. The smoothed sequence is
 * then thinned to knots: a point is kept only when its fitted y value is not
 * NaN and its x lies at least `minXDistance` past the previously kept knot.
 * The knots are passed to `createInterpolatorWithFallback`.
 *
 * @param {object} parms - Smoothing parameters (forwarded unchanged to `smooth`), plus:
 *   - `interpolationMethod` (default "akima"): method name passed to
 *     `createInterpolatorWithFallback`.
 *   - `minXDistance` (default: x range / 100, 1 for a zero range, NaN for no
 *     points): minimum x spacing between consecutive knots.
 *   - `diagInfo` (optional): receives `fitYVals`, `knotFilter`, `knotXVals`
 *     and `knotYVals` in addition to what `smooth` records.
 * @returns the value returned by `createInterpolatorWithFallback`.
 */
export function createLoessInterpolator(parms) {
  // The default spacing is only computed when the caller did not supply one.
  const {
    interpolationMethod: method = "akima",
    minXDistance: minKnotSpacing = getDefaultMinXDistance(parms.xVals),
    diagInfo: diag,
  } = parms;
  const fitted = smooth(parms);
  const keep = createKnotFilter(parms.xVals, fitted, minKnotSpacing);
  const knots = {
    x: filterNumberArray(parms.xVals, keep),
    y: filterNumberArray(fitted, keep),
  };
  if (diag) {
    Object.assign(diag, {
      fitYVals: fitted,
      knotFilter: keep,
      knotXVals: knots.x,
      knotYVals: knots.y,
    });
  }
  return createInterpolatorWithFallback(method, knots.x, knots.y);
}
/**
 * Marks which points may serve as interpolation knots.
 *
 * The points are scanned from left to right. A point is accepted when its x
 * is at least `minXDistance` beyond the x of the last accepted point and its
 * fitted y value is not NaN.
 *
 * @param xVals - x coordinates.
 * @param fitYVals - smoothed y values, one per x.
 * @param {number} minXDistance - minimum spacing between accepted knots.
 * @returns {boolean[]} one flag per point; `true` means "use as knot".
 */
function createKnotFilter(xVals, fitYVals, minXDistance) {
  let lastKeptX = -Infinity;
  return Array.from({ length: xVals.length }, (_, i) => {
    const x = xVals[i];
    const accepted = x - lastKeptX >= minXDistance && !isNaN(fitYVals[i]);
    if (accepted) {
      lastKeptX = x;
    }
    return accepted;
  });
}
/**
 * Keeps the elements of `a` whose matching `filter` flag is truthy.
 *
 * @param a - source values.
 * @param filter - per-element flags.
 * @returns {Float64Array} the kept values, as a subarray view over a buffer
 *   sized for all of `a`.
 */
function filterNumberArray(a, filter) {
  const out = new Float64Array(a.length);
  let kept = 0;
  for (let i = 0; i < a.length; i++) {
    if (!filter[i]) {
      continue;
    }
    out[kept] = a[i];
    kept += 1;
  }
  return out.subarray(0, kept);
}
/**
 * Default minimum spacing between knots: one hundredth of the x range
 * (last x minus first x).
 *
 * @param xVals - x coordinates.
 * @returns {number} NaN for no points, 1 when the range is zero, otherwise range / 100.
 */
function getDefaultMinXDistance(xVals) {
  const count = xVals.length;
  if (count === 0) {
    return NaN;
  }
  const span = xVals[count - 1] - xVals[0];
  return span === 0 ? 1 : span / 100;
}
/**
 * Smooths a sequence of points with LOESS/LOWESS local linear regression.
 *
 * An initial fit (iteration 0) is always computed. It is followed by up to
 * `robustnessIters` robustness passes. Each pass derives per-point weights
 * with the bi-weight function of residual / (medianResidual * outlierDistanceFactor),
 * combines them with the caller's `weights`, and refits. The robustness passes
 * stop early once the median absolute residual is below `accuracy` or is zero.
 *
 * @param {object} parms
 * @param parms.xVals - x coordinates, validated by `checkMonotonicallyIncreasing`.
 * @param parms.yVals - y coordinates; must be finite and have the same length as `xVals`.
 * @param [parms.weights] - optional per-point weights; must be finite and have the same length.
 * @param {number} [parms.bandwidthFraction=0.3] - fraction of the non-zero-weight
 *   points used for each local regression (never fewer than 2 points).
 * @param {number} [parms.robustnessIters=2] - maximum number of robustness passes.
 * @param {number} [parms.accuracy=1E-12] - median-residual convergence threshold.
 *   The local regression also treats a weighted x variance below accuracy² as
 *   zero slope.
 * @param {number} [parms.outlierDistanceFactor=6] - residuals at or beyond this
 *   multiple of the median residual get a robustness weight of 0.
 * @param {object} [parms.diagInfo] - optional object that receives
 *   `robustnessIters`, `secondLastMedianResidual`, `lastMedianResidual` and
 *   `robustnessWeights`.
 * @returns {Float64Array} the fitted y value of every input point (a copy of
 *   `yVals` when there are 2 or fewer points).
 * @throws {Error} "Dimension mismatch." when the array lengths differ. Errors
 *   from the validation helpers and the regression passes are propagated.
 */
export function smooth(parms) {
  const { xVals, yVals, weights, bandwidthFraction = 0.3, robustnessIters = 2, accuracy = 1E-12, outlierDistanceFactor = 6, diagInfo } = parms;
  checkMonotonicallyIncreasing(xVals);
  checkFinite(yVals);
  if (weights) {
    checkFinite(weights);
  }
  const n = xVals.length;
  if (yVals.length !== n || (weights && weights.length !== n)) {
    throw new Error("Dimension mismatch.");
  }
  if (diagInfo) {
    diagInfo.robustnessIters = 0;
    diagInfo.secondLastMedianResidual = undefined;
    diagInfo.lastMedianResidual = undefined;
    diagInfo.robustnessWeights = undefined;
  }
  if (n <= 2) {
    return Float64Array.from(yVals);
  }
  let fitYVals;
  // Iteration 0 (the plain fit) always runs, so a result is produced even when
  // robustnessIters is 0, negative or NaN. Previously a negative/NaN value
  // skipped the loop entirely and returned undefined.
  for (let iter = 0; iter === 0 || iter <= robustnessIters; iter++) {
    let robustnessWeights;
    if (iter > 0) {
      const residuals = absDiff(fitYVals, yVals);
      const medianResidual = getMedian(residuals);
      // A zero median residual would make outlierDistance zero, and residual / 0
      // would yield NaN robustness weights (reachable when accuracy <= 0).
      // Treat it as converged, the same as a residual below `accuracy`.
      if (medianResidual < accuracy || medianResidual === 0) {
        if (diagInfo) {
          diagInfo.lastMedianResidual = medianResidual;
        }
        break;
      }
      const outlierDistance = medianResidual * outlierDistanceFactor;
      robustnessWeights = calculateRobustnessWeights(residuals, outlierDistance);
      if (diagInfo) {
        diagInfo.robustnessIters = iter;
        diagInfo.secondLastMedianResidual = medianResidual;
        diagInfo.robustnessWeights = robustnessWeights;
      }
    }
    const combinedWeights = combineWeights(weights, robustnessWeights);
    fitYVals = calculateSequenceRegression(xVals, yVals, combinedWeights, bandwidthFraction, accuracy, iter);
  }
  return fitYVals;
}
/**
 * Runs one LOESS pass: for every point, a local linear regression over a
 * window of `bandwidthInPoints` non-zero-weight points that is slid along with x.
 *
 * @param xVals - x coordinates.
 * @param yVals - y coordinates.
 * @param [weights] - combined per-point weights; undefined means all points count.
 * @param {number} bandwidthFraction - fraction of the non-zero-weight points per window.
 * @param {number} accuracy - forwarded to the local regression.
 * @param {number} iter - iteration number, used only in the error message.
 * @returns {Float64Array} the fitted y values (NaN where the local weight sum
 *   is below 1E-12).
 * @throws {Error} when fewer than 2 points have non-zero weight.
 */
function calculateSequenceRegression(xVals, yVals, weights, bandwidthFraction, accuracy, iter) {
  const pointCount = xVals.length;
  const usablePoints = weights ? countNonZeros(weights) : pointCount;
  if (usablePoints < 2) {
    throw new Error(`Not enough relevant points in iteration ${iter}.`);
  }
  const requestedPoints = Math.round(usablePoints * bandwidthFraction);
  const windowSize = Math.max(2, Math.min(usablePoints, requestedPoints));
  const interval = findInitialBandwidthInterval(weights, windowSize, pointCount);
  const fitted = new Float64Array(pointCount);
  for (let i = 0; i < pointCount; i++) {
    const x = xVals[i];
    // Shift the window so that it stays centred around x as x increases.
    moveBandwidthInterval(interval, x, xVals, weights);
    fitted[i] = calculateLocalLinearRegression(xVals, yVals, weights, x, interval.iLeft, interval.iRight, accuracy);
  }
  return fitted;
}
/**
 * Weighted least-squares line through points iLeft..iRight (inclusive),
 * evaluated at `x`.
 *
 * Each point's weight is its external weight (1 when `weights` is absent)
 * times triCube(distance / radius). The radius is 1.001 × the larger distance
 * from `x` to the two interval end points, so the farther end point keeps a
 * small non-zero kernel weight. A zero radius is replaced by 1.
 *
 * @returns {number} the fitted y at `x`. Returns NaN when the total weight is
 *   below 1E-12. When the weighted x variance is below accuracy², the slope is
 *   taken as 0 and the weighted mean of y is returned.
 * @throws {Error} "Inconsistent bandwidth parameters." when the radius is
 *   negative (x lies right of xVals[iRight] and left of xVals[iLeft]).
 */
export function calculateLocalLinearRegression(xVals, yVals, weights, x, iLeft, iRight, accuracy) {
  const radius = Math.max(x - xVals[iLeft], xVals[iRight] - x) * 1.001;
  if (radius < 0) {
    throw new Error("Inconsistent bandwidth parameters.");
  }
  const scale = radius === 0 ? 1 : radius;
  let wSum = 0;
  let wxSum = 0;
  let wxxSum = 0;
  let wySum = 0;
  let wxySum = 0;
  for (let k = iLeft; k <= iRight; k++) {
    const xk = xVals[k];
    const w = (weights ? weights[k] : 1) * triCube(Math.abs(xk - x) / scale);
    const wx = xk * w;
    wSum += w;
    wxSum += wx;
    wxxSum += xk * wx;
    wySum += yVals[k] * w;
    wxySum += yVals[k] * wx;
  }
  if (wSum < 1E-12) {
    return NaN;
  }
  const meanX = wxSum / wSum;
  const meanY = wySum / wSum;
  const varianceX = wxxSum / wSum - meanX * meanX;
  // Nearly constant x in the window: the slope is undetermined, so use a flat line.
  const slope = Math.abs(varianceX) < accuracy ** 2 ? 0 : (wxySum / wSum - meanX * meanY) / varianceX;
  return meanY + slope * x - slope * meanX;
}
/**
 * Locates the first window covering `bandwidthInPoints` points with non-zero
 * weight, starting at the first such point.
 *
 * @param [weights] - per-point weights; undefined means every point counts.
 * @param {number} bandwidthInPoints - number of non-zero-weight points in the window.
 * @param {number} n - total number of points.
 * @returns {{iLeft: number, iRight: number}} inclusive index bounds of the window.
 * @throws {Error} when there are not enough non-zero-weight points.
 */
function findInitialBandwidthInterval(weights, bandwidthInPoints, n) {
  const iLeft = findNonZero(weights, 0);
  if (iLeft >= n) {
    throw new Error("Initial bandwidth start point not found.");
  }
  let iRight = iLeft;
  for (let remaining = bandwidthInPoints - 1; remaining > 0; remaining--) {
    iRight = findNonZero(weights, iRight + 1);
    if (iRight >= n) {
      throw new Error("Initial bandwidth end point not found.");
    }
  }
  return { iLeft, iRight };
}
/**
 * Slides the window `bw` to the right, in place.
 *
 * The window advances one non-zero-weight point at a time for as long as the
 * next non-zero-weight point on the right is closer to `x` than the current
 * left end point.
 *
 * @param {{iLeft: number, iRight: number}} bw - window bounds; mutated.
 * @param {number} x - position the window should surround.
 * @param xVals - x coordinates.
 * @param [weights] - per-point weights; undefined means every point counts.
 */
function moveBandwidthInterval(bw, x, xVals, weights) {
  const count = xVals.length;
  for (
    let nextRight = findNonZero(weights, bw.iRight + 1);
    !(nextRight >= count || xVals[nextRight] - x >= x - xVals[bw.iLeft]);
    nextRight = findNonZero(weights, bw.iRight + 1)
  ) {
    bw.iLeft = findNonZero(weights, bw.iLeft + 1);
    bw.iRight = nextRight;
  }
}
/**
 * Maps each absolute residual to a robustness weight with the bi-weight
 * function, scaled so that residuals at or beyond `outlierDistance` get 0.
 *
 * @param residuals - absolute residuals, one per point.
 * @param {number} outlierDistance - residual scale.
 * @returns {Float64Array} one weight per residual.
 */
function calculateRobustnessWeights(residuals, outlierDistance) {
  return Float64Array.from(residuals, (residual) => biWeight(residual / outlierDistance));
}
/**
 * Multiplies two optional weight arrays element-wise.
 *
 * @param [w1] - first weight array.
 * @param [w2] - second weight array (same length as w1 when both are given).
 * @returns a new Float64Array of products when both are given. Otherwise the
 *   one that is present (unchanged, same object), or undefined when neither is.
 */
function combineWeights(w1, w2) {
  if (w1 && w2) {
    return Float64Array.from(w1, (value, i) => value * w2[i]);
  }
  return w1 ?? w2;
}
/**
 * Returns the index of the first element at or after `startPos` that is not
 * (loosely) equal to 0.
 *
 * @param [a] - values to scan; when absent, `startPos` itself is returned.
 * @param {number} startPos - first index to inspect.
 * @returns {number} the found index, or an index >= a.length when none exists.
 */
function findNonZero(a, startPos) {
  if (!a) {
    return startPos;
  }
  const len = a.length;
  let pos = startPos;
  while (pos < len) {
    if (a[pos] != 0) {
      return pos;
    }
    pos++;
  }
  return pos;
}
/**
 * Counts the elements that are not (loosely) equal to 0. NaN counts as non-zero.
 *
 * @param a - values to inspect.
 * @returns {number} the number of non-zero elements.
 */
function countNonZeros(a) {
  let count = 0;
  for (let i = a.length - 1; i >= 0; i--) {
    if (a[i] != 0) {
      count += 1;
    }
  }
  return count;
}
/**
 * Element-wise absolute difference |a1[i] - a2[i]|.
 *
 * @param a1 - first operand; its length determines the result length.
 * @param a2 - second operand.
 * @returns {Float64Array} the absolute differences.
 */
function absDiff(a1, a2) {
  return Float64Array.from(a1, (value, i) => Math.abs(value - a2[i]));
}
/**
 * Tri-cube kernel: (1 - |x|³)³ for |x| < 1, and 0 for |x| >= 1 (NaN stays NaN).
 *
 * @param {number} x - normalized distance.
 * @returns {number} kernel weight.
 */
function triCube(x) {
  const d = Math.abs(x);
  const base = 1 - d * d * d;
  return d >= 1 ? 0 : base * base * base;
}
/**
 * Bi-weight (bisquare) function: (1 - x²)² for |x| < 1, and 0 for |x| >= 1
 * (NaN stays NaN).
 *
 * @param {number} x - scaled residual.
 * @returns {number} robustness weight.
 */
function biWeight(x) {
  const d = Math.abs(x);
  const base = 1 - d * d;
  return d >= 1 ? 0 : base * base;
}