UNPKG

tensorflow-helpers

Version:

Helper functions for using TensorFlow in Node.js for transfer learning, image classification, and more.

255 lines (254 loc) 9.99 kB
"use strict";
// --- TypeScript-emitted CommonJS interop helpers (unchanged) ---
var __createBinding = (this && this.__createBinding) || (Object.create ? (function(o, m, k, k2) {
    if (k2 === undefined) k2 = k;
    var desc = Object.getOwnPropertyDescriptor(m, k);
    if (!desc || ("get" in desc ? !m.__esModule : desc.writable || desc.configurable)) {
        desc = { enumerable: true, get: function() { return m[k]; } };
    }
    Object.defineProperty(o, k2, desc);
}) : (function(o, m, k, k2) {
    if (k2 === undefined) k2 = k;
    o[k2] = m[k];
}));
var __setModuleDefault = (this && this.__setModuleDefault) || (Object.create ? (function(o, v) {
    Object.defineProperty(o, "default", { enumerable: true, value: v });
}) : function(o, v) {
    o["default"] = v;
});
var __importStar = (this && this.__importStar) || (function () {
    var ownKeys = function(o) {
        ownKeys = Object.getOwnPropertyNames || function (o) {
            var ar = [];
            for (var k in o) if (Object.prototype.hasOwnProperty.call(o, k)) ar[ar.length] = k;
            return ar;
        };
        return ownKeys(o);
    };
    return function (mod) {
        if (mod && mod.__esModule) return mod;
        var result = {};
        if (mod != null) for (var k = ownKeys(mod), i = 0; i < k.length; i++) if (k[i] !== "default") __createBinding(result, mod, k[i]);
        __setModuleDefault(result, mod);
        return result;
    };
})();
Object.defineProperty(exports, "__esModule", { value: true });
exports.PreTrainedImageModels = void 0;
exports.saveModel = saveModel;
exports.loadGraphModel = loadGraphModel;
exports.loadLayersModel = loadLayersModel;
exports.cachedLoadGraphModel = cachedLoadGraphModel;
exports.cachedLoadLayersModel = cachedLoadLayersModel;
exports.loadImageModel = loadImageModel;
const tf = __importStar(require("@tensorflow/tfjs-node"));
const fs_1 = require("fs");
const promises_1 = require("fs/promises");
const path_1 = require("path");
const image_1 = require("./image");
const file_1 = require("./file");
const tensor_1 = require("./tensor");
const classifier_utils_1 = require("./classifier-utils");
const spatial_utils_1 = require("./spatial-utils");
var image_model_1 = require("./image-model");
Object.defineProperty(exports, "PreTrainedImageModels", { enumerable: true, get: function () { return image_model_1.PreTrainedImageModels; } });
/**
 * Save a model into `dir` as `model.json` plus one `weight-<i>.bin` file per
 * weight buffer. If `classNames` is given, it is stored inside `model.json`.
 *
 * @param {{ dir: string, model: object, classNames?: string[] }} options
 * @returns {Promise<object>} whatever `model.save()` resolves to
 * @throws {Error} 'missing weightData' when the model artifact carries no weights
 */
async function saveModel(options) {
    let { dir, model, classNames } = options;
    return await model.save({
        async save(modelArtifact) {
            if (classNames) {
                modelArtifact.classNames = classNames;
            }
            let weights = modelArtifact.weightData;
            if (!weights) {
                throw new Error('missing weightData');
            }
            if (!Array.isArray(weights)) {
                weights = [weights];
            }
            (0, fs_1.mkdirSync)(dir, { recursive: true });
            // Await every write. The save then only resolves once all files are
            // on disk, and a failed write rejects here instead of becoming an
            // unhandled rejection that leaves a half-written model directory.
            await Promise.all([
                (0, promises_1.writeFile)((0, path_1.join)(dir, 'model.json'), JSON.stringify(modelArtifact)),
                ...weights.map((weight, i) => (0, promises_1.writeFile)((0, path_1.join)(dir, `weight-${i}.bin`), Buffer.from(weight))),
            ]);
            return {
                modelArtifactsInfo: {
                    dateSaved: new Date(),
                    modelTopologyType: 'JSON',
                },
            };
        },
    });
}
/**
 * Replace the placeholder `weightData` parsed from `model.json` with the bytes
 * of the `weight-<i>.bin` files in `dir`. Mutates `modelArtifact` in place.
 * `model.json` is written with `weightData` still present (see `saveModel`).
 * Its array-ness therefore tells us whether there is one weight file or several.
 *
 * @param {string} dir - model directory
 * @param {object} modelArtifact - parsed `model.json`
 * @throws {Error} 'missing weightData' when `model.json` has no `weightData` entry
 */
