// @jsmlt/jsmlt
// Version:
// JavaScript Machine Learning
// 535 lines (432 loc) • 20.9 kB
// JavaScript
// Module prelude: opt into strict mode and set up the CommonJS export object.
'use strict';
// Flag the exports object with `__esModule`; the interop helpers in this file check that
// flag to decide whether a required module needs to be wrapped.
Object.defineProperty(exports, "__esModule", {
value: true
});
// Placeholder for the named export; the actual BinarySVM constructor is assigned further down
// where the class is defined.
exports.BinarySVM = undefined;
/**
 * Destructuring helper. Arrays are handed back untouched; any other iterable is drained into a
 * new array of at most `i` elements (all elements when `i` is falsy). Anything else is rejected.
 *
 * @param {*} arr - Array or iterable to destructure
 * @param {number} [i] - Maximum number of elements to read from a non-array iterable
 * @return {Array} The input array itself, or a new array with the collected elements
 * @throws {TypeError} When `arr` is neither an array nor an iterable
 */
var _slicedToArray = (function () {
  // Pull up to `limit` values out of an iterable, closing the iterator if we stop early
  function collect(iterable, limit) {
    var items = [];
    // `finished` is only falsy while a step that reported "not done" is being processed
    var finished = true;
    var failed = false;
    var failure = undefined;
    var iterator;

    try {
      iterator = iterable[Symbol.iterator]();
      while (true) {
        var step = iterator.next();
        finished = step.done;
        if (finished) {
          break;
        }
        items.push(step.value);
        if (limit && items.length === limit) {
          break;
        }
        finished = true;
      }
    } catch (err) {
      failed = true;
      failure = err;
    } finally {
      try {
        // Stopped in the middle of the sequence: give the iterator a chance to clean up
        if (!finished && iterator["return"]) {
          iterator["return"]();
        }
      } finally {
        if (failed) {
          throw failure;
        }
      }
    }

    return items;
  }

  return function (arr, i) {
    if (Array.isArray(arr)) {
      return arr;
    }
    if (Symbol.iterator in Object(arr)) {
      return collect(arr, i);
    }
    throw new TypeError("Invalid attempt to destructure non-iterable instance");
  };
})();
/**
 * Shallow-merge helper: copies the own enumerable properties of every source argument onto the
 * target, left to right, and returns the target. Uses the native Object.assign when it exists and
 * a hand-written equivalent otherwise.
 *
 * @param {Object} target - Object receiving the properties
 * @param {...Object} sources - Objects whose own properties are copied onto the target
 * @return {Object} The target object
 */
var _extends = Object.assign || function (target) {
  var owns = Object.prototype.hasOwnProperty;
  var position = 1;

  while (position < arguments.length) {
    var donor = arguments[position];
    for (var name in donor) {
      // Skip anything inherited through the prototype chain
      if (owns.call(donor, name)) {
        target[name] = donor[name];
      }
    }
    position++;
  }

  return target;
};
/**
 * Class-definition helper: installs method/accessor descriptors on a constructor's prototype and
 * static descriptors on the constructor itself.
 *
 * Each descriptor is normalised in place before it is defined: `enumerable` defaults to false,
 * `configurable` is forced to true, and data descriptors (those with a `value`) become writable.
 *
 * @param {Function} Constructor - Constructor function to decorate
 * @param {Array.<Object>} [protoProps] - Descriptors (with a `key`) for prototype members
 * @param {Array.<Object>} [staticProps] - Descriptors (with a `key`) for static members
 * @return {Function} The same constructor
 */
