UNPKG

federer

Version:

Experiments in asynchronous federated learning and decentralized learning

198 lines 7.29 kB
"use strict"; Object.defineProperty(exports, "__esModule", { value: true }); exports.DataSubset = exports.Dataset = void 0; const tslib_1 = require("tslib"); const path = tslib_1.__importStar(require("path")); const assert = require("assert"); const tfjs_npy_node_1 = require("tfjs-npy-node"); const path_or_url_1 = require("./path-or-url"); /** * A dataset is a collection of data. This data is split into multiple * {@link DataSubset}s. For now, only one split is possible: splitting into a * training subset and a test subset. * * @typeParam IR The rank of the tensors holding items (IR means "Item Rank") * @typeParam LR The rank of the tensors holding labels (LR means "Label Rank") */ class Dataset { constructor(data) { this.data = data; } /** Get the training subset */ get train() { return this.data.train; } /** Get the test subset */ get test() { return this.data.test; } /** Equivalent to `tf.dispose`. */ dispose() { this.train.dispose(); this.test.dispose(); } /** * Counts the number of datapoints present in the {@link Dataset}. Each * item and its associated label constitutes a single datapoint. * * @returns The number of datapoints in the {@link Dataset} */ countNumberDatapoints() { return (this.train.countNumberDatapoints() + this.test.countNumberDatapoints()); } concat(that) { return new Dataset({ train: this.train.concat(that.train), test: this.test.concat(that.test), }); } /** * Save the dataset to files. * * @param directory Directory to save to * @param filenamePrefix Prefix of the dataset filenames * @return A promise of the filepaths that the dataset has been saved to */ async save(directory, filenamePrefix) { const [train, test] = await Promise.all([ this.train.save(directory, `${filenamePrefix}-train`), this.test.save(directory, `${filenamePrefix}-test`), ]); return { train, test }; } /** * Load a dataset from files. * * @param filePathsOrUrls Filepaths or URLs to load the dataset from * @param expectedRanks Expected ranks of the items and labels in the dataset. * If not provided, the ranks of the loaded data will not be checked. * @return A promise of the loaded dataset, with the expected ranks. */ static async load(filePathsOrUrls, expectedRanks) { const [train, test] = await Promise.all([ DataSubset.load(filePathsOrUrls.train, expectedRanks), DataSubset.load(filePathsOrUrls.test, expectedRanks), ]); return new Dataset({ train, test }); } /** Cast to a {@link Dataset} with given item and label ranks. */ withRanks(ranks) { return new Dataset({ train: this.train.withRanks(ranks), test: this.test.withRanks(ranks), }); } } exports.Dataset = Dataset; /** * A data subset is a subset of the datapoints of a {@link Dataset}. * * The {@link https://en.wikipedia.org/wiki/Training,_validation,_and_test_sets | * standard terminology} for this is "dataset" or "set", but this is confusing * because the full dataset is also called a "dataset". We therefore call this * a "data subset", as it is a subset of the full dataset. * * @typeParam IR The rank of the tensors holding items (IR means "Item Rank") * @typeParam LR The rank of the tensors holding labels (LR means "Label Rank") */ class DataSubset { constructor(data) { this.data = data; } get items() { return this.data.items; } get labels() { return this.data.labels; } /** Equivalent to `tf.dispose`. */ dispose() { this.items.dispose(); this.labels.dispose(); } concat(that) { return new DataSubset({ items: this.items.concat(that.items), labels: this.labels.concat(that.labels), }); } /**Splits into two with split ratio */ split(split) { assert(split >= 0 && split <= 1, Error("Split ration should be in range of 0 to 1")); const numberOfSamples = Math.ceil(split * this.countNumberDatapoints()); return [ new DataSubset({ items: this.items.slice([0], [numberOfSamples]), labels: this.labels.slice([0], [numberOfSamples]), }), new DataSubset({ items: this.items.slice([numberOfSamples]), labels: this.labels.slice([numberOfSamples]), }), ]; } /** * Counts the number of datapoints present in the {@link DataSubset}. Each * item and its associated label constitutes a single datapoint. * * @returns The number of datapoints in the {@link DataSubset} */ countNumberDatapoints() { return this.labels.shape[0]; } /** * Save the data subset to files. * * @param directory Directory to save to * @param filenamePrefix Prefix of the data subset filenames * @return A promise of the filepaths that the data subset has been saved to */ async save(directory, filenamePrefix) { const itemsFile = path.join(directory, `${filenamePrefix}-items.npy`); const labelsFile = path.join(directory, `${filenamePrefix}-labels.npy`); await Promise.all([ tfjs_npy_node_1.npy.save(itemsFile, this.data.items), tfjs_npy_node_1.npy.save(labelsFile, this.data.labels), ]); return { items: itemsFile, labels: labelsFile }; } /** * Load a data subset from files. * * @param filepathOrURL Filepaths or URLs to load the data subset from * @param expectedRanks Expected ranks of the items and labels in the dataset. * If not provided, the ranks of the loaded data will not be checked. * @return A promise of the loaded data subset, with the expected ranks. */ static async load(filepathOrURL, expectedRanks) { const [items, labels] = await Promise.all([ path_or_url_1.PathOrURL.read(filepathOrURL.items).then((buffer) => tfjs_npy_node_1.npy.parse(buffer)), path_or_url_1.PathOrURL.read(filepathOrURL.labels).then((buffer) => tfjs_npy_node_1.npy.parse(buffer)), ]); if (expectedRanks?.items !== undefined) { assert.strictEqual("R" + items.rankType, expectedRanks.items); } if (expectedRanks?.labels !== undefined) { assert.strictEqual("R" + labels.rankType, expectedRanks.labels); } return new DataSubset({ items: items, labels: labels, }); } /** Cast to a {@link DataSubset} with given item and label ranks. */ withRanks(ranks) { if (ranks.items !== undefined) { assert.strictEqual("R" + this.items.rankType, ranks.items); } if (ranks.labels !== undefined) { assert.strictEqual("R" + this.labels.rankType, ranks.labels); } return new DataSubset({ items: this.items, labels: this.labels, }); } } exports.DataSubset = DataSubset; //# sourceMappingURL=dataset.js.map