@jsmlt/jsmlt
Version:
JavaScript Machine Learning
129 lines (100 loc) • 5.36 kB
JavaScript
;
Object.defineProperty(exports, "__esModule", {
value: true
});
exports.accuracy = accuracy;
exports.auroc = auroc;
var Arrays = _interopRequireWildcard(require("../../arrays"));
/**
 * Lazily create the WeakMap used to memoize wrapped modules.
 *
 * The first successful call allocates the WeakMap and then replaces this function with one
 * that simply hands back that same instance, so every later call shares a single cache.
 *
 * @return {WeakMap|null} The shared cache, or null when the runtime has no WeakMap
 */
function _getRequireWildcardCache() {
  var weakMapSupported = typeof WeakMap === "function";
  if (!weakMapSupported) {
    return null;
  }

  var sharedCache = new WeakMap();

  // Self-replace so subsequent calls skip the feature check and reuse the same WeakMap
  _getRequireWildcardCache = function _getRequireWildcardCache() {
    return sharedCache;
  };

  return sharedCache;
}
/**
 * Wrap the result of a `require` call so it can be consumed like an ES module namespace.
 *
 * Values already flagged with `__esModule` are returned untouched. Anything else is copied
 * into a fresh object (own enumerable properties only, accessors preserved as accessors) and
 * the original value is exposed under the `default` key. Wrappers for objects and functions
 * are memoized in the cache returned by `_getRequireWildcardCache` when one is available.
 *
 * @param {*} obj - Value returned by `require`
 * @return {Object} `obj` itself if it is an ES module, otherwise a namespace-like wrapper
 */
function _interopRequireWildcard(obj) {
  if (obj && obj.__esModule) {
    return obj;
  }

  var cache = _getRequireWildcardCache();

  // WeakMap keys must be objects or functions. Calling `cache.set` with null, undefined or a
  // primitive throws a TypeError, so such values are wrapped but never cached.
  var cacheable = obj !== null && (typeof obj === "object" || typeof obj === "function");

  if (cache && cacheable && cache.has(obj)) {
    return cache.get(obj);
  }

  var newObj = {};

  if (obj != null) {
    var hasPropertyDescriptor = Object.defineProperty && Object.getOwnPropertyDescriptor;

    for (var key in obj) {
      if (Object.prototype.hasOwnProperty.call(obj, key)) {
        var desc = hasPropertyDescriptor ? Object.getOwnPropertyDescriptor(obj, key) : null;

        if (desc && (desc.get || desc.set)) {
          // Keep getters/setters live instead of snapshotting their current value
          Object.defineProperty(newObj, key, desc);
        } else {
          newObj[key] = obj[key];
        }
      }
    }
  }

  newObj["default"] = obj;

  if (cache && cacheable) {
    cache.set(obj, newObj);
  }

  return newObj;
}
// Standard imports
/**
 * Compute the classification accuracy of a list of predictions.
 *
 * Labels are compared position by position using strict equality (`===`).
 *
 * @param {Array.<mixed>} yTrue - Ground-truth labels
 * @param {Array.<mixed>} yPred - Predicted labels, in the same order as yTrue
 * @param {boolean} [normalize = true] - When truthy, the result is the fraction of correct
 *   predictions (0 means none correct, 1 means all correct). When falsy, the raw count of
 *   correct predictions is returned instead
 * @return {number} Fraction of correct predictions (normalize truthy) or the number of correct
 *   predictions (normalize falsy)
 * @throws {Error} If yTrue and yPred do not have the same length
 */
