tensorflow-helpers
Helper functions to use TensorFlow in Node.js for transfer learning, image classification, and more
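Typical use, as a minimal sketch based only on the exports in this file; the model URL, input size, and file paths below are illustrative placeholders, not values shipped with the package:

const { loadImageModel } = require('tensorflow-helpers')

async function main() {
  // spec describes where to download the image model from and the input size it expects;
  // the URL and dimensions here are placeholders, not package defaults
  let imageModel = await loadImageModel({
    spec: {
      url: 'https://example.com/my-image-model', // hypothetical model URL
      width: 224,
      height: 224,
      channels: 3,
    },
    dir: './saved_models/my-image-model', // local directory used as a download cache
    cache: true, // keep embeddings of content-hash named files in memory
  })
  // turn an image file into an embedding tensor produced by the model
  let embedding = await imageModel.imageFileToEmbedding('./images/example.jpg')
  console.log(embedding.shape)
}

main()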
255 lines (254 loc) • 9.99 kB
JavaScript
var __createBinding = (this && this.__createBinding) || (Object.create ? (function(o, m, k, k2) {
if (k2 === undefined) k2 = k;
var desc = Object.getOwnPropertyDescriptor(m, k);
if (!desc || ("get" in desc ? !m.__esModule : desc.writable || desc.configurable)) {
desc = { enumerable: true, get: function() { return m[k]; } };
}
Object.defineProperty(o, k2, desc);
}) : (function(o, m, k, k2) {
if (k2 === undefined) k2 = k;
o[k2] = m[k];
}));
var __setModuleDefault = (this && this.__setModuleDefault) || (Object.create ? (function(o, v) {
Object.defineProperty(o, "default", { enumerable: true, value: v });
}) : function(o, v) {
o["default"] = v;
});
var __importStar = (this && this.__importStar) || (function () {
var ownKeys = function(o) {
ownKeys = Object.getOwnPropertyNames || function (o) {
var ar = [];
for (var k in o) if (Object.prototype.hasOwnProperty.call(o, k)) ar[ar.length] = k;
return ar;
};
return ownKeys(o);
};
return function (mod) {
if (mod && mod.__esModule) return mod;
var result = {};
if (mod != null) for (var k = ownKeys(mod), i = 0; i < k.length; i++) if (k[i] !== "default") __createBinding(result, mod, k[i]);
__setModuleDefault(result, mod);
return result;
};
})();
Object.defineProperty(exports, "__esModule", { value: true });
exports.PreTrainedImageModels = void 0;
exports.saveModel = saveModel;
exports.loadGraphModel = loadGraphModel;
exports.loadLayersModel = loadLayersModel;
exports.cachedLoadGraphModel = cachedLoadGraphModel;
exports.cachedLoadLayersModel = cachedLoadLayersModel;
exports.loadImageModel = loadImageModel;
const tf = __importStar(require("@tensorflow/tfjs-node"));
const fs_1 = require("fs");
const promises_1 = require("fs/promises");
const path_1 = require("path");
const image_1 = require("./image");
const file_1 = require("./file");
const tensor_1 = require("./tensor");
const classifier_utils_1 = require("./classifier-utils");
const spatial_utils_1 = require("./spatial-utils");
var image_model_1 = require("./image-model");
Object.defineProperty(exports, "PreTrainedImageModels", { enumerable: true, get: function () { return image_model_1.PreTrainedImageModels; } });
async function saveModel(options) {
let { dir, model, classNames } = options;
return await model.save({
async save(modelArtifact) {
if (classNames) {
modelArtifact.classNames = classNames;
}
let weights = modelArtifact.weightData;
if (!weights) {
throw new Error('missing weightData');
}
if (!Array.isArray(weights)) {
weights = [weights];
}
(0, fs_1.mkdirSync)(dir, { recursive: true });
// await the writes so save() does not resolve before the files are fully on disk
await (0, promises_1.writeFile)((0, path_1.join)(dir, 'model.json'), JSON.stringify(modelArtifact));
for (let i = 0; i < weights.length; i++) {
await (0, promises_1.writeFile)((0, path_1.join)(dir, `weight-${i}.bin`), Buffer.from(weights[i]));
}
return {
modelArtifactsInfo: {
dateSaved: new Date(),
modelTopologyType: 'JSON',
},
};
},
});
}
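// Load a graph model saved by saveModel() from a local directory, reading each
// weight-<i>.bin shard from disk and re-attaching the stored class names.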
async function loadGraphModel(options) {
let { dir, classNames } = options;
let model = await tf.loadGraphModel({
async load() {
let buffer = await (0, promises_1.readFile)((0, path_1.join)(dir, 'model.json'));
let modelArtifact = JSON.parse(buffer.toString());
classNames = (0, classifier_utils_1.checkClassNames)(modelArtifact, classNames);
let weights = modelArtifact.weightData;
if (!weights) {
throw new Error('missing weightData');
}
if (!Array.isArray(weights)) {
weights = [weights];
}
for (let i = 0; i < weights.length; i++) {
weights[i] = await loadWeightData((0, path_1.join)(dir, `weight-${i}.bin`));
}
return modelArtifact;
},
});
return (0, classifier_utils_1.attachClassNames)(model, classNames);
}
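// Load a layers model saved by saveModel() from a local directory; weight shards are
// read from weight-<i>.bin files and the class count is checked against classNames.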
async function loadLayersModel(options) {
let { dir, classNames } = options;
let model = await tf.loadLayersModel({
async load() {
let buffer = await (0, promises_1.readFile)((0, path_1.join)(dir, 'model.json'));
let modelArtifact = JSON.parse(buffer.toString());
classNames = (0, classifier_utils_1.checkClassNames)(modelArtifact, classNames);
let weights = modelArtifact.weightData;
if (!weights) {
throw new Error('missing weightData');
}
if (!Array.isArray(weights)) {
modelArtifact.weightData = await loadWeightData((0, path_1.join)(dir, `weight-0.bin`));
return modelArtifact;
}
for (let i = 0; i < weights.length; i++) {
weights[i] = await loadWeightData((0, path_1.join)(dir, `weight-${i}.bin`));
}
return modelArtifact;
},
});
if (classNames) {
let classCount = (0, classifier_utils_1.getClassCount)(model.outputShape);
if (classCount != classNames.length) {
throw new Error('number of classes mismatch, expected: ' +
classNames.length +
', got: ' +
classCount);
}
}
return (0, classifier_utils_1.attachClassNames)(model, classNames);
}
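// Read a single weight shard from disk into a Uint8Array.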
async function loadWeightData(file) {
let buffer = await (0, promises_1.readFile)(file);
return new Uint8Array(buffer);
}
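// Load a graph model from the local directory if it already exists there; otherwise
// download it (TFHub format), save it into the directory, and return it.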
async function cachedLoadGraphModel(options) {
let { url: modelUrl, dir: modelDir, classNames } = options;
if ((0, fs_1.existsSync)(modelDir)) {
return await loadGraphModel(options);
}
let model = await tf.loadGraphModel(modelUrl, { fromTFHub: true });
await saveModel({ model, dir: modelDir, classNames });
return (0, classifier_utils_1.attachClassNames)(model, classNames);
}
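// Same as cachedLoadGraphModel, but for layers models.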
async function cachedLoadLayersModel(options) {
let { url: modelUrl, dir: modelDir, classNames } = options;
if ((0, fs_1.existsSync)(modelDir)) {
return await loadLayersModel(options);
}
let model = await tf.loadLayersModel(modelUrl, { fromTFHub: true });
await saveModel({ model, dir: modelDir, classNames });
return (0, classifier_utils_1.attachClassNames)(model, classNames);
}
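// Load an image model described by `spec` (url, input width/height/channels) and wrap it
// with helpers for cropping images, computing embeddings, and caching embeddings by filename.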
async function loadImageModel(options) {
let { spec, dir, aspectRatio, cache } = options;
let { url, width, height, channels } = spec;
let model = await cachedLoadGraphModel({
url,
dir,
});
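// Decode an image file and crop/resize it to the model's expected input size.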
async function loadImageCropped(file, options) {
let imageTensor = await (0, image_1.loadImageFile)(file, {
channels,
expandAnimations: options?.expandAnimations,
crop: {
width,
height,
aspectRatio,
},
});
return imageTensor;
}
let fileEmbeddingCache = cache
? new Map()
: null;
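// Look up a previously computed embedding for a content-hash named file,
// first in memory, then in the user-supplied cache object if one was given.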
function checkCache(file_or_filename) {
if (!fileEmbeddingCache || !(0, file_1.isContentHash)(file_or_filename))
return;
let filename = (0, path_1.basename)(file_or_filename);
let embedding = fileEmbeddingCache.get(filename);
if (embedding)
return embedding;
let values = typeof cache == 'object' ? cache.get(filename) : undefined;
if (!values)
return;
embedding = tf.tensor([values]);
fileEmbeddingCache.set(filename, embedding);
return embedding;
}
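// Remember an embedding in memory and, if a persistent cache object was supplied,
// also store it there as a plain number array.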
async function saveCache(file, embedding) {
let filename = (0, path_1.basename)(file);
fileEmbeddingCache.set(filename, embedding);
if (typeof cache == 'object') {
let values = Array.from(await embedding.data());
cache.set(filename, values);
}
}
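// Read and decode an image file, then run it through the model to get an embedding.
// Files named by content hash are served from and written to the cache.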
async function imageFileToEmbedding(file, options) {
let embedding = checkCache(file);
if (embedding)
return embedding;
let content = await (0, promises_1.readFile)(file);
return tf.tidy(() => {
let dtype = undefined;
let expandAnimations = options?.expandAnimations;
let imageTensor;
try {
imageTensor = tf.node.decodeImage(content, channels, dtype, expandAnimations);
}
catch (error) {
throw new Error('failed to decode image: ' + JSON.stringify(file), {
cause: error,
});
}
let embedding = imageTensorToEmbedding(imageTensor);
if (cache && (0, file_1.isContentHash)(file)) {
saveCache(file, embedding);
}
return embedding;
});
}
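// Crop/resize an image tensor to the model input size and run the model on it,
// collapsing the outputs into a single embedding tensor.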
function imageTensorToEmbedding(imageTensor) {
return tf.tidy(() => {
imageTensor = (0, image_1.cropAndResizeImageTensor)({
imageTensor,
width,
height,
aspectRatio,
});
let outputs = model.predict(imageTensor);
let embedding = (0, tensor_1.toOneTensor)(outputs);
return embedding;
});
}
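// Collect the model's spatial nodes, keep those with unique output shapes,
// and remember the last one in the list.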
let spatialNodes = (0, spatial_utils_1.getSpatialNodes)({ model, tf });
let spatialNodesWithUniqueShapes = (0, spatial_utils_1.filterSpatialNodesWithUniqueShapes)(spatialNodes);
let lastSpatialNode = spatialNodesWithUniqueShapes.slice().pop();
return {
spec,
model,
fileEmbeddingCache,
checkCache,
loadImageCropped,
imageFileToEmbedding,
imageTensorToEmbedding,
spatialNodes,
spatialNodesWithUniqueShapes,
lastSpatialNode,
};
}
;