UNPKG

@jsmlt/jsmlt

Version:

JavaScript Machine Learning

102 lines (82 loc) 4.21 kB
'use strict';

Object.defineProperty(exports, "__esModule", {
  value: true
});

var _extends = Object.assign || function (target) { for (var i = 1; i < arguments.length; i++) { var source = arguments[i]; for (var key in source) { if (Object.prototype.hasOwnProperty.call(source, key)) { target[key] = source[key]; } } } return target; }; /* eslint import/prefer-default-export: "off" */

// Local imports


exports.trainTestSplit = trainTestSplit;

var _arrays = require('../util/arrays');

var Arrays = _interopRequireWildcard(_arrays);

function _interopRequireWildcard(obj) { if (obj && obj.__esModule) { return obj; } else { var newObj = {}; if (obj != null) { for (var key in obj) { if (Object.prototype.hasOwnProperty.call(obj, key)) newObj[key] = obj[key]; } } newObj.default = obj; return newObj; } }

function _toConsumableArray(arr) { if (Array.isArray(arr)) { for (var i = 0, arr2 = Array(arr.length); i < arr.length; i++) { arr2[i] = arr[i]; } return arr2; } else { return Array.from(arr); } }

/**
 * Split a dataset into a training and a test set.
 *
 * Every input array is split using the same selection of indices, so elements that belong
 * together (e.g. a feature vector and its label) end up in the same subset. Within each output
 * array, elements keep their original relative order.
 *
 * @example <caption>Example with n=5 datapoints and d=2 features per sample</caption>
 * // n x d array of features
 * var X = [[0, 0], [0.5, 0.2], [0.3, 2.5], [0.8, 0.9], [0.7, 0.2]];
 *
 * // n-dimensional array of labels
 * var y = [1, 0, 0, 1, 1];
 *
 * // Split into training and test set
 * var [X_train, y_train, X_test, y_test] = trainTestSplit([X, y], {trainSize: 0.8});
 *
 * // X_train and y_train now contain the features and labels of the training set, and X_test
 * // and y_test contain the features and labels of the test set. The training indices are
 * // drawn by Arrays.sample, so the result can differ between calls; one possible outcome:
 * // X_train: [[0, 0], [0.5, 0.2], [0.3, 2.5], [0.7, 0.2]]
 * // y_train: [1, 0, 0, 1]
 * // X_test: [[0.8, 0.9]]
 * // y_test: [1]
 *
 * @param {Array.<Array.<mixed>>} input - List of input arrays. All input arrays must have the
 *   same length (i.e., the same first dimension size)
 * @param {Object} [optionsUser] - User-defined options
 * @param {number} [optionsUser.trainSize = 0.8] - Size of the training set. An integer greater
 *   than 1 is used as the exact number of training samples. Any other non-negative number is
 *   used as the fraction of all elements that goes into the training set (rounded to the
 *   nearest integer count). Since JavaScript does not distinguish 1 from 1.0, a value of 1
 *   means "all elements", not "one element"
 * @return {Array.<Array.<mixed>>} List of 2 * input.length output arrays: first the training
 *   part of every input array (in input order), then the test part of every input array (in
 *   input order). For input [X, y] this is [X_train, y_train, X_test, y_test]. Elements that
 *   are arrays themselves are shallow-copied; other elements are shared with the input
 * @throws {Error} If the input arrays have different lengths, if trainSize is not a finite
 *   non-negative number, or if it resolves to more training elements than are available
 */
function trainTestSplit(input) {
  var optionsUser = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : {};

  // Options
  var optionsDefault = {
    trainSize: 0.8
  };

  var options = _extends({}, optionsDefault, optionsUser);

  // Without any input arrays there is nothing to split: return the empty list of 2 * 0 outputs
  // instead of failing on input[0]
  if (input.length === 0) {
    return [];
  }

  // Total number of elements
  var numElements = input[0].length;

  // Check whether all input data sets have the same size
  if (!input.every(function (x) {
    return x.length === numElements;
  })) {
    throw new Error('All input arrays should have the same length (i.e., the size of their first dimensions should be the same)');
  }

  // Number of training elements (absolute count or fraction, see trainSize docs)
  var numTrainElements = _resolveNumTrainElements(options.trainSize, numElements);

  // Generate list of all possible array indices
  var indices = [].concat(_toConsumableArray(Array(numElements).keys()));

  // Take a random sample from the list of possible indices, which are then used as the indices
  // of the elements to use for the training data. Kept in a Set so that each membership check
  // below is O(1) rather than a linear scan of the sample for every element.
  // NOTE(review): assumes Arrays.sample draws without replacement — confirm in util/arrays
  var trainIndices = new Set(Arrays.sample(indices, numTrainElements));

  // Shallow-copy array-valued elements (e.g. feature vectors) so outputs don't alias input rows
  var copyElement = function copyElement(x) {
    return Array.isArray(x) ? x.slice() : x;
  };

  // Create resulting training and test sets
  var trainArrays = input.map(function (trainArray) {
    return trainArray.filter(function (x, i) {
      return trainIndices.has(i);
    }).map(copyElement);
  });

  var testArrays = input.map(function (testArray) {
    return testArray.filter(function (x, i) {
      return !trainIndices.has(i);
    }).map(copyElement);
  });

  // Return train and test sets: all training parts first, then all test parts
  return [].concat(_toConsumableArray(trainArrays), _toConsumableArray(testArrays));
}

// Convert the trainSize option into an absolute number of training elements, validating it
function _resolveNumTrainElements(trainSize, numElements) {
  if (typeof trainSize !== 'number' || !Number.isFinite(trainSize) || trainSize < 0) {
    throw new Error('trainSize should be a finite, non-negative number');
  }

  // Integers greater than 1 are absolute counts; everything else is a fraction of the dataset
  var numTrainElements = Number.isInteger(trainSize) && trainSize > 1 ? trainSize : Math.round(numElements * trainSize);

  if (numTrainElements > numElements) {
    throw new Error('trainSize resolves to ' + numTrainElements + ' training elements, but only ' + numElements + ' elements are available');
  }

  return numTrainElements;
}