federer
Version:
Experiments in asynchronous federated learning and decentralized learning
48 lines • 2.14 kB
TypeScript
import * as tf from "@tensorflow/tfjs-node";
import { Logger } from "winston";
import * as Comlink from "comlink";
import { ClientStartOptions, DataSubset, SerializedWeights } from "../common";
import { LearningRateSchedule } from "./decay";
/**
* A {@link ClientWorker} is the part of a client that "does the actual work".
* It runs on a worker thread. The {@link Client} is responsible for anything
* related to networking, and the {@link ClientWorker} does anything
* CPU-intensive, like training the model.
*
* The reason for this split is to keep the main thread of the client free, so
* that it can respond to socket.io ping messages, for instance. If we ran the
* client in a browser, it would also be very important to keep the main thread
* free for UI-related work.
*
* The {@link Client} and {@link ClientWorker} communicate using
* {@link https://github.com/GoogleChromeLabs/comlink | comlink}, which is a
* wrapper API around `Worker.postMessage`.
*
* This is a generated declaration file: member types here must stay in sync
* with the implementation they were emitted from.
*
* @see {@link https://maximekjaer.github.io/federer/docs/advanced/workers/}
*/
export declare class ClientWorker {
/**
* Data subset passed to the constructor. Presumably this is the local
* training data owned by this client (TODO confirm against the implementation).
*/
protected readonly data: DataSubset;
/** TensorFlow.js layers model passed to the constructor. */
protected readonly model: tf.LayersModel;
/** Client start options passed to the constructor. */
protected readonly options: ClientStartOptions;
/**
* Logger used by this worker.
* NOTE(review): this is not a constructor parameter, so it is presumably
* created internally (possibly from `options`). Confirm.
*/
protected readonly logger: Logger;
/**
* Learning rate schedule, used for optionally decaying the learning rate as
* rounds progress. Presumably `undefined` means no decay is applied
* (TODO confirm).
*/
protected readonly schedule: LearningRateSchedule | undefined;
/**
* @param data - Data subset this worker holds.
* @param model - Model this worker operates on.
* @param options - Client start options.
*/
constructor(data: DataSubset, model: tf.LayersModel, options: ClientStartOptions);
/**
* Returns a count of data points.
* NOTE(review): presumably this is the number of samples in `data`. Confirm.
*/
countNumberDataPoints(): number;
/**
* Runs one round of work and resolves with serialized weights.
*
* @param weights - Serialized weights to start from. Presumably these are
*   the current global model weights (TODO confirm).
* @param round - Round number. Presumably it is consulted by the learning
*   rate schedule (TODO confirm).
* @returns A promise of the resulting serialized weights.
*/
train(weights: SerializedWeights, round: number): Promise<SerializedWeights>;
/** Private implementation helper. Its signature is not part of the public declaration. */
private doTrain;
/**
* Private helper. Judging by its name, it applies `schedule` when one is
* configured. TODO confirm.
*/
private maybeSetLearningRate;
/**
* Fitting step, declared `protected` so that subclasses can override it.
* Resolves with no value when done.
*/
protected fit(): Promise<void>;
}
/**
* Interface of a builder of {@link ClientWorker}. This builder takes
* {@link ClientStartOptions} and asynchronously builds the {@link ClientWorker}
* on the worker thread. It then returns to the main thread a proxy to the
* worker instance.
*/
export interface ClientWorkerBuilder {
/**
* Builds a {@link ClientWorker} on the worker thread.
*
* @param options - Client start options used to construct the worker.
* @returns A promise of the worker instance, marked with
*   `Comlink.ProxyMarked`. This marker tells comlink to hand the instance to
*   the main thread as a proxy instead of copying it.
*/
build(options: ClientStartOptions): Promise<ClientWorker & Comlink.ProxyMarked>;
}
//# sourceMappingURL=ClientWorker.d.ts.map