UNPKG

decision-tree

Version:

Node.js implementation of decision tree, random forest, and XGBoost algorithms, with comprehensive performance testing (Node.js 20+)

140 lines (120 loc) 3.57 kB
import _ from 'lodash';
import { TreeNode, DecisionTreeData, TrainingData, NODE_TYPES } from './shared/types.js';
import { randomUUID, prob, log2, mostCommon } from './shared/utils.js';
import { createTree, entropy, gain, maxGain } from './shared/id3-algorithm.js';

/**
 * Decision Tree Algorithm
 * @module DecisionTree
 */

/**
 * Decision tree classifier whose model is built by the ID3 `createTree` helper.
 *
 * Supported construction forms:
 * - `new DecisionTree(json)` — restore a model previously produced by `toJSON()`.
 * - `new DecisionTree(target, features)` — create an untrained tree; call `train()` next.
 * - `new DecisionTree(data, target, features)` — create and immediately train on `data`.
 */
class DecisionTree {
  public static readonly NODE_TYPES = NODE_TYPES;

  private model!: TreeNode;
  // Round-tripped by import()/toJSON(); train() does not store the training set here.
  private data: any[] = [];
  private target!: string;
  private features!: string[];

  /**
   * @param args - One of the three forms described on the class.
   * @throws Error when the argument count matches no supported form, or when
   *   `target` is not a non-empty string / `features` is not an array.
   */
  constructor(...args: any[]) {
    const numArgs = args.length;
    if (numArgs === 1) {
      this.import(args[0]);
    } else if (numArgs === 2) {
      const [target, features] = args;
      if (!target || typeof target !== 'string') {
        throw new Error('`target` argument is expected to be a String. Check documentation on usage');
      }
      if (!features || !Array.isArray(features)) {
        throw new Error('`features` argument is expected to be an Array<String>. Check documentation on usage');
      }
      this.target = target;
      this.features = features;
    } else if (numArgs === 3) {
      const [data, target, features] = args;
      // Returning an object from a constructor makes `new` evaluate to that
      // object, so the caller receives this freshly trained instance.
      const instance = new DecisionTree(target, features);
      instance.train(data);
      return instance;
    } else {
      throw new Error('Invalid arguments passed to constructor. Check documentation on usage');
    }
  }

  /**
   * Trains the decision tree with provided data, replacing any existing model.
   * @param data - Array of training data objects
   * @throws Error when `data` is not an array.
   */
  train(data: TrainingData[]): void {
    if (!data || !Array.isArray(data)) {
      throw new Error('`data` argument is expected to be an Array<Object>. Check documentation on usage');
    }
    this.model = createTree(data, this.target, this.features);
  }

  /**
   * Predicts the class for a given sample by walking the tree from the root
   * until a RESULT node is reached.
   *
   * When no branch matches the sample's value for a node's attribute (e.g. a
   * value not seen during training), the node's first branch is followed.
   *
   * @param sample - Sample data to predict
   * @returns Predicted class value (the `val` of the reached RESULT node)
   * @throws Error when no model has been trained or imported, or when the
   *   tree contains a non-result node without a usable child branch.
   */
  predict(sample: TrainingData): any {
    if (!this.model) {
      throw new Error('Model has not been trained. Call train() or import a model before predict()');
    }
    let root = this.model;
    while (root.type !== NODE_TYPES.RESULT) {
      let attr = root.name;
      let sampleVal = sample[attr];
      // Loose equality is intentional: it lets e.g. a numeric sample value
      // match a branch whose name is the same value in string form.
      const childNode = _.find(root.vals, (node) => node.name == sampleVal);
      const branch = childNode ?? root.vals?.[0];
      if (!branch || !branch.child) {
        throw new Error('Malformed decision tree: node for attribute `' + String(attr) + '` has no child to follow');
      }
      root = branch.child;
    }
    return root.val;
  }

  /**
   * Evaluates prediction accuracy on samples.
   * @param samples - Array of test samples (each must contain the target attribute)
   * @returns Accuracy ratio (correct predictions / total predictions);
   *   `NaN` when `samples` is empty, since the ratio is then 0 / 0.
   */
  evaluate(samples: TrainingData[]): number {
    let total = 0;
    let correct = 0;
    _.each(samples, (s) => {
      total++;
      const pred = this.predict(s);
      const actual = s[this.target];
      if (_.isEqual(pred, actual)) {
        correct++;
      }
    });
    return correct / total;
  }

  /**
   * Imports a previously saved model produced by the toJSON() method.
   * Fields are assigned as-is; no validation of the model structure is done.
   * @param json - JSON representation of the model
   */
  import(json: DecisionTreeData): void {
    const { model, data, target, features } = json;
    this.model = model;
    this.data = data;
    this.target = target;
    this.features = features;
  }

  /**
   * Returns a JSON-serializable representation of the model that can be
   * passed back to `import()` or the single-argument constructor.
   * @returns Object containing model, data, target and features
   */
  toJSON(): DecisionTreeData {
    const { data, target, features } = this;
    const model = this.model;
    return { model, data, target, features };
  }
}

// Export the DecisionTree class
export default DecisionTree;