UNPKG

federer

Version:

Experiments in asynchronous federated learning and decentralized learning

81 lines 3.14 kB
"use strict"; Object.defineProperty(exports, "__esModule", { value: true }); exports.GridSearch = exports.ExperimentResults = void 0; const tslib_1 = require("tslib"); const fs = tslib_1.__importStar(require("fs")); const assert = require("assert"); const await_lock_1 = tslib_1.__importDefault(require("await-lock")); const IPCServer_1 = require("./network/IPCServer"); const logging_1 = require("./options/logging"); const cli_1 = require("./cli"); /** * Models the result of a completed experiment run. * * A run consists of the particular coordinator options that were used, and of * the {@link RoundResults} of the best round. * * @typeParam Options - The type of coordinator options with which we ran the * experiment. */ class ExperimentResults { constructor(options, results) { this.options = options; this.results = results; } betterThan(that) { if (this.results === undefined) { return false; } else if (that.results === undefined) { return true; } return (this.results.accuracy > that.results.accuracy && this.results.round.numberEpochs <= that.results.round.numberEpochs); } toJSON() { return { options: this.options, results: this.results, }; } } exports.ExperimentResults = ExperimentResults; /** * Abstract class defining the boilerplate for how to run grid search. * Concrete subclasses must define how to run a single experiment. */ class GridSearch { constructor(stopOptions, searchSpace) { this.stopOptions = stopOptions; this.searchSpace = searchSpace; this.runLock = new await_lock_1.default(); assert(searchSpace.length > 0); this.bestResults = new ExperimentResults(this.searchSpace[0]); } /** Run grid search. Returns a Promise of the best results. */ async run() { if (this.runLock.acquired) { throw new Error("Cannot run two grid searches at once!"); } await this.runLock.acquireAsync(); const stopCondition = ({ accuracy, round }) => accuracy >= this.stopOptions.targetAccuracy || round.numberEpochs >= this.stopOptions.maxEpochs; const logger = logging_1.getCoordinatorLogger(this.searchSpace[0].logging); const ipc = IPCServer_1.IPCServer.create(cli_1.CoordinatorCLIOptions.get("clients-port"), cli_1.CoordinatorCLIOptions.get("server-port"), cli_1.CoordinatorCLIOptions.get("number-clients"), logger); for (const options of this.searchSpace) { const results = await this.runExperiment(options, ipc, logger, stopCondition); if (results.betterThan(this.bestResults)) { this.bestResults = results; this.writeBestResults(results); } } this.runLock.release(); logger.info("Grid search done"); return this.bestResults; } writeBestResults(results) { fs.writeFileSync("bestResults.json", JSON.stringify(results.toJSON())); } } exports.GridSearch = GridSearch; //# sourceMappingURL=grid-search.js.map