UNPKG

federer

Version:

Experiments in asynchronous federated learning and decentralized learning

93 lines 4.14 kB
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.FedCRDTServer = void 0;
const tslib_1 = require("tslib");
const assert_1 = tslib_1.__importDefault(require("assert"));
// Kept even though this module no longer reads from `tf` directly: requiring
// @tensorflow/tfjs-node registers its native backend as a side effect.
// NOTE(review): confirm nothing relies on that before removing this line.
const tf = tslib_1.__importStar(require("@tensorflow/tfjs-node"));
const common_1 = require("../../common");
const AverageWeights_1 = require("../utils/AverageWeights");
const SemiSyncFLServer_1 = require("./SemiSyncFLServer");
/**
 * Semi-synchronous federated-learning server using FedCRDT-style aggregation.
 *
 * Each client upload is:
 * - scaled by the share of data its client holds (`weightAmountData`);
 * - accumulated into a per-round partial average;
 * - weighted in the global model by `alpha * (1 - alpha) ** age`
 *   (`weightAge`), where `age` is how many rounds ago the upload's round was.
 *
 * Once a round's partial average has received `options.numberClientsPerRound`
 * uploads, it is folded into `currentRound.weights` and dropped from the map.
 */
class FedCRDTServer extends SemiSyncFLServer_1.SemiSyncFLServer {
    constructor(server, initialWeights, options, logger) {
        super(server, initialWeights, options, logger);
        this.serverName = "fedcrdt";
        this.expectDeltaUpdates = false;
        /** Partial averages not yet folded into the global weights, keyed by round number. */
        this.partialAverages = new Map();
        /** Expected weight shapes; every upload is asserted against these. */
        this.shape = initialWeights.shapes;
    }
    /**
     * Ends the current round after decaying the folded-in weights by `(1 - alpha)`.
     *
     * Contributions already folded into `currentRound.weights` were scaled by
     * `weightAge` at fold time. Decaying them once per round keeps their
     * effective weight in step with `weightAge` as rounds advance. This
     * assumes the base class advances `roundNumber` by one per `endRound`;
     * confirm in SemiSyncFLServer.
     *
     * Overriding here is a workaround for the current base-class architecture.
     */
    endRound() {
        const alpha = this.options.alpha;
        const decayed = common_1.tidy(() => this.currentRound.weights.mul(1 - alpha));
        this.currentRound.weights.dispose();
        this.currentRound.weights = decayed;
        super.endRound();
    }
    /**
     * Returns the global model: the folded-in weights plus every pending
     * partial average, each scaled by `weightAge(round)`.
     *
     * Intermediates are created inside `common.tidy`. Presumably it disposes
     * them the way `tf.tidy` does. The returned weights are not disposed here.
     */
    getGlobalWeights() {
        return common_1.tidy(() => {
            let partialsSum = common_1.Weights.zero(this.shape);
            for (const [round, avg] of this.partialAverages) {
                // Defensive: an entry with no contributions has nothing to
                // average (and average() may divide by zero).
                if (avg.count === 0) {
                    continue;
                }
                partialsSum = partialsSum.add(avg.average().mul(this.weightAge(round)));
            }
            return this.currentRound.weights.add(partialsSum);
        });
    }
    /**
     * Incorporates an upload message into the state.
     *
     * The uploaded weights are deserialized, shape-checked and scaled by the
     * client's data share, then added to the partial average for
     * `message.round`. Validation happens before the partial average is looked
     * up, so a rejected upload does not leave an empty entry behind. Temporary
     * weights are disposed even when validation or accumulation throws.
     *
     * When the partial average reaches `options.numberClientsPerRound`
     * contributions, it is folded into `currentRound.weights` (scaled by its
     * age weight), disposed, and removed from the map.
     *
     * @param socket - the uploading client's socket; `socket.id` identifies it.
     * @param message - upload carrying `round` and serialized `weights`.
     * @throws {AssertionError} if the uploaded shapes differ from the model's.
     */
    updateRoundState(socket, message) {
        const weights = common_1.Weights.deserialize(message.weights);
        let avg;
        try {
            assert_1.default.deepStrictEqual(weights.shapes, this.shape);
            const weightedWeights = weights.mul(this.weightAmountData(socket));
            try {
                avg = this.getPartialAverage(message.round);
                avg.add(weightedWeights);
            }
            finally {
                weightedWeights.dispose();
            }
        }
        finally {
            weights.dispose();
        }
        if (avg.count === this.options.numberClientsPerRound) {
            // Fold the completed round average into the global weights.
            const merged = common_1.tidy(() => this.currentRound.weights.add(avg.average().mul(this.weightAge(message.round))));
            this.currentRound.weights.dispose();
            this.currentRound.weights = merged;
            avg.dispose();
            this.partialAverages.delete(message.round);
        }
        if (message.round === this.currentRound.roundNumber) {
            this.currentRound.replied.add(socket.id);
        }
    }
    /**
     * Computes the factor used to weigh contributions by the amount of data
     * that they represent: `K * n_k / n`.
     * - `K`: the size of the client pool.
     * - `n_k`: this client's `numberDatapoints`.
     * - `n`: the sum of `numberDatapoints` over all known metadata.
     *
     * If no datapoints are reported at all (`n === 0`), the factor falls back
     * to 1 (uniform weighting). Otherwise it would be NaN or Infinity and
     * would corrupt the global weights.
     */
    weightAmountData(socket) {
        const n = [...this.metadata.values()].reduce((total, metadata) => total + metadata.numberDatapoints, 0);
        const nk = this.getMetadata(socket.id).numberDatapoints;
        const K = this.clientPool.size;
        if (n === 0) {
            return 1;
        }
        return (K * nk) / n;
    }
    /**
     * Computes the factor used to weigh contributions by their age:
     * `alpha * (1 - alpha) ** (currentRound.roundNumber - tau)`.
     *
     * @param tau - the round the contribution belongs to.
     */
    weightAge(tau) {
        const alpha = this.options.alpha;
        const age = this.currentRound.roundNumber - tau;
        return alpha * (1 - alpha) ** age;
    }
    /**
     * Gets the partial average maintained for a given round, creating and
     * registering an empty one if none exists yet.
     */
    getPartialAverage(round) {
        let avg = this.partialAverages.get(round);
        if (avg === undefined) {
            avg = new AverageWeights_1.AverageWeights();
            this.partialAverages.set(round, avg);
        }
        return avg;
    }
}
exports.FedCRDTServer = FedCRDTServer;
//# sourceMappingURL=FedCRDT.js.map