federer
Version:
Experiments in asynchronous federated learning and decentralized learning
81 lines • 3.14 kB
JavaScript
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.GridSearch = exports.ExperimentResults = void 0;
const tslib_1 = require("tslib");
const fs = tslib_1.__importStar(require("fs"));
const assert = require("assert");
const await_lock_1 = tslib_1.__importDefault(require("await-lock"));
const IPCServer_1 = require("./network/IPCServer");
const logging_1 = require("./options/logging");
const cli_1 = require("./cli");
/**
 * Outcome of one finished experiment run: the coordinator options the run
 * used, plus the {@link RoundResults} of its best round. `results` may be
 * left `undefined` (e.g. as a placeholder before any round has finished).
 *
 * @typeParam Options - Type of the coordinator options the experiment was
 * run with.
 */
class ExperimentResults {
    /**
     * @param options - Coordinator options used for this run.
     * @param results - Results of the run's best round; optional.
     */
    constructor(options, results) {
        this.options = options;
        this.results = results;
    }
    /**
     * Decide whether this run should replace `that` as the best run so far.
     *
     * A run without results never wins, and a run with results always beats
     * one without. When both have results, this run wins only if its accuracy
     * is strictly higher AND it needed no more epochs than `that`.
     *
     * @param that - The run to compare against.
     * @returns `true` if this run is better than `that`.
     */
    betterThan(that) {
        const mine = this.results;
        if (mine === undefined) {
            return false;
        }
        // Only look at the other run once we know we have results ourselves.
        const theirs = that.results;
        if (theirs === undefined) {
            return true;
        }
        // Epoch counts are only consulted once accuracy has improved.
        if (!(mine.accuracy > theirs.accuracy)) {
            return false;
        }
        return mine.round.numberEpochs <= theirs.round.numberEpochs;
    }
    /** Plain-object form used by `JSON.stringify`. */
    toJSON() {
        const { options, results } = this;
        return { options, results };
    }
}
exports.ExperimentResults = ExperimentResults;
/**
 * Abstract base class holding the boilerplate for running a grid search over
 * a list of coordinator options.
 *
 * Concrete subclasses must implement
 * `runExperiment(options, ipc, logger, stopCondition)`, which runs a single
 * experiment and resolves to an {@link ExperimentResults}.
 */
class GridSearch {
    /**
     * @param stopOptions - Supplies `targetAccuracy` and `maxEpochs`, which
     * together decide when a single experiment stops.
     * @param searchSpace - Non-empty list of coordinator options, tried in
     * order. The `logging` options of the first entry configure the logger
     * used for the whole search.
     * @throws AssertionError if `searchSpace` is empty.
     */
    constructor(stopOptions, searchSpace) {
        this.stopOptions = stopOptions;
        this.searchSpace = searchSpace;
        this.runLock = new await_lock_1.default();
        assert(searchSpace.length > 0, "Grid search requires a non-empty search space");
        // Placeholder without results: the first experiment that produces
        // results always compares as better than it.
        this.bestResults = new ExperimentResults(this.searchSpace[0]);
    }
    /**
     * Run every experiment in the search space sequentially. Returns a
     * Promise of the best results found. Each new best is written out via
     * `writeBestResults` as soon as it is found.
     *
     * @throws Error if a grid search is already running on this instance.
     * Errors raised while setting up or running an experiment are propagated,
     * and the run lock is still released.
     */
    async run() {
        if (this.runLock.acquired) {
            throw new Error("Cannot run two grid searches at once!");
        }
        await this.runLock.acquireAsync();
        // Everything after acquiring the lock is guarded so that a failing
        // experiment (or server/logger setup) cannot leave the lock held,
        // which would make every later run() fail as "concurrent".
        try {
            const stopCondition = ({ accuracy, round }) => accuracy >= this.stopOptions.targetAccuracy ||
                round.numberEpochs >= this.stopOptions.maxEpochs;
            const logger = logging_1.getCoordinatorLogger(this.searchSpace[0].logging);
            // NOTE(review): this IPC server is never shut down here; confirm
            // whether IPCServer needs explicit cleanup after the search.
            const ipc = IPCServer_1.IPCServer.create(cli_1.CoordinatorCLIOptions.get("clients-port"), cli_1.CoordinatorCLIOptions.get("server-port"), cli_1.CoordinatorCLIOptions.get("number-clients"), logger);
            for (const options of this.searchSpace) {
                const results = await this.runExperiment(options, ipc, logger, stopCondition);
                if (results.betterThan(this.bestResults)) {
                    this.bestResults = results;
                    this.writeBestResults(results);
                }
            }
            logger.info("Grid search done");
        }
        finally {
            this.runLock.release();
        }
        return this.bestResults;
    }
    /**
     * Persist `results` as JSON to `bestResults.json`, resolved against the
     * process's current working directory, overwriting any previous file.
     *
     * @param results - The {@link ExperimentResults} to serialize.
     */
    writeBestResults(results) {
        fs.writeFileSync("bestResults.json", JSON.stringify(results.toJSON()));
    }
}
exports.GridSearch = GridSearch;
//# sourceMappingURL=grid-search.js.map