UNPKG

federer

Version:

Experiments in asynchronous federated learning and decentralized learning

109 lines 4.83 kB
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.ClientWorker = void 0;
const tslib_1 = require("tslib");
const assert = require("assert");
const worker_threads_1 = require("worker_threads");
const worker_threads_2 = require("worker_threads");
const tf = tslib_1.__importStar(require("@tensorflow/tfjs-node"));
const await_lock_1 = tslib_1.__importDefault(require("await-lock"));
const Comlink = tslib_1.__importStar(require("comlink"));
const node_adapter_1 = tslib_1.__importDefault(require("comlink/dist/umd/node-adapter"));
const common_1 = require("../common");
const decay_1 = require("./decay");
// Module-level lock shared by every `train()` call in this worker, so that
// training runs are executed one at a time.
const trainingLock = new await_lock_1.default();
/**
 * A {@link ClientWorker} is a process that "does the actual work". The
 * {@link Client} is responsible for anything related to networking, and the
 * {@link ClientWorker} does anything CPU-intensive, like training the model.
 *
 * The reason for this split is to keep the main thread of the client free, so
 * that it can respond to socket.io ping messages, for instance. If we ran the
 * client in a browser, it would also be very important to keep the main thread
 * free for things related to UI.
 *
 * The {@link Client} and {@link ClientWorker} communicate using
 * {@link https://github.com/GoogleChromeLabs/comlink | comlink}, which is a
 * wrapper API around `Worker.postMessage`.
 *
 * @see {@link https://maximekjaer.github.io/federer/docs/advanced/workers/}
 */
class ClientWorker {
    /**
     * @param data - Training data; this class reads `items` and `labels` from
     *   it and calls `countNumberDatapoints()` on it.
     * @param model - Model to train; must expose `setWeights`, `getWeights`,
     *   `fit` and `optimizer`.
     * @param options - Worker options. Fields read by this class: `id`,
     *   `loggerOptions`, `trainOptions` (`learningRateSchedule`, `batchSize`,
     *   `epochs`), `deltaUpdates` and `tensorflowVerbosity`.
     * @throws AssertionError if constructed on the main thread.
     */
    constructor(data, model, options) {
        this.data = data;
        this.model = model;
        this.options = options;
        assert(!worker_threads_1.isMainThread);
        this.logger = common_1.createLogger(options.loggerOptions, `Worker ${options.id}`, `client-worker-${options.id}.log`);
        this.schedule = decay_1.LearningRateSchedule.get(options.trainOptions.learningRateSchedule);
    }
    /**
     * @returns Whatever `data.countNumberDatapoints()` returns for this
     *   worker's data subset.
     */
    countNumberDataPoints() {
        return this.data.countNumberDatapoints();
    }
    /**
     * Trains the model starting from the given serialized weights.
     *
     * Training runs are serialized through the module-level training lock. The
     * lock is always released, even when training fails; otherwise a single
     * failed round would leave the lock held forever and every subsequent call
     * would hang in `acquireAsync()`.
     *
     * @param weights - Serialized weights to start training from.
     * @param round - Round number, forwarded to the learning rate schedule.
     * @returns The serialized result of {@link ClientWorker.doTrain}.
     * @throws AssertionError if called on the main thread; rethrows any error
     *   raised while training.
     */
    async train(weights, round) {
        this.logger.debug(`Worker received: ${new Date().getTime()}`);
        assert(!worker_threads_1.isMainThread, "Worker method called on main thread; should be called in worker thread");
        if (trainingLock.acquired) {
            // NOTE(review): "straining" below is a typo for "training"; the
            // message is left untouched in case logs are matched on it.
            this.logger.warn("Tried to start straining, but the training lock was already acquired. " +
                "This may indicate that the server requested the client to train twice, " +
                "which may be a bug (or at the very least, overhead which can be eliminated).");
        }
        await trainingLock.acquireAsync();
        try {
            return await this.doTrain(weights, round);
        }
        finally {
            trainingLock.release();
        }
    }
    /**
     * Loads the given weights into the model, fits the model on the local
     * data, and serializes the outcome.
     *
     * @param serializedWeights - Serialized weights to load into the model.
     * @param round - Round number, used to pick the learning rate.
     * @returns If `options.deltaUpdates` is truthy, the serialized difference
     *   (new weights minus received weights); otherwise the serialized new
     *   weights.
     */
    async doTrain(serializedWeights, round) {
        const weights = common_1.Weights.deserialize(serializedWeights);
        try {
            this.model.setWeights(weights.weights);
            this.maybeSetLearningRate(round);
            await this.fit();
            // `newWeights` wraps the tensors returned by `model.getWeights()`.
            // As in the original code it is not disposed here; presumably
            // those tensors are still owned by the model (TODO confirm).
            const newWeights = new common_1.Weights(this.model.getWeights());
            if (!this.options.deltaUpdates) {
                return await newWeights.serialize();
            }
            const delta = newWeights.sub(weights);
            try {
                return await delta.serialize();
            }
            finally {
                delta.dispose();
            }
        }
        finally {
            // Dispose the deserialized input on every path, including failures.
            weights.dispose();
        }
    }
    /**
     * Applies the decayed learning rate for `round` to the optimizer.
     *
     * Does nothing when no schedule is configured, or when the model's
     * optimizer is not a `tf.SGDOptimizer`.
     *
     * @param round - Round number passed to the schedule.
     */
    maybeSetLearningRate(round) {
        if (this.schedule === undefined) {
            return;
        }
        if (this.model.optimizer instanceof tf.SGDOptimizer) {
            const lr = this.schedule.decayedLearningRate(round);
            this.logger.debug(`Round ${round}: setting learning rate to ${lr}`);
            this.model.optimizer.setLearningRate(lr);
        }
    }
    /**
     * Fits the model on `data.items` / `data.labels` with shuffling enabled,
     * using the batch size and epoch count from `options.trainOptions` and the
     * verbosity from `options.tensorflowVerbosity`.
     */
    async fit() {
        await this.model.fit(this.data.items, this.data.labels, {
            batchSize: this.options.trainOptions.batchSize,
            epochs: this.options.trainOptions.epochs,
            shuffle: true,
            verbose: this.options.tensorflowVerbosity,
        });
    }
}
exports.ClientWorker = ClientWorker;
/**
 * Object exposed to the parent thread through Comlink. Its `build` method
 * loads the data subset and the model concurrently, then returns a
 * Comlink-proxied {@link ClientWorker} built from them.
 */
const builder = {
    async build(options) {
        const [data, model] = await Promise.all([
            common_1.DataSubset.load(options.dataPathsOrURLs),
            tf.loadLayersModel(common_1.PathOrURL.getTfIOHandler(options.modelPathOrURL)),
        ]);
        return Comlink.proxy(new ClientWorker(data, model, options));
    },
};
// This module has load-time side effects: it refuses to load outside of a
// worker thread, and otherwise exposes `builder` on the parent port.
if (worker_threads_2.parentPort === null || worker_threads_1.isMainThread) {
    throw new Error("Should only be launched as worker. " +
        "Did you accidentally import 'ClientWorker'? " +
        "If so, try to change the 'import' to an 'import type'.");
}
Comlink.expose(builder, node_adapter_1.default(worker_threads_2.parentPort));
//# sourceMappingURL=ClientWorker.js.map