/*
 * federer
 * Version:
 * Experiments in asynchronous federated learning and decentralized learning
 * 36 lines • 1.74 kB
 * JavaScript
 */
;
Object.defineProperty(exports, "__esModule", { value: true });
exports.FedSemiSyncServer = void 0;
const common_1 = require("../../common");
const SemiSyncFLServer_1 = require("./SemiSyncFLServer");
/**
 * Semi-synchronous federated learning server that folds every client upload
 * into the current round's weights as soon as it arrives.
 *
 * Each upload is added onto the current weights after being scaled by
 * `K * (nk / n) * staleness * alpha`, where `K` is the client pool size,
 * `nk` the uploading client's datapoint count, `n` the total datapoint count
 * over all known client metadata, `staleness = 1 / (1 + age) ** a`, and
 * `a` / `alpha` come from `this.options`.
 */
class FedSemiSyncServer extends SemiSyncFLServer_1.SemiSyncFLServer {
    constructor(...args) {
        super(...args);
        this.serverName = "fedsemisync";
        // Uploads are added onto the current weights (see updateRoundState).
        this.expectDeltaUpdates = true;
    }

    /**
     * Returns a copy of the current round's weights.
     *
     * A clone is handed out because updateRoundState disposes and replaces
     * the round's weights object on every applied upload.
     */
    getGlobalWeights() {
        return this.currentRound.weights.clone();
    }

    /**
     * Incorporates an upload message into the state.
     *
     * @param socket  Client connection; only `socket.id` is read.
     * @param message Upload message; `message.round` and `message.weights`
     *                are read.
     *
     * The weight update is skipped (and logged) when the message claims a
     * round ahead of the server or when the computed scale is not a finite
     * number (for example when the total datapoint count is zero, or
     * `message.round` is missing). Applying such an update would write
     * NaN/Infinity, or a sign-flipped contribution, into the global weights.
     * Reply bookkeeping is unaffected by the skip.
     */
    updateRoundState(socket, message) {
        common_1.assertNoLeakingTensors("updateRoundState", () => {
            const age = this.currentRound.roundNumber - message.round;
            // `**` binds tighter than `/`: this is 1 / ((1 + age) ** a).
            const staleness = 1 / (1 + age) ** this.options.a;
            // Explicit initial value so an empty metadata map yields 0
            // instead of a TypeError from reduce.
            const n = [...this.metadata.values()]
                .map((metadata) => metadata.numberDatapoints)
                .reduce((x, y) => x + y, 0);
            const nk = this.getMetadata(socket.id).numberDatapoints;
            const K = this.clientPool.size;
            const weighting = K * (nk / n);
            this.logger.info(`Updating. Weighting = ${weighting}, age = ${age}, staleness = ${staleness}`);
            const scale = weighting * staleness * this.options.alpha;
            if (age >= 0 && Number.isFinite(scale)) {
                const oldWeights = this.currentRound.weights;
                // Deserialization stays inside tidy, as in the original, so
                // intermediates are cleaned up by it.
                this.currentRound.weights = common_1.tidy(() => oldWeights.add(common_1.Weights.deserialize(message.weights).mul(scale)));
                // The previous weights object has been replaced; release it.
                oldWeights.dispose();
            }
            else {
                this.logger.info(`Skipping weight update from ${socket.id}: scale = ${scale}, age = ${age}`);
            }
            // Only uploads for the current round count as replies to it.
            if (message.round === this.currentRound.roundNumber) {
                this.currentRound.replied.add(socket.id);
            }
        });
    }
}
exports.FedSemiSyncServer = FedSemiSyncServer;
//# sourceMappingURL=FedSemiSyncServer.js.map