// federer — experiments in asynchronous federated learning and decentralized learning.
// (File-listing metadata: version not recorded; 97 lines, 3.5 kB, JavaScript.)
;
Object.defineProperty(exports, "__esModule", { value: true });
exports.readLabels = exports.readLabelFile = exports.readImages = exports.readImageFile = exports.readRawData = void 0;
const tslib_1 = require("tslib");
const assert = require("assert");
const tf = tslib_1.__importStar(require("@tensorflow/tfjs-node"));
const common_1 = require("../../../../common");
const coordinator_1 = require("../../../../coordinator");
const download_1 = require("./download");
/**
 * Read binary MNIST data, as downloaded from the Internet, into a train/test `Dataset`.
 *
 * The four files (train images, train labels, test images, test labels) are located via
 * `getFilepaths(dataset)` and read concurrently.
 *
 * @param dataset - passed to `getFilepaths` to locate the files.
 * @param environment - forwarded unchanged to each file reader.
 * @returns {Promise} resolving to a `Dataset` whose `train` and `test` members are
 *   `DataSubset`s with `items` (image tensor) and `labels` (label tensor).
 */
async function readRawData(dataset, environment) {
    const paths = download_1.getFilepaths(dataset);
    // Start every read before awaiting any; the array order fixes the result order below.
    const reads = [
        readImageFile(paths.train.items, environment),
        readLabelFile(paths.train.labels, environment),
        readImageFile(paths.test.items, environment),
        readLabelFile(paths.test.labels, environment),
    ];
    const results = await Promise.all(reads);
    const makeSubset = (items, labels) => new common_1.DataSubset({ items, labels });
    return new common_1.Dataset({
        train: makeSubset(results[0], results[1]),
        test: makeSubset(results[2], results[3]),
    });
}
exports.readRawData = readRawData;
/**
 * Magic numbers found in the first 4 bytes of the MNIST binary files.
 * Like a TypeScript numeric enum, it maps both name -> value and value -> name.
 */
const MagicNumber = {
    Image: 2051,
    Label: 2049,
    2051: "Image",
    2049: "Label",
};
// Image header: magic number, item count, height, width — four 32-bit big-endian words.
const IMAGE_HEADER_BYTES = 4 * 4;
// Label header: magic number and item count — two 32-bit big-endian words.
const LABEL_HEADER_BYTES = 2 * 4;
/**
 * Read an MNIST image file via the coordinator's environment-aware reader and decode it.
 *
 * @param file - location of the image file, as understood by `readFileInEnvironment`.
 * @param environment - forwarded unchanged to `readFileInEnvironment`.
 * @returns {Promise} resolving to the image tensor produced by `readImages`.
 */
async function readImageFile(file, environment) {
    return readImages(await coordinator_1.readFileInEnvironment(file, environment));
}
exports.readImageFile = readImageFile;
/**
 * Decode the full contents of an MNIST image file (header followed by pixel data).
 *
 * @param buffer - the file contents as a Node `Buffer`.
 * @returns the image tensor built by `readImagesBody` from the parsed header.
 * @throws Error if the header's magic number is not the image magic number.
 */
function readImages(buffer) {
    return readImagesBody(buffer, readImagesHeader(buffer));
}
exports.readImages = readImages;
/**
 * Parse the 16-byte header of an MNIST image file.
 *
 * @param buffer - the file contents as a Node `Buffer`.
 * @returns `{ magicNumber, length, height, width }`, each read as a big-endian uint32.
 * @throws Error if the magic number is not `MagicNumber.Image`.
 */
function readImagesHeader(buffer) {
    const magicNumber = buffer.readUInt32BE(0);
    if (magicNumber === MagicNumber.Image) {
        // Item count, row count and column count follow the magic number, 4 bytes apiece.
        const [length, height, width] = [4, 8, 12].map((offset) => buffer.readUInt32BE(offset));
        return { magicNumber, length, height, width };
    }
    throw new Error(`Invalid magic number ${magicNumber}; expected ${MagicNumber.Image}`);
}
/**
 * Decode the pixel section of an MNIST image file into a tensor.
 *
 * Every byte after the 16-byte header becomes one int32 element.
 *
 * @param buffer - the whole file contents (header included) as a Node `Buffer`.
 * @param header - the parsed header from `readImagesHeader` (`length`, `height`, `width`).
 * @returns an int32 tensor of shape `[length, height, width, 1]`.
 * @throws AssertionError if the number of pixel bytes disagrees with the header dimensions.
 */
function readImagesBody(buffer, header) {
    const { length, height, width } = header;
    // Copy the bytes straight into an Int32Array so the data already matches the "int32"
    // dtype (no float32 intermediate that tfjs would have to convert again).
    // `subarray` replaces the deprecated `Buffer#slice`; both return a view.
    const pixels = new Int32Array(buffer.subarray(IMAGE_HEADER_BYTES));
    // Same header-vs-body consistency check that readLabelsBody performs.
    assert.strictEqual(pixels.length, length * height * width);
    // Build the 4-D tensor directly. The previous `tensor1d(...).reshape(...)` left the 1-D
    // tensor undisposed, so it stayed registered with the tf engine after callers disposed
    // the result.
    return tf.tensor4d(pixels, [length, height, width, 1], "int32");
}
/**
 * Read an MNIST label file via the coordinator's environment-aware reader and decode it.
 *
 * @param path - location of the label file, as understood by `readFileInEnvironment`.
 * @param environment - forwarded unchanged to `readFileInEnvironment`.
 * @returns {Promise} resolving to the label tensor produced by `readLabels`.
 */
async function readLabelFile(path, environment) {
    return readLabels(await coordinator_1.readFileInEnvironment(path, environment));
}
exports.readLabelFile = readLabelFile;
/**
 * Decode the full contents of an MNIST label file (header followed by label bytes).
 *
 * @param buffer - the file contents as a Node `Buffer`.
 * @returns the 1-D int32 label tensor built by `readLabelsBody`.
 * @throws Error if the header's magic number is not the label magic number.
 */
function readLabels(buffer) {
    return readLabelsBody(buffer, readLabelsHeader(buffer));
}
exports.readLabels = readLabels;
/**
 * Parse the 8-byte header of an MNIST label file.
 *
 * @param buffer - the file contents as a Node `Buffer`.
 * @returns `{ magicNumber, length }`, both read as big-endian uint32.
 * @throws Error if the magic number is not `MagicNumber.Label`.
 */
function readLabelsHeader(buffer) {
    const magicNumber = buffer.readUInt32BE(0);
    const expected = MagicNumber.Label;
    if (magicNumber === expected) {
        return { magicNumber, length: buffer.readUInt32BE(4) };
    }
    throw new Error(`Invalid magic number ${magicNumber}; expected ${expected}`);
}
/**
 * Decode the label section of an MNIST label file into a 1-D int32 tensor.
 *
 * Every byte after the 8-byte header becomes one int32 label.
 *
 * @param buffer - the whole file contents (header included) as a Node `Buffer`.
 * @param header - the parsed header from `readLabelsHeader`.
 * @returns a 1-D int32 tensor with `header.length` elements.
 * @throws AssertionError if the number of label bytes differs from `header.length`.
 */
function readLabelsBody(buffer, header) {
    const labelBytes = buffer.subarray(LABEL_HEADER_BYTES);
    const labels = Int32Array.from(labelBytes);
    assert.strictEqual(labels.length, header.length);
    return tf.tensor1d(labels, "int32");
}
//# sourceMappingURL=read-binary.js.map