UNPKG

federer

Version:

Experiments in asynchronous federated learning and decentralized learning

68 lines 3.06 kB
"use strict"; Object.defineProperty(exports, "__esModule", { value: true }); exports.LeafCoordinator = void 0; const tslib_1 = require("tslib"); const assert = require("assert"); const tf = tslib_1.__importStar(require("@tensorflow/tfjs-node")); const tfjs_npy_node_1 = require("tfjs-npy-node"); const coordinator_1 = require("../../../coordinator"); const common_1 = require("../../../common"); const cli_1 = require("../../../coordinator/cli"); const model_1 = require("./model"); const read_json_1 = require("./read-json"); class LeafCoordinator extends coordinator_1.Coordinator { constructor(options, ipc, logger, stopCondition) { checkLeafOptions(options); const modelOptions = { dataset: options.dataset, optimizer: options.model.optimizer, numberOutputClasses: options.numberLabelClasses, }; super(model_1.createModel(modelOptions), ipc, options, logger, stopCondition); this.experimentName = options.dataset; this.options = options; this.modelOptions = modelOptions; } async preprocessData() { const paths = await read_json_1.processRawData(this.options.dataset, cli_1.CoordinatorCLIOptions.get("environment"), cli_1.CoordinatorCLIOptions.get("number-clients"), this.options.numberLabelClasses); const testItems = await tfjs_npy_node_1.npy.load(paths.testpaths.items); const testLabels = await tfjs_npy_node_1.npy.load(paths.testpaths.labels); const fracSplice = 1 / 5; const size = getSize(this.options.dataset); const testSet = new common_1.DataSubset({ items: testItems.slice([0, 0], [testItems.shape[0] * fracSplice, size]), labels: testLabels.slice([0, 0], [testLabels.shape[0] * fracSplice, this.options.numberLabelClasses]), }); return { clientTrainFiles: paths.trainpaths, testSetFiles: paths.testpaths, testSet: testSet, dataDistribution: { distributionMatrix: tf.tensor2d([[0]]) }, }; } /** Possibly implement a run name function */ getRunName() { const base = coordinator_1.runName(this.options, cli_1.CoordinatorCLIOptions.getAll()); const labels = `labels${this.options.numberLabelClasses}`; const batches = `batches${this.options.numberRolesPerClient}`; return `${this.options.dataset}-${base}-${labels}-${batches}`; } } exports.LeafCoordinator = LeafCoordinator; function getSize(dataset) { switch (dataset) { case "shakespeare": return model_1.SEQUENCE_LENGTH; case "synthetic": return model_1.FEATURES; } } function checkLeafOptions(options) { assert(Number.isInteger(options.numberLabelClasses)); assert(options.numberLabelClasses > 0); /** 660 for shakespeare dataset */ assert(options.numberLabelClasses <= 660); assert(Number.isInteger(options.numberRolesPerClient)); assert(options.numberRolesPerClient > 0); } //# sourceMappingURL=LeafCoordinator.js.map