federer
Version:
Experiments in asynchronous federated learning and decentralized learning
74 lines • 3 kB
JavaScript
;
Object.defineProperty(exports, "__esModule", { value: true });
exports.FedAvgServer = void 0;
const common_1 = require("../../common");
const SyncFLServer_1 = require("./SyncFLServer");
/**
 * Implementation of standard FedAvg, as described in
 * {@link https://arxiv.org/abs/1602.05629}.
 *
 * Each round, every known client holds its own copy of the round's weights in
 * `currentRound.clientWeights` (keyed by socket id). Clients upload full
 * weights, which `updateRoundState` stores as-is. The round result is the
 * average of those weights, weighted by each client's `numberDatapoints`.
 */
class FedAvgServer extends SyncFLServer_1.SyncFLServer {
    constructor(...args) {
        super(...args);
        this.serverName = "FedAvg";
        // Uploads are full weights rather than deltas: updateRoundState stores
        // the deserialized upload directly as the client's weights.
        this.expectDeltaUpdates = false;
    }
    /**
     * Registers FedAvg-specific socket handlers on top of the base server's.
     *
     * @param socket - client socket; its `id` keys `currentRound.clientWeights`.
     */
    registerSocketListeners(socket) {
        super.registerSocketListeners(socket);
        // Give the client its own copy of this round's starting weights once it
        // is ready to train.
        socket.on("ready", () => {
            const { clientWeights, initialWeights } = this.currentRound;
            const fresh = initialWeights.clone();
            // A repeated "ready" would otherwise overwrite the previous copy
            // without ever releasing it.
            clientWeights.get(socket.id)?.dispose();
            clientWeights.set(socket.id, fresh);
        });
        // Release and forget the client's weights when it disconnects.
        socket.on("disconnect", () => {
            this.currentRound.clientWeights.get(socket.id)?.dispose();
            this.currentRound.clientWeights.delete(socket.id);
        });
    }
    /**
     * Builds the state for round 0.
     *
     * Every client in the pool gets its own clone of `initialWeights`. Clones are
     * required (no aliasing): client copies are disposed independently (on
     * upload, disconnect, or round end), while `initialWeights` is still cloned by
     * the "ready" handler and disposed separately in `incrementRoundState`.
     *
     * @param initialWeights - starting weights; ownership moves to the round state.
     * @returns the round-0 state `{ roundNumber, initialWeights, clientWeights }`.
     */
    getInitialRoundState(initialWeights) {
        const clientWeights = new Map();
        for (const socketId of this.clientPool.clients) {
            clientWeights.set(socketId, initialWeights.clone());
        }
        return {
            roundNumber: 0,
            initialWeights,
            clientWeights,
        };
    }
    /**
     * Replaces a client's weights with the weights it just uploaded and
     * disposes the weights it previously held.
     *
     * @param message - upload whose `weights` field is passed to `Weights.deserialize`.
     * @param socketId - id of the uploading client's socket.
     * @throws {Error} if `socketId` has no entry in the current round's weights.
     */
    updateRoundState(message, socketId) {
        const previous = this.currentRound.clientWeights.get(socketId);
        if (previous === undefined) {
            throw new Error(`Programmer error: got an upload on a socket that is not in the clientWeights: socket id is ${socketId}`);
        }
        this.currentRound.clientWeights.set(socketId, common_1.Weights.deserialize(message.weights));
        previous.dispose();
    }
    /**
     * Computes the FedAvg weighted average of the current clients' weights:
     * sum_k(n_k * w_k) / sum_k(n_k), where n_k is client k's `numberDatapoints`.
     *
     * Numerator and denominator range over the same clients (those in
     * `currentRound.clientWeights`), so stale metadata cannot skew the result.
     * Intermediate results are disposed here. This assumes `mul`/`add`/`divNoNan`
     * return fresh Weights rather than mutating in place.
     * NOTE(review): confirm in common/Weights.
     *
     * @returns a new Weights object holding the average (owned by the caller).
     * @throws {Error} if the current round has no client weights.
     */
    getRoundAverage() {
        const entries = [...this.currentRound.clientWeights.entries()];
        if (entries.length === 0) {
            throw new Error("Cannot compute the round average: the current round has no client weights");
        }
        let totalNumberDatapoints = 0;
        let weightedSum;
        for (const [id, weights] of entries) {
            const { numberDatapoints } = this.getMetadata(id);
            totalNumberDatapoints += numberDatapoints;
            const scaled = weights.mul(numberDatapoints);
            if (weightedSum === undefined) {
                weightedSum = scaled;
            }
            else {
                const next = weightedSum.add(scaled);
                weightedSum.dispose();
                scaled.dispose();
                weightedSum = next;
            }
        }
        const average = weightedSum.divNoNan(totalNumberDatapoints);
        weightedSum.dispose();
        return average;
    }
    /**
     * Advances to the next round.
     *
     * Disposes every client's old weights and the old initial weights. Each
     * client, and the new round's initial weights, then get a clone of
     * `roundAverage`. `roundAverage` itself is only cloned, never disposed here.
     * NOTE(review): presumably the caller disposes it; confirm in SyncFLServer.
     *
     * @param roundAverage - the result of `getRoundAverage` for the finished round.
     */
    incrementRoundState(roundAverage) {
        const newWeights = new Map();
        for (const [id, oldWeights] of this.currentRound.clientWeights.entries()) {
            oldWeights.dispose();
            newWeights.set(id, roundAverage.clone());
        }
        this.currentRound.initialWeights.dispose();
        this.currentRound = {
            roundNumber: this.currentRound.roundNumber + 1,
            initialWeights: roundAverage.clone(),
            clientWeights: newWeights,
        };
    }
}
exports.FedAvgServer = FedAvgServer;
//# sourceMappingURL=FedAvgServer.js.map