@n2flowjs/nbase
Version:
Neural Vector Database for efficient similarity search
269 lines (225 loc) • 8.18 kB
text/typescript
import { Vector } from '../types';
/**
 * K-means clustering over a set of numeric vectors.
 *
 * The algorithm partitions the input into `k` clusters by alternating two steps until the
 * centroids stop moving (or `maxIterations` is reached): assign every vector to its nearest
 * centroid, then move every centroid to the mean of the vectors assigned to it. Initial
 * centroids are picked with k-means++ seeding. `cluster` is async and yields to the event
 * loop every 10 iterations so that long runs do not monopolise the thread.
 *
 * Randomness comes from `Math.random()`, so results are not reproducible across runs.
 *
 * @example
 * ```typescript
 * const kmeans = new KMeans(3, 100, 0.01);
 * const vectors = [
 *   new Float32Array([1.0, 2.0]),
 *   new Float32Array([1.5, 1.8]),
 *   new Float32Array([5.0, 8.0]),
 *   new Float32Array([8.0, 8.0]),
 *   new Float32Array([1.0, 0.6]),
 *   new Float32Array([9.0, 11.0])
 * ];
 * const centroids = await kmeans.cluster(vectors);
 * console.log(centroids);
 * ```
 */
export class KMeans {
  /** Number of clusters to form. */
  private readonly k: number;
  /** Upper bound on assign/update rounds performed by `cluster`. */
  private readonly maxIterations: number;
  /** Per-coordinate centroid movement above which a round counts as "changed". */
  private readonly tolerance: number;

  /**
   * @param k - Number of clusters to form (default 8).
   * @param maxIterations - Maximum number of refinement iterations (default 100).
   * @param tolerance - Per-coordinate movement threshold used to detect convergence (default 0.001).
   */
  constructor(k: number = 8, maxIterations: number = 100, tolerance: number = 0.001) {
    this.k = k;
    this.maxIterations = maxIterations;
    this.tolerance = tolerance;
  }

  /**
   * Cluster a set of vectors using k-means.
   *
   * When there are no more vectors than clusters, each vector becomes its own centroid and
   * no iteration is performed. The returned arrays are always fresh copies, never the
   * caller's own instances.
   *
   * @param vectors - Set of vectors to cluster; the dimensionality is taken from `vectors[0]`.
   * @returns Array of cluster centroids.
   * @throws Error if `vectors` is empty.
   */
  async cluster(vectors: Vector[]): Promise<Float32Array[]> {
    if (vectors.length === 0) {
      throw new Error('Cannot cluster empty vector set');
    }
    if (vectors.length <= this.k) {
      // Copy so callers can mutate the result without touching their input,
      // matching the main path, which also returns freshly allocated centroids.
      return vectors.map((v) => this._copyVector(v));
    }

    const centroids = this._initializeCentroids(vectors);

    for (let iteration = 1; iteration <= this.maxIterations; iteration++) {
      const assignments = this._assignToClusters(vectors, centroids);
      const changed = this._updateCentroids(vectors, assignments, centroids);
      if (!changed) {
        break;
      }
      // Yield to the event loop periodically so long runs do not block other work.
      if (iteration % 10 === 0) {
        await new Promise<void>((resolve) => setTimeout(resolve, 0));
      }
    }
    return centroids;
  }

  /**
   * Return an independent Float32Array copy of a vector.
   * @private
   */
  private _copyVector(v: Vector): Float32Array {
    return v instanceof Float32Array ? v.slice() : new Float32Array(v);
  }

  /**
   * Initialize centroids using the k-means++ method: the first centroid is a uniformly
   * random input vector, and each further centroid is an input vector drawn with
   * probability proportional to its squared distance from the nearest centroid chosen so far.
   * @private
   */
  private _initializeCentroids(vectors: Vector[]): Float32Array[] {
    const n = vectors.length;
    const centroids: Float32Array[] = [];

    const firstIdx = Math.floor(Math.random() * n);
    centroids.push(this._copyVector(vectors[firstIdx]));

    // minDistances[j] is the squared distance from vectors[j] to its nearest chosen
    // centroid. It is maintained incrementally: only the most recently added centroid can
    // lower it, so each round costs O(n) distance evaluations instead of O(n * chosen).
    // Float64 keeps the stored weights identical to the values summed into totalDistance.
    const minDistances = new Float64Array(n).fill(Infinity);

    for (let i = 1; i < this.k; i++) {
      const newest = centroids[centroids.length - 1];
      let totalDistance = 0;
      for (let j = 0; j < n; j++) {
        minDistances[j] = Math.min(minDistances[j], this._squaredDistance(vectors[j], newest));
        totalDistance += minDistances[j];
      }

      // Roulette-wheel selection weighted by squared distance.
      let rand = Math.random() * totalDistance;
      let nextIdx = -1;
      for (let j = 0; j < n; j++) {
        rand -= minDistances[j];
        if (rand <= 0) {
          nextIdx = j;
          break;
        }
      }
      if (nextIdx === -1) {
        // Floating-point drift (or NaN distances) left the wheel unresolved: pick uniformly.
        nextIdx = Math.floor(Math.random() * n);
      }
      centroids.push(this._copyVector(vectors[nextIdx]));
    }
    return centroids;
  }

