UNPKG

@jsmlt/jsmlt

Version:

JavaScript Machine Learning

308 lines (253 loc) 12.2 kB
"use strict";

Object.defineProperty(exports, "__esModule", {
  value: true
});
exports.OneVsAllClassifier = exports.Classifier = exports.Estimator = void 0;

var Arrays = _interopRequireWildcard(require("../arrays"));

function _getRequireWildcardCache() {
  if (typeof WeakMap !== "function") return null;
  var cache = new WeakMap();
  _getRequireWildcardCache = function _getRequireWildcardCache() {
    return cache;
  };
  return cache;
}

function _interopRequireWildcard(obj) {
  if (obj && obj.__esModule) {
    return obj;
  }
  var cache = _getRequireWildcardCache();
  if (cache && cache.has(obj)) {
    return cache.get(obj);
  }
  var newObj = {};
  if (obj != null) {
    var hasPropertyDescriptor = Object.defineProperty && Object.getOwnPropertyDescriptor;
    for (var key in obj) {
      if (Object.prototype.hasOwnProperty.call(obj, key)) {
        var desc = hasPropertyDescriptor ? Object.getOwnPropertyDescriptor(obj, key) : null;
        if (desc && (desc.get || desc.set)) {
          Object.defineProperty(newObj, key, desc);
        } else {
          newObj[key] = obj[key];
        }
      }
    }
  }
  newObj["default"] = obj;
  if (cache) {
    cache.set(obj, newObj);
  }
  return newObj;
}

function _typeof(obj) {
  if (typeof Symbol === "function" && typeof Symbol.iterator === "symbol") {
    _typeof = function _typeof(obj) {
      return typeof obj;
    };
  } else {
    _typeof = function _typeof(obj) {
      return obj && typeof Symbol === "function" && obj.constructor === Symbol && obj !== Symbol.prototype ? "symbol" : typeof obj;
    };
  }
  return _typeof(obj);
}

function _possibleConstructorReturn(self, call) {
  if (call && (_typeof(call) === "object" || typeof call === "function")) {
    return call;
  }
  return _assertThisInitialized(self);
}

function _assertThisInitialized(self) {
  if (self === void 0) {
    throw new ReferenceError("this hasn't been initialised - super() hasn't been called");
  }
  return self;
}

function _getPrototypeOf(o) {
  _getPrototypeOf = Object.setPrototypeOf ? Object.getPrototypeOf : function _getPrototypeOf(o) {
    return o.__proto__ || Object.getPrototypeOf(o);
  };
  return _getPrototypeOf(o);
}

function _inherits(subClass, superClass) {
  if (typeof superClass !== "function" && superClass !== null) {
    throw new TypeError("Super expression must either be null or a function");
  }
  subClass.prototype = Object.create(superClass && superClass.prototype, {
    constructor: { value: subClass, writable: true, configurable: true }
  });
  if (superClass) _setPrototypeOf(subClass, superClass);
}

function _setPrototypeOf(o, p) {
  _setPrototypeOf = Object.setPrototypeOf || function _setPrototypeOf(o, p) {
    o.__proto__ = p;
    return o;
  };
  return _setPrototypeOf(o, p);
}

function _classCallCheck(instance, Constructor) {
  if (!(instance instanceof Constructor)) {
    throw new TypeError("Cannot call a class as a function");
  }
}

function _defineProperties(target, props) {
  for (var i = 0; i < props.length; i++) {
    var descriptor = props[i];
    descriptor.enumerable = descriptor.enumerable || false;
    descriptor.configurable = true;
    if ("value" in descriptor) descriptor.writable = true;
    Object.defineProperty(target, descriptor.key, descriptor);
  }
}

function _createClass(Constructor, protoProps, staticProps) {
  if (protoProps) _defineProperties(Constructor.prototype, protoProps);
  if (staticProps) _defineProperties(Constructor, staticProps);
  return Constructor;
}

/**
 * Base class for supervised estimators (classifiers or regression models).
 */
