federer
Version:
Experiments in asynchronous federated learning and decentralized learning
32 lines • 1.29 kB
JavaScript
;
Object.defineProperty(exports, "__esModule", { value: true });
exports.dataDistributionStats = void 0;
const tslib_1 = require("tslib");
const assert = require("assert");
const tf = tslib_1.__importStar(require("@tensorflow/tfjs-node"));
/**
 * Gathers statistics describing how the data emitted by a preprocessing
 * pipeline is distributed over the clients.
 *
 * @param clientLabels Labels of the sharded client datasets, one entry per
 * client. Labels are expected as numerical class indices, not one-hot
 * encoded vectors.
 * @param numberLabelClasses Total number of distinct label classes.
 * @returns An object whose `distributionMatrix` property holds the result of
 * `labelDistributionMatrix(clientLabels, numberLabelClasses)`.
 */
function dataDistributionStats(clientLabels, numberLabelClasses) {
    // The label distribution matrix is currently the sole statistic; further
    // statistics can be added as extra properties of the returned object.
    const distributionMatrix = labelDistributionMatrix(clientLabels, numberLabelClasses);
    return { distributionMatrix };
}
exports.dataDistributionStats = dataDistributionStats;
/**
 * Builds the label distribution matrix: one row per client, one column per
 * label class, where each entry is the number of times that class occurs in
 * that client's labels.
 *
 * The computation runs inside `tf.tidy` so that the per-client bincount
 * tensors (and any tensors created while converting the inputs) are released
 * once they have been stacked; only the returned matrix stays allocated.
 * Tensors created by the caller and passed in are not disposed.
 *
 * @param clientLabels Per-client label collections, each passed as the first
 * argument to `tf.bincount`. Labels should be numerical class indices.
 * @param numberLabelClasses Total number of label classes; used as the
 * bincount size and therefore the number of columns.
 * @returns The stacked tensor of shape
 * `[clientLabels.length, numberLabelClasses]`. The caller owns it and is
 * responsible for disposing it.
 * @throws AssertionError if the stacked tensor does not have the expected
 * shape.
 */
function labelDistributionMatrix(clientLabels, numberLabelClasses) {
    return tf.tidy(() => {
        // Empty weights: every label occurrence contributes a count of one.
        const bincounts = clientLabels.map((labels) => tf.bincount(labels, [], numberLabelClasses));
        const distribution = tf.stack(bincounts);
        // Checked inside the tidy scope so a failed assertion also frees
        // the tensors allocated above.
        assert.deepStrictEqual(distribution.shape, [
            clientLabels.length,
            numberLabelClasses,
        ]);
        return distribution;
    });
}
//# sourceMappingURL=distribution-stats.js.map