UNPKG

federer

Version:

Experiments in asynchronous federated learning and decentralized learning

89 lines 3.67 kB
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.FedAsyncServer = void 0;
const common_1 = require("../../common");
const FLServer_1 = require("../FLServer");
const staleness_1 = require("./staleness");
/**
 * Federated-learning server with asynchronous (FedAsync-style) aggregation.
 *
 * Each client upload is mixed into the global weights as soon as it arrives,
 * scaled by a staleness-dependent factor, and immediately closes the current
 * round. A timer-driven scheduler sends the latest weight snapshot to one
 * available client at a time.
 */
class FedAsyncServer extends FLServer_1.FLServer {
    constructor() {
        super(...arguments);
        // Uploads are mixed into the global model as full weights (see
        // updateRoundState), not as deltas.
        // NOTE(review): presumably this flag is consumed by FLServer — confirm.
        this.expectDeltaUpdates = false;
    }
    /**
     * Builds the state for round 0.
     *
     * @param initialWeights Starting global weights; this object becomes the
     *   live model that later uploads are mixed into.
     * @returns Round state holding the weights and a synchronously serialized
     *   snapshot of them (the copy sent to clients).
     */
    getInitialRoundState(initialWeights) {
        return {
            roundNumber: 0,
            weights: initialWeights,
            weightsAtRoundStart: initialWeights.serializeSync(),
        };
    }
    /**
     * Registers listeners to react to events on the socket.
     *
     * On "ready", starts training once enough clients are counted. On
     * "upload", applies the uploaded weights and advances the round.
     *
     * @param socket Socket to the client
     */
    registerSocketListeners(socket) {
        super.registerSocketListeners(socket);
        socket.on("ready", () => {
            // `>=` rather than `===`: numberClients may already be past the
            // minimum when this "ready" event is handled (e.g. several clients
            // were counted before the threshold client reported ready). With an
            // equality test, training would then never start. The round-0 check
            // keeps this one-shot, because endRound() advances the round number
            // synchronously.
            if (this.numberClients >= this.options.minimumNumberClientsForStart &&
                this.currentRound.roundNumber === 0) {
                this.endRound();
                void this.scheduler();
            }
        });
        socket.on("upload", (message) => {
            const id = this.metadata.get(socket.id)?.clientId;
            this.logger.info(`Client ${id} uploaded new weights from round ${message.round}`);
            // Asynchronous aggregation: every upload is applied immediately
            // and closes the current round.
            this.updateRoundState(message);
            this.endRound();
        });
    }
    /**
     * Incorporates an upload message into the state.
     *
     * The new global weights are
     * `(1 - alphaT) * current + alphaT * uploaded`, with
     * `alphaT = options.alpha * staleness(currentRound, message.round, options.staleness)`.
     * The previous weights object is disposed after the mix.
     *
     * @param message Upload message. `message.round` is the round the client
     *   trained from; `message.weights` is passed to `Weights.deserialize`.
     */
    updateRoundState(message) {
        common_1.assertNoLeakingTensors("updateRoundState", () => {
            const s = staleness_1.staleness(this.currentRound.roundNumber, message.round, this.options.staleness);
            const alphaT = this.options.alpha * s;
            const oldWeightedModelSum = this.currentRound.weights;
            // NOTE(review): common_1.tidy is presumably tf.tidy, so the
            // deserialized upload and the scaled intermediates are released
            // and only the returned mix survives.
            this.currentRound.weights = common_1.tidy(() => this.currentRound.weights
                .mul(1 - alphaT)
                .add(common_1.Weights.deserialize(message.weights).mul(alphaT)));
            // The old weights were created outside this tidy() scope, so they
            // must be released explicitly.
            oldWeightedModelSum.dispose();
        });
    }
    /**
     * Ends the current round, and moves on to the next one.
     *
     * The live weights object carries over unchanged. A serialized snapshot is
     * stored as `weightsAtRoundStart` (the payload of createRoundMessage). A
     * round summary is emitted whenever the ending round number is divisible
     * by 10.
     */
    endRound() {
        this.logger.info(`Ending round ${this.currentRound.roundNumber}`);
        const oldRoundNumber = this.currentRound.roundNumber;
        // We serialize synchronously to avoid possible race conditions between the
        // serialization and the next upload message.
        const serializedWeights = this.currentRound.weights.serializeSync();
        this.currentRound = {
            roundNumber: oldRoundNumber + 1,
            weights: this.currentRound.weights,
            weightsAtRoundStart: serializedWeights,
        };
        if (oldRoundNumber % 10 === 0) {
            this.emitRoundSummary(`fedasync-round-${oldRoundNumber}`, oldRoundNumber, serializedWeights);
        }
    }
    /**
     * Periodically triggers training tasks on some clients.
     *
     * Every `options.epochDelay` ms, one available client is sampled from the
     * client pool and sent the current round's weight snapshot. If no client
     * is available, the tick only logs a warning.
     *
     * NOTE(review): the interval handle is not retained, so this timer cannot
     * be cleared from here. The "ready" handler starts it at most once.
     */
    scheduler() {
        setInterval(() => {
            if (this.clientPool.numberAvailable === 0) {
                this.logger.warn("All clients are already training");
                return;
            }
            const [sampledId] = this.clientPool.sampleAvailable(1);
            const sampledClient = this.getSocketWithID(sampledId);
            this.sendDownloadMessage([sampledClient], this.createRoundMessage(this.currentRound));
        }, this.options.epochDelay);
    }
    /**
     * Creates a message that informs clients about a given round state.
     *
     * @param round Round state to describe.
     * @returns `{ round, weights }`, where `weights` is the serialized snapshot
     *   taken when the round started (not the live, possibly-updated weights).
     */
    createRoundMessage(round) {
        return {
            round: round.roundNumber,
            weights: round.weightsAtRoundStart,
        };
    }
}
exports.FedAsyncServer = FedAsyncServer;
//# sourceMappingURL=FedAsyncServer.js.map