UNPKG

federer

Version:

Experiments in asynchronous federated learning and decentralized learning

69 lines 2.24 kB
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.FEATURES = exports.SEQUENCE_LENGTH = exports.getModel = exports.createModel = void 0;
const tslib_1 = require("tslib");
const tf = tslib_1.__importStar(require("@tensorflow/tfjs-node"));
const coordinator_1 = require("../../../coordinator");

/**
 * Builds the model for `options.dataset` and compiles it for training.
 *
 * The model is compiled with categorical cross-entropy loss and an accuracy
 * metric. The optimizer is whatever `getOptimizer` from the coordinator module
 * returns for `options.optimizer` (its accepted values are defined there).
 *
 * @param {object} options - Model options; see `getModel` for the fields used
 *   to build the architecture, plus `options.optimizer`.
 * @returns {tf.Sequential} The compiled model.
 * @throws {Error} If `options.dataset` is not a supported dataset.
 */
function createModel(options) {
    const model = getModel(options);
    model.compile({
        optimizer: coordinator_1.getOptimizer(options.optimizer),
        loss: "categoricalCrossentropy",
        metrics: ["accuracy"],
    });
    return model;
}
exports.createModel = createModel;

/**
 * Builds the (uncompiled) model architecture for a dataset.
 *
 * @param {object} options
 * @param {string} options.dataset - "shakespeare" selects a two-layer LSTM,
 *   "synthetic" selects a single dense softmax layer (logistic regression).
 * @param {number} options.numberOutputClasses - Units of the softmax output
 *   layer; for "shakespeare" it is also the embedding's `inputDim`.
 * @param {*} [options.recurrentInitializer] - Forwarded as
 *   `recurrentInitializer` to both LSTM layers ("shakespeare" only).
 * @returns {tf.Sequential} The uncompiled model.
 * @throws {Error} If `options.dataset` is not one of the supported datasets.
 */
function getModel(options) {
    switch (options.dataset) {
        case "shakespeare":
            return getLSTMModel(options.numberOutputClasses, options.recurrentInitializer);
        case "synthetic":
            return getLogRegModel(options.numberOutputClasses);
        default:
            // Fail loudly here: returning undefined would only surface later as an
            // opaque TypeError when createModel calls `.compile` on it.
            throw new Error(`Unsupported dataset: ${String(options.dataset)}`);
    }
}
exports.getModel = getModel;

/** Output dimension of the embedding layer in the "shakespeare" LSTM model. */
const EMBEDDING_SIZE = 8;

/** Input length (number of tokens per sequence) of the "shakespeare" LSTM model. */
exports.SEQUENCE_LENGTH = 80;

/**
 * Embedding -> LSTM(256, full sequence) -> LSTM(256, last step) -> softmax.
 *
 * @param {number} numOutputClasses - Embedding vocabulary size and number of
 *   softmax output units.
 * @param {*} recurrentInitializer - Recurrent initializer for both LSTM layers.
 * @returns {tf.Sequential}
 */
const getLSTMModel = (numOutputClasses, recurrentInitializer) => tf.sequential({
    layers: [
        tf.layers.embedding({
            name: "Embedding",
            inputDim: numOutputClasses,
            outputDim: EMBEDDING_SIZE,
            inputLength: exports.SEQUENCE_LENGTH,
            maskZero: false,
        }),
        tf.layers.lstm({
            name: "LSTM1",
            units: 256,
            recurrentActivation: "sigmoid",
            recurrentInitializer: recurrentInitializer,
            // Emit the full sequence so the second LSTM receives every time step.
            returnSequences: true,
        }),
        tf.layers.lstm({
            name: "LSTM2",
            units: 256,
            recurrentActivation: "sigmoid",
            recurrentInitializer: recurrentInitializer,
            // Only the final time step feeds the classifier.
            returnSequences: false,
        }),
        tf.layers.dense({
            name: "OutputLayer",
            units: numOutputClasses,
            activation: "softmax",
        }),
    ],
});

/** Input width (features per example) of the "synthetic" logistic-regression model. */
exports.FEATURES = 60;

/**
 * Single dense softmax layer over `FEATURES` inputs (multinomial logistic regression).
 *
 * @param {number} numberOutputClasses - Number of softmax output units.
 * @returns {tf.Sequential}
 */
const getLogRegModel = (numberOutputClasses) => tf.sequential({
    layers: [
        tf.layers.dense({
            name: "OutputLayer",
            inputShape: [exports.FEATURES],
            units: numberOutputClasses,
            activation: "softmax",
        }),
    ],
});
//# sourceMappingURL=model.js.map