federer
Version:
Experiments in asynchronous federated learning and decentralized learning
55 lines • 2.98 kB
JavaScript
;
Object.defineProperty(exports, "__esModule", { value: true });
exports.preprocess = void 0;
const tfjs_node_1 = require("@tensorflow/tfjs-node");
const common_1 = require("../../../common");
const coordinator_1 = require("../../../coordinator");
const read_binary_1 = require("./data/read-binary");
/**
* Preprocess the raw MNIST data into client shards and a test set.
*
* @param options Preprocessing options
* @returns The results of the preprocessing, with all data saved to disk
*/
/**
 * Preprocess the raw MNIST data into client shards and a test set.
 *
 * Builds a `PreprocessPipeline` identified by `uniquePipelineName(options)`,
 * runs it (passing `options.allowReadFromCache` through to `run`), and then
 * re-declares the tensor ranks of the resulting test set so they match the
 * input format the selected model consumes.
 *
 * @param options Preprocessing options
 * @returns The results of the preprocessing, with all data saved to disk
 */
async function preprocess(options) {
    const pipeline = new coordinator_1.PreprocessPipeline(options.dataset, uniquePipelineName(options), options.numberLabelClasses, createPipelineFunctions(options));
    const processed = await pipeline.run(options.allowReadFromCache);
    // Items are flattened only for the "2NN" model (createPipelineFunctions
    // makes the same check); every other model keeps unflattened, rank-4
    // items. Keying the declared rank on the same "2NN" test keeps it in sync
    // with the data actually produced, even for model names other than
    // "CNN" and "2NN".
    const itemsAreFlattened = options.modelName === "2NN";
    const testSet = processed.testSet.withRanks({
        // Labels are one-hot vectors for every model.
        labels: tfjs_node_1.Rank.R2,
        items: itemsAreFlattened ? tfjs_node_1.Rank.R2 : tfjs_node_1.Rank.R4,
    });
    return { ...processed, testSet };
}
exports.preprocess = preprocess;
/**
 * Build the identifier handed to `PreprocessPipeline` for one configuration.
 *
 * The name encodes the model, the number of label classes, the number of
 * client shards and the digit batches per client, e.g.
 * "CNN-10-labels-split-into-100-shards-of-2-batches".
 *
 * @param options Preprocessing options
 * @returns The pipeline name string
 */
function uniquePipelineName(options) {
    const { modelName, numberLabelClasses, numberClients, numberDigitBatchesPerClient } = options;
    const labelSegment = `${modelName}-${numberLabelClasses}-labels`;
    const shardSegment = `split-into-${numberClients}-shards-of-${numberDigitBatchesPerClient}-batches`;
    return `${labelSegment}-${shardSegment}`;
}
/**
 * Assemble the stage functions passed to `PreprocessPipeline` for MNIST.
 *
 * Stages are built in this order:
 *  - `filter`: when `numberLabelClasses` is not 10, a label filter built from
 *    `range(0, numberLabelClasses - 1)`; otherwise `undefined` (no filtering).
 *  - `shard`: sorted-label sharding into
 *    `numberClients * numberDigitBatchesPerClient` batches, wrapped in the
 *    unbalanced/skewed sharder when `shardingOptions.type === "unbalanced"`.
 *  - `preprocess`: maps item values from [0, 255] to [0, 1] (flattening the
 *    items first for the "2NN" model) and one-hot encodes labels.
 *  - `readRawData`: a deferred call to `readRawData(dataset, environment)`.
 *
 * @param options Preprocessing options
 * @returns Object with `readRawData`, `filter`, `shard` and `preprocess`
 */
function createPipelineFunctions(options) {
    const classCount = options.numberLabelClasses;
    // Exactly 10 classes means every digit is used, so no filter stage.
    let filter;
    if (classCount !== 10) {
        // NOTE(review): confirm in common's `range` whether the end bound is
        // inclusive; only then does this keep exactly `classCount` labels.
        filter = coordinator_1.filterByLabels(common_1.range(0, classCount - 1));
    }
    const batchesPerClient = options.numberDigitBatchesPerClient;
    const totalDigitBatches = options.numberClients * batchesPerClient;
    const sortedSharder = coordinator_1.shardIntoSortedLabelBatches(classCount, totalDigitBatches, batchesPerClient);
    const sharding = options.shardingOptions;
    let shard = sortedSharder;
    if (sharding.type === "unbalanced") {
        // Combine sorted-label sharding with skewed sharding, split according
        // to `sortedLabelSplit`.
        const skewedSharder = coordinator_1.shardIntoSkedwedBatches(classCount, options.numberClients, sharding.skewFactorS);
        shard = coordinator_1.shardIntoUnbalancedBatchesPerBatch(sortedSharder, skewedSharder, sharding.sortedLabelSplit);
    }
    const scaleToUnit = coordinator_1.mapValuesToRange([0, 255], [0, 1]);
    // The 2NN MLP takes flat vectors, so its items are flattened before
    // scaling; other models (e.g. the CNN) keep the original item shape.
    const preprocessItems = options.modelName !== "2NN"
        ? scaleToUnit
        : (items) => scaleToUnit(coordinator_1.flattenItems(items));
    // Labels become one-hot vectors for every model.
    const preprocessLabels = coordinator_1.oneHotLabels(classCount);
    return {
        readRawData: () => read_binary_1.readRawData(options.dataset, options.environment),
        filter,
        shard,
        preprocess: { preprocessItems, preprocessLabels },
    };
}
//# sourceMappingURL=preprocess.js.map