UNPKG

federer

Version:

Experiments in asynchronous federated learning and decentralized learning

74 lines 3 kB
"use strict"; Object.defineProperty(exports, "__esModule", { value: true }); exports.FedAvgServer = void 0; const common_1 = require("../../common"); const SyncFLServer_1 = require("./SyncFLServer"); /** * Implementation of standard FedAvg, as described in * {@link https://arxiv.org/abs/1602.05629}. */ class FedAvgServer extends SyncFLServer_1.SyncFLServer { constructor() { super(...arguments); this.serverName = "FedAvg"; this.expectDeltaUpdates = false; } registerSocketListeners(socket) { super.registerSocketListeners(socket); // Add client weights when they're ready to train socket.on("ready", () => { this.currentRound.clientWeights.set(socket.id, this.currentRound.initialWeights.clone()); }); // Remove the weights when they disconnect socket.on("disconnect", () => { this.currentRound.clientWeights.get(socket.id)?.dispose(); this.currentRound.clientWeights.delete(socket.id); }); } getInitialRoundState(initialWeights) { const clientWeights = new Map(); [...this.clientPool.clients].forEach((socketId, index) => { // To save a single clone, we re-use the given initial weights const weights = index === 0 ? initialWeights : initialWeights.clone(); clientWeights.set(socketId, weights); }); return { roundNumber: 0, initialWeights, clientWeights, }; } updateRoundState(message, socketId) { const old = this.currentRound.clientWeights.get(socketId); if (old === undefined) { throw new Error("Programmer error: got an upload on a socket that is not in the clientWeights: socket id is " + socketId); } this.currentRound.clientWeights.set(socketId, common_1.Weights.deserialize(message.weights)); old.dispose(); } getRoundAverage() { const totalNumberDatapoints = [...this.metadata.values()] .map((metadata) => metadata.numberDatapoints) .reduce((x, y) => x + y); return [...this.currentRound.clientWeights.entries()] .map(([id, weights]) => weights.mul(this.getMetadata(id).numberDatapoints)) .reduce((sum, clientWeights) => sum.add(clientWeights)) .divNoNan(totalNumberDatapoints); } incrementRoundState(roundAverage) { const newWeights = new Map(); for (const [id, oldWeights] of this.currentRound.clientWeights.entries()) { oldWeights.dispose(); newWeights.set(id, roundAverage.clone()); } this.currentRound.initialWeights.dispose(); this.currentRound = { roundNumber: this.currentRound.roundNumber + 1, initialWeights: roundAverage.clone(), clientWeights: newWeights, }; } } exports.FedAvgServer = FedAvgServer; //# sourceMappingURL=FedAvgServer.js.map