function accuracy(yTrue, yPred) {
  // Optional third argument: a missing or undefined value falls back to true
  var normalize = arguments[2] === undefined ? true : arguments[2];

  // Both label lists must describe the same data points
  if (yTrue.length !== yPred.length) {
    throw new Error('Number of true labels must match number of predicted labels.');
  }

  // Keep only the positions where the prediction matches the truth, then count them
  var hits = yTrue.filter(function (label, idx) {
    return label === yPred[idx];
  }).length;

  return normalize ? hits / yTrue.length : hits;
}
/**
 * Compute the area under the receiver operating characteristic curve (AUROC) for a set of
 * binary predictions. The curve is integrated with the trapezoidal rule.
 *
 * @param {Array.<number>} yTrue - True labels. Must contain only the integers 0 and 1, and
 *   both must be present
 * @param {Array.<number>} yPred - Predicted confidences for the positive class. Each value
 *   must lie between 0 (fully confident in negative prediction) and 1 (fully confident in
 *   positive prediction), both inclusive
 * @return {number} Calculated AUROC
 * @throws {Error} If the inputs differ in length, the true labels are not exactly the two
 *   classes 0 and 1, or a confidence falls outside [0, 1]
 */
function auroc(yTrue, yPred) {
  // --- Input validation ---------------------------------------------------------------------
  if (yTrue.length !== yPred.length) {
    throw new Error('Number of true labels must match number of predicted labels.');
  }

  if (Arrays.unique(yTrue).length !== 2) {
    throw new Error('Number of classes in true label vector must be exactly 2.');
  }

  var hasBothLabels = yTrue.includes(0) && yTrue.includes(1);
  if (!hasBothLabels) {
    throw new Error('True labels must be integers 0 and 1.');
  }

  // Written as a negated range test so that NaN is rejected as well
  var hasInvalidConfidence = yPred.some(function (confidence) {
    return !(confidence >= 0 && confidence <= 1);
  });
  if (hasInvalidConfidence) {
    throw new Error('Prediction confidence values must be between 0 and 1 (inclusive).');
  }

  // --- ROC curve construction ---------------------------------------------------------------
  // Visit samples from most to least confident; each distinct confidence acts as a threshold
  var order = Arrays.argSort(yPred, function (a, b) {
    return b - a;
  });

  // Class totals are the denominators of the false positive rate and true positive rate
  var negativeTotal = 0;
  var positiveTotal = 0;
  yTrue.forEach(function (label) {
    if (label === 0) {
      negativeTotal += 1;
    } else if (label === 1) {
      positiveTotal += 1;
    }
  });

  // Start at threshold 1, where every sample is predicted negative: the curve begins at (0, 0)
  var falsePositives = 0;
  var truePositives = 0;
  var fprs = [0];
  var tprs = [0];
  var previousIndex = -1;

  order.forEach(function (currentIndex) {
    if (yTrue[currentIndex] === 0) {
      falsePositives += 1;
    } else {
      truePositives += 1;
    }

    // Samples sharing a confidence belong to the same threshold, so they contribute a single
    // curve point: drop the point recorded for the previous tied sample before adding the
    // updated one
    var tiedWithPrevious = previousIndex >= 0 && yPred[currentIndex] === yPred[previousIndex];
    if (tiedWithPrevious) {
      fprs.pop();
      tprs.pop();
    }

    fprs.push(falsePositives / negativeTotal);
    tprs.push(truePositives / positiveTotal);
    previousIndex = currentIndex;
  });

  // --- Trapezoidal integration --------------------------------------------------------------
  // For each pair of consecutive curve points, width = difference in FPR and height = mean of
  // the two TPRs. NOTE(review): assumes Arrays.sum adds arrays element-wise and Arrays.scale
  // multiplies every element by a constant - confirm against the arrays module.
  var leftFprs = fprs.slice(0, -1);
  var rightFprs = fprs.slice(1);
  var widths = Arrays.sum(Arrays.scale(leftFprs, -1), rightFprs);

  var leftTprs = tprs.slice(0, -1);
  var rightTprs = tprs.slice(1);
  var heights = Arrays.scale(Arrays.sum(leftTprs, rightTprs), 0.5);

  return widths.reduce(function (area, width, i) {
    return area + width * heights[i];
  }, 0);
}