var Estimator =
/*#__PURE__*/
function () {
  function Estimator() {
    _classCallCheck(this, Estimator);
  }

  _createClass(Estimator, [
  /**
   * Train the supervised learning algorithm on a dataset.
   *
   * @abstract
   *
   * @param {Array.<Array.<number>>} X - Features per data point
   * @param {Array.<mixed>} y - Class labels per data point
   * @throws {Error} Always, unless overridden by a subclass
   */
  {
    key: "train",
    value: function train(X, y) {
      throw new Error('Method must be implemented child class.');
    }
  },
  /**
   * Make a prediction for a data set.
   *
   * @abstract
   *
   * @param {Array.<Array.<number>>} X - Features for each data point
   * @return {Array.<mixed>} Predictions, one per data point
   * @throws {Error} Always, unless overridden by a subclass
   */
  {
    key: "predict",
    value: function predict(X) {
      throw new Error('Method must be implemented child class.');
    }
  }]);

  return Estimator;
}();

exports.Estimator = Estimator;

/**
 * Base class for classifiers.
 */
var Classifier =
/*#__PURE__*/
function (_Estimator) {
  _inherits(Classifier, _Estimator);

  function Classifier() {
    _classCallCheck(this, Classifier);
    return _possibleConstructorReturn(this, _getPrototypeOf(Classifier).apply(this, arguments));
  }

  return Classifier;
}(Estimator);

exports.Classifier = Classifier;

/**
 * Base class for multiclass classifiers using the one-vs-all classification method. For a training
 * set with k unique class labels, the one-vs-all classifier creates k binary classifiers. Each of
 * these classifiers is trained on the entire data set, where the i-th classifier treats all samples
 * that do not come from the i-th class as being from the same class. In the prediction phase, the
 * one-vs-all classifier runs all k binary classifiers on the test data point, and predicts the
 * class that has the highest normalized prediction value.
 *
 * After createClassifiers(y) has run, this.classifiers holds one { classIndex, classifier } entry
 * per unique label in y. The position of an entry in that array is its "internal" class index.
 */
