UNPKG

ml-knn

Version:

k-nearest neighbors algorithm

125 lines (110 loc) 3.63 kB
import { euclidean as euclideanDistance } from 'ml-distance-euclidean'; import KDTree from './KDTree'; export default class KNN { /** * @param {Array} dataset * @param {Array} labels * @param {object} options * @param {number} [options.k=numberOfClasses + 1] - Number of neighbors to classify. * @param {function} [options.distance=euclideanDistance] - Distance function that takes two parameters. */ constructor(dataset, labels, options = {}) { if (dataset === true) { const model = labels; this.kdTree = new KDTree(model.kdTree, options); this.k = model.k; this.classes = new Set(model.classes); this.isEuclidean = model.isEuclidean; return; } const classes = new Set(labels); const { distance = euclideanDistance, k = classes.size + 1 } = options; const points = new Array(dataset.length); for (var i = 0; i < points.length; ++i) { points[i] = dataset[i].slice(); } for (i = 0; i < labels.length; ++i) { points[i].push(labels[i]); } this.kdTree = new KDTree(points, distance); this.k = k; this.classes = classes; this.isEuclidean = distance === euclideanDistance; } /** * Create a new KNN instance with the given model. * @param {object} model * @param {function} distance=euclideanDistance - distance function must be provided if the model wasn't trained with euclidean distance. * @return {KNN} */ static load(model, distance = euclideanDistance) { if (model.name !== 'KNN') { throw new Error(`invalid model: ${model.name}`); } if (!model.isEuclidean && distance === euclideanDistance) { throw new Error( 'a custom distance function was used to create the model. Please provide it again' ); } if (model.isEuclidean && distance !== euclideanDistance) { throw new Error( 'the model was created with the default distance function. Do not load it with another one' ); } return new KNN(true, model, distance); } /** * Return a JSON containing the kd-tree model. * @return {object} JSON KNN model. */ toJSON() { return { name: 'KNN', kdTree: this.kdTree, k: this.k, classes: Array.from(this.classes), isEuclidean: this.isEuclidean }; } /** * Predicts the output given the matrix to predict. * @param {Array} dataset * @return {Array} predictions */ predict(dataset) { if (Array.isArray(dataset)) { if (typeof dataset[0] === 'number') { return getSinglePrediction(this, dataset); } else if ( Array.isArray(dataset[0]) && typeof dataset[0][0] === 'number' ) { const predictions = new Array(dataset.length); for (var i = 0; i < dataset.length; i++) { predictions[i] = getSinglePrediction(this, dataset[i]); } return predictions; } } throw new TypeError('dataset to predict must be an array or a matrix'); } } function getSinglePrediction(knn, currentCase) { var nearestPoints = knn.kdTree.nearest(currentCase, knn.k); var pointsPerClass = {}; var predictedClass = -1; var maxPoints = -1; var lastElement = nearestPoints[0][0].length - 1; for (var element of knn.classes) { pointsPerClass[element] = 0; } for (var i = 0; i < nearestPoints.length; ++i) { var currentClass = nearestPoints[i][0][lastElement]; var currentPoints = ++pointsPerClass[currentClass]; if (currentPoints > maxPoints) { predictedClass = currentClass; maxPoints = currentPoints; } } return predictedClass; }