UNPKG

ml-knn

Version:

k-nearest neighbors algorithm

407 lines (356 loc) 11.2 kB
'use strict';

var mlDistanceEuclidean = require('ml-distance-euclidean');

/*
 * Original code from:
 *
 * k-d Tree JavaScript - V 1.01
 *
 * https://github.com/ubilabs/kd-tree-javascript
 *
 * @author Mircea Pricop <pricop@ubilabs.net>, 2012
 * @author Martin Kleppe <kleppe@ubilabs.net>, 2012
 * @author Ubilabs http://ubilabs.net, 2012
 * @license MIT License <http://www.opensource.org/licenses/mit-license.php>
 */

/**
 * A single k-d tree node.
 * @param {Array} obj - The point stored at this node.
 * @param {number} dimension - Index (into the tree's `dimensions`) this node splits on.
 * @param {?Node} parent - Parent node, or `null` for the root / serialized copies.
 */
function Node(obj, dimension, parent) {
  this.obj = obj;
  this.left = null;
  this.right = null;
  this.parent = parent;
  this.dimension = dimension;
}

/**
 * k-d tree over array-like points, searched with a caller-supplied metric.
 */
class KDTree {
  /**
   * @param {Array<Array>|object} points - Either an array of points to build
   *   the tree from, or the root of a previously serialized tree (the value
   *   produced by `toJSON()`), which carries a `dimensions` property.
   * @param {function} metric - Distance function taking two points.
   */
  constructor(points, metric) {
    // If points is not an array, assume we're loading a pre-built tree
    if (!Array.isArray(points)) {
      this.dimensions = points.dimensions;
      this.root = points;
      restoreParent(this.root);
    } else {
      // One dimension index per component of the first point.
      this.dimensions = new Array(points[0].length);
      for (let i = 0; i < this.dimensions.length; i++) {
        this.dimensions[i] = i;
      }
      this.root = buildTree(points, 0, null, this.dimensions);
    }
    this.metric = metric;
  }

  /**
   * Convert to a JSON serializable structure; this just requires removing
   * the `parent` property (which would otherwise make the tree cyclic).
   * @return {object} Copy of the root node with `dimensions` attached.
   */
  toJSON() {
    const result = toJSONImpl(this.root);
    result.dimensions = this.dimensions;
    return result;
  }

