// federer: Experiments in asynchronous federated learning and decentralized learning
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.shardIntoUnbalancedBatchesPerBatch = exports.shardIntoSkewedBatches = exports.shardIntoSortedLabelBatches = exports.shardByLabel = void 0;
const assert = require("assert");
const shuffle_seed_1 = require("shuffle-seed");
const common_1 = require("../../common");
const datasplits_1 = require("./datasplits");
const filters_1 = require("./filters");
/**
* Split a dataset into shards, where each shard holds a different label.
*
* @param numShards Number of shards to split a dataset into
* @returns A function that takes a {@link DataSubset}, and returns a list of
* shards, where each shard holds a different label.
*/
function shardByLabel(numShards) {
    return async (dataSubset) => {
        const result = [];
        for (let i = 0; i < numShards; ++i) {
            result.push(filters_1.filterByLabels(i)(dataSubset));
        }
        return Promise.all(result);
    };
}
exports.shardByLabel = shardByLabel;
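// Example usage (a sketch; `mnistSubset` is a hypothetical DataSubset whose
// labels are the digits 0..9):
//
//   const shards = await shardByLabel(10)(mnistSubset);
//   // shards[i] contains the datapoints selected by filterByLabels(i),
//   // presumably those with label i.
/**
 * Sorts a dataset by label: shards it into one shard per label class and
 * concatenates the shards back together in label order.
 */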
async function sortByLabel(numberLabelClasses, data) {
    const splitByLabel = await shardByLabel(numberLabelClasses)(data);
    const concatenated = splitByLabel.reduce((acc, shard) => acc.concat(shard));
    return concatenated;
}
/**
* Splits a dataset into shards, by sorting the dataset by label, splitting into
* `numberBatches` batches, and then combining `numberBatchesPerShard` batches
* per shard.
*
* @param numberLabelClasses Number of different labels in the dataset
* @param numberBatches Number of batches to produce
* @param numberBatchesPerShard Number of batches to place in a shard
* @returns A function that takes a {@link DataSubset}, and returns a list of
* shards, where each shard is as described above.
*/
function shardIntoSortedLabelBatches(numberLabelClasses, numberBatches, numberBatchesPerShard) {
    return async (data) => {
        // It might be easier to call `tf.sort`, but that is not implemented
        // yet in TensorFlow.js. Instead, we split by label, concat, and then split
        // into batches, which are then shuffled and combined into shards.
        const concatenated = await sortByLabel(numberLabelClasses, data);
        const size = data.countNumberDatapoints();
        const batchSize = Math.floor(size / numberBatches);
        const batchSizes = Array.from({ length: numberBatches }).fill(batchSize);
        batchSizes[batchSizes.length - 1] = batchSize + (size % numberBatches);
        // console.log(`Split ${size} datapoints into batches of size ${batchSizes}`);
        const itemBatches = concatenated.items.split(batchSizes);
        const labelBatches = concatenated.labels.split(batchSizes);
        assert.strictEqual(itemBatches.length, labelBatches.length);
        const batches = common_1.zip(itemBatches, labelBatches).map(([items, labels]) => new common_1.DataSubset({ items, labels }));
        return common_1.chunk(shuffle_seed_1.shuffle(batches, ""), numberBatchesPerShard).map((chnk) => chnk.reduce((acc, batch) => acc.concat(batch)));
    };
}
exports.shardIntoSortedLabelBatches = shardIntoSortedLabelBatches;
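// Example usage (a sketch; `trainSubset` is a hypothetical DataSubset with 10
// label classes): split into 20 label-sorted batches and combine 2 batches per
// shard, yielding 10 shards whose label distributions are strongly non-IID.
//
//   const shards = await shardIntoSortedLabelBatches(10, 20, 2)(trainSubset);
/**
 * Splits a dataset into `numberShards` shards of unequal sizes: the dataset is
 * sorted by label and split into one batch per shard, with the batch sizes
 * produced by {@link zipfSplits} (skewed according to `skewFactorS`).
 *
 * @param numberLabelClasses Number of different labels in the dataset
 * @param numberShards Number of shards to produce
 * @param skewFactorS Skew parameter passed to {@link zipfSplits}, controlling
 * how unbalanced the shard sizes are
 * @returns A function that takes a {@link DataSubset}, and returns a list of
 * shards, where each shard is as described above.
 */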
function shardIntoSkewedBatches(numberLabelClasses, numberShards, skewFactorS) {
    return async (data) => {
        const numberBatchesPerShard = 1;
        const concatenated = await sortByLabel(numberLabelClasses, data);
        // TODO: this block (FROM HERE to TO HERE) could be factored out by
        // adding two more parameters to the function call, e.g. splitFn and
        // splitFnParams.
        // FROM HERE
        const size = data.countNumberDatapoints();
        // Since we want one batch per shard, pass numberShards directly to
        // zipfSplits, which returns the (skewed) batch sizes.
        const batchSizes = datasplits_1.zipfSplits(size, numberShards, skewFactorS);
        // TO HERE
        const itemBatches = concatenated.items.split(batchSizes);
        const labelBatches = concatenated.labels.split(batchSizes);
        assert.strictEqual(itemBatches.length, labelBatches.length);
        const batches = common_1.zip(itemBatches, labelBatches).map(([items, labels]) => new common_1.DataSubset({ items, labels }));
        return common_1.chunk(shuffle_seed_1.shuffle(batches, ""), numberBatchesPerShard).map((chnk) => chnk.reduce((acc, batch) => acc.concat(batch)));
    };
}
exports.shardIntoSkewedBatches = shardIntoSkewedBatches;
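// Example usage (a sketch; `trainSubset` is a hypothetical DataSubset with 10
// label classes): split into 5 shards whose sizes follow the zipfSplits
// distribution with skew factor 1.0, so some shards hold far more datapoints
// than others.
//
//   const skewedShards = await shardIntoSkewedBatches(10, 5, 1.0)(trainSubset);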
/**
 * Combines a sorted-label shard function and a skewed shard function to create
 * unbalanced shards with a lower bound on the number of samples per shard.
 *
 * @param sortedLabelBatches Shard function built with {@link shardIntoSortedLabelBatches}
 * @param skewedBatches Shard function built with {@link shardIntoSkewedBatches}
 * @param sortedLabelSplit Split specification passed to `data.split`; the first
 * split is fed to the sorted-label function, the second to the skewed one
 * @returns A shard function
 */
function shardIntoUnbalancedBatchesPerBatch(sortedLabelBatches, skewedBatches, sortedLabelSplit) {
    return async (data) => {
        const splits = data.split(sortedLabelSplit);
        const sortedShards = await sortedLabelBatches(splits[0]);
        const skewedShards = await skewedBatches(splits[1]);
        assert.strictEqual(sortedShards.length, skewedShards.length, "Number of shards generated by the sorted and skewed functions should be the same");
        return common_1.zip(sortedShards, skewedShards).map((chnk) => chnk.reduce((acc, batch) => acc.concat(batch)));
    };
}
exports.shardIntoUnbalancedBatchesPerBatch = shardIntoUnbalancedBatchesPerBatch;
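// Example usage (a sketch; all parameter values are hypothetical, and the exact
// form accepted by `data.split` depends on DataSubset): give every shard an
// equally sized, label-sorted portion of the first split as a lower bound, plus
// a skewed portion of the second split on top. Both inner shard functions must
// produce the same number of shards (here, 5).
//
//   const shardFn = shardIntoUnbalancedBatchesPerBatch(
//       shardIntoSortedLabelBatches(10, 10, 2), // 10 batches, 2 per shard -> 5 shards
//       shardIntoSkewedBatches(10, 5, 1.0),     // 5 skewed shards
//       [0.6, 0.4]);                            // passed to data.split
//   const shards = await shardFn(trainSubset);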
//# sourceMappingURL=shard.js.map