federer
Version:
Experiments in asynchronous federated learning and decentralized learning
125 lines • 5.34 kB
JavaScript
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.PreprocessPipeline = void 0;
const tslib_1 = require("tslib");
const path = tslib_1.__importStar(require("path"));
const fs = tslib_1.__importStar(require("fs"));
const assert = require("assert");
const tfjs_npy_node_1 = require("tfjs-npy-node");
const common_1 = require("../../common");
const distribution_stats_1 = require("./distribution-stats");
const PreprocessedDataPaths_1 = require("./PreprocessedDataPaths");
/**
 * Represents a full preprocessing pipeline for Federated Learning.
 *
 * For FL, preprocessing must split the raw data into "shards"; each FL client's
 * training set is equivalent to one shard.
 *
 * The preprocessing must also return a test set, used by the coordinator to
 * evaluate the performance of the global model.
 *
 * Outputs are written below `<processed experiment dir>/<pipelineName>` in the
 * `train`, `test` and `stats` sub-directories; a `cache.json` file in the same
 * root records where they are so that later runs can reuse them.
 */
class PreprocessPipeline {
    /**
     * Computes the output locations and creates the output directories.
     *
     * @param experimentName - Passed to `absolutePath.data.processed` to locate
     *   the processed-data directory of the experiment.
     * @param pipelineName - Name of the sub-directory (below the experiment's
     *   processed-data directory) that holds this pipeline's outputs.
     * @param numberLabelClasses - Forwarded to `dataDistributionStats`
     *   (presumably the number of distinct label classes; confirm there).
     * @param pipeline - Object providing `readRawData()`, `shard(data)`,
     *   `preprocess.preprocessItems(items)`, `preprocess.preprocessLabels(labels)`
     *   and, optionally, `filter(data)`.
     */
    constructor(experimentName, pipelineName, numberLabelClasses, pipeline) {
        this.numberLabelClasses = numberLabelClasses;
        this.pipeline = pipeline;
        this.rootDir = path.join(common_1.absolutePath.data.processed(experimentName), pipelineName);
        this.directories = {
            train: path.join(this.rootDir, "train"),
            test: path.join(this.rootDir, "test"),
            stats: path.join(this.rootDir, "stats"),
        };
        this.pathsFile = path.join(this.rootDir, "cache.json");
        Object.values(this.directories).forEach((dir) => common_1.mkdirp(dir));
    }
    /**
     * Runs the pipeline, reusing previously cached results when possible.
     *
     * If the cache index file exists and reading from the cache is allowed, the
     * cached results are loaded; if loading them fails for any reason, the
     * pipeline is executed from scratch and its results are cached again.
     *
     * @param allowReadFromCache - When `false`, the cache is ignored and the
     *   pipeline is always executed (its results are still cached afterwards).
     * @returns Whatever `tidySequentialAsync` returns for the callback
     *   (presumably a promise of the pipeline results; confirm in `common`).
     */
    run(allowReadFromCache = true) {
        return common_1.tidySequentialAsync(async () => {
            if (allowReadFromCache && fs.existsSync(this.pathsFile)) {
                try {
                    // `return await` is required here: without the `await`, a
                    // rejection from the cache read would bypass this `catch`
                    // and the fallback below would never run.
                    return await this.readCachedResults();
                }
                catch {
                    // Deliberate best-effort: an unreadable or stale cache is
                    // not fatal, the results are simply recomputed.
                    return this.runAndCache();
                }
            }
            return this.runAndCache();
        });
    }
    /**
     * Executes the pipeline and records the locations of its outputs in the
     * cache index file.
     *
     * @returns The results produced by `runPipeline()`.
     */
    async runAndCache() {
        const results = await this.runPipeline();
        await this.cacheResults(results);
        return results;
    }
    /**
     * Reads the raw data, builds the test set and the per-client training
     * shards, and saves them to the `test` and `train` directories.
     *
     * @returns An object with the in-memory `testSet`, the values returned by
     *   saving the test set (`testSetFiles`) and the shards
     *   (`clientTrainFiles`), and the `dataDistribution` statistics.
     */
    async runPipeline() {
        const dataset = await this.pipeline.readRawData();
        // The test set and the training shards are derived independently.
        const [testSet, [clientTrainSets, dataDistribution]] = await Promise.all([
            this.createTestSet(dataset.test),
            this.createClientTrainSets(dataset.train),
        ]);
        const [testSetFiles, clientTrainFiles] = await Promise.all([
            testSet.save(this.directories.test, "evaluation"),
            this.saveShards(clientTrainSets),
        ]);
        return { testSet, testSetFiles, clientTrainFiles, dataDistribution };
    }
    /**
     * Applies the pipeline's optional filter, then preprocesses the result.
     *
     * @param data - Raw test data (`dataset.test`).
     * @returns The preprocessed test set.
     */
    async createTestSet(data) {
        // A missing filter, or a filter yielding null/undefined, keeps the data as is.
        const filtered = (await this.pipeline.filter?.(data)) ?? data;
        return this.preprocess(filtered);
    }
    /**
     * Filters the raw training data, splits it into shards and preprocesses
     * every shard.
     *
     * @param data - Raw training data (`dataset.train`).
     * @returns A pair `[preprocessedShards, distribution]`.
     */
    async createClientTrainSets(data) {
        // A missing filter, or a filter yielding null/undefined, keeps the data as is.
        const filtered = (await this.pipeline.filter?.(data)) ?? data;
        const shards = await this.pipeline.shard(filtered);
        // We run the data distribution analysis before preprocessing.
        const distribution = distribution_stats_1.dataDistributionStats(shards.map((shard) => shard.labels.as1D()), this.numberLabelClasses);
        return Promise.all([
            Promise.all(shards.map((shard) => this.preprocess(shard))),
            Promise.resolve(distribution),
        ]);
    }
    /**
     * Preprocesses the items and the labels of a data subset concurrently.
     *
     * @param data - Object exposing `items` and `labels`.
     * @returns A new `DataSubset` built from the preprocessed items and labels.
     */
    async preprocess(data) {
        const [items, labels] = await Promise.all([
            this.pipeline.preprocess.preprocessItems(data.items),
            this.pipeline.preprocess.preprocessLabels(data.labels),
        ]);
        return new common_1.DataSubset({ items, labels });
    }
    /**
     * Loads previously cached results using the paths stored in the cache
     * index file.
     *
     * Only the test set and the distribution matrix are loaded; for the client
     * training shards only the cached file references are returned.
     *
     * @returns An object with the same top-level keys as `runPipeline()`; its
     *   `dataDistribution` contains only `distributionMatrix`.
     * @throws If the loaded distribution matrix does not have rank 2, or if any
     *   of the load operations fails.
     */
    async readCachedResults() {
        const paths = await PreprocessedDataPaths_1.RelativeDataPaths.load(this.pathsFile);
        const cachedFiles = paths.toAbsolute().json;
        const [testSet, distributionMatrix] = await Promise.all([
            common_1.DataSubset.load(cachedFiles.testSetFiles),
            tfjs_npy_node_1.npy.load(cachedFiles.dataDistribution.distributionMatrix),
        ]);
        // Sanity check on the cached file: the matrix must be two-dimensional.
        assert.strictEqual(distributionMatrix.rank, 2);
        return {
            clientTrainFiles: cachedFiles.clientTrainFiles,
            testSetFiles: cachedFiles.testSetFiles,
            testSet: testSet,
            dataDistribution: {
                distributionMatrix: distributionMatrix,
            },
        };
    }
    /**
     * Saves the distribution matrix and writes the cache index file.
     *
     * The recorded paths are converted to be relative to the pipeline's root
     * directory before being saved.
     *
     * @param results - Results as returned by `runPipeline()`.
     * @returns Whatever the `save` call on the relative paths returns.
     */
    async cacheResults(results) {
        const distributionMatrixPath = await this.saveDistributionMatrix(results.dataDistribution.distributionMatrix);
        const paths = new PreprocessedDataPaths_1.AbsoluteDataPaths({
            clientTrainFiles: results.clientTrainFiles,
            testSetFiles: results.testSetFiles,
            dataDistribution: {
                distributionMatrix: distributionMatrixPath,
            },
        });
        return paths.toRelative(this.rootDir).save(this.pathsFile);
    }
    /**
     * Writes the distribution matrix to `distribution.npy` in the `stats`
     * directory.
     *
     * @param matrix - The matrix to save.
     * @returns The path of the written file.
     */
    async saveDistributionMatrix(matrix) {
        const filepath = path.join(this.directories.stats, "distribution.npy");
        await tfjs_npy_node_1.npy.save(filepath, matrix);
        return filepath;
    }
    /**
     * Saves every shard to the `train` directory under the name `shard-<index>`.
     *
     * @param shards - Array of objects exposing `save(directory, name)`.
     * @returns A promise of the array of values returned by each `save` call.
     */
    saveShards(shards) {
        return Promise.all(shards.map((shard, id) => shard.save(this.directories.train, `shard-${id}`)));
    }
}
exports.PreprocessPipeline = PreprocessPipeline;
//# sourceMappingURL=PreprocessPipeline.js.map