# Machine Learning Algorithms for Node.js

A comprehensive Node.js library implementing three powerful machine learning algorithms: **Decision Tree**, **Random Forest**, and **XGBoost**. Built with TypeScript and featuring extensive performance testing, this library provides production-ready implementations with full type safety and comprehensive test coverage.

## Table of Contents

- [🚀 Features](#-features)
- [Installation](#installation)
- [TypeScript Support](#typescript-support)
- [Quick Start](#quick-start)
- [Usage](#usage)
  - [Decision Tree](#decision-tree-usage)
  - [Random Forest](#random-forest-usage)
  - [XGBoost](#xgboost-usage)
- [Algorithm Comparison](#algorithm-comparison)
- [Data Validation and Limitations](#data-validation-and-limitations)
- [Error Handling](#error-handling)
- [Performance Benchmarks](#performance-benchmarks)
- [Test Coverage](#test-coverage)
- [Development](#development)
- [Contributing](#contributing)
- [Why Node.js 20+ or Bun 1.0+?](#why-nodejs-20-or-bun-10)

## 🚀 Features

- **Three ML Algorithms**: Decision Tree (ID3), Random Forest, and XGBoost
- **TypeScript Support**: Full type safety and IntelliSense support
- **Performance Optimized**: Comprehensive performance testing with strict benchmarks
- **Production Ready**: 408 tests with 100% pass rate and extensive edge case coverage
- **Model Persistence**: Export/import trained models as JSON
- **Feature Importance**: Built-in feature importance calculation for all algorithms
- **Early Stopping**: XGBoost early stopping to prevent overfitting
- **Regularization**: L1 and L2 regularization support in XGBoost
- **ES Modules**: Modern JavaScript with native ES module support

## Installation

**Requires Node.js 20+ or Bun 1.0+** (ES modules support required)

### Using npm

```bash
npm install decision-tree
```

### Using Bun

```bash
bun add decision-tree
```

## TypeScript Support

This module is written in TypeScript and provides full type definitions. The compiled JavaScript maintains full backward compatibility with existing Node.js and browser projects that support ES modules.

**Note:** This package uses ES modules (`"type": "module"`), so CommonJS `require()` is not supported.
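If you are working in a CommonJS project, dynamic `import()` is one possible workaround, since Node.js supports it inside CJS files (a minimal sketch; the loader function name is illustrative):

```js
// Hypothetical loader for CommonJS callers: dynamic import() works in
// CJS files on Node.js 20+, while require('decision-tree') does not.
async function loadDecisionTree() {
  const { default: DecisionTree } = await import('decision-tree');
  return new DecisionTree('liked', ['color', 'shape']);
}

loadDecisionTree().then((dt) => {
  // dt is ready to train once you have data
});
```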
### TypeScript Usage

```typescript
import DecisionTree from 'decision-tree';
import RandomForest from 'decision-tree/random-forest';
import XGBoost from 'decision-tree/xgboost';

// Full type safety for training data
interface TrainingData {
  color: string;
  shape: string;
  size: string;
  liked: boolean;
}

const training_data: TrainingData[] = [
  {"color":"blue", "shape":"square", "size":"small", "liked":false},
  {"color":"red", "shape":"square", "size":"large", "liked":false},
  {"color":"blue", "shape":"circle", "size":"medium", "liked":true},
  {"color":"red", "shape":"circle", "size":"small", "liked":true}
];

// Decision Tree
const dt = new DecisionTree('liked', ['color', 'shape', 'size']);
dt.train(training_data);
const prediction = dt.predict({ color: "blue", shape: "hexagon", size: "medium" });

// Random Forest
const rf = new RandomForest('liked', ['color', 'shape', 'size'], {
  nEstimators: 100,
  maxFeatures: 'sqrt',
  randomState: 42
});
rf.train(training_data);
const rfPrediction = rf.predict({ color: "blue", shape: "hexagon", size: "medium" });

// XGBoost
const xgb = new XGBoost('liked', ['color', 'shape', 'size'], {
  nEstimators: 100,
  learningRate: 0.1,
  objective: 'binary'
});
xgb.train(training_data);
const xgbPrediction = xgb.predict({ color: "blue", shape: "hexagon", size: "medium" });
```

## Quick Start

```js
import DecisionTree from 'decision-tree';
import RandomForest from 'decision-tree/random-forest';
import XGBoost from 'decision-tree/xgboost';

// Sample data
const data = [
  {"color":"blue", "shape":"square", "liked":false},
  {"color":"red", "shape":"square", "liked":false},
  {"color":"blue", "shape":"circle", "liked":true},
  {"color":"red", "shape":"circle", "liked":true}
];

// Train and predict with Decision Tree
const dt = new DecisionTree('liked', ['color', 'shape']);
dt.train(data);
const prediction = dt.predict({ color: "blue", shape: "hexagon" });

// Train and predict with Random Forest
const rf = new RandomForest('liked', ['color', 'shape'], { nEstimators: 100 });
rf.train(data);
const rfPrediction = rf.predict({ color: "blue", shape: "hexagon" });

// Train and predict with XGBoost
const xgb = new XGBoost('liked', ['color', 'shape'], { nEstimators: 100, objective: 'binary' });
xgb.train(data);
const xgbPrediction = xgb.predict({ color: "blue", shape: "hexagon" });
```

## Usage

### Decision Tree Usage

```js
import DecisionTree from 'decision-tree';
```

**Important:** This package uses ES modules only. CommonJS `require()` is not supported.
### Prepare training dataset

```js
const training_data = [
  {"color":"blue", "shape":"square", "liked":false},
  {"color":"red", "shape":"square", "liked":false},
  {"color":"blue", "shape":"circle", "liked":true},
  {"color":"red", "shape":"circle", "liked":true},
  {"color":"blue", "shape":"hexagon", "liked":false},
  {"color":"red", "shape":"hexagon", "liked":false},
  {"color":"yellow", "shape":"hexagon", "liked":true},
  {"color":"yellow", "shape":"circle", "liked":true}
];
```

### Prepare test dataset

```js
const test_data = [
  {"color":"blue", "shape":"hexagon", "liked":false},
  {"color":"red", "shape":"hexagon", "liked":false},
  {"color":"yellow", "shape":"hexagon", "liked":true},
  {"color":"yellow", "shape":"circle", "liked":true}
];
```

### Set up target class used for prediction

```js
const class_name = "liked";
```

### Set up features to be used by the decision tree

```js
const features = ["color", "shape"];
```

### Create decision tree and train the model

**Method 1: Separate instantiation and training**

```js
const dt = new DecisionTree(class_name, features);
dt.train(training_data);
```

**Method 2: Instantiate and train in one step**

```js
const dt = new DecisionTree(training_data, class_name, features);
```

**Note:** Passing the training data to the constructor trains the model during construction, so Method 2 is equivalent to:

```js
const dt = new DecisionTree(class_name, features);
dt.train(training_data);
```

### Predict class label for an instance

```js
const predicted_class = dt.predict({ color: "blue", shape: "hexagon" });
```

### Evaluate model on a dataset

```js
const accuracy = dt.evaluate(test_data);
```

### Export underlying model for visualization or inspection

```js
const treeJson = dt.toJSON();
```

**Note:** The exported model contains the tree structure but does not preserve the original training data; only models created by importing JSON have training data stored.

### Create a decision tree from a previously trained model

```js
const treeJson = dt.toJSON();
const preTrainedDecisionTree = new DecisionTree(treeJson);
```

### Import a previously trained model on an existing tree instance

```js
const treeJson = dt.toJSON();
dt.import(treeJson);
```

### Random Forest Usage

This package includes a Random Forest implementation that provides better accuracy and reduced overfitting compared to single Decision Trees.
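For a quick sense of the difference, here is a minimal sketch that trains both models on the same data and compares their accuracy (it assumes the `training_data` and `test_data` arrays from the Decision Tree section above):

```js
import DecisionTree from 'decision-tree';
import RandomForest from 'decision-tree/random-forest';

// Assumes `training_data` and `test_data` from the Decision Tree section above.
const dt = new DecisionTree('liked', ['color', 'shape']);
dt.train(training_data);

const rf = new RandomForest('liked', ['color', 'shape'], { nEstimators: 50, randomState: 42 });
rf.train(training_data);

console.log(`Decision Tree: ${(dt.evaluate(test_data) * 100).toFixed(1)}%`);
console.log(`Random Forest: ${(rf.evaluate(test_data) * 100).toFixed(1)}%`);
```

On a dataset this small the two may tie; the ensemble's advantage shows up on larger, noisier data.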
### Import Random Forest

```js
import RandomForest from 'decision-tree/random-forest';
```

### Basic Random Forest Usage

```js
const training_data = [
  {"color":"blue", "shape":"square", "liked":false},
  {"color":"red", "shape":"square", "liked":false},
  {"color":"blue", "shape":"circle", "liked":true},
  {"color":"red", "shape":"circle", "liked":true},
  {"color":"blue", "shape":"hexagon", "liked":false},
  {"color":"red", "shape":"hexagon", "liked":false},
  {"color":"yellow", "shape":"hexagon", "liked":true},
  {"color":"yellow", "shape":"circle", "liked":true}
];

const test_data = [
  {"color":"blue", "shape":"hexagon", "liked":false},
  {"color":"yellow", "shape":"circle", "liked":true}
];

const class_name = "liked";
const features = ["color", "shape"];

// Create and train Random Forest
const rf = new RandomForest(class_name, features);
rf.train(training_data);

// Make predictions
const predicted_class = rf.predict({ color: "blue", shape: "hexagon" });

// Evaluate accuracy
const accuracy = rf.evaluate(test_data);
console.log(`Accuracy: ${(accuracy * 100).toFixed(1)}%`);
```

### Random Forest Configuration

```js
const config = {
  nEstimators: 100,     // Number of trees (default: 100)
  maxFeatures: 'sqrt',  // Features per split: 'sqrt', 'log2', 'auto', or a number
  bootstrap: true,      // Use bootstrap sampling (default: true)
  randomState: 42,      // Random seed for reproducibility
  maxDepth: undefined,  // Maximum tree depth
  minSamplesSplit: 2    // Minimum samples to split
};

const rf = new RandomForest(class_name, features, config);
rf.train(training_data);
```

### Random Forest Features

```js
// Get feature importance scores
const importance = rf.getFeatureImportance();
console.log('Feature importance:', importance);

// Get number of trees
const treeCount = rf.getTreeCount();
console.log(`Number of trees: ${treeCount}`);

// Get configuration
const config = rf.getConfig();
console.log('Configuration:', config);
```

### Random Forest Model Persistence

```js
// Export model
const modelJson = rf.toJSON();

// Import model
const newRf = new RandomForest(modelJson);

// Or import into existing instance
rf.import(modelJson);
```

### Random Forest vs Decision Tree

Random Forest typically provides:

- **Better accuracy** through ensemble learning
- **Reduced overfitting** via bootstrap sampling and feature randomization
- **More stable predictions** through majority voting
- **Feature importance** scores across the ensemble
- **Parallel training** capability for better performance

### XGBoost Usage

XGBoost (eXtreme Gradient Boosting) is a powerful gradient boosting algorithm that builds an ensemble of decision trees sequentially, where each tree corrects the errors of the previous ones.
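The additive idea behind boosting can be sketched in a few lines of plain JavaScript (illustrative only, not this library's internals): the ensemble prediction is a base score plus a learning-rate-scaled correction from each successive tree.

```js
// Illustrative sketch of additive boosting, not this library's internals.
// Each "tree" is modeled as a function returning a correction; real boosted
// trees are fit to the gradient of the loss at each round.
const learningRate = 0.1;
const basePrediction = 0.5;
const trees = [
  (x) => (x.shape === 'circle' ? 0.8 : -0.8),  // round 1: coarse correction
  (x) => (x.color === 'yellow' ? 0.5 : -0.1),  // round 2: refines the residual
];

const boostedPredict = (x) =>
  trees.reduce((pred, tree) => pred + learningRate * tree(x), basePrediction);

console.log(boostedPredict({ color: 'yellow', shape: 'circle' })); // 0.5 + 0.08 + 0.05 = 0.63
```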
### Basic XGBoost Usage

```js
import XGBoost from 'decision-tree/xgboost';

// Basic usage
const xgb = new XGBoost('liked', ['color', 'shape', 'size']);
xgb.train(training_data);

// Make predictions
const prediction = xgb.predict({ color: 'blue', shape: 'hexagon', size: 'medium' });

// Evaluate accuracy
const accuracy = xgb.evaluate(test_data);
console.log(`Accuracy: ${(accuracy * 100).toFixed(1)}%`);
```

### XGBoost Configuration

```js
const config = {
  nEstimators: 100,         // Number of boosting rounds (default: 100)
  learningRate: 0.1,        // Step size shrinkage (default: 0.1)
  maxDepth: 6,              // Maximum tree depth (default: 6)
  minChildWeight: 1,        // Minimum sum of instance weight in a leaf (default: 1)
  subsample: 1.0,           // Fraction of samples for each tree (default: 1.0)
  colsampleByTree: 1.0,     // Fraction of features for each tree (default: 1.0)
  regAlpha: 0,              // L1 regularization (default: 0)
  regLambda: 1,             // L2 regularization (default: 1)
  objective: 'regression',  // Loss function: 'regression', 'binary', 'multiclass'
  earlyStoppingRounds: 10,  // Early stopping patience (default: undefined)
  randomState: 42,          // Random seed for reproducibility
  validationFraction: 0.2   // Fraction for validation set (default: 0.2)
};

const xgb = new XGBoost('liked', ['color', 'shape', 'size'], config);
xgb.train(training_data);
```

### XGBoost Features

```js
// Get feature importance scores
const importance = xgb.getFeatureImportance();
console.log('Feature importance:', importance);

// Get boosting history
const history = xgb.getBoostingHistory();
console.log('Training loss:', history.trainLoss);
console.log('Validation loss:', history.validationLoss);

// Get best iteration (useful with early stopping)
const bestIteration = xgb.getBestIteration();
console.log(`Best iteration: ${bestIteration}`);

// Get number of trees
const treeCount = xgb.getTreeCount();
console.log(`Number of trees: ${treeCount}`);

// Get configuration
const config = xgb.getConfig();
console.log('Configuration:', config);
```

### XGBoost Model Persistence

```js
// Export model
const modelJson = xgb.toJSON();

// Import model
const newXgb = new XGBoost(modelJson);

// Or import into existing instance
xgb.import(modelJson);
```

## Algorithm Comparison

Choose the right algorithm for your use case:

| Feature | Decision Tree | Random Forest | XGBoost |
|---------|---------------|---------------|---------|
| **Best For** | Simple data, interpretability | General purpose, balanced performance | Complex data, highest accuracy |
| **Algorithm** | Single tree (ID3) | Ensemble of trees | Gradient boosting |
| **Overfitting** | Prone to overfitting | Reduces overfitting | Best overfitting control |
| **Accuracy** | Good on simple data | Better on complex data | Best on complex data |
| **Interpretability** | Highly interpretable | Less interpretable | Least interpretable |
| **Training Time** | < 100ms | < 500ms | < 1000ms |
| **Prediction Time** | < 10ms | < 50ms | < 20ms |
| **Stability** | Less stable | More stable | Most stable |
| **Feature Selection** | All features | Random subset per tree | Random subset per tree |
| **Bootstrap Sampling** | No | Yes (by default) | Yes (configurable) |
| **Parallel Training** | No | Yes (trees independent) | No (sequential) |
| **Regularization** | No | No | Yes (L1, L2) |
| **Early Stopping** | No | No | Yes |
| **Learning Rate** | N/A | N/A | Yes |
| **Gradient Boosting** | No | No | Yes |
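The training and prediction times in the table are the thresholds enforced by this library's performance tests; actual numbers depend on your data and hardware. A minimal timing sketch (it assumes the `data`, `class_name`, and `features` values from the examples above):

```js
import DecisionTree from 'decision-tree';
import RandomForest from 'decision-tree/random-forest';
import XGBoost from 'decision-tree/xgboost';

// Assumes `data`, `class_name`, and `features` from the examples above.
const candidates = {
  'Decision Tree': new DecisionTree(class_name, features),
  'Random Forest': new RandomForest(class_name, features, { nEstimators: 100 }),
  'XGBoost': new XGBoost(class_name, features, { nEstimators: 100, objective: 'binary' }),
};

for (const [name, model] of Object.entries(candidates)) {
  const start = performance.now(); // `performance` is global in Node.js 20+
  model.train(data);
  console.log(`${name}: trained in ${(performance.now() - start).toFixed(1)}ms`);
}
```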
### When to Use Each Algorithm

- **Decision Tree**: Use when you need interpretable models, have simple datasets, or require fast training and prediction.
- **Random Forest**: Use as a general-purpose solution that provides good accuracy with reduced overfitting and built-in feature importance.
- **XGBoost**: Use when you need the highest possible accuracy on complex datasets and can afford longer training times.

## Data Validation and Limitations

**Important:** This implementation is intentionally permissive and performs only limited validation:

- **Feature names:** Only validates that the features argument is an array, not the types of its elements
- **Target column:** Does not validate that the target column exists in the training data
- **Empty datasets:** Allows empty training datasets (which may result in unexpected behavior)
- **Data types:** Accepts mixed data types without validation

For production use, ensure your data meets these requirements (a pre-flight check like the sketch below can help):

- Training data must be an array of objects
- Each object should contain the target column
- Feature values should be consistent across samples
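Since the library itself does not enforce these checks, a small validator such as the hypothetical one below (not part of this package's API) can catch bad inputs before training:

```js
// Hypothetical pre-flight check; not part of this package's API.
function validateTrainingData(data, className, features) {
  if (!Array.isArray(data) || data.length === 0) {
    throw new Error('Training data must be a non-empty array of objects');
  }
  data.forEach((row, i) => {
    if (typeof row !== 'object' || row === null) {
      throw new Error(`Row ${i} is not an object`);
    }
    if (!(className in row)) {
      throw new Error(`Row ${i} is missing the target column "${className}"`);
    }
    for (const feature of features) {
      if (!(feature in row)) {
        throw new Error(`Row ${i} is missing the feature "${feature}"`);
      }
    }
  });
}

validateTrainingData(training_data, 'liked', ['color', 'shape']);
```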
## Error Handling

The package handles many edge cases gracefully but may fail silently in some scenarios:

```js
// This will work but may not produce expected results
const dt = new DecisionTree('nonexistent', ['feature1']);
dt.train([{ feature1: 'value1' }]); // Missing target column

// This will work but may not produce expected results
const dt2 = new DecisionTree('target', ['feature1']);
dt2.train([]); // Empty dataset
```

## Test Coverage

This project maintains comprehensive test coverage to ensure reliability and correctness:

### Current Test Statistics

- **Total Tests:** 408 passing tests
- **Test Suites:** 44 suites (see table below) covering Decision Trees, Random Forests, XGBoost, and performance
- **Test Framework:** Mocha with TypeScript support
- **Coverage Areas:**
  - Core decision tree functionality
  - Random Forest ensemble learning
  - XGBoost gradient boosting
  - Data validation and sanitization
  - Edge cases and error handling
  - Performance and scalability
  - Type safety and interface validation
  - Model persistence and import/export
  - Prediction edge cases
  - ID3 algorithm correctness
  - Bootstrap sampling and feature selection
  - Majority voting and ensemble prediction

### Test Suites

| Test Suite | Description | Test Count |
|------------|-------------|------------|
| **Data Validation & Sanitization** | Input validation, feature validation, data type handling | 12 tests |
| **Decision Tree Basics** | Core functionality, initialization, training, prediction | 9 tests |
| **Edge Cases & Error Handling** | Empty datasets, missing features, invalid inputs | 8 tests |
| **Sample Dataset Tests** | Real-world dataset validation (Tic-tac-toe, Voting, Object Evaluation) | 7 tests |
| **ID3 Algorithm Tests** | Entropy calculations, feature selection, tree structure | 9 tests |
| **Model Persistence** | Import/export functionality, data integrity | 15 tests |
| **Performance & Scalability** | Large datasets, memory management, concurrent operations | 12 tests |
| **Prediction Edge Cases** | Missing features, unknown values, data type mismatches | 12 tests |
| **Type Safety & Interface Validation** | TypeScript type checking, interface consistency | 10 tests |
| **Reported Bugs** | Regression tests for previously reported issues | 2 tests |
| **Random Forest Basics** | Core Random Forest functionality, configuration, training | 10 tests |
| **Random Forest Configuration** | Different parameter combinations and edge cases | 9 tests |
| **Random Forest Bootstrap Sampling** | Bootstrap sampling with and without replacement | 3 tests |
| **Random Forest Feature Selection** | Random feature selection strategies | 4 tests |
| **Random Forest Ensemble Prediction** | Majority voting and prediction stability | 3 tests |
| **Random Forest Feature Importance** | Feature importance calculation and normalization | 3 tests |
| **Random Forest Model Persistence** | Export/import functionality for Random Forest models | 3 tests |
| **Random Forest Edge Cases** | Edge cases specific to Random Forest implementation | 15 tests |
| **Random Forest Performance** | Performance testing with large numbers of estimators | 2 tests |
| **Random Forest on Sample Datasets** | Real-world dataset validation with Random Forest | 3 tests |
| **Random Forest Utility Functions** | Bootstrap sampling, feature selection, majority voting utilities | 20 tests |
| **XGBoost Basics** | Core XGBoost functionality, configuration, training | 10 tests |
| **XGBoost Configuration** | Different parameter combinations and edge cases | 11 tests |
| **XGBoost Gradient Boosting** | Gradient boosting iterations and loss tracking | 3 tests |
| **XGBoost Early Stopping** | Early stopping functionality and validation | 3 tests |
| **XGBoost Feature Importance** | Feature importance calculation for XGBoost | 3 tests |
| **XGBoost Model Persistence** | Export/import functionality for XGBoost models | 4 tests |
| **XGBoost Edge Cases** | Edge cases specific to XGBoost implementation | 5 tests |
| **XGBoost Performance** | Performance testing with large numbers of estimators | 2 tests |
| **XGBoost on Sample Datasets** | Real-world dataset validation with XGBoost | 3 tests |
| **XGBoost Loss Functions** | Loss functions (MSE, Logistic, Cross-Entropy) | 15 tests |
| **XGBoost Gradient Boosting Utils** | Gradient boosting utility functions | 8 tests |
| **XGBoost Edge Cases - Empty Datasets** | Empty and invalid dataset handling | 7 tests |
| **XGBoost Edge Cases - Configuration** | Configuration edge cases and validation | 20 tests |
| **XGBoost Edge Cases - Prediction** | Prediction edge cases and validation | 9 tests |
| **XGBoost Edge Cases - Model Persistence** | Model persistence edge cases | 9 tests |
| **XGBoost Edge Cases - Feature Importance** | Feature importance edge cases | 3 tests |
| **XGBoost Edge Cases - Boosting History** | Boosting history edge cases | 3 tests |
| **XGBoost Edge Cases - Performance** | Performance edge cases | 4 tests |
| **Performance Tests - Decision Tree** | Decision Tree performance benchmarks | 6 tests |
| **Performance Tests - Random Forest** | Random Forest performance benchmarks | 7 tests |
| **Performance Tests - XGBoost** | XGBoost performance benchmarks | 8 tests |
| **Performance Tests - Algorithm Comparison** | Cross-algorithm performance comparison | 4 tests |
| **Performance Tests - Edge Cases** | Performance edge cases and stress tests | 3 tests |

### Performance Benchmarks

The library includes comprehensive performance tests to ensure all algorithms meet speed requirements:

- **Decision Tree**: < 100ms training, < 10ms prediction
- **Random Forest**: < 500ms training, < 50ms prediction
- **XGBoost**: < 1000ms training, < 20ms prediction
- **Memory Usage**: < 50MB for large datasets
- **Scalability**: Linear scaling with dataset size and tree count

Performance tests cover:

- Training time benchmarks for small, medium, and large datasets
- Prediction speed tests with multiple iterations
- Memory usage monitoring for large datasets
- Algorithm comparison tests (Decision Tree vs Random Forest vs XGBoost)
- Concurrent operations and edge case performance
- Early stopping and regularization efficiency
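To spot-check the memory target on your own datasets, a rough sketch using `process.memoryUsage()` (heap deltas are approximate and garbage-collection-dependent; `largeDataset` is a placeholder for your own data):

```js
import XGBoost from 'decision-tree/xgboost';

// `largeDataset` is a placeholder for your own training data.
const before = process.memoryUsage().heapUsed;

const xgb = new XGBoost('liked', ['color', 'shape'], { nEstimators: 100, objective: 'binary' });
xgb.train(largeDataset);

const after = process.memoryUsage().heapUsed;
console.log(`Approx. heap growth: ${((after - before) / 1024 / 1024).toFixed(1)} MB`);
```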
### Running Tests

#### Using npm

```bash
# Run all tests
npm test

# Run tests in watch mode (for development)
npm run test:watch

# Run performance tests specifically
npm test -- --grep "Performance Tests"

# Build and test
npm run build && npm test
```

#### Using Bun

```bash
# Run all tests
bun test

# Run tests in watch mode (for development)
bun test --watch

# Run performance tests specifically
bun test --grep "Performance Tests"

# Build and test
bun run build && bun test
```

### Test Quality Standards

- **100% Pass Rate:** All tests must pass before any code changes are merged
- **Comprehensive Coverage:** Tests cover happy paths, edge cases, and error scenarios
- **Performance Testing:** Includes tests for large datasets and memory efficiency
- **Type Safety:** Full TypeScript type checking and interface validation
- **Real-world Scenarios:** Tests with actual datasets (tic-tac-toe, voting records, etc.)

## Development

### Building from Source

This project is written in TypeScript. To build from source:

#### Using npm

```bash
# Install dependencies
npm install

# Build the project
npm run build

# Run tests
npm test

# Watch mode for development
npm run build:watch
```

#### Using Bun

```bash
# Install dependencies
bun install

# Build the project
bun run build

# Run tests
bun run test

# Watch mode for development
bun run build:watch
```

### Windows Users

If you encounter issues with `npm test`, note that this project uses cross-env for cross-platform compatibility. The setup should work automatically, but if problems persist:

1. Ensure you're using Git Bash or WSL
2. Or use PowerShell/Command Prompt after running `npm install`

### Project Structure

- `src/` - TypeScript source files
- `lib/` - Compiled JavaScript output (generated)
- `tst/` - TypeScript test files
- `data/` - Sample datasets for testing

## Contributing

We welcome contributions to improve this machine learning library! Please see our [Contributing Guide](CONTRIBUTING.md) for detailed information on how to contribute.

**Quick Start for Contributors:**

1. Fork the repository
2. Create a feature branch (`git checkout -b feature/amazing-feature`)
3. Make your changes in the `src/` directory
4. Add comprehensive tests in the `tst/` directory
5. Run tests to ensure all pass (`npm test` or `bun test`)
6. Commit your changes (`git commit -m 'feat: add amazing feature'`)
7. Push to your branch (`git push origin feature/amazing-feature`)
8. Open a Pull Request

**Key Requirements:**

- ✅ All 408 tests must pass
- ✅ TypeScript compliance and proper typing
- ✅ Comprehensive test coverage for new features
- ✅ Performance considerations for large datasets
- ✅ Clear documentation and commit messages

For detailed guidelines, code style, and testing requirements, please see [CONTRIBUTING.md](CONTRIBUTING.md).

## Why Node.js 20+ or Bun 1.0+?
This package requires Node.js 20+ or Bun 1.0+ because:

- **ES Modules:** Uses native ES module support (`"type": "module"`)
- **Modern Features:** Leverages ES2022 features for better performance
- **Import Assertions:** Uses modern import syntax for better compatibility
- **Performance:** Takes advantage of Node.js 20+ and Bun 1.0+ optimizations

### Bun Compatibility

Bun is fully supported and offers several advantages:

- **Faster Installation:** Bun's package manager is significantly faster than npm
- **Built-in TypeScript:** No need for ts-node or additional TypeScript tooling
- **Faster Test Execution:** Bun's test runner is optimized for speed
- **Better Performance:** Generally faster execution of JavaScript/TypeScript code