var _createClass = (function () {
  // Normalise every descriptor and define it on the given object, in list order
  function install(owner, descriptors) {
    var index = 0;
    while (index < descriptors.length) {
      var entry = descriptors[index];
      entry.enumerable = entry.enumerable || false;
      entry.configurable = true;
      if ("value" in entry) {
        entry.writable = true;
      }
      Object.defineProperty(owner, entry.key, entry);
      index++;
    }
  }

  return function (Constructor, protoProps, staticProps) {
    if (protoProps) {
      install(Constructor.prototype, protoProps);
    }
    if (staticProps) {
      install(Constructor, staticProps);
    }
    return Constructor;
  };
})();
var _base = require('../base');
var _linalg = require('../../math/linalg');
var LinAlg = _interopRequireWildcard(_linalg);
var _arrays = require('../../util/arrays');
var Arrays = _interopRequireWildcard(_arrays);
var _random = require('../../util/random');
var Random = _interopRequireWildcard(_random);
var _linear = require('../../kernel/linear');
var _linear2 = _interopRequireDefault(_linear);
/**
 * Default-import interop: ES modules (flagged with `__esModule`) pass through unchanged, anything
 * else is wrapped so that it can be reached through a `default` property.
 *
 * @param {*} obj - Result of a require() call
 * @return {*} `obj` itself for ES modules, otherwise `{ default: obj }`
 */
function _interopRequireDefault(obj) {
  if (obj && obj.__esModule) {
    return obj;
  }
  return { default: obj };
}
/**
 * Namespace-import interop: ES modules (flagged with `__esModule`) pass through unchanged. For
 * anything else a fresh namespace object is built that holds a copy of the value's own enumerable
 * properties plus a `default` property pointing at the value itself.
 *
 * @param {*} obj - Result of a require() call
 * @return {Object} `obj` itself for ES modules, otherwise the newly built namespace object
 */
function _interopRequireWildcard(obj) {
  if (obj && obj.__esModule) {
    return obj;
  }

  var namespace = {};
  if (obj != null) {
    // Own enumerable keys only; inherited members are deliberately left out
    Object.keys(obj).forEach(function (name) {
      namespace[name] = obj[name];
    });
  }
  // Assigned last, so it wins over any copied property that happens to be called "default"
  namespace.default = obj;
  return namespace;
}
/**
 * Guard used at the top of transpiled constructors: rejects calls made without `new`.
 *
 * @param {*} instance - The `this` value of the constructor call
 * @param {Function} Constructor - Constructor that `instance` must be an instance of
 * @throws {TypeError} When `instance` is not an instance of `Constructor`
 */
function _classCallCheck(instance, Constructor) {
  var constructedWithNew = instance instanceof Constructor;
  if (constructedWithNew) {
    return;
  }
  throw new TypeError("Cannot call a class as a function");
}
/**
 * Resolve the instance a derived constructor should work with after calling its parent
 * constructor: the parent's return value when that is an object or function, `self` otherwise.
 *
 * @param {Object} self - The `this` value of the derived constructor
 * @param {*} call - Value returned by the parent constructor call
 * @return {Object|Function} The parent's return value if it is an object/function, else `self`
 * @throws {ReferenceError} When `self` is falsy
 */
function _possibleConstructorReturn(self, call) {
  if (!self) {
    throw new ReferenceError("this hasn't been initialised - super() hasn't been called");
  }

  var kind = typeof call;
  // `typeof null` is "object", hence the additional truthiness test
  if (call && (kind === "object" || kind === "function")) {
    return call;
  }
  return self;
}
/**
 * Wire up prototype-based inheritance between two constructors.
 *
 * The subclass receives a fresh prototype object that delegates to the superclass prototype and
 * carries a non-enumerable `constructor` back-reference. When a superclass is given, the subclass
 * constructor itself is also linked to it, so static members are inherited as well.
 *
 * @param {Function} subClass - Constructor that should inherit
 * @param {Function|null} superClass - Constructor to inherit from, or null for no parent
 * @throws {TypeError} When `superClass` is neither a function nor null
 */
