UNPKG

federer

Version:

Experiments in asynchronous federated learning and decentralized learning

158 lines 6.63 kB
"use strict"; Object.defineProperty(exports, "__esModule", { value: true }); exports.Experiment = void 0; const common_1 = require("../common"); const cli_1 = require("./cli"); const _1 = require("."); const DELTA_UPDATES = new Set([ "BatchedFedAsync", "FedSemiSync", ]); class Experiment { constructor(network, options, data, logger) { this.network = network; this.options = options; this.data = data; this.logger = logger; this.clientDelays = _1.getClientDelays(options.clientDelays ?? { type: "none" }, cli_1.CoordinatorCLIOptions.get("number-clients")); this.clientLoggerOptions = _1.getLoggerOptions("client", this.options.logging); this.modelPathOrURL = this.getPathOrUrl(this.data.modelFilePath); const unsubscribe = network.server.events.on("stopped", () => { this.logger.debug("server stopped"); unsubscribe(); }); network.clients.forEach((client, i) => { const unsubscribe = client.events.on("stopped", () => { this.logger.debug(`client ${i} stopped`); unsubscribe(); }); }); } getPathOrUrl(filepath) { const host = cli_1.CoordinatorCLIOptions.get("fileserver-host"); if (host !== undefined) { return encodeURI(`http://${host}/${common_1.relative(filepath)}`); } else { return filepath; } } async start() { await this.startServer(); await this.startClients(); } async stop() { this.network.server.stop(); this.network.ipc.clients.emit("stop"); await Promise.all([ this.network.server.stopped(), ...this.network.clients.map((client) => client.stopped()), ]); this.logger.debug("All nodes have returned 'stopped'"); } async startServer() { this.logger.info(`Starting ${this.options.serverOptions.server} server`); const server = this.network.server; this.logger.debug("Waiting to receive 'ready' from server"); await server.ready(); this.logger.debug("Received 'ready' from server, sending 'start'"); server.start(this.getServerOptions()); this.logger.debug("Waiting to receive 'started' from server"); await server.started(); this.logger.debug(`Started ${this.options.serverOptions.server} server`); } async startClients() { await Promise.all(this.network.clients.map(async (client, i) => { this.logger.debug(`Waiting to receive 'ready' from client ${i}`); await client.ready(); this.logger.debug(`Received 'ready' from client ${i}, sending 'start'`); client.start(this.getClientOptions(i)); this.logger.debug(`Waiting to receive 'started' from client ${i}`); await client.started(); this.logger.debug(`Client ${i} started`); })); } /** Returns the options for client with a given id. */ getClientOptions(id) { const dataFilePaths = this.data.preprocessed.clientTrainFiles[id]; return { id, deltaUpdates: DELTA_UPDATES.has(this.options.serverOptions.server), dataPathsOrURLs: { items: this.getPathOrUrl(dataFilePaths.items), labels: this.getPathOrUrl(dataFilePaths.labels), }, modelPathOrURL: this.modelPathOrURL, replyDelay: this.clientDelays[id], trainOptions: this.options.trainOptions, tensorflowVerbosity: this.options.tensorflowVerbosity ?? 0, loggerOptions: this.clientLoggerOptions, }; } getServerOptions() { const loggerOptions = _1.getLoggerOptions("server", this.options.logging); const numberClients = cli_1.CoordinatorCLIOptions.get("number-clients"); const base = { modelPathOrURL: this.modelPathOrURL, experimentName: this.data.experimentName, minimumNumberClientsForStart: numberClients, loggerOptions, instrument: { memoryUsage: this.options.instrument?.memoryUsage ?? false, uploadStaleness: this.options.instrument?.uploadStaleness ?? false, }, }; switch (this.options.serverOptions.server) { case "FedAvg": return { ...base, server: "FedAvg", numberClients, fractionOfClientsPerRound: this.options.serverOptions.fractionOfClientsPerRound, }; case "LiFedAvg": return { ...base, server: "LiFedAvg", numberClients, fractionOfClientsPerRound: this.options.serverOptions.fractionOfClientsPerRound, }; case "FedAsync": return { ...base, server: "FedAsync", epochDelay: this.options.serverOptions.epochDelay, alpha: this.options.serverOptions.alpha, staleness: this.options.serverOptions.staleness, }; case "BatchedFedAsync": return { ...base, server: "BatchedFedAsync", numberClientsPerRound: this.options.serverOptions.numberClientsPerRound, roundEndFraction: this.options.serverOptions.roundEndFraction, alpha: this.options.serverOptions.alpha, }; case "FedSemiSync": return { ...base, server: "FedSemiSync", numberClientsPerRound: this.options.serverOptions.numberClientsPerRound, roundEndFraction: this.options.serverOptions.roundEndFraction, a: this.options.serverOptions.a, alpha: this.options.serverOptions.alpha, }; case "FedCRDT": return { ...base, server: "FedCRDT", numberClientsPerRound: this.options.serverOptions.numberClientsPerRound, roundEndFraction: this.options.serverOptions.roundEndFraction, a: this.options.serverOptions.a, alpha: this.options.serverOptions.alpha, }; } } } exports.Experiment = Experiment; //# sourceMappingURL=Experiment.js.map