UNPKG

federer

Version:

Experiments in asynchronous federated learning and decentralized learning

36 lines 1.74 kB
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.FedSemiSyncServer = void 0;
const common_1 = require("../../common");
const SemiSyncFLServer_1 = require("./SemiSyncFLServer");
/**
 * Semi-synchronous FL server ("fedsemisync").
 *
 * Each uploaded `message.weights` is treated as a delta. It is added to the
 * global weights when the upload arrives, scaled by the uploader's share of
 * datapoints and by a staleness discount.
 */
class FedSemiSyncServer extends SemiSyncFLServer_1.SemiSyncFLServer {
    constructor(...args) {
        super(...args);
        this.serverName = "fedsemisync";
        // Uploads are deltas that updateRoundState adds onto the global weights.
        this.expectDeltaUpdates = true;
    }
    /**
     * Returns a copy of the current global weights.
     *
     * A clone is returned because updateRoundState replaces
     * `currentRound.weights` and disposes the previous tensor. A caller holding
     * the original reference would otherwise end up with disposed weights.
     */
    getGlobalWeights() {
        return this.currentRound.weights.clone(); // weights updated in place
    }
    /**
     * Incorporates an upload message into the state.
     *
     * The deserialized `message.weights` is multiplied by
     *   K * (nk / n) * staleness * options.alpha
     * and added to the global weights. In that expression:
     *   K         = number of clients in `clientPool`
     *   nk        = `numberDatapoints` of the uploading client's metadata
     *   n         = sum of `numberDatapoints` over all stored client metadata
     *   staleness = 1 / (1 + age)^options.a, where
     *               age = currentRound.roundNumber - message.round
     * If the upload belongs to the current round, the client id is recorded
     * in `currentRound.replied`.
     *
     * @param socket - uploading client's socket; only `socket.id` is read.
     * @param message - upload carrying `round` and serialized `weights`.
     */
    updateRoundState(socket, message) {
        common_1.assertNoLeakingTensors("updateRoundState", () => {
            const age = this.currentRound.roundNumber - message.round;
            // `**` binds tighter than `/`: this is 1 / ((1 + age) ** a).
            const staleness = 1 / (1 + age) ** this.options.a;
            // Start the sum at 0 so an empty metadata map yields 0 rather than
            // a TypeError from reduce() without an initial value.
            const n = [...this.metadata.values()]
                .map((metadata) => metadata.numberDatapoints)
                .reduce((x, y) => x + y, 0);
            const nk = this.getMetadata(socket.id).numberDatapoints;
            const K = this.clientPool.size;
            // NOTE(review): if n is 0, this is NaN and the update below would
            // write NaN into the global weights. Confirm callers rule that out.
            const weighting = K * (nk / n);
            this.logger.info(`Updating. Weighting = ${weighting}, age = ${age}, staleness = ${staleness}`);
            const oldWeights = this.currentRound.weights;
            // tidy() scopes the intermediate tensors (deserialized delta,
            // scaled delta); only the returned sum survives.
            this.currentRound.weights = common_1.tidy(() => oldWeights.add(common_1.Weights.deserialize(message.weights).mul(weighting * staleness * this.options.alpha)));
            if (message.round === this.currentRound.roundNumber) {
                this.currentRound.replied.add(socket.id);
            }
            // The previous global weights were replaced above; release them.
            oldWeights.dispose();
        });
    }
}
exports.FedSemiSyncServer = FedSemiSyncServer;
//# sourceMappingURL=FedSemiSyncServer.js.map