federer
Version:
Experiments in asynchronous federated learning and decentralized learning
89 lines • 3.67 kB
JavaScript
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.FedAsyncServer = void 0;
const common_1 = require("../../common");
const FLServer_1 = require("../FLServer");
const staleness_1 = require("./staleness");
/**
 * Server for asynchronous federated learning (FedAsync).
 *
 * Instead of waiting for a cohort of clients, the server periodically asks a
 * single available client to train (see `scheduler`). It mixes every upload
 * into the global model as soon as the upload arrives, weighted by a
 * staleness-dependent factor (see `updateRoundState`). Every upload ends the
 * current round.
 */
class FedAsyncServer extends FLServer_1.FLServer {
    constructor() {
        super(...arguments);
        // Uploads are mixed directly into the global model in
        // updateRoundState, i.e. treated as full weights rather than deltas.
        this.expectDeltaUpdates = false;
        // Handle of the periodic scheduling interval, set by scheduler().
        this.schedulerInterval = undefined;
    }
    /**
     * Builds the state for round 0.
     *
     * @param initialWeights Initial global model weights.
     * @returns Round state holding the live weights and a synchronously
     *   serialized snapshot of them as `weightsAtRoundStart`.
     */
    getInitialRoundState(initialWeights) {
        return {
            roundNumber: 0,
            weights: initialWeights,
            weightsAtRoundStart: initialWeights.serializeSync(),
        };
    }
    /**
     * Registers listeners to react to events on the socket.
     *
     * - "ready": starts training (once) when enough clients are present.
     * - "upload": mixes the uploaded weights into the global model and ends
     *   the current round.
     *
     * @param socket Socket to the client
     */
    registerSocketListeners(socket) {
        super.registerSocketListeners(socket);
        socket.on("ready", () => {
            // Use `>=` rather than `===`: if more than the minimum number of
            // clients are already counted when the ready events arrive, an
            // equality check would never hold and training would never start.
            // The round-0 check keeps this start-once, because endRound()
            // advances the round number.
            if (this.numberClients >= this.options.minimumNumberClientsForStart &&
                this.currentRound.roundNumber === 0) {
                this.endRound();
                this.scheduler();
            }
        });
        socket.on("upload", (message) => {
            const id = this.metadata.get(socket.id)?.clientId;
            this.logger.info(`Client ${id} uploaded new weights from round ${message.round}`);
            // FedAsync: every upload is applied immediately and closes the round.
            this.updateRoundState(message);
            this.endRound();
        });
    }
    /**
     * Incorporates an upload message into the global model.
     *
     * New weights = (1 - alphaT) * current + alphaT * uploaded, where
     * alphaT = options.alpha * staleness(currentRound, message.round,
     * options.staleness). The previous global weights are disposed afterwards.
     *
     * @param message Upload message with `round` (presumably the round number
     *   from the download message the client trained on — confirm with the
     *   client) and serialized `weights`.
     */
    updateRoundState(message) {
        common_1.assertNoLeakingTensors("updateRoundState", () => {
            const s = staleness_1.staleness(this.currentRound.roundNumber, message.round, this.options.staleness);
            const alphaT = this.options.alpha * s;
            const previousWeights = this.currentRound.weights;
            // tidy releases the intermediate products of mul/add/deserialize.
            this.currentRound.weights = common_1.tidy(() => previousWeights
                .mul(1 - alphaT)
                .add(common_1.Weights.deserialize(message.weights).mul(alphaT)));
            previousWeights.dispose();
        });
    }
    /**
     * Ends the current round and moves on to the next one.
     *
     * Snapshots the current global weights as the next round's
     * `weightsAtRoundStart`, which is what createRoundMessage sends to
     * clients. Emits a round summary for every ended round number divisible
     * by 10.
     */
    endRound() {
        this.logger.info(`Ending round ${this.currentRound.roundNumber}`);
        const oldRoundNumber = this.currentRound.roundNumber;
        // We serialize synchronously to avoid possible race conditions between the
        // serialization and the next upload message.
        const serializedWeights = this.currentRound.weights.serializeSync();
        this.currentRound = {
            roundNumber: oldRoundNumber + 1,
            weights: this.currentRound.weights,
            weightsAtRoundStart: serializedWeights,
        };
        if (oldRoundNumber % 10 === 0) {
            this.emitRoundSummary(`fedasync-round-${oldRoundNumber}`, oldRoundNumber, serializedWeights);
        }
    }
    /**
     * Periodically triggers a training task on one available client.
     *
     * Every `options.epochDelay` ms, one client is sampled from the pool and
     * sent the current round's starting weights. If no client is available,
     * the tick only logs a warning.
     *
     * The interval handle is stored in `this.schedulerInterval` and returned,
     * so the scheduler can be stopped with `clearInterval`.
     *
     * @returns The interval handle.
     */
    scheduler() {
        this.schedulerInterval = setInterval(() => {
            if (this.clientPool.numberAvailable === 0) {
                this.logger.warn("All clients are already training");
                return;
            }
            const [sampledId] = this.clientPool.sampleAvailable(1);
            const sampledClient = this.getSocketWithID(sampledId);
            this.sendDownloadMessage([sampledClient], this.createRoundMessage(this.currentRound));
        }, this.options.epochDelay);
        return this.schedulerInterval;
    }
    /**
     * Creates a message that informs clients about a given round state.
     *
     * @param round Round state.
     * @returns The round number and the serialized weights snapshot taken when
     *   that round started.
     */
    createRoundMessage(round) {
        return {
            round: round.roundNumber,
            weights: round.weightsAtRoundStart,
        };
    }
}
exports.FedAsyncServer = FedAsyncServer;
//# sourceMappingURL=FedAsyncServer.js.map