UNPKG

federer

Version:

Experiments in asynchronous federated learning and decentralized learning

100 lines 4.31 kB
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.SemiSyncFLServer = void 0;
const assert = require("assert");
const FLServer_1 = require("../FLServer");
/**
 * Semi-synchronous federated-learning server.
 *
 * Each round, a sample of available clients is sent the current global
 * weights. A round ends once the fraction of requested clients that have
 * replied reaches `options.roundEndFraction`, or when there are no outstanding
 * requests at all. Ending a round builds the next round's state and sends the
 * new global weights to a fresh sample of clients.
 */
class SemiSyncFLServer extends FLServer_1.FLServer {
    /**
     * @param server - forwarded to FLServer (presumably the socket server — TODO confirm).
     * @param initialWeights - initial global weights, forwarded to FLServer.
     * @param options - server options; `roundEndFraction` must lie in (0, 1].
     * @param logger - logger forwarded to FLServer.
     * @throws {assert.AssertionError} if `options.roundEndFraction` is outside (0, 1].
     */
    constructor(server, initialWeights, options, logger) {
        super(server, initialWeights, options, logger);
        assert(0 < this.options.roundEndFraction);
        assert(this.options.roundEndFraction <= 1);
    }
    /**
     * Adds the semi-synchronous round handling on top of the base listeners.
     *
     * - "ready": starts the first round once enough clients are present.
     * - "disconnect": forgets an unanswered request from this socket, then
     *   checks whether the round can end.
     * - "upload": records the client's reply, then checks whether the round
     *   can end.
     */
    registerSocketListeners(socket) {
        super.registerSocketListeners(socket);
        socket.on("ready", () => {
            if (this.shouldStartLearning()) {
                this.endRound();
            }
        });
        socket.on("disconnect", () => {
            // If it hasn't replied yet, pretend that we never asked it to train.
            // If it has replied, we keep its response.
            if (!this.currentRound.replied.has(socket.id)) {
                this.currentRound.requested.delete(socket.id);
            }
            if (this.shouldEndRound()) {
                this.endRound();
            }
        });
        socket.on("upload", (message) => {
            const id = this.metadata.get(socket.id)?.clientId ?? socket.id;
            this.logger.info(`Client ${id} uploaded new weights from round ${message.round}`);
            this.updateRoundState(socket, message);
            if (this.shouldEndRound()) {
                this.endRound();
            }
        });
    }
    /**
     * Builds the state for round 0: the initial weights and no outstanding
     * requests or replies.
     */
    getInitialRoundState(initialWeights) {
        return {
            roundNumber: 0,
            weights: initialWeights,
            requested: new Set(),
            replied: new Set(),
        };
    }
    /**
     * True while still in round 0 and at least the configured minimum number
     * of clients is present.
     *
     * `>=` rather than `===`: if the client count has already passed the
     * minimum when this is checked, an exact-equality test would never fire.
     * Repeated starts are prevented by the `roundNumber === 0` guard, because
     * `endRound()` advances the round number synchronously.
     */
    shouldStartLearning() {
        return (this.currentRound.roundNumber === 0 &&
            this.numberClients >= this.options.minimumNumberClientsForStart);
    }
    /**
     * True when the fraction of requested clients that have replied is at
     * least `options.roundEndFraction`, or when there are no outstanding
     * requests.
     */
    shouldEndRound() {
        const numRequests = this.currentRound.requested.size;
        const numReplies = this.currentRound.replied.size;
        // We might have 0 requests if all currently training nodes disconnect, or
        // if there were not enough available nodes.
        if (numRequests === 0) {
            return true;
        }
        const fractionOfResponses = numReplies / numRequests;
        assert(0 <= fractionOfResponses && fractionOfResponses <= 1);
        return fractionOfResponses >= this.options.roundEndFraction;
    }
    /**
     * Ends the current round, and moves on to the next one.
     *
     * The steps are:
     * 1. Compute the new global weights.
     * 2. Dispose the old round's weights.
     * 3. Start a new round whose requests are a fresh client sample.
     * 4. Emit a summary for the round that just ended.
     * 5. Send the new weights to the sampled clients.
     */
    endRound() {
        // Bug fix: this previously read `replied.size`, so the log below always
        // reported numReplies/numReplies instead of replies/requests.
        const numRequests = this.currentRound.requested.size;
        const numReplies = this.currentRound.replied.size;
        const oldRoundNumber = this.currentRound.roundNumber;
        this.logger.info(`Ending round ${oldRoundNumber}: received ${numReplies}/${numRequests} responses`);
        // We serialize synchronously to avoid possible race conditions between the
        // serialization and the next upload message.
        const weights = this.getGlobalWeights();
        const serializedWeights = weights.serializeSync();
        const sampledIds = this.sampleClients();
        const sampledClients = sampledIds.map((id) => this.getSocketWithID(id));
        // NOTE(review): if getGlobalWeights() can ever return the same object as
        // this.currentRound.weights (e.g. with no replies), this dispose would
        // invalidate the new round's weights — confirm in FLServer.
        this.currentRound.weights.dispose();
        this.currentRound = {
            roundNumber: oldRoundNumber + 1,
            weights: weights,
            requested: new Set(sampledIds),
            replied: new Set(),
        };
        this.emitRoundSummary(`${this.serverName}-round-${oldRoundNumber}`, oldRoundNumber, serializedWeights);
        this.sendDownloadMessage(sampledClients, {
            round: this.currentRound.roundNumber,
            weights: serializedWeights,
        });
    }
    /**
     * Samples up to `options.numberClientsPerRound` clients from the pool of
     * available clients. It warns when fewer clients are available than
     * requested, or when none are available.
     *
     * @returns the sampled client ids, as returned by `clientPool.sampleAvailable`.
     */
    sampleClients() {
        const nNeeded = this.options.numberClientsPerRound;
        const nAvailable = this.clientPool.numberAvailable;
        if (nNeeded > nAvailable) {
            this.logger.warn(`Should select ${nNeeded} clients, but only ${nAvailable} are available`);
        }
        if (nAvailable === 0) {
            this.logger.warn("There were 0 clients available. Starting a new round, which will end once we get a stale response.");
        }
        return this.clientPool.sampleAvailable(Math.min(nNeeded, nAvailable));
    }
}
exports.SemiSyncFLServer = SemiSyncFLServer;
//# sourceMappingURL=SemiSyncFLServer.js.map