  /**
   * Find up to `maxNodes` stored points closest to `point`.
   * @param {Array} point - Query point.
   * @param {number} maxNodes - Maximum number of neighbours to keep.
   * @param {number} [maxDistance] - When truthy, the heap is pre-filled with
   *   placeholder entries at this distance, so only points strictly closer
   *   than it are returned.
   * @return {Array<Array>} `[storedPoint, distance]` pairs, in the internal
   *   heap order (not sorted by distance).
   */
  nearest(point, maxNodes, maxDistance) {
    const metric = this.metric;
    const dimensions = this.dimensions;

    // Scores are negated distances, so the heap's top is the *farthest*
    // candidate kept so far and is the one evicted on overflow.
    const bestNodes = new BinaryHeap((entry) => -entry[1]);

    // Record a candidate and keep the heap bounded to `maxNodes` entries.
    function saveNode(node, distance) {
      bestNodes.push([node, distance]);
      if (bestNodes.size() > maxNodes) {
        bestNodes.pop();
      }
    }

    function nearestSearch(node) {
      const dimension = dimensions[node.dimension];
      const ownDistance = metric(point, node.obj);

      // Projection of the query onto this node's splitting plane: the query's
      // coordinate on the split dimension, the node's coordinate elsewhere.
      // NOTE(review): this is a plain object without a `length` property; a
      // metric that iterates using `p.length` would see no components here and
      // the pruning test below would then never prune — confirm against the
      // metric in use.
      const linearPoint = {};
      for (let i = 0; i < dimensions.length; i += 1) {
        if (i === node.dimension) {
          linearPoint[dimensions[i]] = point[dimensions[i]];
        } else {
          linearPoint[dimensions[i]] = node.obj[dimensions[i]];
        }
      }
      const linearDistance = metric(linearPoint, node.obj);

      // Leaf: just consider the node itself.
      if (node.right === null && node.left === null) {
        if (bestNodes.size() < maxNodes || ownDistance < bestNodes.peek()[1]) {
          saveNode(node, ownDistance);
        }
        return;
      }

      // Descend first into the side of the split that contains the query.
      let bestChild;
      if (node.right === null) {
        bestChild = node.left;
      } else if (node.left === null) {
        bestChild = node.right;
      } else if (point[dimension] < node.obj[dimension]) {
        bestChild = node.left;
      } else {
        bestChild = node.right;
      }

      nearestSearch(bestChild);

      if (bestNodes.size() < maxNodes || ownDistance < bestNodes.peek()[1]) {
        saveNode(node, ownDistance);
      }

      // Only visit the other side if the splitting plane is closer than the
      // current worst candidate (or we still have room for more candidates).
      if (
        bestNodes.size() < maxNodes ||
        Math.abs(linearDistance) < bestNodes.peek()[1]
      ) {
        const otherChild = bestChild === node.left ? node.right : node.left;
        if (otherChild !== null) {
          nearestSearch(otherChild);
        }
      }
    }

    if (maxDistance) {
      for (let i = 0; i < maxNodes; i += 1) {
        bestNodes.push([null, maxDistance]);
      }
    }

    if (this.root) {
      nearestSearch(this.root);
    }

    const result = [];
    const limit = Math.min(maxNodes, bestNodes.content.length);
    for (let i = 0; i < limit; i += 1) {
      const entry = bestNodes.content[i];
      // Skip the `null` placeholders inserted for `maxDistance`.
      if (entry[0]) {
        result.push([entry[0].obj, entry[1]]);
      }
    }
    return result;
  }
}

/**
 * Recursively copy a subtree into fresh nodes whose `parent` is `null`.
 * @param {Node} src - Root of the subtree to copy.
 * @return {Node} Parent-free copy (point arrays are shared, not cloned).
 */
function toJSONImpl(src) {
  const dest = new Node(src.obj, src.dimension, null);
  if (src.left) dest.left = toJSONImpl(src.left);
  if (src.right) dest.right = toJSONImpl(src.right);
  return dest;
}

/**
 * Recursively build a k-d (sub)tree by median split.
 * Sorts `points` in place on the current dimension.
 * @param {Array<Array>} points - Points for this subtree.
 * @param {number} depth - Current depth; selects the splitting dimension.
 * @param {?Node} parent - Parent of the node being created.
 * @param {Array<number>} dimensions - Dimension indices of the tree.
 * @return {?Node} Subtree root, or `null` when `points` is empty.
 */
function buildTree(points, depth, parent, dimensions) {
  const dim = depth % dimensions.length;

  if (points.length === 0) {
    return null;
  }
  if (points.length === 1) {
    return new Node(points[0], dim, parent);
  }

  points.sort((a, b) => a[dimensions[dim]] - b[dimensions[dim]]);

  const median = Math.floor(points.length / 2);
  const node = new Node(points[median], dim, parent);
  node.left = buildTree(points.slice(0, median), depth + 1, node, dimensions);
  node.right = buildTree(points.slice(median + 1), depth + 1, node, dimensions);

  return node;
}

/**
 * Re-link `parent` references below `root` (they are dropped on serialization).
 * Mutates the given nodes.
 * @param {Node} root
 */
function restoreParent(root) {
  if (root.left) {
    root.left.parent = root;
    restoreParent(root.left);
  }

  if (root.right) {
    root.right.parent = root;
    restoreParent(root.right);
  }
}

// Binary heap implementation from:
// http://eloquentjavascript.net/appendix2.html

/**
 * Min-heap ordered by `scoreFunction(element)`; the element with the lowest
 * score is at `content[0]`.
 */
