UNPKG

tensorflow-helpers

Version:

Helper functions for using TensorFlow in Node.js for transfer learning, image classification, and more

128 lines (127 loc) 4.66 kB
"use strict"; var __createBinding = (this && this.__createBinding) || (Object.create ? (function(o, m, k, k2) { if (k2 === undefined) k2 = k; var desc = Object.getOwnPropertyDescriptor(m, k); if (!desc || ("get" in desc ? !m.__esModule : desc.writable || desc.configurable)) { desc = { enumerable: true, get: function() { return m[k]; } }; } Object.defineProperty(o, k2, desc); }) : (function(o, m, k, k2) { if (k2 === undefined) k2 = k; o[k2] = m[k]; })); var __setModuleDefault = (this && this.__setModuleDefault) || (Object.create ? (function(o, v) { Object.defineProperty(o, "default", { enumerable: true, value: v }); }) : function(o, v) { o["default"] = v; }); var __importStar = (this && this.__importStar) || (function () { var ownKeys = function(o) { ownKeys = Object.getOwnPropertyNames || function (o) { var ar = []; for (var k in o) if (Object.prototype.hasOwnProperty.call(o, k)) ar[ar.length] = k; return ar; }; return ownKeys(o); }; return function (mod) { if (mod && mod.__esModule) return mod; var result = {}; if (mod != null) for (var k = ownKeys(mod), i = 0; i < k.length; i++) if (k[i] !== "default") __createBinding(result, mod, k[i]); __setModuleDefault(result, mod); return result; }; })(); Object.defineProperty(exports, "__esModule", { value: true }); exports.createImageClassifier = createImageClassifier; exports.getClassCount = getClassCount; exports.topClassifyResult = topClassifyResult; exports.mapWithClassName = mapWithClassName; exports.calcClassWeight = calcClassWeight; exports.checkClassNames = checkClassNames; exports.attachClassNames = attachClassNames; const tf = __importStar(require("@tensorflow/tfjs")); function createImageClassifier(spec) { let { hiddenLayers, classNames } = spec; if (spec.classes < 2) { throw new Error('image classifier must be at least 2 classes'); } if (classNames && classNames.length !== spec.classes) { throw new Error('classNames length mismatch'); } let classifierModel = tf.sequential(); 
classifierModel.add(tf.layers.inputLayer({ inputShape: [spec.embeddingFeatures] })); classifierModel.add(tf.layers.dropout({ rate: 0.5 })); if (hiddenLayers) { for (let i = 0; i < hiddenLayers.length; i++) { classifierModel.add(tf.layers.dense({ units: hiddenLayers[i], activation: 'gelu' })); classifierModel.add(tf.layers.dropout({ rate: 0.5 })); } } classifierModel.add(tf.layers.dense({ units: spec.classes, activation: 'softmax' })); return attachClassNames(classifierModel, classNames); } function getClassCount(shape) { for (;;) { let value = shape[0]; for (let i = 1; i < shape.length; i++) { value = shape[i] || value; } if (Array.isArray(value)) { shape = value; continue; } if (!value) { throw new Error('failed to get class count'); } return value; } } function topClassifyResult(items) { let idx = 0; let max = items[idx]; for (let i = 1; i < items.length; i++) { let item = items[i]; if (item.confidence > max.confidence) { max = item; } } return max; } /** * @description the values is returned as is. * It should has be applied softmax already. * */ function mapWithClassName(classNames, values, options) { let result = new Array(classNames.length); for (let i = 0; i < classNames.length; i++) { result[i] = { label: classNames[i], confidence: values[i], }; } if (options?.sort) { result.sort((a, b) => b.confidence - a.confidence); } return result; } function calcClassWeight(options) { let total = options.classCounts.reduce((acc, c) => acc + c, 0); let classWeights = options.classCounts.map(count => total / options.classes / count); return classWeights; } function checkClassNames(modelArtifact, classNames) { if (classNames && modelArtifact.classNames) { let expected = JSON.stringify(classNames); let actual = JSON.stringify(modelArtifact.classNames); if (actual !== expected) { throw new Error(`classNames mismatch, expected: ${expected}, actual: ${actual}`); } } return !classNames && modelArtifact.classNames ? 
modelArtifact.classNames : classNames; } function attachClassNames(model, classNames) { return classNames ? Object.assign(model, { classNames }) : model; }