@thi.ng/k-means
Version:
k-means & k-medians with customizable distance functions and centroid initializations for n-D vectors
132 lines (131 loc) • 3.57 kB
JavaScript
import { argmin } from "@thi.ng/distance/argmin";
import { DIST_SQ } from "@thi.ng/distance/squared";
import { assert } from "@thi.ng/errors/assert";
import { SYSTEM } from "@thi.ng/random/system";
import { weightedRandom } from "@thi.ng/random/weighted-random";
import { add } from "@thi.ng/vectors/add";
import { median } from "@thi.ng/vectors/median";
import { mulN } from "@thi.ng/vectors/muln";
import { zeroes } from "@thi.ng/vectors/setn";
/**
 * k-means / k-medians clustering of `samples` into (at most) `k` clusters.
 *
 * Alternates between assigning every sample to its nearest centroid and
 * recomputing each centroid via the configured strategy, until no assignment
 * changes or `maxIter` rounds have run.
 *
 * @param k - requested number of clusters; replaced by the actual number of
 *   initial centroids
 * @param samples - input vectors
 * @param opts - optional settings:
 *   - `dim`: size passed to `strategy` (default: `samples[0].length`)
 *   - `dist`: distance (default: `DIST_SQ`)
 *   - `maxIter`: max number of assign/refine rounds (default: 32)
 *   - `strategy`: centroid strategy factory (default: `means`)
 *   - `initial`: array of start centroids (not modified), or a function
 *     `(k, samples, dist, rnd)` returning one; if omitted, `kmeansPlusPlus`
 *     is used with `exponent` and `rnd`
 * @returns array of `{ id, centroid, items }`, one per non-empty cluster,
 *   where `items` are indices into `samples`
 *
 * NOTE(review): with `maxIter <= 0` no assignment round runs, so all samples
 * stay in the sentinel cluster id `k`, whose centroid is undefined.
 */
const kmeans = (k, samples, opts = {}) => {
	let {
		dim = samples[0].length,
		dist = DIST_SQ,
		maxIter = 32,
		strategy = means,
		exponent,
		initial,
		rnd
	} = opts;
	const num = samples.length;
	// copy a user-supplied array so refinement doesn't overwrite the caller's data
	const centroids = Array.isArray(initial)
		? initial.slice()
		: initial
			? initial(k, samples, dist, rnd)
			: kmeansPlusPlus(k, samples, dist, rnd, exponent);
	assert(centroids.length > 0, `missing initial centroids`);
	k = centroids.length;
	// `k` acts as an "unassigned" sentinel until the first assignment round
	const clusters = new Uint32Array(num).fill(k);
	let update = true;
	while (update && maxIter-- > 0) {
		update = __assign(samples, centroids, clusters, dist);
		if (!update) break;
		// one accumulator per cluster, fed in a single pass over the samples
		// (each cluster still sees its samples in ascending index order)
		const accumulators = [];
		for (let i = 0; i < k; i++) accumulators.push(strategy(dim));
		for (let j = 0; j < num; j++) {
			// ids outside [0, k) have no accumulator and are skipped
			const impl = accumulators[clusters[j]];
			if (impl) impl.update(samples[j]);
		}
		for (let i = 0; i < k; i++) {
			const centroid = accumulators[i].finish();
			// empty clusters keep their previous centroid
			if (centroid) centroids[i] = centroid;
		}
	}
	return __buildClusters(centroids, clusters);
};
/**
 * k-means++ seeding: picks up to `k` initial centroids from `samples`.
 *
 * The first centroid is chosen via `rnd.int() % samples.length`. Each further
 * centroid is drawn with `weightedRandom`, weighting every sample by
 * `dist.from(d) ** exponent`, where `d` is its metric distance to the nearest
 * already-chosen centroid. Seeding stops early once all weights are zero.
 *
 * @param k - max number of centroids (clamped to `samples.length`; at least
 *   one centroid is always returned)
 * @param samples - input vectors (must be non-empty)
 * @param dist - distance (default: `DIST_SQ`)
 * @param rnd - random source (default: `SYSTEM`)
 * @param exponent - exponent applied to each weight (default: 2)
 * @returns array of chosen centroids (references to entries of `samples`)
 */
