decision-tree
Version:
Node.js implementation of decision tree, random forest, and XGBoost algorithms with comprehensive performance testing (Node.js 20+)
480 lines (415 loc) • 14.8 kB
text/typescript
/**
* XGBoost Algorithm
* @module XGBoost
*/
import _ from 'lodash';
import DecisionTree from './decision-tree.js';
import {
TreeNode,
DecisionTreeData,
TrainingData,
XGBoostConfig,
XGBoostData,
BoostingHistory,
NODE_TYPES
} from './shared/types.js';
import {
SeededRandom
} from './shared/utils.js';
import { LossFunctionFactory } from './shared/loss-functions.js';
import {
createWeightedTree,
createWeightedSample,
calculateBaseScore
} from './shared/gradient-boosting.js';
/**
 * XGBoost-style gradient boosting with decision trees.
 *
 * Every prediction starts at `baseScore`. Each boosting round builds one tree
 * from the loss gradients/hessians (via `createWeightedTree`) and adds
 * `learningRate * treePrediction` to the running prediction. `predict()`
 * replays the first `bestIteration` trees the same way.
 */
class XGBoost {
  public static readonly NODE_TYPES = NODE_TYPES;
  private trees: DecisionTree[] = [];
  private data: any[] = [];
  private target!: string;
  private features!: string[];
  private config: XGBoostConfig;
  private baseScore: number = 0;
  // Number of leading trees used by predict() and getFeatureImportance().
  private bestIteration: number = 0;
  private boostingHistory: BoostingHistory = {
    trainLoss: [],
    validationLoss: [],
    iterations: []
  };

  /**
   * Creates a model. Supported call shapes:
   * - `new XGBoost(json)`: restore a model produced by `toJSON()`.
   * - `new XGBoost(target, features)`: untrained model with default config.
   * - `new XGBoost(target, features, config)`: untrained model; `config` keys
   *   override the defaults below.
   * - `new XGBoost(data, target, features)` and
   *   `new XGBoost(data, target, features, config)`: construct, train on
   *   `data`, and return that trained instance.
   * @throws Error when the argument count is unsupported or target/features are invalid
   */
  constructor(...args: any[]) {
    // Default configuration
    this.config = {
      nEstimators: 100,
      learningRate: 0.1,
      maxDepth: 6,
      minChildWeight: 1,
      subsample: 1,
      colsampleByTree: 1,
      regAlpha: 0,
      regLambda: 1,
      objective: 'regression',
      earlyStoppingRounds: undefined,
      randomState: undefined,
      validationFraction: 0.2
    };

    switch (args.length) {
      case 1:
        this.import(args[0]);
        break;
      case 2: {
        const [target, features] = args;
        XGBoost.assertTargetAndFeatures(target, features);
        this.target = target;
        this.features = features;
        break;
      }
      case 3: {
        // An array in third position means (data, target, features);
        // anything else is treated as (target, features, config).
        if (Array.isArray(args[2])) {
          const [data, target, features] = args;
          const instance = new XGBoost(target, features);
          instance.train(data);
          return instance;
        }
        const [target, features, config] = args;
        XGBoost.assertTargetAndFeatures(target, features);
        if (config && typeof config === 'object') {
          this.config = { ...this.config, ...config };
        }
        this.target = target;
        this.features = features;
        break;
      }
      case 4: {
        const [data, target, features, config] = args;
        const instance = new XGBoost(target, features, config);
        instance.train(data);
        return instance;
      }
      default:
        throw new Error('Invalid arguments passed to constructor. Check documentation on usage');
    }
  }

  /**
   * Validates the `target` and `features` constructor arguments.
   * @throws Error if `target` is not a non-empty string or `features` is not an array
   */
  private static assertTargetAndFeatures(target: unknown, features: unknown): void {
    if (!target || typeof target !== 'string') {
      throw new Error('`target` argument is expected to be a String. Check documentation on usage');
    }
    if (!features || !Array.isArray(features)) {
      throw new Error('`features` argument is expected to be an Array<String>. Check documentation on usage');
    }
  }

  /** Learning rate shared by train() and predict(); falsy values fall back to 0.1. */
  private resolveLearningRate() {
    return this.config.learningRate || 0.1;
  }

  /** Objective shared by train() and predict(); falsy values fall back to 'regression'. */
  private resolveObjective() {
    return this.config.objective || 'regression';
  }

