/*
 * federer
 * Version:
 * Experiments in asynchronous federated learning and decentralized learning
 * 68 lines • 3.06 kB
 * JavaScript
 */
;
Object.defineProperty(exports, "__esModule", { value: true });
exports.LeafCoordinator = void 0;
const tslib_1 = require("tslib");
const assert = require("assert");
const tf = tslib_1.__importStar(require("@tensorflow/tfjs-node"));
const tfjs_npy_node_1 = require("tfjs-npy-node");
const coordinator_1 = require("../../../coordinator");
const common_1 = require("../../../common");
const cli_1 = require("../../../coordinator/cli");
const model_1 = require("./model");
const read_json_1 = require("./read-json");
/**
 * Coordinator for the LEAF experiments in this package (the datasets sized by
 * `getSize`: "shakespeare" and "synthetic"). It validates the LEAF-specific
 * options, builds the model via `createModel`, and prepares the per-client
 * training files plus a reduced test set.
 */
class LeafCoordinator extends coordinator_1.Coordinator {
    /**
     * @param {object} options - coordinator options; validated by
     *   `checkLeafOptions` and read for `dataset`, `model.optimizer`,
     *   `numberLabelClasses` and `numberRolesPerClient`.
     * @param ipc - forwarded unchanged to the base `Coordinator`.
     * @param logger - forwarded unchanged to the base `Coordinator`.
     * @param stopCondition - forwarded unchanged to the base `Coordinator`.
     * @throws {AssertionError} if `options` fails validation.
     */
    constructor(options, ipc, logger, stopCondition) {
        // Validation does not touch `this`, so it may run before `super()`.
        checkLeafOptions(options);
        const modelOptions = {
            dataset: options.dataset,
            optimizer: options.model.optimizer,
            numberOutputClasses: options.numberLabelClasses,
        };
        super(model_1.createModel(modelOptions), ipc, options, logger, stopCondition);
        this.experimentName = options.dataset;
        this.options = options;
        this.modelOptions = modelOptions;
    }
    /**
     * Prepares the experiment data.
     *
     * Calls `processRawData` with the dataset, the `environment` and
     * `number-clients` CLI options and the label-class count. It then loads the
     * test items/labels `.npy` files reported in `paths.testpaths` and keeps
     * only the first fifth of the test rows as an in-memory `DataSubset`.
     *
     * @returns {Promise<object>} `clientTrainFiles` (`paths.trainpaths`),
     *   `testSetFiles` (`paths.testpaths`), `testSet` (the reduced test subset)
     *   and `dataDistribution` (a 1x1 zero placeholder matrix).
     * @throws {Error} if the dataset has no known item width (see `getSize`).
     */
    async preprocessData() {
        const paths = await read_json_1.processRawData(this.options.dataset, cli_1.CoordinatorCLIOptions.get("environment"), cli_1.CoordinatorCLIOptions.get("number-clients"), this.options.numberLabelClasses);
        // Resolve the item width before loading tensors, so that an unsupported
        // dataset cannot leave loaded tensors undisposed.
        const size = getSize(this.options.dataset);
        // Fraction of the test rows kept in memory as the evaluation set.
        const testFraction = 1 / 5;
        const testItems = await tfjs_npy_node_1.npy.load(paths.testpaths.items);
        const testLabels = await tfjs_npy_node_1.npy.load(paths.testpaths.labels);
        let testSet;
        try {
            // Row counts are floored because tf.slice sizes must be whole numbers.
            const itemRows = Math.floor(testItems.shape[0] * testFraction);
            const labelRows = Math.floor(testLabels.shape[0] * testFraction);
            testSet = new common_1.DataSubset({
                items: testItems.slice([0, 0], [itemRows, size]),
                labels: testLabels.slice([0, 0], [labelRows, this.options.numberLabelClasses]),
            });
        }
        finally {
            // `slice` returns new tensors, so the full arrays loaded from disk
            // can be released here instead of leaking native memory.
            testItems.dispose();
            testLabels.dispose();
        }
        return {
            clientTrainFiles: paths.trainpaths,
            testSetFiles: paths.testpaths,
            testSet: testSet,
            // Placeholder only: no data distribution is computed for LEAF datasets here.
            dataDistribution: { distributionMatrix: tf.tensor2d([[0]]) },
        };
    }
    /**
     * Builds the run name `<dataset>-<base>-labels<N>-batches<M>`. Here `<base>`
     * comes from the shared `runName` helper applied to the options and all
     * coordinator CLI options, `<N>` is `numberLabelClasses`, and `<M>` is
     * `numberRolesPerClient`.
     * @returns {string}
     */
    getRunName() {
        const base = coordinator_1.runName(this.options, cli_1.CoordinatorCLIOptions.getAll());
        const labels = `labels${this.options.numberLabelClasses}`;
        const batches = `batches${this.options.numberRolesPerClient}`;
        return `${this.options.dataset}-${base}-${labels}-${batches}`;
    }
}
exports.LeafCoordinator = LeafCoordinator;
/**
 * Returns the column count used when slicing the test items of `dataset`.
 * @param {string} dataset - dataset name; "shakespeare" or "synthetic".
 * @returns {number} `model_1.SEQUENCE_LENGTH` for "shakespeare",
 *   `model_1.FEATURES` for "synthetic".
 * @throws {Error} for any other dataset name. Previously `undefined` was
 *   returned silently and the failure surfaced later inside `tf.slice`.
 */
function getSize(dataset) {
    switch (dataset) {
        case "shakespeare":
            return model_1.SEQUENCE_LENGTH;
        case "synthetic":
            return model_1.FEATURES;
        default:
            throw new Error(`Unsupported dataset "${dataset}": expected "shakespeare" or "synthetic"`);
    }
}
/**
 * Validates the LEAF-specific coordinator options.
 * @param {object} options
 * @param {number} options.numberLabelClasses - integer in [1, 660].
 * @param {number} options.numberRolesPerClient - positive integer.
 * @throws {AssertionError} if a check fails; the message names the offending field and value.
 */
function checkLeafOptions(options) {
    const { numberLabelClasses, numberRolesPerClient } = options;
    assert(Number.isInteger(numberLabelClasses), `numberLabelClasses must be an integer, got ${numberLabelClasses}`);
    assert(numberLabelClasses > 0, `numberLabelClasses must be positive, got ${numberLabelClasses}`);
    // 660 is the label-class count noted for the shakespeare dataset; the bound
    // is enforced for every dataset.
    assert(numberLabelClasses <= 660, `numberLabelClasses must be at most 660, got ${numberLabelClasses}`);
    assert(Number.isInteger(numberRolesPerClient), `numberRolesPerClient must be an integer, got ${numberRolesPerClient}`);
    assert(numberRolesPerClient > 0, `numberRolesPerClient must be positive, got ${numberRolesPerClient}`);
}
//# sourceMappingURL=LeafCoordinator.js.map