federer
Version:
Experiments in asynchronous federated learning and decentralized learning
85 lines • 3.97 kB
JavaScript
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.SyncFLServer = void 0;
const common_1 = require("../../common");
const FLServer_1 = require("../FLServer");
/**
 * Abstract class implementing a server that aggregates weights in rounds.
 * Subclasses must implement the aggregation and operations on the round state
 * (the methods `getRoundAverage`, `updateRoundState` and `incrementRoundState`
 * called below).
 */
class SyncFLServer extends FLServer_1.FLServer {
    /**
     * @param server - Forwarded unchanged to the FLServer constructor.
     * @param initialWeights - Forwarded unchanged to the FLServer constructor.
     * @param options - Server options; this class reads `fractionOfClientsPerRound`,
     *   `numberClients` and `minimumNumberClientsForStart`.
     * @param logger - Logger used for round progress messages.
     */
    constructor(server, initialWeights, options, logger) {
        super(server, initialWeights, options, logger);
        // Clients sampled per round: the configured fraction of all clients, rounded
        // down, then passed through `clamp` together with the bounds 1 and numberClients.
        // NOTE(review): confirm clamp's argument order in ../../common.
        this.numberClientsPerRound = common_1.clamp(1, Math.floor(options.fractionOfClientsPerRound * options.numberClients), this.options.numberClients);
    }
    /**
     * Registers the round-related socket handlers on top of those installed by
     * FLServer: "ready", "upload" and "disconnect".
     * @param socket - A connected client socket.
     */
    registerSocketListeners(socket) {
        super.registerSocketListeners(socket);
        socket.on("ready", () => {
            if (this.shouldStartLearning()) {
                // All the clients we were waiting for in round 0 are ready: start FL.
                // The `else` below matters: endRound() has just advanced the round
                // number, so shouldEndRound() must not be allowed to end the freshly
                // started round in this same event, before anyone has trained.
                this.endRound();
            }
            else if (this.shouldEndRound()) {
                // If all the clients disconnected and this is the first client to
                // reconnect, end the round.
                this.endRound();
            }
        });
        // When the client uploads, incorporate its update into the state.
        // If all clients have replied, end the round.
        socket.on("upload", (message) => {
            // Ignore uploads for any round other than the current one. Optional
            // chaining also drops malformed (null/undefined) client payloads instead
            // of throwing inside the socket listener.
            if (message?.round !== this.currentRound.roundNumber) {
                return;
            }
            const id = this.metadata.get(socket.id)?.clientId;
            this.logger.debug(`Client ${id} uploaded new weights from round ${message.round}`);
            this.updateRoundState(message, socket.id);
            // Progress estimate: sampled clients minus those still marked as training.
            const numResponses = this.numberClientsPerRound - this.clientPool.numberTraining;
            this.logger.info(`Round ${this.currentRound.roundNumber}: ${numResponses} / ${this.numberClientsPerRound} responses`);
            if (this.shouldEndRound()) {
                this.endRound();
            }
        });
        // If this was the last client we were waiting for, end the round.
        socket.on("disconnect", () => {
            if (this.shouldEndRound()) {
                this.endRound();
            }
        });
    }
    /**
     * Returns whether we can start learning, if we haven't started already: at least
     * `options.minimumNumberClientsForStart` clients are known and we are in round 0.
     */
    shouldStartLearning() {
        return (this.numberClients >= this.options.minimumNumberClientsForStart &&
            this.currentRound.roundNumber === 0);
    }
    /**
     * Returns whether we can end the current round: learning has started (round > 0),
     * no client is still training, and enough clients are available to sample the
     * next round.
     */
    shouldEndRound() {
        return (this.currentRound.roundNumber > 0 &&
            this.clientPool.numberTraining === 0 &&
            this.clientPool.numberAvailable >= this.numberClientsPerRound);
    }
    /**
     * Ends the current round and moves on to the next one. It aggregates the round,
     * advances the round state, samples the next round's clients, reports the
     * aggregate to the coordinator and sends it to the sampled clients.
     */
    endRound() {
        this.logger.info(`Ending round ${this.currentRound.roundNumber}`);
        // Compute the round aggregate inside tidy(); presumably this is tf.tidy so
        // intermediate tensors are released. NOTE(review): confirm in ../../common.
        const roundAverage = common_1.tidy(() => this.getRoundAverage());
        const oldRoundNumber = this.currentRound.roundNumber;
        let newRoundNumber;
        let sampledSockets;
        let serializedAverage;
        try {
            // Increment round state
            this.incrementRoundState(roundAverage);
            newRoundNumber = this.currentRound.roundNumber;
            // Sample clients for next round
            const sampledClients = this.clientPool.sampleAvailable(this.numberClientsPerRound);
            sampledSockets = sampledClients.map((id) => this.getSocketWithID(id));
            // Serialize once; the same payload goes to the coordinator and the clients.
            serializedAverage = roundAverage.serializeSync();
        }
        finally {
            // Release the aggregate even if updating, sampling or serializing throws.
            roundAverage.dispose();
        }
        // Send round summary to coordinator
        this.emitRoundSummary(`${this.serverName}-round-${oldRoundNumber}`, oldRoundNumber, serializedAverage);
        // Send new weights to clients
        this.sendDownloadMessage(sampledSockets, {
            round: newRoundNumber,
            weights: serializedAverage,
        });
    }
}
exports.SyncFLServer = SyncFLServer;
//# sourceMappingURL=SyncFLServer.js.map