UNPKG

@jsmlt/jsmlt

Version:

JavaScript Machine Learning

235 lines (184 loc) 8.22 kB
'use strict';

// NOTE: this module is Babel (ES2015 -> ES5) output. The "classes" below must stay ES5
// constructor functions: compiled subclasses call their parent constructor via
// `(Parent.__proto__ || Object.getPrototypeOf(Parent)).apply(this, arguments)`, which an
// ES2015 `class` constructor would reject.

Object.defineProperty(exports, "__esModule", {
  value: true
});
exports.OneVsAllClassifier = exports.Classifier = exports.Estimator = undefined;

var _createClass = function () {
  function defineProperties(target, props) {
    for (var i = 0; i < props.length; i++) {
      var descriptor = props[i];
      descriptor.enumerable = descriptor.enumerable || false;
      descriptor.configurable = true;
      if ("value" in descriptor) descriptor.writable = true;
      Object.defineProperty(target, descriptor.key, descriptor);
    }
  }
  return function (Constructor, protoProps, staticProps) {
    if (protoProps) defineProperties(Constructor.prototype, protoProps);
    if (staticProps) defineProperties(Constructor, staticProps);
    return Constructor;
  };
}();

// Standard imports
var _linalg = require('../math/linalg');

var LinAlg = _interopRequireWildcard(_linalg);

var _arrays = require('../util/arrays');

var Arrays = _interopRequireWildcard(_arrays);

function _interopRequireWildcard(obj) {
  if (obj && obj.__esModule) {
    return obj;
  } else {
    var newObj = {};
    if (obj != null) {
      for (var key in obj) {
        if (Object.prototype.hasOwnProperty.call(obj, key)) newObj[key] = obj[key];
      }
    }
    newObj.default = obj;
    return newObj;
  }
}

function _possibleConstructorReturn(self, call) {
  if (!self) {
    throw new ReferenceError("this hasn't been initialised - super() hasn't been called");
  }
  return call && (typeof call === "object" || typeof call === "function") ? call : self;
}

function _inherits(subClass, superClass) {
  if (typeof superClass !== "function" && superClass !== null) {
    throw new TypeError("Super expression must either be null or a function, not " + typeof superClass);
  }
  subClass.prototype = Object.create(superClass && superClass.prototype, {
    constructor: { value: subClass, enumerable: false, writable: true, configurable: true }
  });
  if (superClass) Object.setPrototypeOf ? Object.setPrototypeOf(subClass, superClass) : subClass.__proto__ = superClass;
}

function _classCallCheck(instance, Constructor) {
  if (!(instance instanceof Constructor)) {
    throw new TypeError("Cannot call a class as a function");
  }
}

/**
 * Base class for supervised estimators (classifiers or regression models).
 */
var Estimator = exports.Estimator = function () {
  function Estimator() {
    _classCallCheck(this, Estimator);
  }

  _createClass(Estimator, [{
    key: 'train',

    /**
     * Train the supervised learning algorithm on a dataset.
     *
     * @abstract
     *
     * @param {Array.<Array.<number>>} X - Features per data point
     * @param {Array.<mixed>} y - Class labels per data point
     * @throws {Error} Always; child classes must override this method
     */
    value: function train(X, y) {
      throw new Error('Method must be implemented child class.');
    }

    /**
     * Make a prediction for a data set.
     *
     * @abstract
     *
     * @param {Array.<Array.<number>>} X - Features for each data point
     * @return {Array.<mixed>} Predictions, one per data point
     * @throws {Error} Always; child classes must override this method
     */
  }, {
    key: 'test',
    value: function test(X) {
      throw new Error('Method must be implemented child class.');
    }
  }]);

  return Estimator;
}();

/**
 * Base class for classifiers.
 */
var Classifier = exports.Classifier = function (_Estimator) {
  _inherits(Classifier, _Estimator);

  function Classifier() {
    _classCallCheck(this, Classifier);

    return _possibleConstructorReturn(this, (Classifier.__proto__ || Object.getPrototypeOf(Classifier)).apply(this, arguments));
  }

  return Classifier;
}(Estimator);

/**
 * Base class for one-vs-all multiclass classifiers. One binary classifier is created per unique
 * class label. Each one is trained to separate its label (target 1) from all other labels
 * (target 0). Prediction picks the label whose classifier gives the highest normalized output.
 *
 * Child classes implement `createClassifier`.
 */
