UNPKG

federer

Version:

Experiments in asynchronous federated learning and decentralized learning

85 lines 3.97 kB
"use strict"; Object.defineProperty(exports, "__esModule", { value: true }); exports.SyncFLServer = void 0; const common_1 = require("../../common"); const FLServer_1 = require("../FLServer"); /** * Abstract class implementing a server that aggregates weights in rounds. * Subclasses must implement the aggregation and operations on the round state. */ class SyncFLServer extends FLServer_1.FLServer { constructor(server, initialWeights, options, logger) { super(server, initialWeights, options, logger); this.numberClientsPerRound = common_1.clamp(1, Math.floor(options.fractionOfClientsPerRound * options.numberClients), this.options.numberClients); } registerSocketListeners(socket) { super.registerSocketListeners(socket); socket.on("ready", () => { // If all the clients we were waiting for in round 0 are ready, start FL. if (this.shouldStartLearning()) { this.endRound(); } // If all the clients disconnected this is the first client that reconnects, end round. if (this.shouldEndRound()) { this.endRound(); } }); // When the client uploads, incorporate its update into the state. // If all clients have replied, end the round. socket.on("upload", (message) => { if (message.round !== this.currentRound.roundNumber) { return; } const id = this.metadata.get(socket.id)?.clientId; this.logger.debug(`Client ${id} uploaded new weights from round ${message.round}`); this.updateRoundState(message, socket.id); const numResponses = this.numberClientsPerRound - this.clientPool.numberTraining; this.logger.info(`Round ${this.currentRound.roundNumber}: ${numResponses} / ${this.numberClientsPerRound} responses`); if (this.shouldEndRound()) { this.endRound(); } }); // If this was the last client we were waiting for, end the round. socket.on("disconnect", () => { if (this.shouldEndRound()) { this.endRound(); } }); } /** Returns whether we can start learning, if we haven't started already. */ shouldStartLearning() { return (this.numberClients >= this.options.minimumNumberClientsForStart && this.currentRound.roundNumber === 0); } /** Returns whether we can end the current round. */ shouldEndRound() { return (this.currentRound.roundNumber > 0 && this.clientPool.numberTraining === 0 && this.clientPool.numberAvailable >= this.numberClientsPerRound); } /** Ends the current round, and moves on to the next one. */ endRound() { this.logger.info(`Ending round ${this.currentRound.roundNumber}`); // Compute round aggregation const roundAverage = common_1.tidy(() => this.getRoundAverage()); // Increment round state const oldRoundNumber = this.currentRound.roundNumber; this.incrementRoundState(roundAverage); const newRoundNumber = this.currentRound.roundNumber; // Sample clients for next round const sampledClients = this.clientPool.sampleAvailable(this.numberClientsPerRound); const sampledSockets = sampledClients.map((id) => this.getSocketWithID(id)); // Serialize aggregate const serializedAverage = roundAverage.serializeSync(); roundAverage.dispose(); // Send round summary to coordinator this.emitRoundSummary(`${this.serverName}-round-${oldRoundNumber}`, oldRoundNumber, serializedAverage); // Send new weights to clients this.sendDownloadMessage(sampledSockets, { round: newRoundNumber, weights: serializedAverage, }); } } exports.SyncFLServer = SyncFLServer; //# sourceMappingURL=SyncFLServer.js.map