UNPKG

federer

Version:

Experiments in asynchronous federated learning and decentralized learning

224 lines 9.73 kB
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.FLServer = void 0;
const tslib_1 = require("tslib");
const assert = require("assert");
const io = tslib_1.__importStar(require("socket.io"));
const nanoevents_1 = require("nanoevents");
const common_1 = require("../common");
const Timer_1 = require("./utils/Timer");
const ClientPool_1 = require("./utils/ClientPool");
const Instrumentation_1 = require("./utils/Instrumentation");
/**
 * Abstract base class for a Federated Learning aggregation server.
 *
 * This abstract class defines the basic interface that all servers should
 * respect; most notably, servers should expose an event emitter that informs
 * code outside of this class of certain server events.
 *
 * The class also does some basic bookkeeping. For instance, it maintains
 * metadata that clients have sent when authenticating, as well as the set of
 * clients that are ready to train, and some statistics on how the FL process is
 * advancing.
 *
 * Finally, the class offers a few utility methods that subclasses can use.
 *
 * Subclasses must provide `getInitialRoundState(initialWeights)` (called from
 * the constructor) and an `expectDeltaUpdates` value (compared against each
 * client's "ready" metadata). Neither is defined in this class.
 */
class FLServer {
    /**
     * @param server HTTP server that the socket.io server attaches to.
     * @param initialWeights Passed to `getInitialRoundState` to build the state
     *   of the first round.
     * @param options Server options. This class reads `instrument`,
     *   `loggerOptions`, `experimentName` and `minimumNumberClientsForStart`.
     * @param logger See the NOTE(review) below: currently not used.
     */
    constructor(server, initialWeights, options, logger) {
        this.options = options;
        // NOTE(review): the `logger` argument is stored here but unconditionally
        // replaced further down by `createLogger(...)`, so the value passed in is
        // never used. Confirm whether callers expect it to be honored.
        this.logger = logger;
        /** Event emitter notifying outside code of FL events. */
        this.events = nanoevents_1.createNanoEvents();
        /**
         * Mapping of socket IDs to metadata about the client represented by this
         * socket ID.
         *
         * The metadata is collected from the clients, which send it through socket.io
         * when connecting to the server. Note that clients are not considered "ready"
         * for training until they have sent their metadata.
         */
        this.clientsMetadata = new Map();
        /** Client pool, keeping track of ready clients. */
        this.clientPool = new ClientPool_1.ClientPool();
        /** Timer used for round summaries. */
        this.timer = new Timer_1.Timer();
        /**
         * Number of epochs that have been aggregated so far. On every upload this
         * grows by the uploading client's `numberEpochs` metadata.
         */
        this.numberEpochs = 0;
        /** Number of upload messages received so far. */
        this.numberUploads = 0;
        this.server = new io.Server(server, {
            // Maximum message size. If messages get bigger than this, the server
            // will disconnect the client.
            maxHttpBufferSize: 1e8, // 100 MB
        });
        this.instrumentation = new Instrumentation_1.Instrumentation(options.instrument);
        this.logger = common_1.createLogger(options.loggerOptions, "Server", "server.log");
        // Subclass hook: builds the state of the first round from the initial weights.
        this.currentRound = this.getInitialRoundState(initialWeights);
        this.registerServerListeners();
    }
    /** Closes the socket.io and HTTP servers and drops all event listeners. */
    terminate() {
        this.server.close(); // Close socket.io and HTTP servers
        this.events.events = {}; // Unbind all event listeners
    }
    /**
     * Registers listeners to react to events on `this.server`.
     *
     * Subclasses are not meant to call this method (doing so could register
     * listeners twice), so this is marked as private.
     */
    registerServerListeners() {
        this.server.on("connection", (socket) => {
            this.registerSocketListeners(socket);
        });
    }
    /**
     * Registers listeners to react to events on the given `socket`.
     *
     * Subclasses overriding this method should first call
     * `super.registerSocketListeners(socket)`.
     *
     * @param socket Socket to a client
     */
    registerSocketListeners(socket) {
        // When a client is ready, do the necessary book-keeping
        socket.on("ready", (metadata) => {
            // Refuse clients whose update mode disagrees with what this server
            // aggregates; the assertion throws inside the event handler.
            assert.strictEqual(metadata.deltaUpdates, this.expectDeltaUpdates, `Client ${metadata.clientId} sent deltaUpdates: ${metadata.deltaUpdates}, ` +
                `but ${this.expectDeltaUpdates} was expected.`);
            this.registerMetadata(socket.id, metadata);
            this.clientPool.addAvailableClient(socket.id);
            this.logger.info(`Client ${metadata.clientId} (socket ${socket.id}) ready. There are now ${this.numberClients} ready clients`);
            if (this.shouldStartTimer()) {
                this.timer.start();
            }
        });
        // If a client disconnects, delete the resources associated to it
        socket.on("disconnect", (reason) => {
            // We do not delete clients metadata because it may still be needed, e.g.
            // to compute an average.
            this.clientPool.delete(socket.id);
            this.logDisconnect(socket.id, reason);
        });
        // When a client uploads, do the necessary book-keeping
        socket.on("upload", (message) => {
            this.numberUploads += 1;
            // Throws (via getMetadata) if the client uploads before sending "ready".
            this.numberEpochs += this.getMetadata(socket.id).numberEpochs;
            this.instrumentation.registerUpload(this.currentRound.roundNumber, message.round);
            this.clientPool.transition(socket.id, "available");
        });
    }
    /** Number of currently connected and ready clients. */
    get numberClients() {
        return this.clientPool.size;
    }
    /**
     * Readonly map of socket IDs to the metadata we keep about this node.
     */
    get metadata() {
        return this.clientsMetadata;
    }
    /**
     * Stores the metadata a client sent with its "ready" message.
     *
     * Asserts that `numberDatapoints` and `numberEpochs` are positive. If
     * metadata already exists for `id`, an error is logged and the old entry is
     * overwritten.
     */
    registerMetadata(id, metadata) {
        assert(metadata.numberDatapoints > 0);
        assert(metadata.numberEpochs > 0);
        if (this.clientsMetadata.has(id)) {
            const value = JSON.stringify(this.clientsMetadata.get(id));
            this.logger.error(`Already have metadata for client ${id}: ${value}`);
        }
        this.clientsMetadata.set(id, metadata);
    }
    /**
     * Logs a warning for a disconnected socket. The message depends on whether
     * the socket had sent "ready" (i.e. whether we hold metadata for it).
     */
    logDisconnect(id, reason) {
        const metadata = this.clientsMetadata.get(id);
        if (metadata === undefined) {
            this.logger.warn(`Client ${id} disconnected before sending ready, for the following reason: ${reason}`);
        }
        else {
            const numClients = this.clientPool.size;
            const clientId = metadata.clientId;
            this.logger.warn(`Client ${clientId} disconnected for the following reason: '${reason}'. There are now ${numClients} clients`);
        }
    }
    /**
     * Gets the client metadata for a socket, or fails if we do not have metadata
     * for that socket.
     *
     * @throws Error if no metadata is stored for `socketId`; the message
     *   distinguishes connected-but-not-ready sockets from unknown ones.
     */
    getMetadata(socketId) {
        const metadata = this.metadata.get(socketId);
        if (metadata === undefined) {
            if (this.server.sockets.sockets.has(socketId)) {
                throw new Error(`Programmer error: socket ${socketId} is connected, ` +
                    `but we do not have metadata for it. ` +
                    `Make sure you only call getMetadata() if you know that the client ` +
                    `is ready to train (i.e. that it has sent a "ready" message)`);
            }
            throw new Error(`Programmer error: socket ${socketId} is not connected, and has no metadata. ` +
                `Make sure you only call getMetadata() for connected clients.`);
        }
        return metadata;
    }
    /**
     * Returns the socket with ID `id`.
     * @throws if no socket with ID `id` is connected
     */
    getSocketWithID(id) {
        const socket = this.server.sockets.sockets.get(id);
        if (socket === undefined) {
            throw new Error(`Could not get socket with id ${id}`);
        }
        return socket;
    }
    /**
     * Sends a {@link DownloadMessage} to a list of clients.
     *
     * Every client is first moved to the "training" state in the client pool.
     *
     * @param clients List of client sockets to send to
     * @param message Payload emitted to those clients as the "download" event
     */
    sendDownloadMessage(clients, message) {
        for (const socket of clients) {
            this.clientPool.transition(socket.id, "training");
        }
        // Create a temporary socket.io room that we can broadcast to. We use this
        // approach instead of emitting individually to each client so that the
        // socket.io library only needs to serialize / compress the message once.
        //
        // Note that we treat the returned value of `join` or `leave` as `void`; it
        // could be `Promise<void>` if we changed the adapter, but for our adapter
        // it's `void`.
        const room = "downloading-clients";
        clients.forEach((client) => client.join(room));
        this.server.to(room).emit("download", message);
        clients.forEach((client) => client.leave(room));
    }
    /**
     * Emit the round summary to the coordinator.
     *
     * Before emitting, the serialized weights are written synchronously to
     * `weightsFilepath(filename)` and the time taken is logged at debug level.
     * The summary is then emitted as a "roundEnd" event.
     *
     * @param filename Unique name used to build the weights file path
     * @param finishedRoundNumber Number of the round that just finished
     * @param serializedWeights Weights to save and include in the summary
     */
    emitRoundSummary(filename, finishedRoundNumber, serializedWeights) {
        const filepath = this.weightsFilepath(filename);
        const [ms] = Timer_1.Timer.time(() => common_1.Weights.saveSerializedSync(filepath, serializedWeights));
        this.logger.debug(`Saved weights for round ${finishedRoundNumber} in ${ms} ms`);
        const roundSummary = this.createRoundSummary(finishedRoundNumber, serializedWeights);
        this.events.emit("roundEnd", roundSummary);
    }
    /**
     * Builds the summary object for a finished round.
     *
     * Side effect: marks the end of the round in the instrumentation before
     * reading its metrics.
     */
    createRoundSummary(roundNumber, weights) {
        this.instrumentation.registerRoundEnd();
        return {
            roundNumber,
            weights,
            numberEpochs: this.numberEpochs,
            numberUploads: this.numberUploads,
            numberClientsTraining: this.clientPool.numberTraining,
            timeSinceStart: this.timer.ms(),
            instrumentation: this.instrumentation.metrics(),
        };
    }
    /**
     * Given a unique name for the weights, returns the full filepath to save
     * weights to.
     */
    weightsFilepath(name) {
        return common_1.absolutePath.weights.server(this.options.experimentName, name);
    }
    /**
     * Whether the round timer should be started now: enough clients are ready
     * and the timer has not been started yet.
     *
     * Uses `>=` rather than `===` so that a minimum of 0 (or less) still starts
     * the timer. The check runs only after a client has been added, so the pool
     * is never empty at that point, and an exact-equality test would then never
     * match. The `isStarted()` guard keeps the timer from being started twice.
     */
    shouldStartTimer() {
        return (this.numberClients >= this.options.minimumNumberClientsForStart &&
            !this.timer.isStarted());
    }
}
exports.FLServer = FLServer;
//# sourceMappingURL=FLServer.js.map