var OneVsAllClassifier =
/*#__PURE__*/
function (_Classifier) {
  _inherits(OneVsAllClassifier, _Classifier);

  function OneVsAllClassifier() {
    _classCallCheck(this, OneVsAllClassifier);
    return _possibleConstructorReturn(this, _getPrototypeOf(OneVsAllClassifier).apply(this, arguments));
  }

  _createClass(OneVsAllClassifier, [
  /**
   * Create a binary classifier for one of the classes.
   *
   * @abstract
   *
   * @param {number} classIndex - Class label of the positive class for the binary classifier
   * @return {Classifier} Binary classifier
   * @throws {Error} Always, unless overridden by a subclass
   */
  {
    key: "createClassifier",
    value: function createClassifier(classIndex) {
      throw new Error('Method must be implemented child class.');
    }
  },
  /**
   * Create all binary classifiers, one per unique class label in y, and store them in
   * this.classifiers as { classIndex, classifier } entries (in the order returned by
   * Arrays.unique).
   *
   * @param {Array.<number>} y - Class labels for the training data
   */
  {
    key: "createClassifiers",
    value: function createClassifiers(y) {
      var _this = this;

      // Get unique labels
      var uniqueClassIndices = Arrays.unique(y);

      // Create one binary classifier per label. The label is passed through so implementations
      // can configure the classifier for its positive class, as createClassifier's contract states
      this.classifiers = uniqueClassIndices.map(function (classIndex) {
        var classifier = _this.createClassifier(classIndex);
        return {
          classIndex: classIndex,
          classifier: classifier
        };
      });
    }
  },
  /**
   * Get the class labels corresponding with each internal class index. Can be used to determine
   * which prediction is for which class in predictProba.
   *
   * @return {Array.<number>} The n-th element in this array contains the class label of what is
   *   internally class n
   */
  {
    key: "getClasses",
    value: function getClasses() {
      return this.classifiers.map(function (entry) {
        return entry.classIndex;
      });
    }
  },
  /**
   * Train all binary classifiers one-by-one. The classifier for label c is trained on targets that
   * are 1 where y equals c and 0 elsewhere.
   *
   * @param {Array.<Array.<number>>} X - Features per data point
   * @param {Array.<mixed>} y - Class labels per data point
   */
  {
    key: "trainBatch",
    value: function trainBatch(X, y) {
      this.classifiers.forEach(function (entry) {
        var yOneVsAll = y.map(function (classIndex) {
          return classIndex === entry.classIndex ? 1 : 0;
        });
        entry.classifier.train(X, yOneVsAll);
      });
    }
  },
  /**
   * Train all binary classifiers iteration by iteration, i.e. start with the first training
   * iteration for each binary classifier, then execute the second training iteration for each
   * binary classifier, and so forth. Can be used when one needs to keep track of information per
   * iteration, e.g. accuracy.
   *
   * Requires createClassifiers to have been called first. Each binary classifier must implement
   * trainIteration() and checkConvergence(). A classifier is no longer iterated once
   * checkConvergence() returns a truthy value.
   *
   * @param {number} [maxEpochs=100] - Maximum number of passes over the unconverged classifiers
   */
  {
    key: "trainIterative",
    value: function trainIterative(maxEpochs) {
      var epochLimit = maxEpochs === undefined ? 100 : maxEpochs;

      // Entries whose binary classifier has not converged yet. These are taken from
      // this.classifiers (filled by createClassifiers) rather than from a separate label list, so
      // every entry is addressed directly instead of by using its label as an array index
      var remaining = this.classifiers.slice();
      var epoch = 0;

      while (epoch < epochLimit && remaining.length > 0) {
        var unconverged = [];

        // Run a single iteration for each remaining classifier and keep it only if it has not
        // converged afterwards
        remaining.forEach(function (entry) {
          entry.classifier.trainIteration();

          if (!entry.classifier.checkConvergence()) {
            unconverged.push(entry);
          }
        });

        remaining = unconverged;

        // Emit event the outside can hook into.
        // NOTE(review): `emit` is not defined in this module; presumably a subclass or mixin
        // (e.g. an EventEmitter) provides it — confirm.
        this.emit('iterationCompleted');

        epoch += 1;
      }

      // Emit event the outside can hook into
      this.emit('converged');
    }
  },
  /**
   * @see {Classifier#predict}
   *
   * @param {Array.<Array.<number>>} X - Features for each data point
   * @return {Array.<mixed>} For each data point, the class label whose binary classifier produced
   *   the highest normalized output
   */
  {
    key: "predict",
    value: function predict(X) {
      var _this = this;

      // Get predictions from all classifiers for all data points by predicting all data points with
      // each classifier (getting an array of predictions for each classifier) and transposing
      var datapointsPredictions = Arrays.transpose(this.classifiers.map(function (entry) {
        return entry.classifier.predict(X, {
          output: 'normalized'
        });
      }));

      // argMax yields a position in this.classifiers. Map it back to that classifier's class
      // label so the prediction is a label, not an internal index
      return datapointsPredictions.map(function (x) {
        return _this.classifiers[Arrays.argMax(x)].classIndex;
      });
    }
  },
  /**
   * Make a probabilistic prediction for a data set.
   *
   * @param {Array.<Array.<number>>} X - Features for each data point
   * @return {Array.<Array.<number>>} Probability predictions. Each array element contains the
   *   probability of that particular class. The array elements are ordered in the order the
   *   classes appear in the training data (i.e., if class "A" occurs first in the labels list in
   *   the training procedure, its probability is returned in the first array element of each
   *   sub-array). Use getClasses() to map positions to labels.
   * @throws {Error} If the base classifier does not implement predictProba
   */
  {
    key: "predictProba",
    value: function predictProba(X) {
      if (typeof this.classifiers[0].classifier.predictProba !== 'function') {
        throw new Error('Base classifier does not implement the predictProba method, which was attempted to be called from the one-vs-all classifier.');
      }

      // Get probability predictions from all classifiers for all data points by predicting all data
      // points with each classifier (getting an array of predictions for each classifier) and
      // transposing. Index 1 is read as the positive-class score because trainBatch trains each
      // binary classifier with target 1 for its own class
      var predictions = Arrays.transpose(this.classifiers.map(function (entry) {
        return entry.classifier.predictProba(X).map(function (probs) {
          return probs[1];
        });
      }));

      // Scale all predictions to yield valid probabilities. If every binary classifier scores a
      // data point as 0, fall back to a uniform distribution instead of dividing by zero
      return predictions.map(function (x) {
        var total = Arrays.internalSum(x);

        if (total === 0) {
          return x.map(function () {
            return 1 / x.length;
          });
        }

        return Arrays.scale(x, 1 / total);
      });
    }
  },
  /**
   * Retrieve the individual binary one-vs-all classifiers.
   *
   * @return {Array.<Object>} The { classIndex, classifier } entries created by createClassifiers,
   *   one per class
   */
  {
    key: "getClassifiers",
    value: function getClassifiers() {
      return this.classifiers;
    }
  }]);

  return OneVsAllClassifier;
}(Classifier);

exports.OneVsAllClassifier = OneVsAllClassifier;