federer
Version:
Experiments in asynchronous federated learning and decentralized learning
198 lines • 7.29 kB
JavaScript
;
Object.defineProperty(exports, "__esModule", { value: true });
exports.DataSubset = exports.Dataset = void 0;
const tslib_1 = require("tslib");
const path = tslib_1.__importStar(require("path"));
const assert = require("assert");
const tfjs_npy_node_1 = require("tfjs-npy-node");
const path_or_url_1 = require("./path-or-url");
/**
* A dataset is a collection of data. This data is split into multiple
* {@link DataSubset}s. For now, only one split is possible: splitting into a
* training subset and a test subset.
*
* @typeParam IR The rank of the tensors holding items (IR means "Item Rank")
* @typeParam LR The rank of the tensors holding labels (LR means "Label Rank")
*/
class Dataset {
  /**
   * Wrap an object that exposes the two subsets of the dataset.
   *
   * @param data Object whose `train` and `test` properties hold the subsets
   */
  constructor(data) {
    this.data = data;
  }

  /** The subset stored under `data.train`. */
  get train() {
    const { train } = this.data;
    return train;
  }

  /** The subset stored under `data.test`. */
  get test() {
    const { test } = this.data;
    return test;
  }

  /** Dispose both subsets: the training subset first, then the test subset. */
  dispose() {
    for (const subset of [this.train, this.test]) {
      subset.dispose();
    }
  }

  /**
   * Total number of datapoints held by this {@link Dataset}, where one
   * datapoint is an item together with its label.
   *
   * @returns The training subset count plus the test subset count
   */
  countNumberDatapoints() {
    const trainCount = this.train.countNumberDatapoints();
    const testCount = this.test.countNumberDatapoints();
    return trainCount + testCount;
  }

  /**
   * Combine this dataset with another one, subset by subset.
   *
   * @param that The dataset whose subsets are concatenated onto this one's
   * @returns A new {@link Dataset} built from the concatenated subsets
   */
  concat(that) {
    const train = this.train.concat(that.train);
    const test = this.test.concat(that.test);
    return new Dataset({ train, test });
  }

  /**
   * Write both subsets to files.
   *
   * The training subset is saved with the prefix `<filenamePrefix>-train` and
   * the test subset with `<filenamePrefix>-test`.
   *
   * @param directory Directory to save to
   * @param filenamePrefix Prefix of the dataset filenames
   * @return A promise of `{ train, test }`, each being whatever the
   * corresponding subset's `save` resolved to
   */
  async save(directory, filenamePrefix) {
    // Both saves are started before either one is awaited.
    const trainSaved = this.train.save(directory, `${filenamePrefix}-train`);
    const testSaved = this.test.save(directory, `${filenamePrefix}-test`);
    const [train, test] = await Promise.all([trainSaved, testSaved]);
    return { train, test };
  }

  /**
   * Read a dataset back from files.
   *
   * @param filePathsOrUrls Object whose `train` and `test` properties locate
   * the files of each subset
   * @param expectedRanks Expected ranks of the items and labels in the dataset.
   * Forwarded unchanged to {@link DataSubset.load} for both subsets.
   * @return A promise of the loaded dataset
   */
  static async load(filePathsOrUrls, expectedRanks) {
    const { train: trainSource, test: testSource } = filePathsOrUrls;
    const pendingSubsets = [trainSource, testSource].map((source) =>
      DataSubset.load(source, expectedRanks),
    );
    const [train, test] = await Promise.all(pendingSubsets);
    return new Dataset({ train, test });
  }

  /**
   * View this dataset as a {@link Dataset} with the given item and label
   * ranks, by applying `withRanks` to each subset.
   *
   * @param ranks Ranks forwarded to both subsets
   * @returns A new {@link Dataset} wrapping the results
   */
  withRanks(ranks) {
    const train = this.train.withRanks(ranks);
    const test = this.test.withRanks(ranks);
    return new Dataset({ train, test });
  }
}
exports.Dataset = Dataset;
/**
* A data subset is a subset of the datapoints of a {@link Dataset}.
*
* The {@link https://en.wikipedia.org/wiki/Training,_validation,_and_test_sets |
* standard terminology} for this is "dataset" or "set", but this is confusing
* because the full dataset is also called a "dataset". We therefore call this
* a "data subset", as it is a subset of the full dataset.
*
* @typeParam IR The rank of the tensors holding items (IR means "Item Rank")
* @typeParam LR The rank of the tensors holding labels (LR means "Label Rank")
*/
class DataSubset {
  /**
   * Wrap an object that exposes the tensors of this subset.
   *
   * @param data Object whose `items` and `labels` properties hold the tensors
   */
  constructor(data) {
    this.data = data;
  }

  /** Get the tensor holding the items */
  get items() {
    return this.data.items;
  }

  /** Get the tensor holding the labels */
  get labels() {
    return this.data.labels;
  }

  /** Dispose the items tensor, then the labels tensor. */
  dispose() {
    this.items.dispose();
    this.labels.dispose();
  }

  /**
   * Concatenate this subset with another one.
   *
   * @param that The subset whose items and labels are passed to the `concat`
   * of this subset's items and labels
   * @returns A new {@link DataSubset}; neither operand is disposed here
   */
  concat(that) {
    return new DataSubset({
      items: this.items.concat(that.items),
      labels: this.labels.concat(that.labels),
    });
  }

