UNPKG

@jsmlt/jsmlt

Version:

JavaScript Machine Learning

102 lines (79 loc) 5.89 kB
"use strict"; Object.defineProperty(exports, "__esModule", { value: true }); exports.trainTestSplit = trainTestSplit; var Random = _interopRequireWildcard(require("../random")); function _getRequireWildcardCache() { if (typeof WeakMap !== "function") return null; var cache = new WeakMap(); _getRequireWildcardCache = function _getRequireWildcardCache() { return cache; }; return cache; } function _interopRequireWildcard(obj) { if (obj && obj.__esModule) { return obj; } var cache = _getRequireWildcardCache(); if (cache && cache.has(obj)) { return cache.get(obj); } var newObj = {}; if (obj != null) { var hasPropertyDescriptor = Object.defineProperty && Object.getOwnPropertyDescriptor; for (var key in obj) { if (Object.prototype.hasOwnProperty.call(obj, key)) { var desc = hasPropertyDescriptor ? Object.getOwnPropertyDescriptor(obj, key) : null; if (desc && (desc.get || desc.set)) { Object.defineProperty(newObj, key, desc); } else { newObj[key] = obj[key]; } } } } newObj["default"] = obj; if (cache) { cache.set(obj, newObj); } return newObj; } function _toConsumableArray(arr) { return _arrayWithoutHoles(arr) || _iterableToArray(arr) || _nonIterableSpread(); } function _nonIterableSpread() { throw new TypeError("Invalid attempt to spread non-iterable instance"); } function _iterableToArray(iter) { if (Symbol.iterator in Object(iter) || Object.prototype.toString.call(iter) === "[object Arguments]") return Array.from(iter); } function _arrayWithoutHoles(arr) { if (Array.isArray(arr)) { for (var i = 0, arr2 = new Array(arr.length); i < arr.length; i++) { arr2[i] = arr[i]; } return arr2; } } function ownKeys(object, enumerableOnly) { var keys = Object.keys(object); if (Object.getOwnPropertySymbols) { var symbols = Object.getOwnPropertySymbols(object); if (enumerableOnly) symbols = symbols.filter(function (sym) { return Object.getOwnPropertyDescriptor(object, sym).enumerable; }); keys.push.apply(keys, symbols); } return keys; } function _objectSpread(target) { for (var i = 1; i < arguments.length; i++) { var source = arguments[i] != null ? arguments[i] : {}; if (i % 2) { ownKeys(source, true).forEach(function (key) { _defineProperty(target, key, source[key]); }); } else if (Object.getOwnPropertyDescriptors) { Object.defineProperties(target, Object.getOwnPropertyDescriptors(source)); } else { ownKeys(source).forEach(function (key) { Object.defineProperty(target, key, Object.getOwnPropertyDescriptor(source, key)); }); } } return target; } function _defineProperty(obj, key, value) { if (key in obj) { Object.defineProperty(obj, key, { value: value, enumerable: true, configurable: true, writable: true }); } else { obj[key] = value; } return obj; } /** * Split a dataset into a training and a test set. * * @example <caption>Example with n=5 datapoints and d=2 features per sample</caption> * // n x d array of features * var X = [[0, 0], [0.5, 0.2], [0.3, 2.5], [0.8, 0.9], [0.7, 0.2]]; * * // n-dimensional array of labels * var y = [1, 0, 0, 1, 1]; // n-dimensional array of labels * * // Split into training and test set * var [X_train, y_train, X_test, y_test] = trainTestSplit([X, y], {trainSize: 0.8}); * * // Now, X_train and y_train will contain the features and labels of the training set, * // respectively, and X_test and y_test will contain the features and labels of the test set. * * // Depending on the random seed, the result might be the following * X_train: [[0, 0], [0.5, 0.2], [0.3, 2.5], [0.7, 0.2]] * y_train: [1, 0, 0, 1] * X_test: [[0.8, 0.9]] * y_test: [1] * * @param {Array.<Array.<mixed>>} input - List of input arrays. The input arrays should have the * same length (i.e., they should have the same first dimension size) * @param {Object} optionsUser - User-defined options. See method implementation for details * @param {number} [optionsUser.trainSize = 0.8] - Size of the training set. If int, this exact * number of training samples is used. If float, the total number of elements times the float * number is used as the number of training elements * @return {Array} List of output arrays. The number of elements is 2 times the number of input * elements. For each input element, a pair of output elements is returned. */ function trainTestSplit(input) { var optionsUser = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : {}; // Options var optionsDefault = { trainSize: 0.8 }; var options = _objectSpread({}, optionsDefault, {}, optionsUser); // Total number of elements var numElements = input[0].length; // Check whether all input data sets have the same size if (!input.every(function (x) { return x.length === input[0].length; })) { throw new Error("All input arrays should have the same length (i.e., the size of their\n first dimensions should be the same"); } // Generate list of all possible array indices var indices = _toConsumableArray(Array(numElements).keys()); // Number of training elements var numTrainElements = Math.round(numElements * options.trainSize); // Take a random sample from the list of possible indices, which are then used as the indices // of the elements to use for the training data var trainIndices = Random.sample(indices, numTrainElements, false); // Create resulting training and test sets var trainArrays = input.map(function (trainArray) { return trainArray.filter(function (x, i) { return trainIndices.includes(i); }).map(function (x) { return Array.isArray(x) ? x.slice() : x; }); }); var testArrays = input.map(function (testArray) { return testArray.filter(function (x, i) { return !trainIndices.includes(i); }).map(function (x) { return Array.isArray(x) ? x.slice() : x; }); }); // Return train and test sets return [].concat(_toConsumableArray(trainArrays), _toConsumableArray(testArrays)); }