UNPKG

federer

Version:

Experiments in asynchronous federated learning and decentralized learning

97 lines 5.19 kB
"use strict"; Object.defineProperty(exports, "__esModule", { value: true }); exports.shardIntoUnbalancedBatchesPerBatch = exports.shardIntoSkedwedBatches = exports.shardIntoSortedLabelBatches = exports.shardByLabel = void 0; const assert = require("assert"); const shuffle_seed_1 = require("shuffle-seed"); const common_1 = require("../../common"); const datasplits_1 = require("./datasplits"); const filters_1 = require("./filters"); /** * Split a dataset into shards, where each shard holds a different label. * * @param numShards Number of shards to split a dataset into * @returns A function that takes a {@link DataSubset}, and returns a list of * shards, where each shard holds a different label. */ function shardByLabel(numShards) { return async (dataSubset) => { const result = []; for (let i = 0; i < numShards; ++i) { result.push(filters_1.filterByLabels(i)(dataSubset)); } return Promise.all(result); }; } exports.shardByLabel = shardByLabel; async function sortByLabel(numberLabelClasses, data) { const splitByLabel = await shardByLabel(numberLabelClasses)(data); const concatenated = splitByLabel.reduce((acc, shard) => acc.concat(shard)); return concatenated; } /** * Splits a dataset into shards, by sorting the dataset by label, splitting into * `numberBatches` batches, and then combining `numberBatchesPerShard` batches * per shard. * * @param numberLabelClasses Number of different labels in the dataset * @param numberBatches Number of batches to produce * @param numberBatchesPerShard Number of batches to place in a shard * @returns A function that takes a {@link DataSubset}, and returns a list of * shards, where each shard is as described above. */ function shardIntoSortedLabelBatches(numberLabelClasses, numberBatches, numberBatchesPerShard) { return async (data) => { // It might be easier to call `tf.sort`, but that is not implemented // yet in TensorFlow.js. Instead, we split by label, concat, and then split // into batches, which are then shuffled and combined into shards. const concatenated = await sortByLabel(numberLabelClasses, data); const size = data.countNumberDatapoints(); const batchSize = Math.floor(size / numberBatches); const batchSizes = Array.from({ length: numberBatches }).fill(batchSize); batchSizes[batchSizes.length - 1] = batchSize + (size % numberBatches); // console.log(`Split ${size} datapoints into batches of size ${batchSizes}`); const itemBatches = concatenated.items.split(batchSizes); const labelBatches = concatenated.labels.split(batchSizes); assert.strictEqual(itemBatches.length, labelBatches.length); const batches = common_1.zip(itemBatches, labelBatches).map(([items, labels]) => new common_1.DataSubset({ items, labels })); return common_1.chunk(shuffle_seed_1.shuffle(batches, ""), numberBatchesPerShard).map((chnk) => chnk.reduce((acc, batch) => acc.concat(batch))); }; } exports.shardIntoSortedLabelBatches = shardIntoSortedLabelBatches; function shardIntoSkedwedBatches(numberLabelClasses, numberShards, skewFactorS) { return async (data) => { const numberBatchesPerShard = 1; const concatenated = await sortByLabel(numberLabelClasses, data); // TODO: refactorable code by adding two more params in the function call // splitfn and splitFnParams // FROM HERE const size = data.countNumberDatapoints(); // since we want 1 batch per shard we are passing numberShards directly // to the function call and returning batchSizes const batchSizes = datasplits_1.zipfSplits(size, numberShards, skewFactorS); // TO HERE const itemBatches = concatenated.items.split(batchSizes); const labelBatches = concatenated.labels.split(batchSizes); assert.strictEqual(itemBatches.length, labelBatches.length); const batches = common_1.zip(itemBatches, labelBatches).map(([items, labels]) => new common_1.DataSubset({ items, labels })); return common_1.chunk(shuffle_seed_1.shuffle(batches, ""), numberBatchesPerShard).map((chnk) => chnk.reduce((acc, batch) => acc.concat(batch))); }; } exports.shardIntoSkedwedBatches = shardIntoSkedwedBatches; /** * Combines sortedLables and skewed functions to create unbalanced * shards with lowerbound on samples per shard. * * @returns a shard function */ function shardIntoUnbalancedBatchesPerBatch(sortedLabelBatches, skewedBatches, sortedLabelSplit) { return async (data) => { const splits = data.split(sortedLabelSplit); const sortedShards = await sortedLabelBatches(splits[0]); const skewedShards = await skewedBatches(splits[1]); assert(sortedShards.length == skewedShards.length, Error("Number of shards generated by sorted function and skewed function should be same")); return common_1.zip(sortedShards, skewedShards).map((chnk) => chnk.reduce((acc, batch) => acc.concat(batch))); }; } exports.shardIntoUnbalancedBatchesPerBatch = shardIntoUnbalancedBatchesPerBatch; //# sourceMappingURL=shard.js.map