// federer
// Version:
// Experiments in asynchronous federated learning and decentralized learning
// 154 lines • 7.13 kB
// JavaScript
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.Coordinator = void 0;
const tslib_1 = require("tslib");
const assert = require("assert");
const path = tslib_1.__importStar(require("path"));
const fs = tslib_1.__importStar(require("fs"));
const tf = tslib_1.__importStar(require("@tensorflow/tfjs-node"));
const common_1 = require("../common");
const coordinator_1 = require("./options/coordinator");
const TensorBoardOutput_1 = require("./evaluation/TensorBoardOutput");
const Evaluator_1 = require("./evaluation/Evaluator");
const Experiment_1 = require("./Experiment");
const cli_1 = require("./cli");
const config_1 = require("./aws/config");
/**
 * The `Coordinator` class is the central piece of the coordinator node. It
 * manages:
 *
 * - Preprocessing,
 * - Starting nodes in the correct location,
 * - Setting up communication between nodes,
 * - Listening to results from the FL server and evaluating them,
 * - Writing results to TensorBoard.
 *
 * To do all these tasks, the coordinator uses a series of helper classes and
 * functions, including {@link TensorBoardOutput}, {@link ExperimentBuilder},
 * {@link Experiment}, {@link Evaluator}, ...
 *
 * This class is meant to be extended: it calls `this.getRunName()` and
 * `this.preprocessData()` and reads `this.experimentName`, none of which are
 * defined here. Subclasses have to provide them.
 */
