UNPKG

@jsmlt/jsmlt

Version:

JavaScript Machine Learning

129 lines (100 loc) 5.36 kB
"use strict";

Object.defineProperty(exports, "__esModule", {
  value: true
});
exports.accuracy = accuracy;
exports.auroc = auroc;

var Arrays = _interopRequireWildcard(require("../../arrays"));

// Returns a WeakMap used to memoize _interopRequireWildcard results (or null when WeakMap is
// unavailable). After the first call the function replaces itself so the same cache is reused.
function _getRequireWildcardCache() {
  if (typeof WeakMap !== "function") return null;
  var cache = new WeakMap();
  _getRequireWildcardCache = function _getRequireWildcardCache() {
    return cache;
  };
  return cache;
}

// Wraps a required module so that namespace-style access (`Arrays.x`) works whether the module
// is flagged as an ES module (`__esModule`, returned as-is) or is a plain CommonJS object (own
// properties are copied, accessor properties keep their descriptors, and `default` points at
// the original object).
function _interopRequireWildcard(obj) {
  if (obj && obj.__esModule) {
    return obj;
  }
  var cache = _getRequireWildcardCache();
  if (cache && cache.has(obj)) {
    return cache.get(obj);
  }
  var newObj = {};
  if (obj != null) {
    var hasPropertyDescriptor = Object.defineProperty && Object.getOwnPropertyDescriptor;
    for (var key in obj) {
      if (Object.prototype.hasOwnProperty.call(obj, key)) {
        var desc = hasPropertyDescriptor ? Object.getOwnPropertyDescriptor(obj, key) : null;
        if (desc && (desc.get || desc.set)) {
          Object.defineProperty(newObj, key, desc);
        } else {
          newObj[key] = obj[key];
        }
      }
    }
  }
  newObj["default"] = obj;
  if (cache) {
    cache.set(obj, newObj);
  }
  return newObj;
}

// Standard imports

/**
 * Evaluate the accuracy of a set of predictions.
 *
 * @param {Array.<mixed>} yTrue - True labels
 * @param {Array.<mixed>} yPred - Predicted labels
 * @param {boolean} [normalize = true] - Whether to normalize the accuracy to a range between
 *   0 and 1. In this context, 0 means no predictions were correct, and 1 means all predictions
 *   were correct. If set to false, the integer number of correct predictions is returned
 * @return {number} Proportion of correct predictions (if normalize=true) or integer number of
 *   correct predictions (if normalize=false). With two empty arrays and normalize=true the
 *   result is NaN (0 / 0)
 * @throws {Error} If yTrue and yPred differ in length
 */
function accuracy(yTrue, yPred) {
  var normalize = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : true;

  // Check input lengths
  if (yTrue.length !== yPred.length) {
    throw new Error('Number of true labels must match number of predicted labels.');
  }

  // Count the number of correctly classified points (labels are compared with strict equality)
  var numCorrect = yTrue.reduce(function (r, a, i) {
    return r + (a === yPred[i] ? 1 : 0);
  }, 0);

  // If specified, normalize the accuracy to a number between 0 and 1
  if (normalize) {
    return numCorrect / yTrue.length;
  }

  return numCorrect;
}

/**
 * Calculate the area under the receiver-operator characteristic curve (AUROC) for a set of
 * predictions. Area is calculated using the trapezoidal rule. Predictions with identical
 * confidence values are treated as a single threshold, i.e. they contribute one ROC point.
 *
 * @param {Array.<number>} yTrue - True labels. Must contain only integers 0 and 1, and both
 *   must be present
 * @param {Array.<number>} yPred - Predicted label confidences. Must be numbers between 0 (fully
 *   confident in negative prediction) and 1 (fully confident in positive prediction), both
 *   inclusive
 * @return {number} Calculated AUROC
 * @throws {Error} If the lengths differ, if yTrue does not consist of exactly the classes 0 and
 *   1, or if any confidence is not a number in [0, 1]
 */
function auroc(yTrue, yPred) {
  // Check input lengths
  if (yTrue.length !== yPred.length) {
    throw new Error('Number of true labels must match number of predicted labels.');
  }

  // Check number of classes
  var numClasses = Arrays.unique(yTrue).length;

  if (numClasses !== 2) {
    throw new Error('Number of classes in true label vector must be exactly 2.');
  }

  // Check class labels. Combined with the two-class check above, yTrue now holds only 0s and 1s
  if (!yTrue.includes(0) || !yTrue.includes(1)) {
    throw new Error('True labels must be integers 0 and 1.');
  }

  // Check prediction confidence values. The typeof guard is required because >= and <= coerce
  // their operands: without it null, false, '' and numeric strings like '0.5' pass the range
  // check. NaN already fails the range comparisons.
  if (!yPred.every(function (x) {
    return typeof x === 'number' && x >= 0 && x <= 1;
  })) {
    throw new Error('Prediction confidence values must be between 0 and 1 (inclusive).');
  }

  // Sort the prediction probabilities descendingly to get a list of all possible thresholds
  var sortedIndices = Arrays.argSort(yPred, function (a, b) {
    return b - a;
  });

  // To find the false positive rate and true positive rate, we need to know the number of
  // negatives and positives, respectively, in the true labels list
  var numNegative = yTrue.filter(function (x) {
    return x === 0;
  }).length;
  var numPositive = yTrue.filter(function (x) {
    return x === 1;
  }).length;

  // Running counts of false and true positives. The curve starts at (0, 0): a threshold above
  // every score, so that all examples are predicted to be negative
  var fp = 0;
  var tp = 0;

  // False positive rates and true positive rates; entries at the same index form one ROC point
  var fprs = [0];
  var tprs = [0];

  // Lower the threshold one example at a time and record the resulting (fpr, tpr) point
  var previousIndex = -1;

  sortedIndices.forEach(function (index) {
    if (yTrue[index] === 0) {
      fp += 1;
    } else {
      tp += 1;
    }

    // A score equal to the previous one is the same threshold: replace the previous point
    // instead of adding a new one, so each group of ties yields a single (diagonal) step
    if (previousIndex >= 0 && yPred[index] === yPred[previousIndex]) {
      fprs.pop();
      tprs.pop();
    }

    fprs.push(fp / numNegative);
    tprs.push(tp / numPositive);
    previousIndex = index;
  });

  // Sum the trapezoids between consecutive ROC points: width is the step in false positive
  // rate, height is the mean of the two true positive rates
  var area = 0;

  for (var i = 1; i < fprs.length; i += 1) {
    area += (fprs[i] - fprs[i - 1]) * ((tprs[i - 1] + tprs[i]) * 0.5);
  }

  return area;
}