@jsmlt/jsmlt
Version:
JavaScript Machine Learning
103 lines (84 loc) • 3.66 kB
JavaScript
;
// Mark this CommonJS module as compiled from an ES module, so that interop helpers that check
// `__esModule` (such as _interopRequireWildcard in this file) use the exports object as-is.
// NOTE(review): keep these statements in this literal form; static export detectors (e.g. Node's
// cjs-module-lexer) recognize exactly these patterns.
Object.defineProperty(exports, "__esModule", {
value: true
});
// Public API of this module. Both names are function declarations further down in this file;
// function declarations are hoisted, so referencing them here is valid.
exports.accuracy = accuracy;
exports.auroc = auroc;
var _arrays = require('../../arrays');
var Arrays = _interopRequireWildcard(_arrays);
/**
 * Wrap a required module so it can be used like an ES module namespace object.
 *
 * Modules flagged with `__esModule` are returned unchanged. Anything else is shallow-copied
 * (own enumerable properties only) into a fresh object whose `default` property is the
 * original value.
 *
 * @param {*} obj - Value returned by require()
 * @return {Object} Namespace-like object
 */
function _interopRequireWildcard(obj) {
  if (obj && obj.__esModule) {
    return obj;
  }

  var namespace = {};
  if (obj != null) {
    var hasOwn = Object.prototype.hasOwnProperty;
    for (var prop in obj) {
      // for...in also visits inherited keys; copy only the module's own properties
      if (hasOwn.call(obj, prop)) {
        namespace[prop] = obj[prop];
      }
    }
  }

  // Set last, so it overrides any own property called "default"
  namespace.default = obj;
  return namespace;
}
/**
 * Compute the classification accuracy of a set of predictions: how many predicted labels are
 * strictly equal (===) to the true label at the same position.
 *
 * @param {Array.<mixed>} yTrue - True labels
 * @param {Array.<mixed>} yPred - Predicted labels, in the same order as yTrue
 * @param {boolean} [normalize = true] - When truthy, return the fraction of correct predictions
 *   (0 means none were correct, 1 means all were correct). When falsy, return the integer
 *   number of correct predictions
 * @return {number} Fraction of correct predictions (normalize truthy) or count of correct
 *   predictions (normalize falsy). With normalize enabled and empty inputs this is NaN (0 / 0)
 * @throws {Error} If yTrue and yPred have different lengths
 */
function accuracy(yTrue, yPred) {
  // Optional third argument, read from `arguments` (as in the transpiled original) so that
  // accuracy.length remains 2. Omitted or undefined means "normalize".
  var normalize = arguments[2];
  if (normalize === undefined) {
    normalize = true;
  }

  if (yTrue.length !== yPred.length) {
    throw new Error('Number of true labels must match number of predicted labels.');
  }

  // Tally positions where prediction and truth agree
  var correctCount = 0;
  yTrue.forEach(function (label, index) {
    if (label === yPred[index]) {
      correctCount += 1;
    }
  });

  return normalize ? correctCount / yTrue.length : correctCount;
}
/**
 * Calculate the area under the receiver operating characteristic (ROC) curve for a set of
 * binary predictions. The area is calculated using the trapezoidal rule.
 *
 * Predictions with equal scores are treated as a single threshold, so a group of tied scores
 * contributes one (possibly diagonal) segment to the curve. The result therefore does not
 * depend on how tied predictions happen to be ordered.
 *
 * @param {Array.<number>} yTrue - True labels; each must be 0 (negative) or 1 (positive)
 * @param {Array.<number>} yPred - Predicted scores (e.g. probabilities of the positive class);
 *   a higher score means "more likely positive"
 * @return {number} Calculated AUROC, between 0 and 1. NaN when yTrue does not contain both
 *   classes (including empty input), because the ROC curve is undefined in that case
 * @throws {Error} If yTrue and yPred have different lengths, or a true label is not 0 or 1
 */
function auroc(yTrue, yPred) {
  if (yTrue.length !== yPred.length) {
    throw new Error('Number of true labels must match number of predicted labels.');
  }

  // The false positive rate and true positive rate are normalized by the number of negatives
  // and positives, respectively, in the true labels list
  var numNegative = 0;
  var numPositive = 0;
  for (var i = 0; i < yTrue.length; i += 1) {
    if (yTrue[i] === 0) {
      numNegative += 1;
    } else if (yTrue[i] === 1) {
      numPositive += 1;
    } else {
      // Any other label would be miscounted below (e.g. producing negative rates)
      throw new Error('True labels must be 0 (negative) or 1 (positive).');
    }
  }

  // Without at least one sample of each class one of the rates is 0 / 0
  if (numNegative === 0 || numPositive === 0) {
    return NaN;
  }

  // Sample indices in ascending order of predicted score; each distinct score is a threshold
  var sortedIndices = [];
  for (var j = 0; j < yPred.length; j += 1) {
    sortedIndices.push(j);
  }
  sortedIndices.sort(function (a, b) {
    return yPred[a] - yPred[b];
  });

  // Start at the threshold below every score, where all samples are predicted positive
  // (FPR = TPR = 1), then raise the threshold past one distinct score at a time, ending at (0, 0)
  var fp = numNegative;
  var tp = numPositive;
  var prevFpr = 1;
  var prevTpr = 1;
  var area = 0;
  var k = 0;
  while (k < sortedIndices.length) {
    var score = yPred[sortedIndices[k]];

    // Move every sample sharing this score to the predicted-negative side at once. The do/while
    // always consumes at least one sample, so NaN scores (never === to anything) still progress.
    do {
      if (yTrue[sortedIndices[k]] === 0) {
        fp -= 1;
      } else {
        tp -= 1;
      }
      k += 1;
    } while (k < sortedIndices.length && yPred[sortedIndices[k]] === score);

    var fpr = fp / numNegative;
    var tpr = tp / numPositive;

    // Area of the trapezoid between consecutive curve points and the x-axis
    area += (prevFpr - fpr) * (prevTpr + tpr) / 2;
    prevFpr = fpr;
    prevTpr = tpr;
  }

  return area;
}