const kmeansPlusPlus = (k, samples, dist = DIST_SQ, rnd = SYSTEM, exponent = 2) => {
	const num = samples.length;
	assert(num > 0, `missing samples`);
	k = Math.min(k, num);
	const metric = dist.metric;
	const indices = Array.from({ length: num }, (_, i) => i);
	const chosen = new Set();
	const centroids = [];
	// minDist[j]: raw metric distance from sample j to its nearest chosen
	// centroid, updated incrementally as centroids are added
	const minDist = new Array(num).fill(Infinity);
	const addCentroid = (id) => {
		chosen.add(id);
		const c = samples[id];
		centroids.push(c);
		for (let j = 0; j < num; j++) {
			const d = metric(samples[j], c);
			if (d < minDist[j]) minDist[j] = d;
		}
	};
	addCentroid(rnd.int() % num);
	while (centroids.length < k) {
		let psum = 0;
		const probs = minDist.map((d) => {
			const w = dist.from(d) ** exponent;
			psum += w;
			return w;
		});
		// all remaining samples coincide with chosen centroids
		if (!psum) break;
		// weights are fixed for this round, so build the sampler only once
		const pick = weightedRandom(indices, probs, rnd);
		let id;
		do {
			id = pick();
			// retry on a missing pick or an already-chosen sample
		} while (id === undefined || chosen.has(id));
		addCentroid(id);
	}
	return centroids;
};
/**
 * Moves every sample to its nearest centroid (as chosen by `argmin` under
 * `dist`), writing the winning centroid index into `assignments` in place.
 *
 * @returns true if at least one sample changed its cluster
 */
const __assign = (samples, centroids, assignments, dist) => {
	let changed = false;
	// visit samples from last to first
	for (let idx = samples.length - 1; idx >= 0; idx--) {
		const nearest = argmin(samples[idx], centroids, dist);
		if (nearest === assignments[idx]) continue;
		assignments[idx] = nearest;
		changed = true;
	}
	return changed;
};
/**
 * Groups sample indices by their assigned cluster id.
 *
 * @param centroids - centroid per cluster id
 * @param assignments - cluster id per sample index
 * @returns `{ id, centroid, items }` for every id that has at least one
 *   sample, ordered by ascending id
 */
const __buildClusters = (centroids, assignments) => {
	// sparse array keyed by cluster id; unused ids remain holes and are
	// removed by the final filter
	const byId = [];
	assignments.forEach((id, idx) => {
		let cluster = byId[id];
		if (!cluster) {
			cluster = byId[id] = { id, centroid: centroids[id], items: [] };
		}
		cluster.items.push(idx);
	});
	return byId.filter(Boolean);
};
/**
 * Centroid strategy computing the arithmetic mean of all samples added.
 *
 * @param dim - vector size of the accumulator
 * @returns `{ update, finish }`: `update(p)` adds a sample to the running
 *   sum; `finish()` scales the sum (in place) by 1/count and returns it, or
 *   returns undefined if no samples were added
 */
const means = (dim) => {
	const sum = zeroes(dim);
	let count = 0;
	const update = (p) => {
		add(sum, sum, p);
		count += 1;
	};
	const finish = () => {
		if (count === 0) return undefined;
		return mulN(sum, sum, 1 / count);
	};
	return { update, finish };
};
/**
 * Centroid strategy computing the component-wise median (via `median`) of
 * all samples added.
 *
 * @returns `{ update, finish }`: `update(p)` stores the sample and returns
 *   the new sample count; `finish()` returns the median vector, or undefined
 *   if no samples were added
 */
const medians = () => {
	const points = [];
	const update = (p) => points.push(p);
	const finish = () => (points.length > 0 ? median([], points) : undefined);
	return { update, finish };
};
/**
 * Centroid strategy averaging `[lat, lon]` pairs in degrees.
 *
 * The first component is treated as a periodic angle wrapping at ±180; the
 * second is averaged arithmetically. Each first-component value is shifted
 * by a full turn, if needed, so that it lies within 180° of the first sample
 * seen. It is then averaged, and the mean is wrapped back into (-180, 180].
 * Compact clusters straddling either 0 or the ±180 seam therefore average
 * correctly (e.g. -10 and 10 give 0; 170 and -170 give 180).
 *
 * NOTE(review): the wrapped component is the *first* one, although ±180 is
 * the range of longitude rather than latitude — confirm the axis order.
 *
 * @returns `{ update, finish }`: `update(p)` accumulates a sample;
 *   `finish()` returns `[lat, lon]`, or undefined if no samples were added
 */
const meansLatLon = () => {
	let ref;
	let lat = 0;
	let lon = 0;
	let n = 0;
	return {
		update: ([$lat, $lon]) => {
			if (ref === undefined) ref = $lat;
			// unwrap relative to the first sample (inputs assumed in [-180, 180])
			const delta = $lat - ref;
			lat += delta > 180 ? $lat - 360 : delta < -180 ? $lat + 360 : $lat;
			lon += $lon;
			n++;
		},
		finish: () => {
			if (!n) return;
			// wrap the mean back into (-180, 180] without touching the sums
			let mlat = (lat / n) % 360;
			if (mlat > 180) mlat -= 360;
			else if (mlat <= -180) mlat += 360;
			return [mlat, lon / n];
		}
	};
};
// Public API: clustering entry point, k-means++ seeding and the centroid
// strategies; `__assign` and `__buildClusters` stay module-private.
export {
kmeans,
kmeansPlusPlus,
means,
meansLatLon,
medians
};