UNPKG

federer

Version:

Experiments in asynchronous federated learning and decentralized learning

55 lines 2.98 kB
"use strict"; Object.defineProperty(exports, "__esModule", { value: true }); exports.preprocess = void 0; const tfjs_node_1 = require("@tensorflow/tfjs-node"); const common_1 = require("../../../common"); const coordinator_1 = require("../../../coordinator"); const read_binary_1 = require("./data/read-binary"); /** * Preprocess the raw MNIST data into client shards and a test set. * * @param options Preprocessing options * @returns The results of the preprocessing, with all data saved to disk */ async function preprocess(options) { const pipeline = new coordinator_1.PreprocessPipeline(options.dataset, uniquePipelineName(options), options.numberLabelClasses, createPipelineFunctions(options)); const processed = await pipeline.run(options.allowReadFromCache); const testSet = processed.testSet.withRanks({ labels: tfjs_node_1.Rank.R2, items: options.modelName === "CNN" ? tfjs_node_1.Rank.R4 : tfjs_node_1.Rank.R2, }); return { ...processed, testSet }; } exports.preprocess = preprocess; function uniquePipelineName(options) { return (`${options.modelName}-${options.numberLabelClasses}-labels-split-into-` + `${options.numberClients}-shards-of-${options.numberDigitBatchesPerClient}-batches`); } function createPipelineFunctions(options) { // If we have less than 10 classes, we need to filter out labels that aren't // part of the experiment; if we have exactly 10 classes, we don't do any // filtering const filter = options.numberLabelClasses !== 10 ? coordinator_1.filterByLabels(common_1.range(0, options.numberLabelClasses - 1)) : undefined; const shardSortedFn = coordinator_1.shardIntoSortedLabelBatches(options.numberLabelClasses, options.numberClients * options.numberDigitBatchesPerClient, options.numberDigitBatchesPerClient); const shard = options.shardingOptions.type === "unbalanced" ? coordinator_1.shardIntoUnbalancedBatchesPerBatch(shardSortedFn, coordinator_1.shardIntoSkedwedBatches(options.numberLabelClasses, options.numberClients, options.shardingOptions.skewFactorS), options.shardingOptions.sortedLabelSplit) : shardSortedFn; // Map values from the range [0, 255] to [0, 1]. If the model is the 2NN MLP, // the input needs to be flattened first; otherwise, if it's the CNN, it // shouldn't be flattened. const mapToUnitRange = coordinator_1.mapValuesToRange([0, 255], [0, 1]); const preprocessItems = options.modelName === "2NN" ? (items) => mapToUnitRange(coordinator_1.flattenItems(items)) : mapToUnitRange; // Both CNN and 2NN work with labels as one-hot vectors. const preprocessLabels = coordinator_1.oneHotLabels(options.numberLabelClasses); return { readRawData: () => read_binary_1.readRawData(options.dataset, options.environment), filter, shard, preprocess: { preprocessItems, preprocessLabels }, }; } //# sourceMappingURL=preprocess.js.map