  /**
   * Trains the model on `data`, replacing any previously trained trees.
   *
   * When both `earlyStoppingRounds` and `validationFraction` are set, a shuffled
   * hold-out of `floor(data.length * validationFraction)` rows is used to track
   * validation loss. The hold-out is capped so at least one training row remains.
   * Training stops after `earlyStoppingRounds` consecutive rounds without
   * improvement, and `bestIteration` records the best round. Otherwise (or if
   * the hold-out is empty) every trained tree is used.
   * @param data - Array of training data objects
   * @throws Error if `data` is not a non-empty array
   */
  train(data: TrainingData[]): void {
    if (!data || !Array.isArray(data)) {
      throw new Error('`data` argument is expected to be an Array<Object>. Check documentation on usage');
    }
    if (data.length === 0) {
      throw new Error('`data` argument is expected to be an Array<Object>. Check documentation on usage');
    }
    this.data = data;
    this.trees = [];
    // Reset so a value left over from an earlier train()/import() cannot leak in.
    this.bestIteration = 0;
    this.boostingHistory = {
      trainLoss: [],
      validationLoss: [],
      iterations: []
    };

    // `||` means randomState 0 is treated as unset and a random seed is drawn.
    const random = new SeededRandom(this.config.randomState || Math.floor(Math.random() * 1000000));
    const nEstimators = this.config.nEstimators !== undefined ? this.config.nEstimators : 100;
    const objective = this.resolveObjective();
    const learningRate = this.resolveLearningRate();
    const earlyStoppingRounds = this.config.earlyStoppingRounds;

    // Hold out a validation split only when early stopping is enabled.
    let trainData: TrainingData[] = data;
    let validationData: TrainingData[] = [];
    if (earlyStoppingRounds && this.config.validationFraction) {
      const split = this.splitForValidation(data, this.config.validationFraction, random);
      trainData = split.trainData;
      validationData = split.validationData;
    }

    // Base score from the training rows only, so held-out rows stay unseen.
    this.baseScore = calculateBaseScore(trainData, this.target, objective);

    const predictions: number[] = new Array<number>(trainData.length).fill(this.baseScore);
    const validationPredictions: number[] = new Array<number>(validationData.length).fill(this.baseScore);
    // Targets do not change between rounds, so read them once.
    const targetValues = trainData.map(sample => sample[this.target]);
    const validationTargetValues = validationData.map(sample => sample[this.target]);

    const lossFunction = LossFunctionFactory.create(objective);
    let bestValidationLoss = Infinity;
    let noImprovementCount = 0;

    for (let i = 0; i < nEstimators; i++) {
      let gradient: number[];
      let hessian: number[];
      if (objective === 'multiclass') {
        // NOTE(review): multiclass uses all-zero gradients and unit hessians, so the
        // trees presumably learn nothing; multiclass boosting looks unimplemented.
        gradient = new Array<number>(predictions.length).fill(0);
        hessian = new Array<number>(predictions.length).fill(1);
      } else {
        const result = lossFunction.calculateGradientsAndHessians(predictions, targetValues);
        gradient = result.gradient;
        hessian = result.hessian;
      }

      // NOTE(review): gradient/hessian are indexed like `trainData`. If
      // createWeightedSample subsamples or reorders rows, they may not line up
      // with weightedSample.data — confirm in shared/gradient-boosting.
      const weightedSample = createWeightedSample(trainData, this.config, random);
      weightedSample.gradients = gradient;
      weightedSample.hessians = hessian;

      const tree = createWeightedTree(
        weightedSample.data,
        this.target,
        this.features,
        weightedSample.weights,
        weightedSample.gradients,
        weightedSample.hessians,
        this.config
      );

      // Wrap in a DecisionTree so prediction and serialization reuse its API.
      const treeData: DecisionTreeData = {
        model: tree,
        data: weightedSample.data,
        target: this.target,
        features: this.features
      };
      const decisionTree = new DecisionTree(treeData);
      this.trees.push(decisionTree);

      for (let j = 0; j < trainData.length; j++) {
        predictions[j] += learningRate * decisionTree.predict(trainData[j]);
      }
      for (let j = 0; j < validationData.length; j++) {
        validationPredictions[j] += learningRate * decisionTree.predict(validationData[j]);
      }

      this.boostingHistory.trainLoss.push(lossFunction.calculateLoss(predictions, targetValues));
      this.boostingHistory.iterations.push(i + 1);

      if (validationData.length > 0) {
        const validationLoss = lossFunction.calculateLoss(validationPredictions, validationTargetValues);
        this.boostingHistory.validationLoss.push(validationLoss);

        if (earlyStoppingRounds) {
          if (validationLoss < bestValidationLoss) {
            bestValidationLoss = validationLoss;
            this.bestIteration = i + 1;
            noImprovementCount = 0;
          } else {
            noImprovementCount++;
            if (noImprovementCount >= earlyStoppingRounds) {
              break;
            }
          }
        }
      }
    }

    // Use every tree unless early stopping actually selected a round. This also
    // covers early stopping with an empty validation split, which would
    // otherwise leave bestIteration at 0 so predict() ignored every tree.
    if (!earlyStoppingRounds || this.bestIteration === 0) {
      this.bestIteration = this.trees.length;
    }
  }

