UNPKG

ml-knn

Version:

k-nearest neighbors algorithm

280 lines (245 loc) 7.59 kB
/* * Original code from: * * k-d Tree JavaScript - V 1.01 * * https://github.com/ubilabs/kd-tree-javascript * * @author Mircea Pricop <pricop@ubilabs.net>, 2012 * @author Martin Kleppe <kleppe@ubilabs.net>, 2012 * @author Ubilabs http://ubilabs.net, 2012 * @license MIT License <http://www.opensource.org/licenses/mit-license.php> */ function Node(obj, dimension, parent) { this.obj = obj; this.left = null; this.right = null; this.parent = parent; this.dimension = dimension; } export default class KDTree { constructor(points, metric) { // If points is not an array, assume we're loading a pre-built tree if (!Array.isArray(points)) { this.dimensions = points.dimensions; this.root = points; restoreParent(this.root); } else { this.dimensions = new Array(points[0].length); for (var i = 0; i < this.dimensions.length; i++) { this.dimensions[i] = i; } this.root = buildTree(points, 0, null, this.dimensions); } this.metric = metric; } // Convert to a JSON serializable structure; this just requires removing // the `parent` property toJSON() { const result = toJSONImpl(this.root, true); result.dimensions = this.dimensions; return result; } nearest(point, maxNodes, maxDistance) { const metric = this.metric; const dimensions = this.dimensions; var i; const bestNodes = new BinaryHeap(function (e) { return -e[1]; }); function nearestSearch(node) { const dimension = dimensions[node.dimension]; const ownDistance = metric(point, node.obj); const linearPoint = {}; var bestChild, linearDistance, otherChild, i; function saveNode(node, distance) { bestNodes.push([node, distance]); if (bestNodes.size() > maxNodes) { bestNodes.pop(); } } for (i = 0; i < dimensions.length; i += 1) { if (i === node.dimension) { linearPoint[dimensions[i]] = point[dimensions[i]]; } else { linearPoint[dimensions[i]] = node.obj[dimensions[i]]; } } linearDistance = metric(linearPoint, node.obj); if (node.right === null && node.left === null) { if (bestNodes.size() < maxNodes || ownDistance < bestNodes.peek()[1]) { saveNode(node, ownDistance); } return; } if (node.right === null) { bestChild = node.left; } else if (node.left === null) { bestChild = node.right; } else { if (point[dimension] < node.obj[dimension]) { bestChild = node.left; } else { bestChild = node.right; } } nearestSearch(bestChild); if (bestNodes.size() < maxNodes || ownDistance < bestNodes.peek()[1]) { saveNode(node, ownDistance); } if ( bestNodes.size() < maxNodes || Math.abs(linearDistance) < bestNodes.peek()[1] ) { if (bestChild === node.left) { otherChild = node.right; } else { otherChild = node.left; } if (otherChild !== null) { nearestSearch(otherChild); } } } if (maxDistance) { for (i = 0; i < maxNodes; i += 1) { bestNodes.push([null, maxDistance]); } } if (this.root) { nearestSearch(this.root); } const result = []; for (i = 0; i < Math.min(maxNodes, bestNodes.content.length); i += 1) { if (bestNodes.content[i][0]) { result.push([bestNodes.content[i][0].obj, bestNodes.content[i][1]]); } } return result; } } function toJSONImpl(src) { const dest = new Node(src.obj, src.dimension, null); if (src.left) dest.left = toJSONImpl(src.left); if (src.right) dest.right = toJSONImpl(src.right); return dest; } function buildTree(points, depth, parent, dimensions) { const dim = depth % dimensions.length; if (points.length === 0) { return null; } if (points.length === 1) { return new Node(points[0], dim, parent); } points.sort((a, b) => a[dimensions[dim]] - b[dimensions[dim]]); const median = Math.floor(points.length / 2); const node = new Node(points[median], dim, parent); node.left = buildTree(points.slice(0, median), depth + 1, node, dimensions); node.right = buildTree(points.slice(median + 1), depth + 1, node, dimensions); return node; } function restoreParent(root) { if (root.left) { root.left.parent = root; restoreParent(root.left); } if (root.right) { root.right.parent = root; restoreParent(root.right); } } // Binary heap implementation from: // http://eloquentjavascript.net/appendix2.html class BinaryHeap { constructor(scoreFunction) { this.content = []; this.scoreFunction = scoreFunction; } push(element) { // Add the new element to the end of the array. this.content.push(element); // Allow it to bubble up. this.bubbleUp(this.content.length - 1); } pop() { // Store the first element so we can return it later. var result = this.content[0]; // Get the element at the end of the array. var end = this.content.pop(); // If there are any elements left, put the end element at the // start, and let it sink down. if (this.content.length > 0) { this.content[0] = end; this.sinkDown(0); } return result; } peek() { return this.content[0]; } size() { return this.content.length; } bubbleUp(n) { // Fetch the element that has to be moved. var element = this.content[n]; // When at 0, an element can not go up any further. while (n > 0) { // Compute the parent element's index, and fetch it. const parentN = Math.floor((n + 1) / 2) - 1; const parent = this.content[parentN]; // Swap the elements if the parent is greater. if (this.scoreFunction(element) < this.scoreFunction(parent)) { this.content[parentN] = element; this.content[n] = parent; // Update 'n' to continue at the new position. n = parentN; } else { // Found a parent that is less, no need to move it further. break; } } } sinkDown(n) { // Look up the target element and its score. var length = this.content.length; var element = this.content[n]; var elemScore = this.scoreFunction(element); while (true) { // Compute the indices of the child elements. var child2N = (n + 1) * 2; var child1N = child2N - 1; // This is used to store the new position of the element, // if any. var swap = null; // If the first child exists (is inside the array)... if (child1N < length) { // Look it up and compute its score. var child1 = this.content[child1N]; var child1Score = this.scoreFunction(child1); // If the score is less than our element's, we need to swap. if (child1Score < elemScore) { swap = child1N; } } // Do the same checks for the other child. if (child2N < length) { var child2 = this.content[child2N]; var child2Score = this.scoreFunction(child2); if (child2Score < (swap === null ? elemScore : child1Score)) { swap = child2N; } } // If the element needs to be moved, swap it, and continue. if (swap !== null) { this.content[n] = this.content[swap]; this.content[swap] = element; n = swap; } else { // Otherwise, we are done. break; } } } }