function _inherits(subClass, superClass) {
  var superKind = typeof superClass;
  if (superKind !== "function" && superClass !== null) {
    throw new TypeError("Super expression must either be null or a function, not " + superKind);
  }

  var parentPrototype = superClass && superClass.prototype;
  subClass.prototype = Object.create(parentPrototype, {
    constructor: {
      value: subClass,
      enumerable: false,
      writable: true,
      configurable: true
    }
  });

  if (!superClass) {
    return;
  }

  // Link the constructors themselves so that static members are inherited
  if (Object.setPrototypeOf) {
    Object.setPrototypeOf(subClass, superClass);
  } else {
    subClass.__proto__ = superClass;
  }
}
// Internal dependencies
/**
 * SVM learner for binary classification problems.
 *
 * Class indices 0 and 1 are mapped to the signs -1 and +1 internally. Training repeatedly picks a
 * sample that violates the KKT conditions, pairs it with a randomly chosen second sample and
 * optimises the two corresponding Lagrange multipliers together, until a configurable number of
 * consecutive passes over the data leaves all multipliers unchanged.
 */
var BinarySVM = exports.BinarySVM = function (_Classifier) {
  _inherits(BinarySVM, _Classifier);

  /**
   * Constructor. Initialize class members and store user-defined options.
   *
   * @param {Object} [optionsUser] - User-defined options for SVM
   * @param {number} [optionsUser.C = 100] - Regularization (i.e. penalty for slack variables)
   * @param {Object} [optionsUser.kernel] - Kernel. Defaults to the linear kernel when omitted,
   *   null or undefined
   * @param {number} [optionsUser.convergenceNumPasses = 10] - Number of passes without alphas
   *   changing to treat the algorithm as converged
   * @param {number} [optionsUser.numericalTolerance = 1e-6] - Numerical tolerance below which two
   *   values in the SMO algorithm are treated as equal
   * @param {boolean} [optionsUser.useKernelCache = true] - Whether to cache calculated kernel
   *   values for training sample pairs. Enabling this option (which is the default) generally
   *   improves the performance in terms of speed at the cost of memory
   */
  function BinarySVM() {
    var optionsUser = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : {};

    _classCallCheck(this, BinarySVM);

    var _this = _possibleConstructorReturn(this, (BinarySVM.__proto__ || Object.getPrototypeOf(BinarySVM)).call(this));

    // Parse options
    var optionsDefault = {
      C: 100.0,
      kernel: null,
      convergenceNumPasses: 10,
      numericalTolerance: 1e-6,
      useKernelCache: true
    };
    var options = _extends({}, optionsDefault, optionsUser);

    // Set options. An explicit `kernel: undefined` in the user options overrides the null default
    // during the merge, so both null and undefined have to select the linear kernel
    var hasKernel = options.kernel !== null && options.kernel !== undefined;
    _this.C = options.C;
    _this.kernel = hasKernel ? options.kernel : new _linear2.default();
    _this.convergenceNumPasses = options.convergenceNumPasses;
    _this.numericalTolerance = options.numericalTolerance;
    _this.useKernelCache = options.useKernelCache;

    // Set properties
    _this.isTraining = false;
    return _this;
  }

  _createClass(BinarySVM, [{
    /**
     * Get the signed value of the class index. Returns -1 for class index 0, 1 for class index 1.
     *
     * @param {number} classIndex - Class index
     * @return {number} Sign corresponding to class index
     */
    key: 'getClassIndexSign',
    value: function getClassIndexSign(classIndex) {
      return classIndex * 2 - 1;
    }
  }, {
    /**
     * Get the class index corresponding to a sign. Returns 0 for sign -1, 1 for sign 1.
     *
     * @param {number} sign - Sign
     * @return {number} Class index corresponding to sign
     */
    key: 'getSignClassIndex',
    value: function getSignClassIndex(sign) {
      return (sign + 1) / 2;
    }
  }, {
    /**
     * Train the SVM on a labelled data set. Stores the training data, the Lagrange multipliers
     * (`alphas`), the bias term (`b`) and the indices of the support vectors on the instance.
     *
     * With fewer than two samples there is no pair of multipliers to optimise; in that case all
     * alphas stay zero and the list of support vectors is empty.
     *
     * @see {@link Classifier#train}
     *
     * @param {Array.<Array.<number>>} X - Features per training sample
     * @param {Array.<number>} y - Class index (0 or 1) per training sample
     */
    key: 'train',
    value: function train(X, y) {
      var _this2 = this;

      // Mark that the SVM is currently in the training procedure
      this.isTraining = true;

      try {
        // Number of training samples
        var numSamples = X.length;

        // Alphas (Lagrange multipliers)
        this.alphas = LinAlg.zeroVector(numSamples);

        // Bias term
        this.b = 0.0;

        // Kernel cache. The two numSamples x numSamples matrices are only allocated when caching
        // is enabled; applyKernel() never touches them otherwise
        if (this.useKernelCache) {
          this.kernelCache = LinAlg.full([numSamples, numSamples], 0.0);
          this.kernelCacheStatus = LinAlg.full([numSamples, numSamples], false);
        } else {
          this.kernelCache = null;
          this.kernelCacheStatus = null;
        }

        // Shorthand notation for features and labels
        this.training = { X: X, y: y };
        var ySigns = y.map(function (label) {
          return _this2.getClassIndexSign(label);
        });

        // Number of passes of the algorithm without any alphas changing
        var numPasses = 0;

        // Each optimisation step needs a second sample j != i, so there is nothing to optimise
        // with fewer than two samples (picking j would run past the end of the data)
        var hasSamplePair = numSamples >= 2;

        while (hasSamplePair && numPasses < this.convergenceNumPasses) {
          var alphasChanged = 0;

          // Loop over all training samples
          for (var i = 0; i < numSamples; i += 1) {
            // Calculate offset to the 1-margin of sample i
            var ei = this.sampleMargin(i) - ySigns[i];

            // Check whether the KKT constraints were violated
            var violatesKKT = ySigns[i] * ei < -this.numericalTolerance && this.alphas[i] < this.C || ySigns[i] * ei > this.numericalTolerance && this.alphas[i] > 0;
            if (!violatesKKT) {
              continue;
            }

            /* Now, we need to update \alpha_i as it violates the KKT constraints */

            // Thus, we pick a random \alpha_j such that j does not equal i. Drawing from a range
            // that is one element short and shifting every draw >= i up by one skips i itself
            // NOTE(review): assumes Random.randint's upper bound is exclusive - confirm
            var j = Random.randint(0, numSamples - 1);
            if (j >= i) j += 1;

            // Calculate offset to the 1-margin of sample j
            var ej = this.sampleMargin(j) - ySigns[j];

            // Calculate lower and upper bounds for \alpha_j
            var bounds = this.calculateAlphaBounds(i, j);
            var boundL = bounds[0];
            var boundH = bounds[1];

            if (Math.abs(boundH - boundL) < this.numericalTolerance) {
              // Difference between bounds is practically zero, so there's not much to optimize.
              // Continue to next sample.
              continue;
            }

            // Calculate second derivative of cost function from Lagrange dual problem. Note
            // that a_i = (g - a_j * y_j) / y_i, where g is the negative sum of all a_k * y_k where
            // k is not equal to i or j
            var Kij = this.applyKernel(i, j);
            var Kii = this.applyKernel(i, i);
            var Kjj = this.applyKernel(j, j);
            var eta = 2 * Kij - Kii - Kjj;

            // The update below divides by eta and is only carried out for negative eta
            if (eta >= 0) {
              continue;
            }

            // Compute new \alpha_j, clipped to its bounds
            var oldAlphaJ = this.alphas[j];
            var newAlphaJ = oldAlphaJ - ySigns[j] * (ei - ej) / eta;
            newAlphaJ = Math.min(newAlphaJ, boundH);
            newAlphaJ = Math.max(newAlphaJ, boundL);

            // Don't update if the difference is too small
            if (Math.abs(newAlphaJ - oldAlphaJ) < this.numericalTolerance) {
              continue;
            }

            // Compute new \alpha_i
            var oldAlphaI = this.alphas[i];
            var newAlphaI = oldAlphaI + ySigns[i] * ySigns[j] * (oldAlphaJ - newAlphaJ);

            // Update \alpha_j and \alpha_i
            this.alphas[j] = newAlphaJ;
            this.alphas[i] = newAlphaI;

            // Update the bias term, interpolating between the bias terms for \alpha_i and \alpha_j
            var b1 = this.b - ei - ySigns[i] * (newAlphaI - oldAlphaI) * Kii - ySigns[j] * (newAlphaJ - oldAlphaJ) * Kij;
            var b2 = this.b - ej - ySigns[i] * (newAlphaI - oldAlphaI) * Kij - ySigns[j] * (newAlphaJ - oldAlphaJ) * Kjj;

            if (newAlphaJ > 0 && newAlphaJ < this.C) {
              this.b = b2;
            } else if (newAlphaI > 0 && newAlphaI < this.C) {
              this.b = b1;
            } else {
              this.b = (b1 + b2) / 2;
            }

            alphasChanged += 1;
          }

          // Only consecutive passes without changes count towards convergence
          if (alphasChanged === 0) {
            numPasses += 1;
          } else {
            numPasses = 0;
          }
        }

        // Store indices of support vectors (where alpha > 0, or, in this case, where alpha is
        // greater than some numerical tolerance)
        this.supportVectors = Arrays.zipWithIndex(this.alphas).filter(function (x) {
          return x[0] > 1e-6;
        }).map(function (x) {
          return x[1];
        });
      } finally {
        // Mark that training has ended, also when it was aborted by an error, so that the
        // instance is not left flagged as "training" forever
        this.isTraining = false;
      }
    }
  }, {
    /**
     * Calculate the margin (distance to the decision boundary) for a single sample.
     *
     * While training is in progress the sum runs over all training samples; afterwards only the
     * support vectors are used.
     *
     * @param {Array.<number>|number} sample - Sample features array or training sample index
     * @return {number} Distance to decision boundary
     */
    key: 'sampleMargin',
    value: function sampleMargin(sample) {
      var rval = this.b;

      if (this.isTraining) {
        // If we're in the training phase, we have to loop over all elements. Passing the index
        // (instead of the features) lets applyKernel() use the kernel cache
        for (var i = 0; i < this.training.X.length; i += 1) {
          var k = this.applyKernel(sample, i);
          rval += this.getClassIndexSign(this.training.y[i]) * this.alphas[i] * k;
        }
      } else {
        // If training is done, we only loop over the support vectors
        var supportVectors = this.supportVectors;
        for (var s = 0; s < supportVectors.length; s += 1) {
          var sv = supportVectors[s];
          var kSv = this.applyKernel(sample, this.training.X[sv]);
          rval += this.getClassIndexSign(this.training.y[sv]) * this.alphas[sv] * kSv;
        }
      }

      return rval;
    }
  }, {
    /**
     * Apply the kernel to two data points. Accepts both feature arrays and training data point
     * indices for x and y. If x and y are both numbers and kernel caching is enabled, attempts to
     * fetch the kernel result for the corresponding training data points from cache, and computes
     * and stores the result in cache if it isn't found
     *
     * @param {Array.<number>|number} x - Feature vector or data point index for first data point.
     *   Numbers are treated as training data point indices, anything else as a feature vector
     * @param {Array.<number>|number} y - Feature vector or data point index for second data point.
     *   Numbers are treated as training data point indices, anything else as a feature vector
     * @return {number} Kernel result
     */
    key: 'applyKernel',
    value: function applyKernel(x, y) {
      var fromCache = this.useKernelCache && typeof x === 'number' && typeof y === 'number';

      if (fromCache && this.kernelCacheStatus[x][y] === true) {
        return this.kernelCache[x][y];
      }

      // Resolve training sample indices to their feature vectors
      var xf = typeof x === 'number' ? this.training.X[x] : x;
      var yf = typeof y === 'number' ? this.training.X[y] : y;
      var result = this.kernel.apply(xf, yf);

      if (fromCache) {
        this.kernelCache[x][y] = result;
        this.kernelCacheStatus[x][y] = true;
      }

      return result;
    }
  }, {
    /**
     * Calculate the bounds on \alpha_j to make sure it can be clipped to the [0,C] box and that it
     * can be chosen to satisfy the linear equality constraint stemming from the fact that the sum
     * of all products y_i * a_i should equal 0.
     *
     * @param {number} i Index of \alpha_i
     * @param {number} j Index of \alpha_j
     * @return {Array.<number>} Two-element array containing the lower and upper bound
     */
    key: 'calculateAlphaBounds',
    value: function calculateAlphaBounds(i, j) {
      var boundL = void 0;
      var boundH = void 0;

      if (this.training.y[i] === this.training.y[j]) {
        // The alphas lie on a line with slope -1
        boundL = this.alphas[j] - (this.C - this.alphas[i]);
        boundH = this.alphas[j] + this.alphas[i];
      } else {
        // The alphas lie on a line with slope 1
        boundL = this.alphas[j] - this.alphas[i];
        boundH = this.alphas[j] + (this.C - this.alphas[i]);
      }

      // Clip to the [0, C] box
      boundL = Math.max(0, boundL);
      boundH = Math.min(this.C, boundH);

      return [boundL, boundH];
    }
  }, {
    /**
     * Make a prediction for a data set.
     *
     * @param {Array.<Array.<mixed>>} features - Features for each data point. Each array element
     *   should be an array containing the features of the data point
     * @param {Object} [optionsUser] - Options for prediction
     * @param {string} [optionsUser.output = 'classLabels'] - Output for predictions. Either
     *   "classLabels" (default, output predicted class label), "raw" or "normalized" (both output
     *   margin (distance to decision boundary) for each sample)
     * @return {Array.<mixed>} Predictions. Output dependent on options.output, defaults to class
     *   labels
     */
    key: 'predict',
    value: function predict(features) {
      var _this3 = this;

      var optionsUser = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : {};

      // Options
      var optionsDefault = {
        output: 'classLabels' // 'classLabels', 'normalized' or 'raw'
      };
      var options = _extends({}, optionsDefault, optionsUser);

      return features.map(function (x) {
        var output = _this3.sampleMargin(x);

        // Store prediction
        if (options.output === 'raw' || options.output === 'normalized') {
          // Raw output: do nothing
        } else {
          // Class label output: a strictly positive margin maps to class index 1, anything else
          // to class index 0
          output = _this3.getSignClassIndex(output > 0 ? 1 : -1);
        }

        return output;
      });
    }
  }]);

  return BinarySVM;
}(_base.Classifier);
/**
* Support Vector Machine (SVM) classification model for 2 or more classes. The model is a
* one-vs-all classifier and uses the {@link BinarySVM} classifier as its base model. For training
* individual models, a simplified version of John Platt's
* [SMO algorithm](https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/tr-98-14.pdf)
* is used.
*
* ## Support Vector Machines (SVM)
* Support Vector Machines train a classifier by finding the decision boundary between the classes
* that maximizes the margin between the boundary and the data points on either side of it. It is a
* very intuitive approach to classification. The soft-margin SVM is a modification of this approach
* where samples are allowed to be misclassified, but at some cost. Furthermore, SVMs can implement
* the "kernel trick", where an implicit feature transformation is applied.
*
* This is all implemented in the SVM implementation in JSMLT. For more information about Support
* Vector Machines, you can start by checking out the [Wikipedia](https://en.wikipedia.org/wiki/Support_vector_machine)
* page on SVMs.
*
* ## Examples
* **Example 1:** Training a multiclass SVM on a well-known three-class dataset, the
* [Iris dataset](https://github.com/jsmlt/datasets-repository/tree/master/iris#readme).
*
* @example <caption>Example 1. SVM training on a multiclass classification task.</caption>
* // Import JSMLT
* var jsmlt = require('@jsmlt/jsmlt');
*
* // Load the iris dataset. When loading is completed, process the data and run the classifier
* jsmlt.Datasets.loadIris((X, y_raw) => {
* // Encode the labels (which are strings) into integers
* var labelencoder = new jsmlt.Preprocessing.LabelEncoder();
* var y = labelencoder.encode(y_raw);
*
* // Split the data into a training set and a test set
* [X_train, y_train, X_test, y_test] = jsmlt.ModelSelection.trainTestSplit([X, y]);
*
* // Create and train classifier
* var clf = new jsmlt.Supervised.SVM.SVM({
* kernel: new jsmlt.Kernel.Gaussian(1),
* });
* clf.train(X_train, y_train);
*
* // Make predictions on test data
* var predictions = clf.predict(X_test);
*
* // Evaluate and output the classifier's accuracy
* var accuracy = jsmlt.Validation.Metrics.accuracy(predictions, y_test);
* console.log(`Accuracy: ${accuracy}`);
* });
*
* @see {@link BinarySVM}
*/
var SVM = function (_OneVsAllClassifier) {
  _inherits(SVM, _OneVsAllClassifier);

  /**
   * Constructor. Stores the user-defined options so that they can be handed to every BinarySVM
   * created by this one-vs-all classifier.
   *
   * @see {@link BinarySVM#constructor}
   *
   * @param {Object} [optionsUser] - User-defined options for SVM. Passed on unchanged to the
   *   created BinarySVM objects
   * @param {number} [optionsUser.C = 100] - Regularization (i.e. penalty for slack variables)
   * @param {Object} [optionsUser.kernel] - Kernel. Defaults to the linear kernel
   * @param {number} [optionsUser.convergenceNumPasses = 10] - Number of passes without alphas
   *   changing to treat the algorithm as converged
   * @param {number} [optionsUser.numericalTolerance = 1e-6] - Numerical tolerance below which two
   *   values in the SMO algorithm are treated as equal
   * @param {boolean} [optionsUser.useKernelCache = true] - Whether to cache calculated kernel
   *   values for training sample pairs. Enabling this option (which is the default) generally
   *   improves the performance in terms of speed at the cost of memory
   */
  function SVM() {
    _classCallCheck(this, SVM);

    // A missing (or explicitly undefined) first argument falls back to an empty options object
    var optionsUser = arguments[0] === undefined ? {} : arguments[0];

    var parentConstructor = SVM.__proto__ || Object.getPrototypeOf(SVM);
    var instance = _possibleConstructorReturn(this, parentConstructor.call(this));

    /**
     * Number of errors per iteration. Only used if accuracy tracking is enabled
     *
     * @type {Array.<mixed>}
     */
    instance.numErrors = null;

    // Keep the raw user options; createClassifier() forwards them to each BinarySVM
    instance.optionsUser = optionsUser;

    return instance;
  }

  var prototypeMembers = [{
    /**
     * Create a new binary base classifier configured with the stored user options.
     *
     * @see {@link OneVsAll#createClassifier}
     *
     * @return {BinarySVM} Freshly constructed binary SVM
     */
    key: 'createClassifier',
    value: function createClassifier() {
      var binaryClassifier = new BinarySVM(this.optionsUser);
      return binaryClassifier;
    }
  }, {
    /**
     * Store the training data, create the underlying classifiers and train them on the data.
     *
     * @see {@link Estimator#train}
     *
     * @param {Array.<Array.<number>>} X - Features per training sample
     * @param {Array.<mixed>} y - Class label per training sample
     */
    key: 'train',
    value: function train(X, y) {
      var trainingData = { X: X, y: y };
      this.training = trainingData;

      this.createClassifiers(y);
      this.trainBatch(X, y);
    }
  }];

  _createClass(SVM, prototypeMembers);

  return SVM;
}(_base.OneVsAllClassifier);

exports.default = SVM;