async function loadWeightFiles(dir, modelArtifact) {
    let weights = modelArtifact.weightData;
    if (!weights) {
        throw new Error('missing weightData');
    }
    if (!Array.isArray(weights)) {
        // Single buffer: assign onto the artifact itself. Filling a freshly
        // wrapped local array would leave the artifact's placeholder untouched.
        modelArtifact.weightData = await loadWeightData((0, path_1.join)(dir, `weight-0.bin`));
        return;
    }
    for (let i = 0; i < weights.length; i++) {
        weights[i] = await loadWeightData((0, path_1.join)(dir, `weight-${i}.bin`));
    }
}
/**
 * Load a graph model previously written by `saveModel` from `dir`.
 *
 * @param {{ dir: string, classNames?: string[] }} options
 * @returns {Promise<object>} the model with class names attached via `attachClassNames`
 */
async function loadGraphModel(options) {
    let { dir, classNames } = options;
    let model = await tf.loadGraphModel({
        async load() {
            let buffer = await (0, promises_1.readFile)((0, path_1.join)(dir, 'model.json'));
            let modelArtifact = JSON.parse(buffer.toString());
            classNames = (0, classifier_utils_1.checkClassNames)(modelArtifact, classNames);
            await loadWeightFiles(dir, modelArtifact);
            return modelArtifact;
        },
    });
    return (0, classifier_utils_1.attachClassNames)(model, classNames);
}
/**
 * Load a layers model previously written by `saveModel` from `dir`.
 *
 * @param {{ dir: string, classNames?: string[] }} options
 * @returns {Promise<object>} the model with class names attached via `attachClassNames`
 * @throws {Error} when the number of class names does not match the class
 *   count derived from `model.outputShape`
 */
async function loadLayersModel(options) {
    let { dir, classNames } = options;
    let model = await tf.loadLayersModel({
        async load() {
            let buffer = await (0, promises_1.readFile)((0, path_1.join)(dir, 'model.json'));
            let modelArtifact = JSON.parse(buffer.toString());
            classNames = (0, classifier_utils_1.checkClassNames)(modelArtifact, classNames);
            await loadWeightFiles(dir, modelArtifact);
            return modelArtifact;
        },
    });
    if (classNames) {
        let classCount = (0, classifier_utils_1.getClassCount)(model.outputShape);
        if (classCount != classNames.length) {
            throw new Error('number of classes mismatch, expected: ' +
                classNames.length +
                ', got: ' +
                classCount);
        }
    }
    return (0, classifier_utils_1.attachClassNames)(model, classNames);
}
/**
 * Read a weight file and return its bytes as a fresh `Uint8Array`.
 *
 * @param {string} file
 * @returns {Promise<Uint8Array>}
 */
async function loadWeightData(file) {
    let buffer = await (0, promises_1.readFile)(file);
    return new Uint8Array(buffer);
}
/**
 * Load a graph model from `dir` if that directory exists. Otherwise download
 * it from `url` (with `fromTFHub: true`), save it into `dir`, and return it.
 * Only the directory's existence is checked, not its contents.
 *
 * @param {{ url: string, dir: string, classNames?: string[] }} options
 */
async function cachedLoadGraphModel(options) {
    let { url: modelUrl, dir: modelDir, classNames } = options;
    if ((0, fs_1.existsSync)(modelDir)) {
        return await loadGraphModel(options);
    }
    let model = await tf.loadGraphModel(modelUrl, { fromTFHub: true });
    await saveModel({ model, dir: modelDir, classNames });
    return (0, classifier_utils_1.attachClassNames)(model, classNames);
}
/**
 * Load a layers model from `dir` if that directory exists. Otherwise download
 * it from `url` (with `fromTFHub: true`), save it into `dir`, and return it.
 * Only the directory's existence is checked, not its contents.
 *
 * @param {{ url: string, dir: string, classNames?: string[] }} options
 */