class Coordinator {
    /**
     * @param model - Model that is saved to disk before the experiment starts
     * and that is handed to the {@link Evaluator}.
     * @param ipc - Object whose `networkReady()` resolves with the network
     * once all nodes are connected.
     * @param options - Coordinator options; validated with
     * `checkCoordinatorOptions` before anything else is set up.
     * @param logger - Logger; `info` and `debug` are used by this class.
     * @param stopCondition - Optional predicate called with each evaluation
     * result. The run stops the first time it returns a truthy value.
     */
    constructor(model, ipc, options, logger, stopCondition) {
        this.model = model;
        this.ipc = ipc;
        this.options = options;
        this.logger = logger;
        this.stopCondition = stopCondition;
        coordinator_1.checkCoordinatorOptions(options);
        // `getRunName` is not defined in this class; it comes from the subclass.
        this.runName = this.getRunName();
        this.tensorboard = new TensorBoardOutput_1.TensorBoardOutput(this.runName, options, logger);
    }
    /**
     * Runs one experiment from preprocessing until the stop condition is met.
     *
     * @returns A promise resolved with the evaluation results that satisfied
     * the stop condition. It is rejected if starting the experiment fails, if
     * stopping it (or backing up the results) fails, or if a run is already in
     * progress.
     */
    async run() {
        // Fail fast, before doing any expensive work, when a run is in progress.
        assert(this.resolveRun === undefined, "Cannot start a run while another one is in progress");
        this.logger.info("Running experiment with options", this.options);
        this.logger.info("Saving model and preprocessing data");
        const [preprocessed, modelFilePath] = await Promise.all([
            this.preprocessData(),
            this.saveInitialModel(),
        ]);
        const { dataDistribution } = preprocessed;
        this.printDataDistributionStats(dataDistribution);
        this.tensorboard.writeDataVisualization(dataDistribution.distributionMatrix);
        this.logger.info("Data distrbution posted on tensorboard");
        dataDistribution.distributionMatrix.dispose();
        this.logger.info("Waiting for all nodes to be connected...");
        this.network = await this.ipc.networkReady();
        const { socket } = this.network.server;
        const experimentData = {
            experimentName: this.experimentName,
            modelFilePath,
            preprocessed,
        };
        const experiment = new Experiment_1.Experiment(this.network, this.options, experimentData, this.logger);
        const evaluator = new Evaluator_1.Evaluator(this.model, preprocessed.testSet, this.options);
        const summaryHandler = (summary) => {
            this.tensorboard.writeRoundSummary(summary);
            // Evaluation runs in the background; failures are only logged so
            // that one bad round does not take the whole coordinator down.
            void evaluator
                .evaluate(summary)
                .then((results) => {
                    this.logger.info(`Accuracy at round ${results.round.roundNumber}: ${results.accuracy}`);
                    this.tensorboard.writeEvaluationResults(results);
                    return this.checkStopCondition(experiment, results);
                })
                .catch((err) => this.logger.debug(`Did not write evaluation result because ${err}`));
        };
        // Register the completion callbacks BEFORE subscribing to round events
        // and starting the experiment. `checkStopCondition` treats an undefined
        // `resolveRun` as "run already finished", so a round evaluated while
        // `experiment.start()` is still pending would otherwise be ignored.
        assert(this.resolveRun === undefined, "Cannot start a run while another one is in progress");
        const completion = new Promise((resolve, reject) => {
            this.resolveRun = resolve;
            this.rejectRun = reject;
        });
        this.unsubscribeNetworkEvents = () => {
            socket.removeListener("roundEnd", summaryHandler);
        };
        socket.on("roundEnd", summaryHandler);
        this.logger.info("Starting experiment");
        try {
            // Awaiting both together attaches a handler to each promise right
            // away, so neither can end up as an unhandled rejection.
            const [, results] = await Promise.all([experiment.start(), completion]);
            return results;
        }
        catch (err) {
            // Leave the coordinator in a clean state: no dangling listener and
            // no stale callbacks that would block a later call to `run()`.
            this.unsubscribeNetworkEvents?.();
            this.resolveRun = undefined;
            this.rejectRun = undefined;
            throw err;
        }
    }
    /**
     * Saves `this.model` to a file.
     *
     * @returns A promise of the path to the saved model file
     */
    async saveInitialModel() {
        const modelFolder = common_1.absolutePath.models.coordinator(this.experimentName, this.runName);
        // NOTE(review): the result is not awaited, which assumes `mkdirp` is
        // synchronous - confirm against its definition in `../common`.
        common_1.mkdirp(modelFolder);
        await this.model.save(tf.io.fileSystem(modelFolder), {
            includeOptimizer: true,
        });
        const modelFilePath = path.join(modelFolder, "model.json");
        // These two files should exist after saving.
        // See https://js.tensorflow.org/api/latest/
        assert(fs.existsSync(modelFilePath));
        assert(fs.existsSync(path.join(modelFolder, "weights.bin")));
        return modelFilePath;
    }
    /**
     * Prints the data distribution matrix to the console, but only when
     * `options.debug.printDataDistribution` is set.
     *
     * @param stats - Object holding the `distributionMatrix` tensor to print.
     */
    printDataDistributionStats(stats) {
        if (this.options.debug?.printDataDistribution) {
            // This console log is only used for debug purposes.
            // TODO: use tf.summary.image or something like that once it's implemented
            // eslint-disable-next-line no-console
            console.table(stats.distributionMatrix.arraySync());
        }
    }
    /**
     * Stops the experiment and settles the promise returned by `run()` when
     * the stop condition accepts `results`.
     *
     * @param experiment - The running experiment; its `stop()` is awaited.
     * @param results - Evaluation results of the round that just ended.
     * @returns A promise that settles once the check (and, if triggered, the
     * shutdown) is done. It rejects if stopping or backing up fails; in that
     * case the promise returned by `run()` is rejected with the same error
     * instead of staying pending forever.
     */
    async checkStopCondition(experiment, results) {
        if (this.resolveRun === undefined) {
            // The run has already finished (the stop condition has been reached), but
            // we received a latent IPC message. In this case, we don't need to check
            // the stop condition again, as the run has already been stopped.
            return;
        }
        if (!this.stopCondition?.(results)) {
            return;
        }
        this.logger.info("Stopping experiment...");
        this.unsubscribeNetworkEvents?.();
        // Mark the experiment as stopped by setting this.resolveRun to undefined
        // This prevents multiple stop messages from being sent.
        const resolve = this.resolveRun;
        const reject = this.rejectRun;
        this.resolveRun = undefined;
        this.rejectRun = undefined;
        try {
            await Promise.all([experiment.stop(), this.maybeBackupResults()]);
        }
        catch (err) {
            // The callbacks were already cleared above, so nobody else can
            // settle the run promise any more: reject it here.
            reject?.(err);
            throw err;
        }
        this.logger.info("Experiment stopped");
        resolve(results);
    }
    /**
     * Uploads every regular file found directly inside the TensorBoard results
     * folder to S3. Does nothing unless the `environment` CLI option is `aws`.
     */
    async maybeBackupResults() {
        if (cli_1.CoordinatorCLIOptions.get("environment") !== "aws") {
            return;
        }
        const config = config_1.getAWSConfig();
        const s3 = new common_1.S3({ region: config.region });
        const { resultsPath } = this.tensorboard;
        // `withFileTypes` yields the entry type together with the listing, which
        // avoids one blocking `lstatSync` call per entry.
        const entries = await fs.promises.readdir(resultsPath, { withFileTypes: true });
        const files = entries
            .filter((entry) => entry.isFile())
            .map((entry) => path.join(resultsPath, entry.name));
        await Promise.all(files.map(async (file) => {
            this.logger.info(`Uploading ${file} to S3`);
            await s3.upload(file, config.bucketName);
            this.logger.info(`Done uploading ${file} to S3`);
        }));
        this.logger.info("Uploaded results file to S3");
    }
    /**
     * Kills the child processes, removes the tracker listeners and closes the
     * IPC channel.
     *
     * @returns Whatever `tracker.allChildrenExited()` returns.
     */
    cleanUp() {
        // Not currently used. We prefer to "soft reset" children rather than a hard
        // kill.
        assert(this.network !== undefined, "Cannot be called before the experiment has been started");
        this.network.tracker.killChildProcesses();
        this.network.tracker.removeListeners();
        this.network.ipc.close();
        return this.network.tracker.allChildrenExited();
    }
}
exports.Coordinator = Coordinator;
//# sourceMappingURL=Coordinator.js.map