  /**
   * Shuffles row indices with `random` (Fisher–Yates-style swap loop) and holds
   * out the first `floor(data.length * fraction)` of them for validation. The
   * hold-out size is clamped to [0, data.length - 1], so at least one training
   * row always remains.
   */
  private splitForValidation(
    data: TrainingData[],
    fraction: number,
    random: SeededRandom
  ): { trainData: TrainingData[]; validationData: TrainingData[] } {
    const validationSize = Math.max(0, Math.min(Math.floor(data.length * fraction), data.length - 1));
    const order = data.map((_row, index) => index);
    for (let i = order.length - 1; i > 0; i--) {
      const j = random.nextInt(i + 1);
      [order[i], order[j]] = [order[j], order[i]];
    }
    return {
      validationData: order.slice(0, validationSize).map(index => data[index]),
      trainData: order.slice(validationSize).map(index => data[index])
    };
  }

  /**
   * Predicts a value for `sample`.
   * @param sample - Sample data to predict
   * @returns The raw boosted score. For the 'binary' objective it instead
   *   returns a boolean: whether sigmoid(score) > 0.5.
   * @throws Error if the model is untrained or `sample` is not a plain object
   */
  predict(sample: TrainingData): any {
    if (this.trees.length === 0) {
      throw new Error('XGBoost has not been trained yet. Call train() first.');
    }
    if (!sample || typeof sample !== 'object' || Array.isArray(sample)) {
      throw new Error('Sample must be an object');
    }

    const learningRate = this.resolveLearningRate();
    // Clamp: an imported model may declare more iterations than it has trees.
    const treeCount = Math.min(this.bestIteration, this.trees.length);
    let prediction = this.baseScore;
    for (let i = 0; i < treeCount; i++) {
      prediction += learningRate * this.trees[i].predict(sample);
    }

    if (this.resolveObjective() === 'binary') {
      // Sigmoid of the raw score; clamping keeps Math.exp finite.
      const clamped = Math.max(-500, Math.min(500, prediction));
      const probability = 1 / (1 + Math.exp(-clamped));
      return probability > 0.5;
    }
    return prediction;
  }

  /**
   * Evaluates prediction accuracy on samples.
   * Each prediction is coerced to the target's type (boolean or number) before
   * being compared with `_.isEqual`. An empty `samples` yields NaN (0 / 0).
   * @param samples - Array of test samples
   * @returns Accuracy ratio (correct predictions / total predictions)
   */
  evaluate(samples: TrainingData[]): number {
    let total = 0;
    let correct = 0;
    _.each(samples, (sample) => {
      total++;
      let predicted = this.predict(sample);
      const actual = sample[this.target];
      if (typeof actual === 'boolean') {
        predicted = Boolean(predicted);
      } else if (typeof actual === 'number') {
        predicted = Number(predicted);
      }
      if (_.isEqual(predicted, actual)) {
        correct++;
      }
    });
    return correct / total;
  }

