ml-naivebayes (GaussianNB source)
import { Matrix } from 'ml-matrix';

import { separateClasses } from './utils';

export class GaussianNB {
  /**
   * Constructor for the Gaussian Naive Bayes classifier. The parameters here are just for loading purposes.
   * @constructor
   * @param {boolean} reload
   * @param {object} model
   */
  constructor(reload, model) {
    if (reload) {
      this.means = model.means;
      this.calculateProbabilities = model.calculateProbabilities;
    }
  }

  /**
   * Trains the classifier with a matrix that represents the training set and an array that
   * represents the label of each row of the training set. The labels must be numbers between 0 and n-1,
   * where n is the number of classes.
   *
   * WARNING: if, within one class, all the cases have the same value for one or more features
   * (zero standard deviation), the Gaussian Naive Bayes classifier will not work well.
   * @param {Matrix|Array} trainingSet
   * @param {Matrix|Array} trainingLabels
   */
  train(trainingSet, trainingLabels) {
    var C1 = Math.sqrt(2 * Math.PI); // precalculated constant sqrt(2 * pi)
    trainingSet = Matrix.checkMatrix(trainingSet);

    if (trainingSet.rows !== trainingLabels.length) {
      throw new RangeError(
        'the size of the training set and the training labels must be the same.'
      );
    }

    var separatedClasses = separateClasses(trainingSet, trainingLabels);
    var calculateProbabilities = new Array(separatedClasses.length);
    this.means = new Array(separatedClasses.length);
    for (var i = 0; i < separatedClasses.length; ++i) {
      var means = separatedClasses[i].mean('column');
      var std = separatedClasses[i].standardDeviation('column', {
        mean: means
      });

      var logPriorProbability = Math.log(
        separatedClasses[i].rows / trainingSet.rows
      );

      // calculateProbabilities[i] holds the log prior at index 0, followed by one
      // [1 / (sqrt(2 * pi) * std), -2 * std^2] pair per feature.
      calculateProbabilities[i] = new Array(means.length + 1);
      calculateProbabilities[i][0] = logPriorProbability;
      for (var j = 1; j < means.length + 1; ++j) {
        var currentStd = std[j - 1];
        calculateProbabilities[i][j] = [
          1 / (C1 * currentStd),
          -2 * currentStd * currentStd
        ];
      }

      this.means[i] = means;
    }

    this.calculateProbabilities = calculateProbabilities;
  }

  /**
   * Predicts the class of each row of the dataset (must be a matrix).
   *
   * @param {Matrix|Array} dataset
   * @return {Array}
   */
  predict(dataset) {
    dataset = Matrix.checkMatrix(dataset);
    // calculateProbabilities[0] holds the log prior plus one entry per feature,
    // so the dataset must have calculateProbabilities[0].length - 1 columns.
    if (dataset.columns !== this.calculateProbabilities[0].length - 1) {
      throw new RangeError(
        'the dataset must have the same number of features as the training set'
      );
    }

    var predictions = new Array(dataset.rows);
    for (var i = 0; i < predictions.length; ++i) {
      predictions[i] = getCurrentClass(
        dataset.getRow(i),
        this.means,
        this.calculateProbabilities
      );
    }

    return predictions;
  }

  /**
   * Exports the NaiveBayes model.
   * @return {object}
   */
  toJSON() {
    return {
      modelName: 'NaiveBayes',
      means: this.means,
      calculateProbabilities: this.calculateProbabilities
    };
  }

  /**
   * Creates a GaussianNB classifier from the given exported model.
   * @param {object} model
   * @return {GaussianNB}
   */
  static load(model) {
    if (model.modelName !== 'NaiveBayes') {
      throw new RangeError(
        `The given model is not a Gaussian Naive Bayes model, current model: ${model.modelName}`
      );
    }

    return new GaussianNB(true, model);
  }
}

/**
 * @private
 * Retrieves the prediction for one case.
 *
 * @param {Array} currentCase
 * @param {Array} mean - Precalculated means of each trained class
 * @param {Array} classes - Precalculated values of each class (log prior probability and per-feature constants)
 * @return {number}
 */
function getCurrentClass(currentCase, mean, classes) {
  var maxProbability = 0;
  var predictedClass = -1;

  // going through all precalculated values for the classes
  for (var i = 0; i < classes.length; ++i) {
    var currentProbability = classes[i][0]; // initialize with the log prior probability
    // classes[i] holds the log prior at index 0 followed by one [C1, C2] pair per feature
    for (var j = 1; j < classes[0].length; ++j) {
      currentProbability += calculateLogProbability(
        currentCase[j - 1],
        mean[i][j - 1],
        classes[i][j][0],
        classes[i][j][1]
      );
    }

    currentProbability = Math.exp(currentProbability);
    if (currentProbability > maxProbability) {
      maxProbability = currentProbability;
      predictedClass = i;
    }
  }

  return predictedClass;
}

/**
 * @private
 * Retrieves the log probability of the feature given the class.
 * @param {number} value - value of the feature.
 * @param {number} mean - mean of the feature for the given class.
 * @param {number} C1 - precalculated value of (1 / (sqrt(2 * pi) * std)).
 * @param {number} C2 - precalculated value of (-2 * std^2), the denominator of the exponent.
 * @return {number}
 */
function calculateLogProbability(value, mean, C1, C2) {
  value = value - mean;
  return Math.log(C1 * Math.exp((value * value) / C2));
}
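For reference: train() precomputes, per class, the log prior probability plus two per-feature constants, C1 = 1 / (sqrt(2 * pi) * std) and C2 = -2 * std^2, so that predict() only has to sum Gaussian log-density terms of the form log(C1) + (x - mean)^2 / C2 and pick the class with the largest total. Below is a minimal usage sketch, not part of the original file. It assumes the package entry point re-exports GaussianNB (the class is exported above); the data values are made up for illustration.

// Usage sketch (assumption: the ml-naivebayes index re-exports GaussianNB).
import { GaussianNB } from 'ml-naivebayes';

// Toy data: two features per case, labels are integers 0..n-1 as train() requires.
var trainingSet = [
  [6.0, 180],
  [5.9, 190],
  [5.8, 170],
  [5.0, 100],
  [5.2, 110],
  [5.1, 105]
];
var trainingLabels = [0, 0, 0, 1, 1, 1];

var classifier = new GaussianNB();
classifier.train(trainingSet, trainingLabels);

var predictions = classifier.predict([[5.9, 185], [5.1, 102]]); // expected: [0, 1]

// The model can be serialized and restored through toJSON() / load().
var exported = classifier.toJSON();
var restored = GaussianNB.load(exported);

Note the WARNING on train(): a feature whose values are all identical within one class has zero standard deviation, which makes C1 infinite and the class scores unusable, so such features should be removed (or slightly perturbed) before training.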