UNPKG

federer

Version:

Experiments in asynchronous federated learning and decentralized learning

80 lines 3.67 kB
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.createFLServer = void 0;
const tslib_1 = require("tslib");
const assert = require("assert");
const http = tslib_1.__importStar(require("http"));
const tf = tslib_1.__importStar(require("@tensorflow/tfjs-node"));
const io = tslib_1.__importStar(require("socket.io-client"));
const common_1 = require("../common");
const FedAsyncServer_1 = require("./async/FedAsyncServer");
const LiFedAvgServer_1 = require("./sync/LiFedAvgServer");
const FedAvgServer_1 = require("./sync/FedAvgServer");
const BatchedFedAsyncServer_1 = require("./semisync/BatchedFedAsyncServer");
const cli_1 = require("./cli");
const FedSemiSyncServer_1 = require("./semisync/FedSemiSyncServer");
const FedCRDT_1 = require("./semisync/FedCRDT");

/**
 * Supported values of `options.server`, each mapped to a factory that builds that
 * server. Every server class is constructed with the same four arguments
 * (httpServer, initialWeights, options, logger). The class is looked up inside the
 * factory, at call time, exactly as the original switch statement did.
 */
const serverFactories = new Map([
  ["FedAsync", (...ctorArgs) => new FedAsyncServer_1.FedAsyncServer(...ctorArgs)],
  ["BatchedFedAsync", (...ctorArgs) => new BatchedFedAsyncServer_1.BatchedFedAsyncServer(...ctorArgs)],
  ["FedSemiSync", (...ctorArgs) => new FedSemiSyncServer_1.FedSemiSyncServer(...ctorArgs)],
  ["FedCRDT", (...ctorArgs) => new FedCRDT_1.FedCRDTServer(...ctorArgs)],
  ["FedAvg", (...ctorArgs) => new FedAvgServer_1.FedAvgServer(...ctorArgs)],
  ["LiFedAvg", (...ctorArgs) => new LiFedAvgServer_1.LiFedAvgServer(...ctorArgs)],
]);

/**
 * Creates the federated-learning server selected by `options.server`, attaches it
 * to a new HTTP server, and starts listening on `port`.
 *
 * According to the FedAvg paper, weights must be initialized randomly, ideally
 * with the same initial random weights on all client nodes (see Figure 1, which
 * compares an independent random initialization to a common random
 * initialization). To do this, the weights are initialized on the coordinator,
 * and the server loads those initialized weights from `options.modelPathOrURL`.
 *
 * @param {object} options - Start options received from the coordinator. Reads
 *   `options.server` (which server to build) and `options.modelPathOrURL` (the
 *   model whose weights seed the server). The whole object is also passed to the
 *   server constructor.
 * @param {number} port - Port the HTTP server listens on.
 * @param {object} logger - Logger handed to the server constructor.
 * @returns {Promise<object>} Resolves with the server instance once the HTTP
 *   server is listening.
 * @throws {Error} Rejects if `options.server` is not a supported server type, if
 *   the model cannot be loaded, or if the HTTP server fails to listen (for
 *   example, because the port is already in use).
 */
async function createFLServer(options, port, logger) {
  // Validate first so a bad server type fails before the model is loaded. Without
  // this check, an unknown type produced `undefined` and crashed later.
  const buildServer = serverFactories.get(options.server);
  if (buildServer === undefined) {
    throw new Error(
      `Unknown server type: ${options.server}. Expected one of: ${[...serverFactories.keys()].join(", ")}`,
    );
  }
  const httpServer = http.createServer();
  const model = await tf.loadLayersModel(common_1.PathOrURL.getTfIOHandler(options.modelPathOrURL));
  const initialWeights = new common_1.Weights(model.getWeights());
  const server = buildServer(httpServer, initialWeights, options, logger);
  await new Promise((resolve, reject) => {
    // A failed listen is reported through the 'error' event. Without a listener it
    // would be thrown as an uncaught error and this promise would never settle.
    httpServer.once("error", reject);
    httpServer.listen(port, () => {
      httpServer.off("error", reject);
      resolve();
    });
  });
  return server;
}
exports.createFLServer = createFLServer;

// Process entry point: connect to the coordinator and drive the server's
// lifecycle through the "start" / "stop" / "kill" messages it sends.

// The running FL server; undefined until a start completes, and again after "stop".
let server;
// Logger built from the most recent "start" options; undefined before the first "start".
let logger;
const args = cli_1.getCLIArgs();
const coordinator = io.io(args["coordinator-url"]);

coordinator.on("start", (options) => {
  // NOTE(review): `server` is only assigned once createFLServer resolves. A second
  // "start" (or a "stop") arriving while a start is still in flight is therefore
  // not caught by these asserts.
  assert(server === undefined);
  logger = common_1.createLogger(options.loggerOptions, "Server", "server.log");
  logger.debug(`Received start options: ${JSON.stringify(options, null, " ")}`);
  const startLogger = logger;
  // Fire-and-forget: success and failure are both handled in the callbacks below.
  void createFLServer(options, args.port, logger)
    .then((newServer) => {
      server = newServer;
      // Forward every round summary to the coordinator.
      server.events.on("roundEnd", (summary) => coordinator.emit("roundEnd", summary));
      coordinator.emit("started");
    })
    .catch((err) => {
      startLogger.error(`Failed to start FL server: ${err instanceof Error ? err.stack : String(err)}`);
      // Rethrow so the failure is not swallowed; "started" is never emitted here.
      throw err;
    });
});

coordinator.on("stop", () => {
  assert(server !== undefined);
  logger?.debug("Received stop message");
  server.terminate();
  server = undefined;
  coordinator.emit("stopped");
});

coordinator.once("kill", () => {
  logger?.debug("Received kill message");
  // NOTE(review): process.exit runs immediately after emit; confirm the "killed"
  // message is reliably delivered before the process terminates.
  coordinator.emit("killed", "received kill");
  process.exit(0);
});

coordinator.on("disconnect", (reason) => {
  // The message deliberately keeps the line break before the quoted reason.
  logger?.warn(`Disconnected from coordinator for reason: \n'${reason}'`);
});

coordinator.emit("ready");
//# sourceMappingURL=index.js.map