// federer
// Version: (not specified)
// Experiments in asynchronous federated learning and decentralized learning
// 224 lines • 9.73 kB
// JavaScript
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.FLServer = void 0;
const tslib_1 = require("tslib");
const assert = require("assert");
const io = tslib_1.__importStar(require("socket.io"));
const nanoevents_1 = require("nanoevents");
const common_1 = require("../common");
const Timer_1 = require("./utils/Timer");
const ClientPool_1 = require("./utils/ClientPool");
const Instrumentation_1 = require("./utils/Instrumentation");
/**
 * Abstract base class for a Federated Learning aggregation server.
 *
 * This abstract class defines the basic interface that all servers should
 * respect; most notably, servers expose an event emitter (`events`) that
 * informs code outside of this class of certain server events, such as
 * "roundEnd".
 *
 * The class also does some basic bookkeeping. For instance, it maintains the
 * metadata that clients send when they announce they are ready, the pool of
 * clients that are ready to train, and some statistics on how the FL process
 * is advancing (number of uploads, number of aggregated epochs).
 *
 * Finally, the class offers a few utility methods that subclasses can use.
 *
 * Subclasses must provide `getInitialRoundState(initialWeights)` (called from
 * the constructor) and an `expectDeltaUpdates` property (compared against each
 * client's "ready" metadata). Neither is defined in this class.
 */