  /**
   * Assign each vector to the index of its nearest centroid (ties go to the lowest index).
   * @private
   */
  private _assignToClusters(vectors: Vector[], centroids: Float32Array[]): number[] {
    const n = vectors.length;
    const assignments = new Array<number>(n);
    for (let i = 0; i < n; i++) {
      let minDist = Infinity;
      let nearestCentroid = 0;
      for (let c = 0; c < centroids.length; c++) {
        const dist = this._squaredDistance(vectors[i], centroids[c]);
        if (dist < minDist) {
          minDist = dist;
          nearestCentroid = c;
        }
      }
      assignments[i] = nearestCentroid;
    }
    return assignments;
  }

  /**
   * Move every centroid to the mean of its assigned vectors, updating `centroids` in place.
   *
   * A cluster that received no vectors is re-seeded with a random member of the largest
   * cluster, and that re-seeded position is stored in `centroids` so the next assignment
   * round can actually populate it.
   *
   * @returns `true` if any coordinate of any centroid moved by more than `tolerance`, or if
   * any empty cluster was re-seeded; `false` otherwise.
   * @private
   */
  private _updateCentroids(vectors: Vector[], assignments: number[], centroids: Float32Array[]): boolean {
    const n = vectors.length;
    const dimensions = vectors[0].length;
    const k = centroids.length;

    // Accumulate in double precision; summing many float32 values in a Float32Array
    // loses accuracy as clusters grow.
    const counts = new Array<number>(k).fill(0);
    const sums: Float64Array[] = [];
    for (let c = 0; c < k; c++) {
      sums.push(new Float64Array(dimensions));
    }

    for (let i = 0; i < n; i++) {
      const clusterIdx = assignments[i];
      const vector = vectors[i];
      const sum = sums[clusterIdx];
      counts[clusterIdx]++;
      for (let d = 0; d < dimensions; d++) {
        sum[d] += vector[d];
      }
    }

    // Members of the largest cluster, collected at most once (counts do not change below).
    let donorPool: number[] | null = null;

    let changed = false;
    for (let c = 0; c < k; c++) {
      if (counts[c] === 0) {
        if (donorPool === null) {
          donorPool = this._largestClusterMembers(assignments, counts);
        }
        if (donorPool.length > 0) {
          const donor = vectors[donorPool[Math.floor(Math.random() * donorPool.length)]];
          const reseeded = new Float32Array(dimensions);
          for (let d = 0; d < dimensions; d++) {
            reseeded[d] = donor[d];
          }
          // Store the new position; without this the empty cluster would keep its stale
          // centroid forever while still reporting a change.
          centroids[c] = reseeded;
          changed = true;
        }
        continue;
      }

      const mean = new Float32Array(dimensions);
      const sum = sums[c];
      const count = counts[c];
      const previous = centroids[c];
      for (let d = 0; d < dimensions; d++) {
        mean[d] = sum[d] / count;
        if (Math.abs(mean[d] - previous[d]) > this.tolerance) {
          changed = true;
        }
      }
      centroids[c] = mean;
    }
    return changed;
  }

  /**
   * Collect the indices of the vectors assigned to the most populated cluster
   * (the lowest-indexed one when several clusters tie).
   * @private
   */
  private _largestClusterMembers(assignments: number[], counts: number[]): number[] {
    let largestCluster = 0;
    let maxCount = 0;
    for (let j = 0; j < counts.length; j++) {
      if (counts[j] > maxCount) {
        maxCount = counts[j];
        largestCluster = j;
      }
    }
    const members: number[] = [];
    for (let i = 0; i < assignments.length; i++) {
      if (assignments[i] === largestCluster) {
        members.push(i);
      }
    }
    return members;
  }

  /**
   * Calculate the squared Euclidean distance between two vectors. If the lengths differ,
   * only the leading coordinates common to both are compared.
   * @private
   */
  private _squaredDistance(a: Vector, b: Vector): number {
    let sum = 0;
    const len = Math.min(a.length, b.length);
    for (let i = 0; i < len; i++) {
      const diff = a[i] - b[i];
      sum += diff * diff;
    }
    return sum;
  }
}