  /**
   * Replaces this model's state with a previously saved model.
   * @param json - JSON representation produced by toJSON()
   * @throws Error if `json` is not an object or any required property is missing or mistyped
   */
  import(json: XGBoostData): void {
    if (!json || typeof json !== 'object') {
      throw new Error('Invalid model: expected an object');
    }
    const { trees, target, features, config, data, baseScore, bestIteration, boostingHistory } = json;
    if (!trees || !Array.isArray(trees)) {
      throw new Error('Invalid model: trees property is required and must be an array');
    }
    if (!target || typeof target !== 'string') {
      throw new Error('Invalid model: target property is required and must be a string');
    }
    if (!features || !Array.isArray(features)) {
      throw new Error('Invalid model: features property is required and must be an array');
    }
    if (!config || typeof config !== 'object') {
      throw new Error('Invalid model: config property is required and must be an object');
    }
    if (!data || !Array.isArray(data)) {
      throw new Error('Invalid model: data property is required and must be an array');
    }
    if (typeof baseScore !== 'number') {
      throw new Error('Invalid model: baseScore property is required and must be a number');
    }
    if (typeof bestIteration !== 'number') {
      throw new Error('Invalid model: bestIteration property is required and must be a number');
    }
    if (!boostingHistory || typeof boostingHistory !== 'object') {
      throw new Error('Invalid model: boostingHistory property is required and must be an object');
    }

    this.trees = trees.map(treeData => new DecisionTree(treeData));
    this.data = data;
    this.target = target;
    this.features = features;
    this.config = config;
    // `||` maps a NaN baseScore to 0, and a stored bestIteration of 0 (or NaN)
    // to "use every tree".
    this.baseScore = baseScore || 0;
    this.bestIteration = bestIteration || trees.length;
    this.boostingHistory = boostingHistory;
  }

  /**
   * Returns a JSON representation of the trained model, suitable for import().
   * @returns JSON object containing model data
   */
  toJSON(): XGBoostData {
    const { data, target, features, config } = this;
    return {
      trees: this.trees.map(tree => tree.toJSON()),
      data,
      target,
      features,
      config,
      baseScore: this.baseScore,
      bestIteration: this.bestIteration,
      boostingHistory: this.boostingHistory
    };
  }

  /**
   * Gets feature importance scores, summed over the trees used by predict().
   * Each split node contributes `gain * sampleSize` to its feature.
   * @returns Object with feature names as keys and importance scores as values
   * @throws Error if the model is untrained
   */
  getFeatureImportance(): { [feature: string]: number } {
    if (this.trees.length === 0) {
      throw new Error('XGBoost has not been trained yet. Call train() first.');
    }
    const importance: { [feature: string]: number } = {};
    // Every declared feature appears in the result, even if it is never split on.
    this.features.forEach(feature => {
      importance[feature] = 0;
    });
    // Same clamp as predict(): never read past the last stored tree.
    const treeCount = Math.min(this.bestIteration, this.trees.length);
    for (let i = 0; i < treeCount; i++) {
      this.calculateTreeImportance(this.trees[i].toJSON().model, importance);
    }
    return importance;
  }

  /**
   * Recursively adds `gain * sampleSize` of every FEATURE node under `node`
   * into `importance`. Nodes with a falsy gain or sampleSize are skipped.
   * @private
   */
  private calculateTreeImportance(node: TreeNode, importance: { [feature: string]: number }): void {
    if (node.type === NODE_TYPES.FEATURE && node.gain && node.sampleSize) {
      const feature = node.name;
      importance[feature] = (importance[feature] || 0) + node.gain * node.sampleSize;
    }
    if (node.vals) {
      node.vals.forEach(val => {
        if (val.child) {
          this.calculateTreeImportance(val.child, importance);
        }
      });
    }
  }

  /**
   * Gets a copy of the boosting history. The loss and iteration arrays are
   * copied too, so callers cannot mutate the model's history in place.
   * @returns Boosting history with losses and iterations
   */
  getBoostingHistory(): BoostingHistory {
    const { trainLoss, validationLoss, iterations } = this.boostingHistory;
    // `?.` tolerates partial histories, which import() does not validate field by field.
    return {
      ...this.boostingHistory,
      trainLoss: trainLoss?.slice(),
      validationLoss: validationLoss?.slice(),
      iterations: iterations?.slice()
    };
  }

  /**
   * Gets the number of leading trees predict() uses.
   * @returns Best iteration number
   */
  getBestIteration(): number {
    return this.bestIteration;
  }

  /**
   * Gets the number of trees stored in the model. With early stopping, this
   * can exceed getBestIteration().
   * @returns Number of trees
   */
  getTreeCount(): number {
    return this.trees.length;
  }

  /**
   * Gets a shallow copy of the configuration used for this model.
   * @returns Configuration object
   */
  getConfig(): XGBoostConfig {
    return { ...this.config };
  }
}

// Export the XGBoost class
export default XGBoost;