class FLServer {
    /**
     * @param server Server that socket.io attaches to; it is closed again by
     *   `terminate()`.
     * @param initialWeights Weights handed to the subclass's
     *   `getInitialRoundState()` to build the first round state.
     * @param options Server options. This class reads `instrument`,
     *   `loggerOptions`, `experimentName` and `minimumNumberClientsForStart`.
     * @param logger Optional logger. When `null`/`undefined`, a logger is
     *   created with `createLogger(options.loggerOptions, "Server", "server.log")`.
     */
    constructor(server, initialWeights, options, logger) {
        this.options = options;
        /** Event emitter notifying outside code of FL events. */
        this.events = nanoevents_1.createNanoEvents();
        /**
         * Mapping of socket IDs to metadata about the client represented by this
         * socket ID.
         *
         * The metadata is collected from the clients, which send it through
         * socket.io (in their "ready" message). Note that clients are not
         * considered "ready" for training until they have sent their metadata.
         */
        this.clientsMetadata = new Map();
        /** Client pool, keeping track of ready clients. */
        this.clientPool = new ClientPool_1.ClientPool();
        /** Timer used for round summaries. */
        this.timer = new Timer_1.Timer();
        /** Number of epochs that have been aggregated so far. */
        this.numberEpochs = 0;
        /** Number of upload messages received so far. */
        this.numberUploads = 0;
        this.server = new io.Server(server, {
            // Maximum message size. If messages get bigger than this, the server
            // will disconnect the client.
            maxHttpBufferSize: 1e8, // 100 MB
        });
        this.instrumentation = new Instrumentation_1.Instrumentation(options.instrument);
        // Honour a caller-supplied logger and only build the default one as a
        // fallback, so that the `logger` argument is not silently discarded.
        // The logger is assigned before `getInitialRoundState()` runs, so
        // subclasses can already log from there.
        this.logger = logger != null
            ? logger
            : common_1.createLogger(options.loggerOptions, "Server", "server.log");
        this.currentRound = this.getInitialRoundState(initialWeights);
        this.registerServerListeners();
    }
    /**
     * Shuts the server down: closes socket.io (and the server it is attached
     * to) and removes every listener registered on `this.events`.
     */
    terminate() {
        this.server.close(); // Close socket.io and HTTP servers
        this.events.events = {}; // Unbind all event listeners
    }
    /**
     * Registers listeners to react to events on `this.server`.
     *
     * Subclasses are not meant to call this method (doing so could register
     * listeners twice), so this is marked as private.
     */
    registerServerListeners() {
        this.server.on("connection", (socket) => {
            this.registerSocketListeners(socket);
        });
    }
    /**
     * Registers listeners to react to events on the given `socket`.
     *
     * Handled events:
     * - "ready": checks the client's `deltaUpdates` flag, stores its metadata,
     *   marks it available and possibly starts the timer.
     * - "disconnect": removes the client from the pool (its metadata is kept).
     * - "upload": updates upload/epoch counters and instrumentation, and marks
     *   the client available again.
     *
     * Subclasses overriding this method should first call
     * `super.registerSocketListeners(socket)`.
     *
     * @param socket Socket to a client
     */
    registerSocketListeners(socket) {
        // When a client is ready, do the necessary book-keeping
        socket.on("ready", (metadata) => {
            // Throws if the client's delta-update mode differs from the one
            // this server expects.
            assert.strictEqual(metadata.deltaUpdates, this.expectDeltaUpdates, `Client ${metadata.clientId} sent deltaUpdates: ${metadata.deltaUpdates}, ` +
                `but ${this.expectDeltaUpdates} was expected.`);
            this.registerMetadata(socket.id, metadata);
            this.clientPool.addAvailableClient(socket.id);
            this.logger.info(`Client ${metadata.clientId} (socket ${socket.id}) ready. There are now ${this.numberClients} ready clients`);
            if (this.shouldStartTimer()) {
                this.timer.start();
            }
        });
        // If a client disconnects, delete the resources associated to it
        socket.on("disconnect", (reason) => {
            // We do not delete clients metadata because it may still be needed, e.g.
            // to compute an average.
            this.clientPool.delete(socket.id);
            this.logDisconnect(socket.id, reason);
        });
        // When a client uploads, do the necessary book-keeping
        socket.on("upload", (message) => {
            this.numberUploads += 1;
            // getMetadata() throws if this socket never sent "ready".
            this.numberEpochs += this.getMetadata(socket.id).numberEpochs;
            this.instrumentation.registerUpload(this.currentRound.roundNumber, message.round);
            this.clientPool.transition(socket.id, "available");
        });
    }
    /** Number of currently connected and ready clients. */
    get numberClients() {
        return this.clientPool.size;
    }
    /**
     * Map of socket IDs to the metadata we keep about each client. Callers
     * should treat it as read-only.
     */
    get metadata() {
        return this.clientsMetadata;
    }
    /**
     * Stores `metadata` for socket `id`, overwriting any previous entry. A
     * previous entry is reported with `logger.error`, not rejected.
     *
     * @throws AssertionError if `numberDatapoints` or `numberEpochs` is not > 0
     */
    registerMetadata(id, metadata) {
        assert(metadata.numberDatapoints > 0);
        assert(metadata.numberEpochs > 0);
        if (this.clientsMetadata.has(id)) {
            const value = JSON.stringify(this.clientsMetadata.get(id));
            this.logger.error(`Already have metadata for client ${id}: ${value}`);
        }
        this.clientsMetadata.set(id, metadata);
    }
    /**
     * Logs a client disconnection. The message differs depending on whether
     * the client had sent its "ready" metadata.
     */
    logDisconnect(id, reason) {
        const metadata = this.clientsMetadata.get(id);
        if (metadata === undefined) {
            this.logger.warn(`Client ${id} disconnected before sending ready, for the following reason: ${reason}`);
        }
        else {
            const numClients = this.numberClients;
            const clientId = metadata.clientId;
            this.logger.warn(`Client ${clientId} disconnected for the following reason: '${reason}'. There are now ${numClients} clients`);
        }
    }
    /**
     * Gets the client metadata for a socket, or fails if we do not have metadata
     * for that socket.
     *
     * @throws Error whose message says whether the socket is still connected
     */
    getMetadata(socketId) {
        const metadata = this.metadata.get(socketId);
        if (metadata === undefined) {
            if (this.server.sockets.sockets.has(socketId)) {
                throw new Error(`Programmer error: socket ${socketId} is connected, ` +
                    `but we do not have metadata for it. ` +
                    `Make sure you only call getMetadata() if you know that the client ` +
                    `is ready to train (i.e. that it has sent a "ready" message)`);
            }
            throw new Error(`Programmer error: socket ${socketId} is not connected, and has no metadata. ` +
                `Make sure you only call getMetadata() for connected clients.`);
        }
        return metadata;
    }
    /**
     * Returns the socket with ID `id`.
     * @throws if no socket with ID `id` is connected
     */
    getSocketWithID(id) {
        const socket = this.server.sockets.sockets.get(id);
        if (socket === undefined) {
            throw new Error(`Could not get socket with id ${id}`);
        }
        return socket;
    }
    /**
     * Sends a {@link DownloadMessage} to a list of clients, and marks each of
     * them as "training" in the client pool first.
     *
     * @param clients List of client sockets to send to
     * @param message Message emitted as the "download" event
     */
    sendDownloadMessage(clients, message) {
        for (const socket of clients) {
            this.clientPool.transition(socket.id, "training");
        }
        // Create a temporary socket.io room that we can broadcast to. We use this
        // approach instead of emitting individually to each client so that the
        // socket.io library only needs to serialize / compress the message once.
        //
        // Note that we treat the returned value of `join` or `leave` as `void`; it
        // could be `Promise<void>` if we changed the adapter, but for our adapter
        // it's `void`.
        const room = "downloading-clients";
        clients.forEach((client) => client.join(room));
        this.server.to(room).emit("download", message);
        clients.forEach((client) => client.leave(room));
    }
    /**
     * Emits the round summary to the coordinator.
     *
     * Synchronously saves `serializedWeights` to the path given by
     * `weightsFilepath(filename)`, logs how long the save took, and then emits
     * a "roundEnd" event carrying the summary from `createRoundSummary`.
     */
    emitRoundSummary(filename, finishedRoundNumber, serializedWeights) {
        const filepath = this.weightsFilepath(filename);
        const [ms] = Timer_1.Timer.time(() => common_1.Weights.saveSerializedSync(filepath, serializedWeights));
        this.logger.debug(`Saved weights for round ${finishedRoundNumber} in ${ms} ms`);
        const roundSummary = this.createRoundSummary(finishedRoundNumber, serializedWeights);
        this.events.emit("roundEnd", roundSummary);
    }
    /**
     * Builds the summary object for a finished round.
     *
     * Side effect: calls `instrumentation.registerRoundEnd()` before reading
     * the instrumentation metrics.
     */
    createRoundSummary(roundNumber, weights) {
        this.instrumentation.registerRoundEnd();
        return {
            roundNumber,
            weights,
            numberEpochs: this.numberEpochs,
            numberUploads: this.numberUploads,
            numberClientsTraining: this.clientPool.numberTraining,
            timeSinceStart: this.timer.ms(),
            instrumentation: this.instrumentation.metrics(),
        };
    }
    /**
     * Given a unique name for the weights, returns the full filepath to save
     * weights to.
     */
    weightsFilepath(name) {
        return common_1.absolutePath.weights.server(this.options.experimentName, name);
    }
    /**
     * True when the number of ready clients equals
     * `options.minimumNumberClientsForStart` exactly and the timer has not
     * been started yet.
     */
    shouldStartTimer() {
        return (this.numberClients === this.options.minimumNumberClientsForStart &&
            !this.timer.isStarted());
    }
}
exports.FLServer = FLServer;
//# sourceMappingURL=FLServer.js.map