federer
Version:
Experiments in asynchronous federated learning and decentralized learning
80 lines • 3.67 kB
JavaScript
;
Object.defineProperty(exports, "__esModule", { value: true });
exports.createFLServer = void 0;
const tslib_1 = require("tslib");
const assert = require("assert");
const http = tslib_1.__importStar(require("http"));
const tf = tslib_1.__importStar(require("@tensorflow/tfjs-node"));
const io = tslib_1.__importStar(require("socket.io-client"));
const common_1 = require("../common");
const FedAsyncServer_1 = require("./async/FedAsyncServer");
const LiFedAvgServer_1 = require("./sync/LiFedAvgServer");
const FedAvgServer_1 = require("./sync/FedAvgServer");
const BatchedFedAsyncServer_1 = require("./semisync/BatchedFedAsyncServer");
const cli_1 = require("./cli");
const FedSemiSyncServer_1 = require("./semisync/FedSemiSyncServer");
const FedCRDT_1 = require("./semisync/FedCRDT");
/**
 * Creates a federated-learning server of the requested kind and starts it
 * listening on `port`.
 *
 * The model referenced by `options.modelPathOrURL` is loaded, and its current
 * weights become the initial weights handed to the server implementation
 * selected by `options.server`.
 *
 * @param options - start options; `options.modelPathOrURL` and `options.server`
 *   are read here, and the whole object is passed to the server constructor.
 * @param port - port the underlying HTTP server listens on.
 * @param logger - logger passed to the server implementation.
 * @returns a promise that resolves to the server instance once the HTTP server
 *   is listening.
 * @throws (rejects) if `options.server` names no known server type, if the
 *   model cannot be loaded, or if the HTTP server fails to listen
 *   (for example, because the port is already in use).
 */
async function createFLServer(options, port, logger) {
    const httpServer = http.createServer();
    // According to the FedAvg paper, weights must be initialized randomly,
    // ideally with the same initial random weights on all client nodes (see
    // Figure 1, which compares an independent random initialization to a common
    // random initialization).
    //
    // To do this, we initialize the weights on the coordinator and load those
    // initialized weights on the server.
    const model = await tf.loadLayersModel(common_1.PathOrURL.getTfIOHandler(options.modelPathOrURL));
    const initialWeights = new common_1.Weights(model.getWeights());
    const getServer = () => {
        switch (options.server) {
            case "FedAsync":
                return new FedAsyncServer_1.FedAsyncServer(httpServer, initialWeights, options, logger);
            case "BatchedFedAsync":
                return new BatchedFedAsyncServer_1.BatchedFedAsyncServer(httpServer, initialWeights, options, logger);
            case "FedSemiSync":
                return new FedSemiSyncServer_1.FedSemiSyncServer(httpServer, initialWeights, options, logger);
            case "FedCRDT":
                return new FedCRDT_1.FedCRDTServer(httpServer, initialWeights, options, logger);
            case "FedAvg":
                return new FedAvgServer_1.FedAvgServer(httpServer, initialWeights, options, logger);
            case "LiFedAvg":
                return new LiFedAvgServer_1.LiFedAvgServer(httpServer, initialWeights, options, logger);
            default:
                // Fail loudly instead of resolving with `undefined`, which
                // would only surface later as an opaque TypeError.
                throw new Error(`Unknown server type: '${String(options.server)}'`);
        }
    };
    const server = getServer();
    return new Promise((resolve, reject) => {
        // Without an 'error' listener, a failed listen (e.g. EADDRINUSE) is
        // thrown as an uncaught exception and this promise never settles.
        const onListenError = (err) => reject(err);
        httpServer.once("error", onListenError);
        httpServer.listen(port, () => {
            // Startup succeeded; later errors are not startup failures, so stop
            // routing them into this (now settled) promise.
            httpServer.removeListener("error", onListenError);
            resolve(server);
        });
    });
}
exports.createFLServer = createFLServer;
// Process-wide state: the currently running FL server (undefined while
// stopped) and the logger created by the most recent "start" message.
let server;
let logger;
const args = cli_1.getCLIArgs();
// Connection to the coordinator, which drives this process through the
// "start" / "stop" / "kill" events handled below.
const coordinator = io.io(args["coordinator-url"]);
/**
 * "start": create a logger from `options.loggerOptions`, build and start an FL
 * server on `args.port`, forward its "roundEnd" summaries to the coordinator,
 * and reply with "started" once it is listening.
 */
coordinator.on("start", (options) => {
    assert(server === undefined);
    // Keep a local reference so the failure handler below logs through this
    // start's logger even if a later "start" reassigns the module-level one.
    const startLogger = common_1.createLogger(options.loggerOptions, "Server", "server.log");
    logger = startLogger;
    startLogger.debug(`Received start options: ${JSON.stringify(options, null, " ")}`);
    void createFLServer(options, args.port, startLogger)
        .then((newServer) => {
            server = newServer;
            server.events.on("roundEnd", (summary) => coordinator.emit("roundEnd", summary));
            coordinator.emit("started");
        })
        .catch((err) => {
            // Record the startup failure in the server log, then rethrow so it
            // still reaches Node's unhandled-rejection handling as before.
            startLogger.warn(`Failed to start server: ${err instanceof Error ? err.stack : String(err)}`);
            throw err;
        });
});
/** "stop": terminate the running server, clear it, and reply with "stopped". */
coordinator.on("stop", () => {
    assert(server !== undefined);
    logger?.debug("Received stop message");
    server.terminate();
    server = undefined;
    coordinator.emit("stopped");
});
/** "kill" (handled once): reply with "killed" and exit the process. */
coordinator.once("kill", () => {
    logger?.debug("Received kill message");
    // NOTE(review): process.exit runs right after emit without waiting for the
    // message to be delivered — confirm the coordinator does not rely on
    // receiving "killed".
    coordinator.emit("killed", "received kill");
    process.exit(0);
});
coordinator.on("disconnect", (reason) => {
    logger?.warn(`Disconnected from coordinator for reason: '${reason}'`);
});
// Announce to the coordinator that this process is ready to receive "start".
coordinator.emit("ready");
//# sourceMappingURL=index.js.map