  /**
   * Split this subset into two consecutive parts.
   *
   * The first part is sliced from index 0 with size
   * `Math.ceil(split * countNumberDatapoints())`; the second part is sliced
   * from that index onwards.
   *
   * @param split Fraction of the datapoints that goes to the first part,
   * between 0 and 1 inclusive
   * @returns A pair `[first, second]` of new {@link DataSubset}s
   * @throws Error if `split` is not within `[0, 1]` (this includes `NaN`)
   */
  split(split) {
    const isValidRatio = split >= 0 && split <= 1;
    // An Error instance passed as the message is thrown as-is by `assert`.
    assert(isValidRatio, new Error("Split ratio should be in range of 0 to 1"));
    const numberOfSamples = Math.ceil(split * this.countNumberDatapoints());
    // Only the first axis is specified; this relies on the tensors' `slice`
    // keeping every other axis whole (assumed tfjs behaviour - confirm).
    return [
      new DataSubset({
        items: this.items.slice([0], [numberOfSamples]),
        labels: this.labels.slice([0], [numberOfSamples]),
      }),
      new DataSubset({
        items: this.items.slice([numberOfSamples]),
        labels: this.labels.slice([numberOfSamples]),
      }),
    ];
  }

  /**
   * Counts the number of datapoints present in the {@link DataSubset}. Each
   * item and its associated label constitutes a single datapoint.
   *
   * The count is read from the first dimension of the labels tensor; the
   * items tensor is not consulted.
   *
   * @returns The number of datapoints in the {@link DataSubset}
   */
  countNumberDatapoints() {
    return this.labels.shape[0];
  }

  /**
   * Save the data subset to files.
   *
   * Items go to `<filenamePrefix>-items.npy` and labels to
   * `<filenamePrefix>-labels.npy`, both inside `directory`.
   *
   * @param directory Directory to save to
   * @param filenamePrefix Prefix of the data subset filenames
   * @return A promise of the filepaths that the data subset has been saved to
   */
  async save(directory, filenamePrefix) {
    const itemsFile = path.join(directory, `${filenamePrefix}-items.npy`);
    const labelsFile = path.join(directory, `${filenamePrefix}-labels.npy`);
    await Promise.all([
      tfjs_npy_node_1.npy.save(itemsFile, this.items),
      tfjs_npy_node_1.npy.save(labelsFile, this.labels),
    ]);
    return { items: itemsFile, labels: labelsFile };
  }

  /**
   * Load a data subset from files.
   *
   * Tensors that were successfully loaded are disposed before this method
   * rejects, because the caller never receives them and so could not dispose
   * them itself.
   *
   * @param filepathOrURL Filepaths or URLs to load the data subset from
   * @param expectedRanks Expected ranks of the items and labels in the dataset.
   * If not provided, the ranks of the loaded data will not be checked.
   * @return A promise of the loaded data subset, with the expected ranks.
   * @throws Rejects with the read/parse error if a file cannot be loaded, or
   * with the assertion error from {@link DataSubset.withRanks} on a rank
   * mismatch
   */
  static async load(filepathOrURL, expectedRanks) {
    const readTensor = async (source) => {
      const buffer = await path_or_url_1.PathOrURL.read(source);
      return tfjs_npy_node_1.npy.parse(buffer);
    };
    // Wait for both reads to settle so that a tensor which did load can be
    // released when the other read fails.
    const outcomes = await Promise.allSettled([
      readTensor(filepathOrURL.items),
      readTensor(filepathOrURL.labels),
    ]);
    const failure = outcomes.find((outcome) => outcome.status === "rejected");
    if (failure !== undefined) {
      for (const outcome of outcomes) {
        if (outcome.status === "fulfilled") {
          outcome.value.dispose();
        }
      }
      throw failure.reason;
    }
    const [items, labels] = outcomes.map((outcome) => outcome.value);
    const subset = new DataSubset({ items, labels });
    if (expectedRanks == null) {
      return subset;
    }
    try {
      // Reuse the rank checks of `withRanks` instead of duplicating them.
      return subset.withRanks(expectedRanks);
    } catch (error) {
      subset.dispose();
      throw error;
    }
  }

  /**
   * Cast to a {@link DataSubset} with given item and label ranks.
   *
   * Each rank that is provided is compared against `"R"` followed by the
   * `rankType` of the corresponding tensor. The returned subset wraps the same
   * tensors as this one; nothing is copied.
   *
   * @param ranks Object with optional `items` and `labels` expected ranks
   * @returns A new {@link DataSubset} over the same items and labels
   * @throws AssertionError if a provided rank does not match
   */
  withRanks(ranks) {
    if (ranks.items !== undefined) {
      assert.strictEqual("R" + this.items.rankType, ranks.items);
    }
    if (ranks.labels !== undefined) {
      assert.strictEqual("R" + this.labels.rankType, ranks.labels);
    }
    return new DataSubset({
      items: this.items,
      labels: this.labels,
    });
  }
}
exports.DataSubset = DataSubset;
//# sourceMappingURL=dataset.js.map