UNPKG

@antv/g2

Version:

The Grammar of Graphics in JavaScript

148 lines (128 loc) 4.4 kB
import { dot, norm2, weightedSum } from './blas1';

/**
 * Minimizes a function using the downhill simplex (Nelder-Mead) method.
 *
 * The simplex has `x0.length + 1` vertices: `x0` itself plus one vertex per
 * coordinate, built by perturbing that coordinate of a copy of `x0`.
 *
 * Note: `x0` is used in place as a simplex vertex, so it gains `fx` and `id`
 * properties and its coordinates may be overwritten while the search runs.
 *
 * @param f - Objective function. Called with a point (array-like of numbers);
 *   its result is compared and subtracted, so it should return a number.
 * @param x0 - Starting point.
 * @param parameters - Optional settings. For the numeric options read with
 *   `||`, any falsy value (including `0`) selects the default.
 *   - `maxIterations`: iteration cap (default `x0.length * 200`).
 *   - `nonZeroDelta`: factor applied to a non-zero coordinate of `x0` when
 *     building the initial vertices (default `1.05`).
 *   - `zeroDelta`: value used in place of a zero (falsy) coordinate of `x0`
 *     when building the initial vertices (default `0.001`).
 *   - `minErrorDelta`: threshold on `|f(best) - f(worst)|` (default `1e-6`).
 *   - `minTolerance`: threshold on the largest per-coordinate distance
 *     between the two best vertices. Defaults to `minErrorDelta` when that is
 *     given, otherwise `1e-5`.
 *   - `rho`, `chi`, `psi`, `sigma`: reflection, expansion, contraction and
 *     reduction coefficients (defaults `1`, `2`, `-0.5`, `0.5`).
 *   - `history`: if present, one snapshot `{ x, fx, simplex }` is pushed onto
 *     it per iteration.
 * @returns `{ fx, x }` where `x` is the best vertex found and `fx` its value.
 */
export function nelderMead(f, x0, parameters?: any) {
  parameters = parameters || {};

  const maxIterations = parameters.maxIterations || x0.length * 200;
  const nonZeroDelta = parameters.nonZeroDelta || 1.05;
  const zeroDelta = parameters.zeroDelta || 0.001;
  const minErrorDelta = parameters.minErrorDelta || 1e-6;
  // Honour an explicit `minTolerance`. Falling back to `minErrorDelta` keeps
  // results unchanged for callers that only ever passed `minErrorDelta`.
  const minTolerance =
    parameters.minTolerance || parameters.minErrorDelta || 1e-5;
  const rho = parameters.rho !== undefined ? parameters.rho : 1;
  const chi = parameters.chi !== undefined ? parameters.chi : 2;
  const psi = parameters.psi !== undefined ? parameters.psi : -0.5;
  const sigma = parameters.sigma !== undefined ? parameters.sigma : 0.5;
  let maxDiff;

  // initialize simplex.
  const N = x0.length;
  const simplex = new Array(N + 1);
  simplex[0] = x0;
  simplex[0].fx = f(x0);
  simplex[0].id = 0;
  for (let i = 0; i < N; ++i) {
    const point = x0.slice();
    point[i] = point[i] ? point[i] * nonZeroDelta : zeroDelta;
    simplex[i + 1] = point;
    simplex[i + 1].fx = f(point);
    simplex[i + 1].id = i + 1;
  }

  // Overwrites the worst vertex (last after sorting) in place, so the vertex
  // object and its `id` survive across iterations.
  function updateSimplex(value) {
    for (let i = 0; i < value.length; i++) {
      simplex[N][i] = value[i];
    }
    simplex[N].fx = value.fx;
  }

  const sortOrder = (a, b) => a.fx - b.fx;

  // Scratch buffers reused on every iteration.
  const centroid = x0.slice();
  const reflected = x0.slice();
  const contracted = x0.slice();
  const expanded = x0.slice();

  for (let iteration = 0; iteration < maxIterations; ++iteration) {
    simplex.sort(sortOrder);

    if (parameters.history) {
      // copy the simplex (since later iterations will mutate) and
      // sort it to have a consistent order between iterations
      const sortedSimplex = simplex.map((x) => {
        const state = x.slice();
        state.fx = x.fx;
        state.id = x.id;
        return state;
      });
      sortedSimplex.sort((a, b) => a.id - b.id);

      parameters.history.push({
        x: simplex[0].slice(),
        fx: simplex[0].fx,
        simplex: sortedSimplex,
      });
    }

    // Largest per-coordinate gap between the best and second-best vertices.
    maxDiff = 0;
    for (let i = 0; i < N; ++i) {
      maxDiff = Math.max(maxDiff, Math.abs(simplex[0][i] - simplex[1][i]));
    }

    // Converged: function values and vertex positions are both close enough.
    if (
      Math.abs(simplex[0].fx - simplex[N].fx) < minErrorDelta &&
      maxDiff < minTolerance
    ) {
      break;
    }

    // compute the centroid of all but the worst point in the simplex
    for (let i = 0; i < N; ++i) {
      centroid[i] = 0;
      for (let j = 0; j < N; ++j) {
        centroid[i] += simplex[j][i];
      }
      centroid[i] /= N;
    }

    // reflect the worst point past the centroid and compute loss at reflected
    // point
    const worst = simplex[N];
    weightedSum(reflected, 1 + rho, centroid, -rho, worst);
    reflected.fx = f(reflected);

    // if the reflected point is the best seen, then possibly expand
    if (reflected.fx < simplex[0].fx) {
      weightedSum(expanded, 1 + chi, centroid, -chi, worst);
      expanded.fx = f(expanded);
      if (expanded.fx < reflected.fx) {
        updateSimplex(expanded);
      } else {
        updateSimplex(reflected);
      }
    }

    // if the reflected point is worse than the second worst, we need to
    // contract
    else if (reflected.fx >= simplex[N - 1].fx) {
      let shouldReduce = false;

      if (reflected.fx > worst.fx) {
        // do an inside contraction
        weightedSum(contracted, 1 + psi, centroid, -psi, worst);
        contracted.fx = f(contracted);
        if (contracted.fx < worst.fx) {
          updateSimplex(contracted);
        } else {
          shouldReduce = true;
        }
      } else {
        // do an outside contraction
        weightedSum(contracted, 1 - psi * rho, centroid, psi * rho, worst);
        contracted.fx = f(contracted);
        if (contracted.fx < reflected.fx) {
          updateSimplex(contracted);
        } else {
          shouldReduce = true;
        }
      }

      if (shouldReduce) {
        // if we don't contract here, we're done
        if (sigma >= 1) break;

        // do a reduction
        for (let i = 1; i < simplex.length; ++i) {
          weightedSum(simplex[i], 1 - sigma, simplex[0], sigma, simplex[i]);
          simplex[i].fx = f(simplex[i]);
        }
      }
    } else {
      updateSimplex(reflected);
    }
  }

  simplex.sort(sortOrder);
  return { fx: simplex[0].fx, x: simplex[0] };
}