federer

Experiments in asynchronous federated learning and decentralized learning
"use strict"; Object.defineProperty(exports, "__esModule", { value: true }); exports.PreprocessPipeline = void 0; const tslib_1 = require("tslib"); const path = tslib_1.__importStar(require("path")); const fs = tslib_1.__importStar(require("fs")); const assert = require("assert"); const tfjs_npy_node_1 = require("tfjs-npy-node"); const common_1 = require("../../common"); const distribution_stats_1 = require("./distribution-stats"); const PreprocessedDataPaths_1 = require("./PreprocessedDataPaths"); /** * Represents a full preprocessing pipeline for Federated Learning. * * For FL, preprocessing must split the raw data into "shards"; each FL client's * training set is equivalent to one shard. * * The preprocessing must also return an test set, used by the coordinator to * evaluate the performance of the global model. */ class PreprocessPipeline { constructor(experimentName, pipelineName, numberLabelClasses, pipeline) { this.numberLabelClasses = numberLabelClasses; this.pipeline = pipeline; this.rootDir = path.join(common_1.absolutePath.data.processed(experimentName), pipelineName); this.directories = { train: path.join(this.rootDir, "train"), test: path.join(this.rootDir, "test"), stats: path.join(this.rootDir, "stats"), }; this.pathsFile = path.join(this.rootDir, "cache.json"); Object.values(this.directories).forEach((dir) => common_1.mkdirp(dir)); } run(allowReadFromCache = true) { return common_1.tidySequentialAsync(async () => { if (fs.existsSync(this.pathsFile) && allowReadFromCache) { try { return this.readCachedResults(); } catch (err) { return this.runAndCache(); } } else { return this.runAndCache(); } }); } async runAndCache() { const results = await this.runPipeline(); await this.cacheResults(results); return results; } async runPipeline() { const dataset = await this.pipeline.readRawData(); const [testSet, [clientTrainSets, dataDistribution]] = await Promise.all([ this.createTestSet(dataset.test), this.createClientTrainSets(dataset.train), ]); const [testSetFiles, clientTrainFiles] = await Promise.all([ testSet.save(this.directories.test, "evaluation"), this.saveShards(clientTrainSets), ]); return { testSet, testSetFiles, clientTrainFiles, dataDistribution }; } async createTestSet(data) { const filtered = (await this.pipeline.filter?.(data)) ?? data; return this.preprocess(filtered); } async createClientTrainSets(data) { const filtered = (await this.pipeline.filter?.(data)) ?? data; const shards = await this.pipeline.shard(filtered); // We run the data distribution analysis before preprocessing. 
const distribution = distribution_stats_1.dataDistributionStats(shards.map((shard) => shard.labels.as1D()), this.numberLabelClasses); return Promise.all([ Promise.all(shards.map((shard) => this.preprocess(shard))), Promise.resolve(distribution), ]); } async preprocess(data) { const [items, labels] = await Promise.all([ this.pipeline.preprocess.preprocessItems(data.items), this.pipeline.preprocess.preprocessLabels(data.labels), ]); return new common_1.DataSubset({ items, labels }); } async readCachedResults() { const paths = await PreprocessedDataPaths_1.RelativeDataPaths.load(this.pathsFile); const cachedFiles = paths.toAbsolute().json; const [testSet, distributionMatrix] = await Promise.all([ common_1.DataSubset.load(cachedFiles.testSetFiles), tfjs_npy_node_1.npy.load(cachedFiles.dataDistribution.distributionMatrix), ]); assert.strictEqual(distributionMatrix.rank, 2); return { clientTrainFiles: cachedFiles.clientTrainFiles, testSetFiles: cachedFiles.testSetFiles, testSet: testSet, dataDistribution: { distributionMatrix: distributionMatrix, }, }; } async cacheResults(results) { const distributionMatrixPath = await this.saveDistributionMatrix(results.dataDistribution.distributionMatrix); const paths = new PreprocessedDataPaths_1.AbsoluteDataPaths({ clientTrainFiles: results.clientTrainFiles, testSetFiles: results.testSetFiles, dataDistribution: { distributionMatrix: distributionMatrixPath, }, }); return paths.toRelative(this.rootDir).save(this.pathsFile); } async saveDistributionMatrix(matrix) { const filepath = path.join(this.directories.stats, "distribution.npy"); await tfjs_npy_node_1.npy.save(filepath, matrix); return filepath; } saveShards(shards) { return Promise.all(shards.map((shard, id) => shard.save(this.directories.train, `shard-${id}`))); } } exports.PreprocessPipeline = PreprocessPipeline; //# sourceMappingURL=PreprocessPipeline.js.map
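Below is a minimal usage sketch, not part of the package itself. It assumes the pipeline argument exposes the hooks that PreprocessPipeline calls above: readRawData(), an optional filter(), shard(), and a preprocess object with preprocessItems() and preprocessLabels(). The require paths, the examplePipeline object, the tensor shapes, and the use of @tensorflow/tfjs-node are illustrative assumptions, not the package's documented API.

// Usage sketch (illustrative only): module paths, shapes, and the tfjs dependency are assumed.
const tf = require("@tensorflow/tfjs-node");
const { PreprocessPipeline } = require("./PreprocessPipeline");

// Hypothetical pipeline: raw data is split evenly across 4 simulated clients,
// items are scaled to [0, 1], and labels are one-hot encoded.
const examplePipeline = {
    async readRawData() {
        // Returns { train, test }, each with items/labels tensors (shapes assumed).
        const make = (n) => ({
            items: tf.randomUniform([n, 28, 28]),
            labels: tf.randomUniform([n], 0, 10, "int32"),
        });
        return { train: make(400), test: make(100) };
    },
    // No filter() here: PreprocessPipeline then uses the raw subsets unchanged.
    async shard(data) {
        // One shard per client; labels must support .as1D() for the distribution stats.
        const items = tf.split(data.items, 4);
        const labels = tf.split(data.labels, 4);
        return items.map((it, i) => ({ items: it, labels: labels[i] }));
    },
    preprocess: {
        preprocessItems: async (items) => items.div(255),
        preprocessLabels: async (labels) => tf.oneHot(labels, 10),
    },
};

// "mnist-demo" / "iid-4-clients" are made-up experiment and pipeline names; outputs go
// under the processed-data directory for that experiment and are cached in cache.json.
const preprocessing = new PreprocessPipeline("mnist-demo", "iid-4-clients", 10, examplePipeline);
preprocessing.run().then(({ clientTrainFiles, testSetFiles, dataDistribution }) => {
    console.log("client shard files:", clientTrainFiles.length);
    console.log("test set files:", testSetFiles);
    dataDistribution.distributionMatrix.print(); // rank-2 label-distribution matrix
});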