UNPKG

federer

Version:

Experiments in asynchronous federated learning and decentralized learning

97 lines 3.5 kB
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.readLabels = exports.readLabelFile = exports.readImages = exports.readImageFile = exports.readRawData = void 0;
const tslib_1 = require("tslib");
const assert = require("assert");
const tf = tslib_1.__importStar(require("@tensorflow/tfjs-node"));
const common_1 = require("../../../../common");
const coordinator_1 = require("../../../../coordinator");
const download_1 = require("./download");

/**
 * Read binary MNIST data, as downloaded from the Internet.
 *
 * The four files (train/test images and labels) located by
 * `getFilepaths(dataset)` are read in parallel and bundled into a `Dataset`
 * with `train` and `test` subsets.
 *
 * @param {*} dataset - Identifier forwarded to `getFilepaths`.
 * @param {*} environment - Forwarded to `readFileInEnvironment`.
 * @returns {Promise<common_1.Dataset>} Images as int32 tensors of shape
 *   [count, height, width, 1]; labels as 1-D int32 tensors.
 * @throws {Error} If a file has the wrong magic number.
 * @throws {assert.AssertionError} If a file body size disagrees with its header.
 */
async function readRawData(dataset, environment) {
    const filepaths = download_1.getFilepaths(dataset);
    const [trainImages, trainLabels, testImages, testLabels] = await Promise.all([
        readImageFile(filepaths.train.items, environment),
        readLabelFile(filepaths.train.labels, environment),
        readImageFile(filepaths.test.items, environment),
        readLabelFile(filepaths.test.labels, environment),
    ]);
    return new common_1.Dataset({
        train: new common_1.DataSubset({
            items: trainImages,
            labels: trainLabels,
        }),
        test: new common_1.DataSubset({
            items: testImages,
            labels: testLabels,
        }),
    });
}
exports.readRawData = readRawData;

// Expected big-endian uint32 found at byte offset 0 of each file type.
var MagicNumber;
(function (MagicNumber) {
    MagicNumber[MagicNumber["Image"] = 2051] = "Image";
    MagicNumber[MagicNumber["Label"] = 2049] = "Label";
})(MagicNumber || (MagicNumber = {}));

// Image header: magic, count, height, width (4 x uint32).
const IMAGE_HEADER_BYTES = 16;
// Label header: magic, count (2 x uint32).
const LABEL_HEADER_BYTES = 8;

/**
 * Load an image file via `readFileInEnvironment` and decode it.
 *
 * @param {*} file - Path forwarded to `readFileInEnvironment`.
 * @param {*} environment - Forwarded to `readFileInEnvironment`.
 * @returns {Promise<tf.Tensor4D>} int32 tensor of shape [count, height, width, 1].
 */
async function readImageFile(file, environment) {
    const buffer = await coordinator_1.readFileInEnvironment(file, environment);
    return readImages(buffer);
}
exports.readImageFile = readImageFile;

/**
 * Decode an in-memory image file (header + one byte per pixel).
 *
 * @param {Buffer} buffer - Raw file contents.
 * @returns {tf.Tensor4D} int32 tensor of shape [count, height, width, 1].
 * @throws {Error} If the magic number is not `MagicNumber.Image`.
 * @throws {assert.AssertionError} If the pixel count disagrees with the header.
 */
function readImages(buffer) {
    const header = readImagesHeader(buffer);
    return readImagesBody(buffer, header);
}
exports.readImages = readImages;

/** Parse and validate the 16-byte image header. */
function readImagesHeader(buffer) {
    const magicNumber = buffer.readUInt32BE(0);
    if (magicNumber !== MagicNumber.Image) {
        throw new Error(`Invalid magic number ${magicNumber}; expected ${MagicNumber.Image}`);
    }
    return {
        magicNumber,
        length: buffer.readUInt32BE(4),
        height: buffer.readUInt32BE(8),
        width: buffer.readUInt32BE(12),
    };
}

/** Convert the pixel bytes after the header into a 4-D int32 tensor. */
function readImagesBody(buffer, header) {
    const { length, height, width } = header;
    // `subarray` is the non-deprecated form of `Buffer#slice` (a view, no copy);
    // the Int32Array constructor then copies element-wise, one byte per pixel.
    const pixels = new Int32Array(buffer.subarray(IMAGE_HEADER_BYTES));
    // Same integrity check as the label reader, so a truncated or padded file
    // fails with a clear message instead of an opaque tensor-shape error.
    assert.strictEqual(pixels.length, length * height * width);
    const imagesShape = [length, height, width, 1];
    // Build the tensor with its final shape directly: tensor1d(...).reshape(...)
    // left the intermediate 1-D tensor allocated and never disposed.
    return tf.tensor4d(pixels, imagesShape, "int32");
}

/**
 * Load a label file via `readFileInEnvironment` and decode it.
 *
 * @param {*} path - Path forwarded to `readFileInEnvironment`.
 * @param {*} environment - Forwarded to `readFileInEnvironment`.
 * @returns {Promise<tf.Tensor1D>} 1-D int32 tensor with one entry per label.
 */
async function readLabelFile(path, environment) {
    const buffer = await coordinator_1.readFileInEnvironment(path, environment);
    return readLabels(buffer);
}
exports.readLabelFile = readLabelFile;

/**
 * Decode an in-memory label file (header + one byte per label).
 *
 * @param {Buffer} buffer - Raw file contents.
 * @returns {tf.Tensor1D} 1-D int32 tensor with one entry per label.
 * @throws {Error} If the magic number is not `MagicNumber.Label`.
 * @throws {assert.AssertionError} If the label count disagrees with the header.
 */
function readLabels(buffer) {
    const header = readLabelsHeader(buffer);
    return readLabelsBody(buffer, header);
}
exports.readLabels = readLabels;

/** Parse and validate the 8-byte label header. */
function readLabelsHeader(buffer) {
    const magicNumber = buffer.readUInt32BE(0);
    if (magicNumber !== MagicNumber.Label) {
        throw new Error(`Invalid magic number ${magicNumber}; expected ${MagicNumber.Label}`);
    }
    return {
        magicNumber,
        length: buffer.readUInt32BE(4),
    };
}

/** Convert the label bytes after the header into a 1-D int32 tensor. */
function readLabelsBody(buffer, header) {
    // One byte per label, copied element-wise into int32 storage.
    const labels = new Int32Array(buffer.subarray(LABEL_HEADER_BYTES));
    assert.strictEqual(labels.length, header.length);
    return tf.tensor1d(labels, "int32");
}
//# sourceMappingURL=read-binary.js.map