federer
Version:
Experiments in asynchronous federated learning and decentralized learning
158 lines • 6.63 kB
JavaScript
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.Experiment = void 0;
const common_1 = require("../common");
const cli_1 = require("./cli");
const _1 = require(".");
/**
 * Names of the server types for which clients are started with the
 * `deltaUpdates` option set to `true`; every other server type gets `false`.
 */
const DELTA_UPDATES = new Set(["BatchedFedAsync", "FedSemiSync"]);
/**
 * Coordinates one experiment run over an already-constructed network:
 * starts the server, then all clients, and later stops every node.
 *
 * The per-node option objects handed to `start(...)` are assembled here from
 * the experiment options, the prepared data description and the coordinator
 * CLI options.
 */
class Experiment {
    /**
     * @param network Object exposing `server`, `clients` (an array) and
     *   `ipc.clients`; each node has `events.on(...)`, `ready()`, `start()`,
     *   `started()` and `stopped()`.
     * @param options Experiment options (`serverOptions`, `trainOptions`,
     *   and optionally `clientDelays`, `logging`, `instrument`,
     *   `tensorflowVerbosity`).
     * @param data Prepared data description (`modelFilePath`,
     *   `experimentName`, `preprocessed.clientTrainFiles`).
     * @param logger Logger providing `debug` and `info`.
     */
    constructor(network, options, data, logger) {
        this.network = network;
        this.options = options;
        this.data = data;
        this.logger = logger;
        // One delay entry per client; defaults to the "none" delay specification.
        this.clientDelays = _1.getClientDelays(options.clientDelays ?? { type: "none" }, cli_1.CoordinatorCLIOptions.get("number-clients"));
        this.clientLoggerOptions = _1.getLoggerOptions("client", this.options.logging);
        this.modelPathOrURL = this.getPathOrUrl(this.data.modelFilePath);
        // Log the first "stopped" event of each node, then drop the listener.
        const unsubscribe = network.server.events.on("stopped", () => {
            this.logger.debug("server stopped");
            unsubscribe();
        });
        network.clients.forEach((client, i) => {
            const unsubscribe = client.events.on("stopped", () => {
                this.logger.debug(`client ${i} stopped`);
                unsubscribe();
            });
        });
    }
    /**
     * Maps a local file path to the location a node should load it from.
     *
     * @param filepath Local path of the file.
     * @returns An encoded `http://<host>/<relative path>` URL when the
     *   `fileserver-host` CLI option is set, otherwise `filepath` unchanged.
     */
    getPathOrUrl(filepath) {
        const host = cli_1.CoordinatorCLIOptions.get("fileserver-host");
        if (host !== undefined) {
            return encodeURI(`http://${host}/${common_1.relative(filepath)}`);
        }
        return filepath;
    }
    /** Starts the server first and, once it has started, all clients. */
    async start() {
        await this.startServer();
        await this.startClients();
    }
    /**
     * Asks the server and all clients to stop, then waits until every node's
     * `stopped()` promise has settled successfully.
     */
    async stop() {
        this.network.server.stop();
        this.network.ipc.clients.emit("stop");
        await Promise.all([
            this.network.server.stopped(),
            ...this.network.clients.map((client) => client.stopped()),
        ]);
        this.logger.debug("All nodes have returned 'stopped'");
    }
    /**
     * Waits for the server to be ready, sends it its options and waits until
     * it has started.
     *
     * @throws {Error} If the configured server type is not recognised
     *   (see `getServerOptions`).
     */
    async startServer() {
        this.logger.info(`Starting ${this.options.serverOptions.server} server`);
        const server = this.network.server;
        this.logger.debug("Waiting to receive 'ready' from server");
        await server.ready();
        this.logger.debug("Received 'ready' from server, sending 'start'");
        server.start(this.getServerOptions());
        this.logger.debug("Waiting to receive 'started' from server");
        await server.started();
        this.logger.debug(`Started ${this.options.serverOptions.server} server`);
    }
    /**
     * Runs the ready/start/started handshake with every client concurrently
     * and resolves once all of them have started.
     */
    async startClients() {
        await Promise.all(this.network.clients.map(async (client, i) => {
            this.logger.debug(`Waiting to receive 'ready' from client ${i}`);
            await client.ready();
            this.logger.debug(`Received 'ready' from client ${i}, sending 'start'`);
            client.start(this.getClientOptions(i));
            this.logger.debug(`Waiting to receive 'started' from client ${i}`);
            await client.started();
            this.logger.debug(`Client ${i} started`);
        }));
    }
    /**
     * Returns the options for client with a given id.
     *
     * @param id Index of the client in `network.clients`; also used to look
     *   up its training files and reply delay.
     * @returns The option object passed to `client.start(...)`.
     * @throws {Error} If the prepared data has no training files for `id`.
     */
    getClientOptions(id) {
        const dataFilePaths = this.data.preprocessed.clientTrainFiles[id];
        if (dataFilePaths == null) {
            // Fail with a message naming the client instead of an opaque
            // "cannot read properties of undefined" TypeError.
            throw new Error(`No training data files for client ${id}`);
        }
        return {
            id,
            deltaUpdates: DELTA_UPDATES.has(this.options.serverOptions.server),
            dataPathsOrURLs: {
                items: this.getPathOrUrl(dataFilePaths.items),
                labels: this.getPathOrUrl(dataFilePaths.labels),
            },
            modelPathOrURL: this.modelPathOrURL,
            replyDelay: this.clientDelays[id],
            trainOptions: this.options.trainOptions,
            tensorflowVerbosity: this.options.tensorflowVerbosity ?? 0,
            loggerOptions: this.clientLoggerOptions,
        };
    }
    /**
     * Builds the option object passed to `server.start(...)`: a set of fields
     * common to all server types plus the fields specific to the configured
     * `serverOptions.server`.
     *
     * @returns The server options for the configured server type.
     * @throws {Error} If `serverOptions.server` is not one of the known
     *   server types.
     */
    getServerOptions() {
        const { serverOptions } = this.options;
        const loggerOptions = _1.getLoggerOptions("server", this.options.logging);
        const numberClients = cli_1.CoordinatorCLIOptions.get("number-clients");
        const base = {
            modelPathOrURL: this.modelPathOrURL,
            experimentName: this.data.experimentName,
            minimumNumberClientsForStart: numberClients,
            loggerOptions,
            instrument: {
                memoryUsage: this.options.instrument?.memoryUsage ?? false,
                uploadStaleness: this.options.instrument?.uploadStaleness ?? false,
            },
        };
        switch (serverOptions.server) {
            // These two server types take the same set of options.
            case "FedAvg":
            case "LiFedAvg":
                return {
                    ...base,
                    server: serverOptions.server,
                    numberClients,
                    fractionOfClientsPerRound: serverOptions.fractionOfClientsPerRound,
                };
            case "FedAsync":
                return {
                    ...base,
                    server: "FedAsync",
                    epochDelay: serverOptions.epochDelay,
                    alpha: serverOptions.alpha,
                    staleness: serverOptions.staleness,
                };
            case "BatchedFedAsync":
                return {
                    ...base,
                    server: "BatchedFedAsync",
                    numberClientsPerRound: serverOptions.numberClientsPerRound,
                    roundEndFraction: serverOptions.roundEndFraction,
                    alpha: serverOptions.alpha,
                };
            // These two server types take the same set of options.
            case "FedSemiSync":
            case "FedCRDT":
                return {
                    ...base,
                    server: serverOptions.server,
                    numberClientsPerRound: serverOptions.numberClientsPerRound,
                    roundEndFraction: serverOptions.roundEndFraction,
                    a: serverOptions.a,
                    alpha: serverOptions.alpha,
                };
            default:
                // Previously fell through and returned undefined, which was
                // then sent to the server as its start options.
                throw new Error(`Unknown server type: ${String(serverOptions.server)}`);
        }
    }
}
exports.Experiment = Experiment;
//# sourceMappingURL=Experiment.js.map