var OneVsAllClassifier = exports.OneVsAllClassifier = function (_Classifier) {
  _inherits(OneVsAllClassifier, _Classifier);

  function OneVsAllClassifier() {
    _classCallCheck(this, OneVsAllClassifier);

    return _possibleConstructorReturn(this, (OneVsAllClassifier.__proto__ || Object.getPrototypeOf(OneVsAllClassifier)).apply(this, arguments));
  }

  _createClass(OneVsAllClassifier, [{
    key: 'createClassifier',

    /**
     * Create a binary classifier for one of the classes.
     *
     * @abstract
     *
     * @param {number} classIndex - Class index of the positive class for the binary classifier
     * @return {BinaryClassifier} Binary classifier
     * @throws {Error} Always; child classes must override this method
     */
    value: function createClassifier(classIndex) {
      throw new Error('Method must be implemented child class.');
    }

    /**
     * Create all binary classifiers. Creates one classifier per unique class label and stores
     * them in `this.classifiers` as a list of `{ classIndex, classifier }` entries. Entries are
     * ordered by each label's first appearance in `y`.
     *
     * @param {Array.<number>} y - Class labels for the training data
     */
  }, {
    key: 'createClassifiers',
    value: function createClassifiers(y) {
      var _this = this;

      // Get unique labels (in order of first appearance)
      var uniqueClassIndices = Array.from(new Set(y));

      // Initialize label set and classifier for all labels. The label is passed on so that
      // implementations of createClassifier receive the positive class they are built for
      this.classifiers = uniqueClassIndices.map(function (classIndex) {
        return {
          classIndex: classIndex,
          classifier: _this.createClassifier(classIndex)
        };
      });
    }

    /**
     * Train all binary classifiers one-by-one. Each classifier gets target 1 for data points of
     * its own class and 0 for all others.
     *
     * @param {Array.<Array.<number>>} X - Features per data point
     * @param {Array.<mixed>} y - Class labels per data point
     */
  }, {
    key: 'trainBatch',
    value: function trainBatch(X, y) {
      this.classifiers.forEach(function (classifier) {
        var yOneVsAll = y.map(function (classIndex) {
          return classIndex === classifier.classIndex ? 1 : 0;
        });
        classifier.classifier.train(X, yOneVsAll);
      });
    }

    /**
     * Train all binary classifiers iteration by iteration. The first training iteration runs for
     * each binary classifier, then the second iteration for each, and so forth. Use this when
     * you need to track information per iteration, e.g. accuracy.
     *
     * A classifier stops receiving iterations once its `checkConvergence()` returns true. An
     * 'iterationCompleted' event is emitted after every epoch. A 'converged' event is emitted at
     * the end, including when the epoch limit is reached before all classifiers converged.
     *
     * NOTE(review): this method reads `this.training.labels` and calls `this.emit(...)`. Neither is
     * defined in this class; presumably a child class or mixin provides them — confirm.
     *
     * @param {number} [maxEpochs=100] - Maximum number of training epochs
     */
  }, {
    key: 'trainIterative',
    value: function trainIterative() {
      var maxEpochs = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : 100;

      var remainingClassIndices = Array.from(new Set(this.training.labels));

      // this.classifiers is a list of { classIndex, classifier } entries (see createClassifiers),
      // so it cannot be indexed by label directly; build a label -> binary classifier lookup
      var classifiersByClassIndex = new Map(this.classifiers.map(function (entry) {
        return [entry.classIndex, entry.classifier];
      }));

      var epoch = 0;

      while (epoch < maxEpochs && remainingClassIndices.length > 0) {
        var remainingClassIndicesNew = remainingClassIndices.slice();

        // Loop over all 1-vs-all classifiers that have not converged yet
        for (var i = 0; i < remainingClassIndices.length; i++) {
          var classIndex = remainingClassIndices[i];
          var classifier = classifiersByClassIndex.get(classIndex);

          // Run a single iteration for the classifier
          classifier.trainIteration();

          if (classifier.checkConvergence()) {
            remainingClassIndicesNew.splice(remainingClassIndicesNew.indexOf(classIndex), 1);
          }
        }

        remainingClassIndices = remainingClassIndicesNew;

        // Emit event the outside can hook into
        this.emit('iterationCompleted');

        epoch += 1;
      }

      // Emit event the outside can hook into
      this.emit('converged');
    }

    /**
     * Predict class labels. Each data point gets the label of the binary classifier with the
     * highest normalized output.
     *
     * @param {Array.<Array.<number>>} X - Features for each data point
     * @return {Array.<mixed>} Predicted class label for each data point
     */
  }, {
    key: 'predict',
    value: function predict(X) {
      var classifiers = this.classifiers;

      // Get predictions from all classifiers for all data points by predicting all data points with
      // each classifier (getting an array of predictions for each classifier) and transposing
      var datapointsPredictions = LinAlg.transpose(classifiers.map(function (entry) {
        return entry.classifier.predict(X, { output: 'normalized' });
      }));

      // The position of the maximum normalized output indexes this.classifiers, not the label
      // space; map it back to the class label of the winning classifier
      return datapointsPredictions.map(function (x) {
        return classifiers[Arrays.argMax(x)].classIndex;
      });
    }
  }]);

  return OneVsAllClassifier;
}(Classifier);