federer
Version:
Experiments in asynchronous federated learning and decentralized learning
109 lines • 4.83 kB
JavaScript
;
Object.defineProperty(exports, "__esModule", { value: true });
exports.ClientWorker = void 0;
const tslib_1 = require("tslib");
const assert = require("assert");
const worker_threads_1 = require("worker_threads");
const worker_threads_2 = require("worker_threads");
const tf = tslib_1.__importStar(require("@tensorflow/tfjs-node"));
const await_lock_1 = tslib_1.__importDefault(require("await-lock"));
const Comlink = tslib_1.__importStar(require("comlink"));
const node_adapter_1 = tslib_1.__importDefault(require("comlink/dist/umd/node-adapter"));
const common_1 = require("../common");
const decay_1 = require("./decay");
// Module-level lock shared by every ClientWorker created in this worker thread.
// `ClientWorker.train` acquires it around `doTrain`, so overlapping train
// requests are serialized instead of training the same model concurrently.
const trainingLock = new await_lock_1.default();
/**
* A {@link ClientWorker} is a process that "does the actual work". The
* {@link Client} is responsible for anything related to networking, and the
* {@link ClientWorker} does anything CPU-intensive, like training the model.
*
* The reason for this split is to keep the main thread of the client free, so
* that it can respond to socket.io ping messages, for instance. If we ran the
* client in a browser, it would also be very important to keep the main thread
* free for things related to UI.
*
* The {@link Client} and {@link ClientWorker} communicate using
* {@link https://github.com/GoogleChromeLabs/comlink | comlink}, which is a
* wrapper API around `Worker.postMessage`.
*
* @see {@link https://maximekjaer.github.io/federer/docs/advanced/workers/}
*/
class ClientWorker {
constructor(data, model, options) {
this.data = data;
this.model = model;
this.options = options;
assert(!worker_threads_1.isMainThread);
this.logger = common_1.createLogger(options.loggerOptions, `Worker ${options.id}`, `client-worker-${options.id}.log`);
this.schedule = decay_1.LearningRateSchedule.get(options.trainOptions.learningRateSchedule);
}
countNumberDataPoints() {
return this.data.countNumberDatapoints();
}
async train(weights, round) {
this.logger.debug(`Worker received: ${new Date().getTime()}`);
assert(!worker_threads_1.isMainThread, "Worker method called on main thread; should be called in worker thread");
if (trainingLock.acquired) {
this.logger.warn("Tried to start straining, but the training lock was already acquired. " +
"This may indicate that the server requested the client to train twice, " +
"which may be a bug (or at the very least, overhead which can be eliminated).");
}
await trainingLock.acquireAsync();
const newWeights = await this.doTrain(weights, round);
trainingLock.release();
return newWeights;
}
async doTrain(serializedWeights, round) {
const weights = common_1.Weights.deserialize(serializedWeights);
this.model.setWeights(weights.weights);
this.maybeSetLearningRate(round);
await this.fit();
const newWeights = new common_1.Weights(this.model.getWeights());
if (this.options.deltaUpdates) {
const delta = newWeights.sub(weights);
const serialized = await delta.serialize();
weights.dispose();
delta.dispose();
return serialized;
}
else {
weights.dispose();
return newWeights.serialize();
}
}
maybeSetLearningRate(round) {
if (this.schedule === undefined) {
return;
}
if (this.model.optimizer instanceof tf.SGDOptimizer) {
const lr = this.schedule.decayedLearningRate(round);
this.logger.debug(`Round ${round}: setting learning rate to ${lr}`);
this.model.optimizer.setLearningRate(lr);
}
}
async fit() {
await this.model.fit(this.data.items, this.data.labels, {
batchSize: this.options.trainOptions.batchSize,
epochs: this.options.trainOptions.epochs,
shuffle: true,
verbose: this.options.tensorflowVerbosity,
});
}
}
exports.ClientWorker = ClientWorker;
/**
 * Factory exposed to the main thread over Comlink. The main-thread Client
 * calls `build(options)` to construct a ClientWorker inside this worker.
 */
const builder = {
    /**
     * Loads the local data subset and the model concurrently, then wraps a new
     * ClientWorker in a Comlink proxy so it is passed by reference, not copied.
     *
     * @param options - Worker options; `dataPathsOrURLs` and `modelPathOrURL`
     *   are read here, and the whole object is forwarded to the ClientWorker.
     */
    async build(options) {
        const pendingData = common_1.DataSubset.load(options.dataPathsOrURLs);
        const modelHandler = common_1.PathOrURL.getTfIOHandler(options.modelPathOrURL);
        const pendingModel = tf.loadLayersModel(modelHandler);
        const [data, model] = await Promise.all([pendingData, pendingModel]);
        const worker = new ClientWorker(data, model, options);
        return Comlink.proxy(worker);
    },
};
// This module must only ever be loaded as a worker thread. Loading it on the
// main thread (e.g. through a value import of ClientWorker) is a mistake, and
// there would be no parent port to expose the builder on.
if (worker_threads_1.isMainThread || worker_threads_2.parentPort === null) {
    throw new Error([
        "Should only be launched as worker.",
        "Did you accidentally import 'ClientWorker'?",
        "If so, try to change the 'import' to an 'import type'.",
    ].join(" "));
}
// Publish the builder on the parent port (via Comlink's Node adapter) so the
// main-thread Client can call `build` remotely.
Comlink.expose(builder, node_adapter_1.default(worker_threads_2.parentPort));
//# sourceMappingURL=ClientWorker.js.map