class BinaryHeap {
  /**
   * @param {function} scoreFunction - Maps an element to its numeric score.
   */
  constructor(scoreFunction) {
    this.content = [];
    this.scoreFunction = scoreFunction;
  }

  /** Add an element and restore the heap property. */
  push(element) {
    // Add the new element to the end of the array.
    this.content.push(element);
    // Allow it to bubble up.
    this.bubbleUp(this.content.length - 1);
  }

  /**
   * Remove and return the lowest-scored element
   * (`undefined` when the heap is empty).
   */
  pop() {
    // Store the first element so we can return it later.
    const result = this.content[0];
    // Get the element at the end of the array.
    const end = this.content.pop();
    // If there are any elements left, put the end element at the
    // start, and let it sink down.
    if (this.content.length > 0) {
      this.content[0] = end;
      this.sinkDown(0);
    }
    return result;
  }

  /** Return the lowest-scored element without removing it. */
  peek() {
    return this.content[0];
  }

  /** Number of elements currently stored. */
  size() {
    return this.content.length;
  }

  /** Move the element at index `n` up until its parent's score is not greater. */
  bubbleUp(n) {
    // Fetch the element that has to be moved.
    const element = this.content[n];
    // When at 0, an element can not go up any further.
    while (n > 0) {
      // Compute the parent element's index, and fetch it.
      const parentN = Math.floor((n + 1) / 2) - 1;
      const parent = this.content[parentN];
      // Swap the elements if the parent is greater.
      if (this.scoreFunction(element) < this.scoreFunction(parent)) {
        this.content[parentN] = element;
        this.content[n] = parent;
        // Update 'n' to continue at the new position.
        n = parentN;
      } else {
        // Found a parent that is less, no need to move it further.
        break;
      }
    }
  }

  /** Move the element at index `n` down until both children score no lower. */
  sinkDown(n) {
    // Look up the target element and its score.
    const length = this.content.length;
    const element = this.content[n];
    const elemScore = this.scoreFunction(element);

    while (true) {
      // Compute the indices of the child elements.
      const child2N = (n + 1) * 2;
      const child1N = child2N - 1;
      // This is used to store the new position of the element, if any.
      let swap = null;
      // Declared here because the second child's check compares against it.
      let child1Score;
      // If the first child exists (is inside the array)...
      if (child1N < length) {
        // Look it up and compute its score.
        child1Score = this.scoreFunction(this.content[child1N]);
        // If the score is less than our element's, we need to swap.
        if (child1Score < elemScore) {
          swap = child1N;
        }
      }
      // Do the same checks for the other child.
      if (child2N < length) {
        const child2Score = this.scoreFunction(this.content[child2N]);
        if (child2Score < (swap === null ? elemScore : child1Score)) {
          swap = child2N;
        }
      }

      // If the element needs to be moved, swap it, and continue.
      if (swap !== null) {
        this.content[n] = this.content[swap];
        this.content[swap] = element;
        n = swap;
      } else {
        // Otherwise, we are done.
        break;
      }
    }
  }
}

class KNN {
  /**
   * @param {Array} dataset
   * @param {Array} labels
   * @param {object} options
   * @param {number} [options.k=numberOfClasses + 1] - Number of neighbors to classify.
   * @param {function} [options.distance=euclideanDistance] - Distance function that takes two parameters.
   * @throws {RangeError} If `dataset` and `labels` do not have the same length.
   */
  constructor(dataset, labels, options = {}) {
    // Internal loading path used by `KNN.load`: (true, model, distance).
    if (dataset === true) {
      const model = labels;
      this.kdTree = new KDTree(model.kdTree, options);
      this.k = model.k;
      this.classes = new Set(model.classes);
      this.isEuclidean = model.isEuclidean;
      return;
    }

    // Every point needs exactly one label; a mismatch would otherwise either
    // crash with an opaque TypeError or silently store unlabelled points.
    if (dataset.length !== labels.length) {
      throw new RangeError('dataset and labels must have the same length');
    }

    const classes = new Set(labels);

    const {
      distance = mlDistanceEuclidean.euclidean,
      k = classes.size + 1
    } = options;

    // Copy each row and append its label as the last component, so the
    // caller's arrays are left untouched.
    const points = new Array(dataset.length);
    for (let i = 0; i < points.length; ++i) {
      points[i] = dataset[i].slice();
      points[i].push(labels[i]);
    }

    this.kdTree = new KDTree(points, distance);
    this.k = k;
    this.classes = classes;
    this.isEuclidean = distance === mlDistanceEuclidean.euclidean;
  }

