UNPKG

federer

Version:

Experiments in asynchronous federated learning and decentralized learning

32 lines 1.29 kB
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.dataDistributionStats = void 0;
const tslib_1 = require("tslib");
const assert = require("assert");
const tf = tslib_1.__importStar(require("@tensorflow/tfjs-node"));
/**
 * Computes statistics on the distribution of data that is produced by a
 * preprocessing pipeline.
 *
 * @param clientLabels Labels in the sharded client datasets, one tensor per
 *     client. These should be in numerical format (class indices), not
 *     one-hot encoded. The input tensors are not disposed by this function.
 * @param numberLabelClasses Number of labels that exist in total.
 * @returns An object whose `distributionMatrix` is a tensor of shape
 *     `[clientLabels.length, numberLabelClasses]`. The caller owns the
 *     returned tensor and is responsible for disposing it.
 */
function dataDistributionStats(clientLabels, numberLabelClasses) {
    // For now, this is the only statistic we compute. This can be extended to
    // compute more statistics.
    return {
        distributionMatrix: labelDistributionMatrix(clientLabels, numberLabelClasses),
    };
}
exports.dataDistributionStats = dataDistributionStats;
/**
 * Builds the per-client label histogram matrix: row i is the `tf.bincount`
 * of client i's labels over `numberLabelClasses` bins.
 *
 * NOTE(review): presumably each element of `clientLabels` is a 1-D int32
 * tensor, as required by `tf.bincount` — confirm against callers.
 *
 * @param clientLabels Per-client label tensors (not disposed here).
 * @param numberLabelClasses Number of bins (label classes) per row.
 * @returns A tensor of shape `[clientLabels.length, numberLabelClasses]`.
 */
function labelDistributionMatrix(clientLabels, numberLabelClasses) {
    if (clientLabels.length === 0) {
        // tf.stack rejects an empty list; with no clients the matrix simply
        // has zero rows, which still matches the documented shape.
        return tf.zeros([0, numberLabelClasses]);
    }
    // tf.tidy disposes the intermediate per-client bincount tensors (and the
    // stacked result if the shape assertion throws); only the returned matrix
    // survives. Tensors created outside the tidy scope, such as the inputs,
    // are left untouched.
    return tf.tidy(() => {
        const bincounts = clientLabels.map((labels) => tf.bincount(labels, [], numberLabelClasses));
        const distribution = tf.stack(bincounts);
        assert.deepStrictEqual(distribution.shape, [
            clientLabels.length,
            numberLabelClasses,
        ]);
        return distribution;
    });
}
//# sourceMappingURL=distribution-stats.js.map