UNPKG

@thi.ng/k-means

Version:

k-means & k-medians with customizable distance functions and centroid initializations for n-D vectors

132 lines (131 loc) 3.57 kB
import { argmin } from "@thi.ng/distance/argmin";
import { DIST_SQ } from "@thi.ng/distance/squared";
import { assert } from "@thi.ng/errors/assert";
import { SYSTEM } from "@thi.ng/random/system";
import { weightedRandom } from "@thi.ng/random/weighted-random";
import { add } from "@thi.ng/vectors/add";
import { median } from "@thi.ng/vectors/median";
import { mulN } from "@thi.ng/vectors/muln";
import { zeroes } from "@thi.ng/vectors/setn";

/** Degrees -> radians conversion factor (used by `meansLatLon`). */
const __DEG2RAD = Math.PI / 180;

/**
 * k-means / k-medians clustering of `samples`.
 *
 * Initial centroids are taken from `opts.initial` (either an array of
 * centroids or an init function) or, by default, computed via
 * `kmeansPlusPlus`. The algorithm then alternates between assigning each
 * sample to its nearest centroid and recomputing every centroid via
 * `opts.strategy`, until no assignment changes or `opts.maxIter` rounds
 * have run.
 *
 * Options:
 * - `dim`: vector dimension passed to the strategy (default: size of first sample)
 * - `dist`: distance object (default `DIST_SQ`)
 * - `maxIter`: max number of assign/refine rounds (default 32)
 * - `strategy`: centroid strategy factory (default `means`)
 * - `exponent`, `rnd`: forwarded to the centroid initialization
 * - `initial`: array of initial centroids (copied, never modified) or an
 *   init function called as `initial(k, samples, dist, rnd)`
 *
 * @param k - requested number of clusters (unused if `opts.initial` is an
 *   array; the effective count is always the number of initial centroids)
 * @param samples - input points
 * @param opts - options (see above)
 * @returns array of `{ id, centroid, items }` objects, where `items` are
 *   indices into `samples`. Clusters without any assigned samples are omitted.
 */
const kmeans = (k, samples, opts = {}) => {
	const num = samples.length;
	const {
		// guard empty input so the "missing samples" assertion in
		// kmeansPlusPlus can fire instead of a TypeError here
		dim = num > 0 ? samples[0].length : 0,
		dist = DIST_SQ,
		maxIter = 32,
		strategy = means,
		exponent,
		initial,
		rnd,
	} = opts;
	const centroids = Array.isArray(initial)
		? // copy, so that centroid refinement doesn't overwrite the caller's array
		  [...initial]
		: initial
		? initial(k, samples, dist, rnd)
		: kmeansPlusPlus(k, samples, dist, rnd, exponent);
	assert(centroids.length > 0, `missing initial centroids`);
	// sentinel ID (never a valid cluster index), ensures the first
	// assignment pass always reports a change
	const clusters = new Uint32Array(num).fill(centroids.length);
	for (let i = 0; i < maxIter; i++) {
		if (!__assign(samples, centroids, clusters, dist)) break;
		__refine(samples, centroids, clusters, strategy, dim);
	}
	// if no round ran at all (maxIter <= 0 or NaN), all samples would still
	// carry the sentinel ID and form a bogus cluster without centroid
	if (!(maxIter > 0)) __assign(samples, centroids, clusters, dist);
	return __buildClusters(centroids, clusters);
};

/**
 * k-means++ centroid initialization. The first centroid is a randomly chosen
 * sample; each further centroid is a sample picked with probability
 * proportional to `dist.from(m) ** exponent`, where `m` is the sample's
 * metric value to its nearest already-chosen centroid.
 *
 * @param k - number of centroids to pick (clamped to `samples.length`)
 * @param samples - input points (must be non-empty)
 * @param dist - distance object providing `metric` and `from`
 *   (`from` presumably maps a raw metric value to an actual distance)
 * @param rnd - random source (assumes `rnd.int()` yields a non-negative integer)
 * @param exponent - weighting exponent applied to distances
 * @returns array of chosen centroids (references to elements of `samples`).
 *   May contain fewer than `k` entries if all remaining weights are zero,
 *   i.e. every sample coincides with an already chosen centroid.
 * @throws if `samples` is empty
 */
