federer
Version:
Experiments in asynchronous federated learning and decentralized learning
100 lines • 4.31 kB
JavaScript
"use strict";
// Flag the CommonJS exports object with `__esModule: true`.
Object.defineProperty(exports, "__esModule", { value: true });
// Pre-declare the export; the real value is assigned after the class
// definition at the bottom of this file.
exports.SemiSyncFLServer = void 0;
const assert = require("assert");
const FLServer_1 = require("../FLServer");
/**
 * Federated-learning server that runs training in rounds and ends a round as
 * soon as a configurable fraction (`options.roundEndFraction`) of the clients
 * that were asked to train have replied.
 *
 * Per-round state lives in `this.currentRound`:
 *   - `roundNumber`: 0 until learning starts, then incremented on every round end
 *   - `weights`:     the global weights for the round
 *   - `requested`:   socket ids that were asked to train this round
 *   - `replied`:     socket ids that have replied this round
 */
class SemiSyncFLServer extends FLServer_1.FLServer {
    /**
     * All arguments are forwarded unchanged to the base `FLServer` constructor.
     *
     * @param server forwarded to the base class
     * @param initialWeights forwarded to the base class
     * @param options must provide `roundEndFraction` in the range (0, 1]
     * @param logger forwarded to the base class
     * @throws {AssertionError} if `roundEndFraction` is not in (0, 1]
     */
    constructor(server, initialWeights, options, logger) {
        super(server, initialWeights, options, logger);
        assert(0 < this.options.roundEndFraction);
        assert(this.options.roundEndFraction <= 1);
    }

    /**
     * Registers the base-class listeners on `socket` and then adds the
     * "ready", "disconnect" and "upload" handlers that drive the rounds.
     *
     * @param socket client socket exposing `on(event, handler)` and `id`
     */
    registerSocketListeners(socket) {
        super.registerSocketListeners(socket);
        socket.on("ready", () => {
            // The very first round is kicked off by "ending" round 0.
            if (this.shouldStartLearning()) {
                this.endRound();
            }
        });
        socket.on("disconnect", () => {
            // If it hasn't replied yet, pretend that we never asked it to train.
            // If it has replied, we keep its response.
            if (!this.currentRound.replied.has(socket.id)) {
                this.currentRound.requested.delete(socket.id);
            }
            // NOTE(review): `requested` is empty during round 0, so
            // shouldEndRound() is true here and a disconnect before learning
            // has started also triggers endRound(). Confirm this is intended.
            if (this.shouldEndRound()) {
                this.endRound();
            }
        });
        socket.on("upload", (message) => {
            // Prefer the registered client id for logging; fall back to the socket id.
            const id = this.metadata.get(socket.id)?.clientId ?? socket.id;
            this.logger.info(`Client ${id} uploaded new weights from round ${message.round}`);
            this.updateRoundState(socket, message);
            if (this.shouldEndRound()) {
                this.endRound();
            }
        });
    }

    /**
     * Builds the state for round 0: the given weights and no outstanding
     * requests or replies.
     *
     * @param initialWeights weights to store as the round-0 global weights
     * @returns the initial round-state object
     */
    getInitialRoundState(initialWeights) {
        return {
            roundNumber: 0,
            weights: initialWeights,
            requested: new Set(),
            replied: new Set(),
        };
    }

    /**
     * @returns {boolean} true while still in round 0 and the number of
     *   clients is exactly `options.minimumNumberClientsForStart`
     */
    shouldStartLearning() {
        return (this.currentRound.roundNumber === 0 &&
            this.numberClients === this.options.minimumNumberClientsForStart);
    }

    /**
     * @returns {boolean} true if there are no outstanding requests, or if the
     *   fraction of requested clients that replied has reached
     *   `options.roundEndFraction`
     */
    shouldEndRound() {
        const numRequests = this.currentRound.requested.size;
        const numReplies = this.currentRound.replied.size;
        // We might have 0 requests if all currently training nodes disconnect, or
        // if there were not enough available nodes.
        if (numRequests === 0) {
            return true;
        }
        const fractionOfResponses = numReplies / numRequests;
        assert(0 <= fractionOfResponses && fractionOfResponses <= 1);
        return fractionOfResponses >= this.options.roundEndFraction;
    }

    /** Ends the current round, and moves on to the next one. */
    endRound() {
        // Report replies against the number of clients that were actually asked
        // to train (`requested`), not against the reply count itself.
        const numRequests = this.currentRound.requested.size;
        const numReplies = this.currentRound.replied.size;
        const oldRoundNumber = this.currentRound.roundNumber;
        this.logger.info(`Ending round ${oldRoundNumber}: received ${numReplies}/${numRequests} responses`);
        // We serialize synchronously to avoid possible race conditions between the
        // serialization and the next upload message.
        const weights = this.getGlobalWeights();
        const serializedWeights = weights.serializeSync();
        const sampledIds = this.sampleClients();
        const sampledClients = sampledIds.map((id) => this.getSocketWithID(id));
        // Release the previous round's weights before replacing the round state.
        // NOTE(review): assumes getGlobalWeights() returns a fresh object rather
        // than `this.currentRound.weights` itself - confirm.
        this.currentRound.weights.dispose();
        this.currentRound = {
            roundNumber: oldRoundNumber + 1,
            weights: weights,
            requested: new Set(sampledIds),
            replied: new Set(),
        };
        this.emitRoundSummary(`${this.serverName}-round-${oldRoundNumber}`, oldRoundNumber, serializedWeights);
        this.sendDownloadMessage(sampledClients, {
            round: this.currentRound.roundNumber,
            weights: serializedWeights,
        });
    }

    /**
     * Samples the clients for the next round from the client pool, asking for
     * `options.numberClientsPerRound` clients capped at the number currently
     * available. Logs a warning when fewer clients are available than needed,
     * and another when none are available.
     *
     * @returns the ids returned by `clientPool.sampleAvailable`
     */
    sampleClients() {
        const nNeeded = this.options.numberClientsPerRound;
        const nAvailable = this.clientPool.numberAvailable;
        if (nNeeded > nAvailable) {
            this.logger.warn(`Should select ${nNeeded} clients, but only ${nAvailable} are available`);
        }
        if (nAvailable === 0) {
            this.logger.warn("There were 0 clients available. Starting a new round, which will end once we get a stale response.");
        }
        return this.clientPool.sampleAvailable(Math.min(nNeeded, nAvailable));
    }
}
exports.SemiSyncFLServer = SemiSyncFLServer;
//# sourceMappingURL=SemiSyncFLServer.js.map