UNPKG

@__username/decision-tree

Version:

Node.js implementation of a decision tree using the ID3 algorithm

257 lines (217 loc) 5.44 kB
var _ = require('lodash'); /** * ID3 Decision Tree Algorithm * @module DecisionTreeID3 */ module.exports = (function() { /** * Map of valid tree node types * @constant * @static */ const NODE_TYPES = DecisionTreeID3.NODE_TYPES = { RESULT: 'result', FEATURE: 'feature', FEATURE_VALUE: 'feature_value' }; /** * Underlying model * @private */ var model; /** * @constructor * @return {DecisionTreeID3} */ function DecisionTreeID3(data, target, features) { this.data = data; this.target = target; this.features = features; model = createTree(data, target, features); } /** * @public API */ DecisionTreeID3.prototype = { /** * Predicts class for sample */ predict: function(sample) { var root = model; while (root.type !== NODE_TYPES.RESULT) { var attr = root.name; var sampleVal = sample[attr]; var childNode = _.find(root.vals, function(node) { return node.name == sampleVal }); if (childNode){ root = childNode.child; } else { root = root.vals[0].child; } } return root.val; }, /** * Evalutes prediction accuracy on samples */ evaluate: function(samples) { var instance = this; var target = this.target; var total = 0; var correct = 0; _.each(samples, function(s) { total++; var pred = instance.predict(s); var actual = s[target]; if (pred == actual) { correct++; } }); return correct / total; }, /** * Returns JSON representation of trained model */ toJSON: function() { return model; } }; /** * Creates a new tree * @private */ function createTree(data, target, features) { var targets = _.uniq(_.map(data, target)); if (targets.length == 1) { return { type: NODE_TYPES.RESULT, val: targets[0], name: targets[0], alias: targets[0] + randomUUID() }; } if (features.length == 0) { var topTarget = mostCommon(targets); return { type: NODE_TYPES.RESULT, val: topTarget, name: topTarget, alias: topTarget + randomUUID() }; } var bestFeature = maxGain(data, target, features); var remainingFeatures = _.without(features, bestFeature); var possibleValues = _.uniq(_.map(data, bestFeature)); var node = { 
name: bestFeature, alias: bestFeature + randomUUID() }; node.type = NODE_TYPES.FEATURE; node.vals = _.map(possibleValues, function(v) { var _newS = data.filter(function(x) { return x[bestFeature] == v }); var child_node = { name: v, alias: v + randomUUID(), type: NODE_TYPES.FEATURE_VALUE }; child_node.child = createTree(_newS, target, remainingFeatures); return child_node; }); return node; } /** * Computes entropy of a list * @private */ function entropy(vals) { var uniqueVals = _.uniq(vals); var probs = uniqueVals.map(function(x) { return prob(x, vals) }); var logVals = probs.map(function(p) { return -p * log2(p) }); return logVals.reduce(function(a, b) { return a + b }, 0); } /** * Computes gain * @private */ function gain(data, target, feature) { var attrVals = _.uniq(_.map(data, feature)); var setEntropy = entropy(_.map(data, target)); var setSize = _.size(data); var entropies = attrVals.map(function(n) { var subset = data.filter(function(x) { return x[feature] === n }); return (subset.length / setSize) * entropy(_.map(subset, target)); }); var sumOfEntropies = entropies.reduce(function(a, b) { return a + b }, 0); return setEntropy - sumOfEntropies; } /** * Computes Max gain across features to determine best split * @private */ function maxGain(data, target, features) { return _.max(features, function(element) { return gain(data, target, element) }); } /** * Computes probability of of a given value existing in a given list * @private */ function prob(value, list) { var occurrences = _.filter(list, function(element) { return element === value }); var numOccurrences = occurrences.length; var numElements = list.length; return numOccurrences / numElements; } /** * Computes Log with base-2 * @private */ function log2(n) { return Math.log(n) / Math.log(2); } /** * Finds element with highest occurrence in a list * @private */ function mostCommon(list) { var elementFrequencyMap = {}; var largestFrequency = -1; var mostCommonElement = null; 
list.forEach(function(element) { var elementFrequency = (elementFrequencyMap[element] || 0) + 1; elementFrequencyMap[element] = elementFrequency; if (largestFrequency < elementFrequency) { mostCommonElement = element; largestFrequency = elementFrequency; } }); return mostCommonElement; } /** * Generates random UUID * @private */ function randomUUID() { return "_r" + Math.random().toString(32).slice(2); } /** * @class DecisionTreeID3 */ return DecisionTreeID3; })();