federer
Version:
Experiments in asynchronous federated learning and decentralized learning
69 lines • 2.24 kB
JavaScript
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.FEATURES = exports.SEQUENCE_LENGTH = exports.getModel = exports.createModel = void 0;
const tslib_1 = require("tslib");
const tf = tslib_1.__importStar(require("@tensorflow/tfjs-node"));
const coordinator_1 = require("../../../coordinator");
function createModel(options) {
const model = getModel(options);
model.compile({
optimizer: coordinator_1.getOptimizer(options.optimizer),
loss: "categoricalCrossentropy",
metrics: ["accuracy"],
});
return model;
}
exports.createModel = createModel;
function getModel(options) {
switch (options.dataset) {
case "shakespeare":
return getLSTMModel(options.numberOutputClasses, options.recurrentInitializer);
case "synthetic":
return getLogRegModel(options.numberOutputClasses);
}
}
exports.getModel = getModel;
const EMBEDDING_SIZE = 8;
exports.SEQUENCE_LENGTH = 80;
const getLSTMModel = (numOutputClasses, recurrentInitializer) => tf.sequential({
layers: [
tf.layers.embedding({
name: "Embedding",
inputDim: numOutputClasses,
outputDim: EMBEDDING_SIZE,
inputLength: exports.SEQUENCE_LENGTH,
maskZero: false,
}),
tf.layers.lstm({
name: "LSTM1",
units: 256,
recurrentActivation: "sigmoid",
recurrentInitializer: recurrentInitializer,
returnSequences: true,
}),
tf.layers.lstm({
name: "LSTM2",
units: 256,
recurrentActivation: "sigmoid",
recurrentInitializer: recurrentInitializer,
returnSequences: false,
}),
tf.layers.dense({
name: "OutputLayer",
units: numOutputClasses,
activation: "softmax",
}),
],
});
exports.FEATURES = 60;
const getLogRegModel = (numberOutputClasses) => tf.sequential({
layers: [
tf.layers.dense({
name: "OutputLayer",
inputShape: [exports.FEATURES],
units: numberOutputClasses,
activation: "softmax",
}),
],
});
//# sourceMappingURL=model.js.map