  /**
   * Create a new KNN instance with the given model.
   * @param {object} model
   * @param {function} distance=euclideanDistance - distance function must be provided if the model wasn't trained with euclidean distance.
   * @return {KNN}
   */
  static load(model, distance = mlDistanceEuclidean.euclidean) {
    if (model.name !== 'KNN') {
      throw new Error(`invalid model: ${model.name}`);
    }
    if (!model.isEuclidean && distance === mlDistanceEuclidean.euclidean) {
      throw new Error(
        'a custom distance function was used to create the model. Please provide it again'
      );
    }
    if (model.isEuclidean && distance !== mlDistanceEuclidean.euclidean) {
      throw new Error(
        'the model was created with the default distance function. Do not load it with another one'
      );
    }
    return new KNN(true, model, distance);
  }

  /**
   * Return a JSON containing the kd-tree model.
   * The tree is serialized eagerly so the returned object is plain data and
   * can be handed to `KNN.load` directly, not only after a
   * `JSON.stringify`/`JSON.parse` round-trip.
   * @return {object} JSON KNN model.
   */
  toJSON() {
    return {
      name: 'KNN',
      kdTree: this.kdTree.toJSON(),
      k: this.k,
      classes: Array.from(this.classes),
      isEuclidean: this.isEuclidean
    };
  }

  /**
   * Predicts the output given the matrix to predict.
   * @param {Array} dataset - A single case (array of numbers) or a matrix of cases.
   * @return {Array} predictions - A single class for a single case, otherwise one class per row.
   * @throws {TypeError} If `dataset` is neither an array of numbers nor an array of such arrays.
   */
  predict(dataset) {
    if (Array.isArray(dataset)) {
      if (typeof dataset[0] === 'number') {
        return getSinglePrediction(this, dataset);
      } else if (
        Array.isArray(dataset[0]) &&
        typeof dataset[0][0] === 'number'
      ) {
        const predictions = new Array(dataset.length);
        for (let i = 0; i < dataset.length; i++) {
          predictions[i] = getSinglePrediction(this, dataset[i]);
        }
        return predictions;
      }
    }
    throw new TypeError('dataset to predict must be an array or a matrix');
  }
}

/**
 * Classify one case by majority vote among its `knn.k` nearest stored points.
 * Ties go to the class that first reaches the winning count while scanning
 * the neighbours in the order returned by the tree.
 * @param {KNN} knn
 * @param {Array<number>} currentCase
 * @return {*} The predicted class label.
 */
function getSinglePrediction(knn, currentCase) {
  const nearestPoints = knn.kdTree.nearest(currentCase, knn.k);
  const pointsPerClass = {};
  let predictedClass = -1;
  let maxPoints = -1;
  // The label was appended as the last component of every stored point.
  const lastElement = nearestPoints[0][0].length - 1;

  for (const element of knn.classes) {
    pointsPerClass[element] = 0;
  }

  for (let i = 0; i < nearestPoints.length; ++i) {
    const currentClass = nearestPoints[i][0][lastElement];
    const currentPoints = ++pointsPerClass[currentClass];
    if (currentPoints > maxPoints) {
      predictedClass = currentClass;
      maxPoints = currentPoints;
    }
  }

  return predictedClass;
}

module.exports = KNN;