const kmeansPlusPlus = (k, samples, dist = DIST_SQ, rnd = SYSTEM, exponent = 2) => {
	const num = samples.length;
	assert(num > 0, `missing samples`);
	k = Math.min(k, num);
	const metric = dist.metric;
	const indices = Array.from({ length: num }, (_, i) => i);
	const firstID = rnd.int() % num;
	const chosen = new Set([firstID]);
	const centroids = [samples[firstID]];
	// cached metric value of each sample to its nearest chosen centroid,
	// updated incrementally instead of re-scanning all centroids each round
	const minMetric = samples.map((p) => metric(p, centroids[0]));
	while (centroids.length < k) {
		let psum = 0;
		const probs = minMetric.map((m) => {
			const d = dist.from(m) ** exponent;
			psum += d;
			return d;
		});
		if (!psum) break;
		// build the sampler once per round (construction consumes no randomness)
		const pick = weightedRandom(indices, probs, rnd);
		let id;
		do {
			id = pick();
		} while (chosen.has(id));
		chosen.add(id);
		const centroid = samples[id];
		centroids.push(centroid);
		for (let i = 0; i < num; i++) {
			const m = metric(samples[i], centroid);
			if (m < minMetric[i]) minMetric[i] = m;
		}
	}
	return centroids;
};

/**
 * Assigns each sample to the index of its nearest centroid (via `argmin`),
 * writing into `assignments` in place.
 *
 * @returns true if at least one assignment changed
 */
const __assign = (samples, centroids, assignments, dist) => {
	let update = false;
	for (let i = samples.length; i-- > 0; ) {
		const id = argmin(samples[i], centroids, dist);
		if (id !== assignments[i]) {
			assignments[i] = id;
			update = true;
		}
	}
	return update;
};

/**
 * Recomputes all centroids in place from their currently assigned samples,
 * using one fresh `strategy(dim)` instance per cluster and a single pass over
 * `samples`. Centroids whose strategy yields no result (e.g. no assigned
 * samples) keep their previous value.
 */
const __refine = (samples, centroids, assignments, strategy, dim) => {
	const impls = centroids.map(() => strategy(dim));
	for (let i = 0, n = samples.length; i < n; i++) {
		const impl = impls[assignments[i]];
		// samples without a valid cluster ID are ignored
		if (impl) impl.update(samples[i]);
	}
	impls.forEach((impl, id) => {
		const centroid = impl.finish();
		if (centroid) centroids[id] = centroid;
	});
};

/**
 * Groups sample indices by their assigned cluster ID.
 *
 * @returns array of `{ id, centroid, items }`, ordered by cluster ID, with
 *   clusters that received no samples omitted
 */
const __buildClusters = (centroids, assignments) => {
	const clusters = [];
	for (let i = 0, n = assignments.length; i < n; i++) {
		const id = assignments[i];
		(clusters[id] || (clusters[id] = { id, centroid: centroids[id], items: [] })).items.push(i);
	}
	return clusters.filter((x) => !!x);
};

/**
 * Centroid strategy computing the component-wise arithmetic mean of
 * `dim`-dimensional points. `finish()` returns undefined if no points
 * were added.
 */
const means = (dim) => {
	const acc = zeroes(dim);
	let n = 0;
	return {
		update: (p) => {
			add(acc, acc, p);
			n++;
		},
		finish: () => (n ? mulN(acc, acc, 1 / n) : void 0),
	};
};

/**
 * Centroid strategy collecting all points and returning their median
 * (computed via `median()`). `finish()` returns undefined if no points
 * were added.
 */
const medians = () => {
	const acc = [];
	return {
		update: (p) => acc.push(p),
		finish: () => (acc.length ? median([], acc) : void 0),
	};
};

/**
 * Centroid strategy for `[lat, lon]` points given in decimal degrees.
 *
 * Latitude is averaged arithmetically. Longitude uses the circular mean
 * (atan2 of summed sines/cosines), so clusters straddling the +/-180 degree
 * antimeridian are averaged correctly. The resulting longitude lies in
 * (-180, 180]. If the longitudes cancel out exactly (e.g. two antipodal
 * points), their circular mean is undefined and the returned longitude is
 * arbitrary. `finish()` returns undefined if no points were added.
 */
const meansLatLon = () => {
	let lat = 0;
	let sin = 0;
	let cos = 0;
	let n = 0;
	return {
		update: ([$lat, $lon]) => {
			const theta = $lon * __DEG2RAD;
			lat += $lat;
			sin += Math.sin(theta);
			cos += Math.cos(theta);
			n++;
		},
		finish: () => (n ? [lat / n, Math.atan2(sin, cos) / __DEG2RAD] : void 0),
	};
};

export { kmeans, kmeansPlusPlus, means, meansLatLon, medians };