async function cachedLoadLayersModel(options) {
    let { url: modelUrl, dir: modelDir, classNames } = options;
    if ((0, fs_1.existsSync)(modelDir)) {
        return await loadLayersModel(options);
    }
    let model = await tf.loadLayersModel(modelUrl, { fromTFHub: true });
    await saveModel({ model, dir: modelDir, classNames });
    return (0, classifier_utils_1.attachClassNames)(model, classNames);
}
/**
 * Load (or download and cache) the graph model described by `spec` and return
 * helpers that turn image files or tensors into embeddings with it.
 *
 * Caching: when `cache` is truthy, embeddings for files whose names pass
 * `isContentHash` are memoized in an in-memory Map keyed by file basename.
 * When `cache` is an object, its `get`/`set` are also used to persist the
 * embedding values as plain arrays (presumably a Map-like store; confirm
 * against callers).
 *
 * @param {{ spec: { url: string, width: number, height: number, channels: number },
 *           dir: string, aspectRatio?: number, cache?: boolean | object }} options
 */
async function loadImageModel(options) {
    let { spec, dir, aspectRatio, cache } = options;
    let { url, width, height, channels } = spec;
    let model = await cachedLoadGraphModel({
        url,
        dir,
    });
    /** Load an image file cropped to the model's input size. */
    async function loadImageCropped(file, options) {
        let imageTensor = await (0, image_1.loadImageFile)(file, {
            channels,
            expandAnimations: options?.expandAnimations,
            crop: {
                width,
                height,
                aspectRatio,
            },
        });
        return imageTensor;
    }
    let fileEmbeddingCache = cache ? new Map() : null;
    /** Return a cached embedding for a content-hashed file, or undefined. */
    function checkCache(file_or_filename) {
        if (!fileEmbeddingCache || !(0, file_1.isContentHash)(file_or_filename))
            return;
        let filename = (0, path_1.basename)(file_or_filename);
        let embedding = fileEmbeddingCache.get(filename);
        // The tensor is shared with callers. If one of them disposed it, do
        // not hand back a dead tensor; fall through and rebuild it instead.
        if (embedding && !embedding.isDisposed)
            return embedding;
        let values = typeof cache == 'object' ? cache.get(filename) : undefined;
        if (!values)
            return;
        embedding = tf.tensor([values]);
        fileEmbeddingCache.set(filename, embedding);
        return embedding;
    }
    /** Remember an embedding in memory and, if `cache` is an object, persist its values. */
    async function saveCache(file, embedding) {
        let filename = (0, path_1.basename)(file);
        fileEmbeddingCache.set(filename, embedding);
        if (typeof cache == 'object') {
            let values = Array.from(await embedding.data());
            cache.set(filename, values);
        }
    }
    /**
     * Decode an image file and compute its embedding, using the cache when possible.
     * @throws {Error} 'failed to decode image: ...' when decoding fails
     */
    async function imageFileToEmbedding(file, options) {
        let cached = checkCache(file);
        if (cached)
            return cached;
        let content = await (0, promises_1.readFile)(file);
        let embedding = tf.tidy(() => {
            let dtype = undefined;
            let expandAnimations = options?.expandAnimations;
            let imageTensor;
            try {
                imageTensor = tf.node.decodeImage(content, channels, dtype, expandAnimations);
            }
            catch (error) {
                throw new Error('failed to decode image: ' + JSON.stringify(file), {
                    cause: error,
                });
            }
            return imageTensorToEmbedding(imageTensor);
        });
        // saveCache is async, so it must run outside tf.tidy. Awaiting it
        // surfaces persistence errors instead of leaving a floating promise.
        if (cache && (0, file_1.isContentHash)(file)) {
            await saveCache(file, embedding);
        }
        return embedding;
    }
    /** Crop/resize an image tensor to the model input and run the model on it. */
    function imageTensorToEmbedding(imageTensor) {
        return tf.tidy(() => {
            imageTensor = (0, image_1.cropAndResizeImageTensor)({
                imageTensor,
                width,
                height,
                aspectRatio,
            });
            let outputs = model.predict(imageTensor);
            let embedding = (0, tensor_1.toOneTensor)(outputs);
            return embedding;
        });
    }
    let spatialNodes = (0, spatial_utils_1.getSpatialNodes)({ model, tf });
    let spatialNodesWithUniqueShapes = (0, spatial_utils_1.filterSpatialNodesWithUniqueShapes)(spatialNodes);
    let lastSpatialNode = spatialNodesWithUniqueShapes.slice().pop();
    return {
        spec,
        model,
        fileEmbeddingCache,
        checkCache,
        loadImageCropped,
        imageFileToEmbedding,
        imageTensorToEmbedding,
        spatialNodes,
        spatialNodesWithUniqueShapes,
        lastSpatialNode,
    };
}