federer
Version:
Experiments in asynchronous federated learning and decentralized learning
93 lines • 4.14 kB
JavaScript
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.FedCRDTServer = void 0;
const tslib_1 = require("tslib");
const assert_1 = tslib_1.__importDefault(require("assert"));
const tf = tslib_1.__importStar(require("@tensorflow/tfjs-node"));
const common_1 = require("../../common");
const AverageWeights_1 = require("../utils/AverageWeights");
const SemiSyncFLServer_1 = require("./SemiSyncFLServer");
/**
 * Semi-synchronous federated learning server variant ("fedcrdt").
 *
 * Client uploads are accumulated into one partial average per round. Once a
 * round has received `options.numberClientsPerRound` uploads, its average is
 * scaled by an age-dependent factor (see `weightAge`) and folded into the
 * running weights of the current round; until then it is only blended in on
 * demand by `getGlobalWeights`.
 */
class FedCRDTServer extends SemiSyncFLServer_1.SemiSyncFLServer {
    /**
     * @param server Forwarded unchanged to the base class constructor.
     * @param initialWeights Initial model weights; its `shapes` are recorded
     *     and every later upload must match them exactly.
     * @param options Server options. This class reads `alpha` and
     *     `numberClientsPerRound` from it.
     * @param logger Forwarded unchanged to the base class constructor.
     */
    constructor(server, initialWeights, options, logger) {
        super(server, initialWeights, options, logger);
        this.serverName = "fedcrdt";
        // NOTE(review): presumably tells the base class that clients upload
        // full weights rather than deltas - confirm against SemiSyncFLServer.
        this.expectDeltaUpdates = false;
        // round number -> running average of the uploads received for it.
        this.partialAverages = new Map();
        this.shape = initialWeights.shapes;
    }
    /**
     * Scales the current round's weights by `(1 - alpha)` and then delegates
     * to the base class.
     */
    endRound() {
        // Overriding is a bit of a hack to get around the current architecture...
        // NOTE(review): the decay appears to age everything already folded in
        // by one round, in line with the `(1 - alpha) ** age` factor used by
        // `weightAge` - this assumes super.endRound() advances the round number.
        const alpha = this.options.alpha;
        const newWeights = common_1.tidy(() => this.currentRound.weights.mul(1 - alpha));
        this.currentRound.weights.dispose();
        this.currentRound.weights = newWeights;
        super.endRound();
    }
    /**
     * Returns the current round's weights plus the age-weighted average of
     * every round that is still incomplete.
     *
     * The computation runs inside `tidy`, so intermediates are expected to be
     * released there; the stored state itself is not modified.
     */
    getGlobalWeights() {
        return common_1.tidy(() => {
            // NOTE(review): diagnostic output written straight to stdout;
            // consider routing it through the logger passed to the constructor.
            console.log("THERE ARE", this.partialAverages.size, "ENTRIES");
            const partialsSum = [...this.partialAverages.entries()]
                .map(([round, avg]) => avg.average().mul(this.weightAge(round)))
                .reduce((a, b) => a.add(b), common_1.Weights.zero(this.shape));
            return this.currentRound.weights.add(partialsSum);
        });
    }
    /**
     * Incorporates an upload message into the state.
     *
     * @param socket Socket of the uploading client; `socket.id` identifies it.
     * @param message Upload message; `message.round` and `message.weights`
     *     are read.
     * @throws AssertionError if the uploaded shapes differ from the shapes
     *     recorded at construction, or if the data-amount weight cannot be
     *     computed (see `weightAmountData`).
     */
    updateRoundState(socket, message) {
        console.log("update round state before", tf.memory().numTensors);
        const weights = common_1.Weights.deserialize(message.weights);
        let avg;
        try {
            // Validate before touching `partialAverages`: a rejected upload
            // must not leave an empty average behind for its round, because
            // `getGlobalWeights` averages every entry in the map.
            assert_1.default.deepStrictEqual(weights.shapes, this.shape);
            const weightedWeights = weights.mul(this.weightAmountData(socket));
            try {
                avg = this.getPartialAverage(message.round);
                avg.add(weightedWeights);
            }
            finally {
                weightedWeights.dispose();
            }
        }
        finally {
            // Release the deserialized tensors even when validation throws.
            weights.dispose();
        }
        if (avg.count === this.options.numberClientsPerRound) {
            console.log("update round final client before", tf.memory().numTensors);
            // Add round average to global weights
            const newWeights = common_1.tidy(() => this.currentRound.weights.add(avg.average().mul(this.weightAge(message.round))));
            this.currentRound.weights.dispose();
            this.currentRound.weights = newWeights;
            // Delete partial weights
            avg.dispose();
            this.partialAverages.delete(message.round);
            console.log("update round final client after", tf.memory().numTensors);
        }
        // Only uploads for the round in progress count as a reply to it.
        if (message.round === this.currentRound.roundNumber) {
            this.currentRound.replied.add(socket.id);
        }
        console.log("update round after", tf.memory().numTensors);
    }
    /**
     * Computes the factor used to weigh contributions by the amount of data that
     * they represent.
     *
     * @param socket Socket of the contributing client.
     * @returns `(K * nk) / n`, where `n` is the sum of `numberDatapoints` over
     *     all known client metadata, `nk` is the contributing client's
     *     `numberDatapoints` and `K` is the size of the client pool.
     * @throws AssertionError if `n` is not a positive number; dividing by it
     *     would otherwise yield NaN/Infinity and propagate into the weights.
     */
    weightAmountData(socket) {
        const n = [...this.metadata.values()]
            .map((metadata) => metadata.numberDatapoints)
            .reduce((x, y) => x + y, 0);
        assert_1.default.ok(n > 0, "FedCRDT: total number of datapoints across clients must be positive");
        const nk = this.getMetadata(socket.id).numberDatapoints;
        const K = this.clientPool.size;
        return (K * nk) / n;
    }
    /**
     * Computes the factor used to weigh contributions by their age.
     *
     * @param tau Round number the contribution belongs to.
     * @returns `alpha * (1 - alpha) ** age`, with `age` being the current
     *     round number minus `tau`.
     */
    weightAge(tau) {
        const alpha = this.options.alpha;
        const age = this.currentRound.roundNumber - tau;
        return alpha * (1 - alpha) ** age;
    }
    /**
     * Get the partial average that we are maintaining for a given round,
     * creating and registering an empty one on first access.
     *
     * @param round Round number used as the key in `partialAverages`.
     */
    getPartialAverage(round) {
        let avg = this.partialAverages.get(round);
        if (avg === undefined) {
            avg = new AverageWeights_1.AverageWeights();
            this.partialAverages.set(round, avg);
        }
        return avg;
    }
}
// Publish the class on the CommonJS exports object (the binding was
// pre-declared as `void 0` near the top of the file).
exports.FedCRDTServer = FedCRDTServer;
//# sourceMappingURL=FedCRDT.js.map