tensorflow-helpers
Version:
Helper functions to use tensorflow in nodejs for transfer learning, image classification, and more
128 lines (127 loc) • 4.66 kB
JavaScript
// Module-interop helper (looks like compiler-emitted boilerplate): exposes property `k` of
// module `m` on object `o` under the name `k2` (which defaults to `k`).
// A helper already present on `this` is reused; otherwise the implementation is chosen by
// whether Object.create is available.
var __createBinding = (this && this.__createBinding) || (Object.create ? (function(o, m, k, k2) {
if (k2 === undefined) k2 = k;
var desc = Object.getOwnPropertyDescriptor(m, k);
// Replace the descriptor with an enumerable live getter when: there is no own descriptor,
// or it is an accessor on a module not flagged __esModule, or it is a data descriptor that
// is writable or configurable.
if (!desc || ("get" in desc ? !m.__esModule : desc.writable || desc.configurable)) {
desc = { enumerable: true, get: function() { return m[k]; } };
}
Object.defineProperty(o, k2, desc);
}) : (function(o, m, k, k2) {
// Fallback when Object.create is missing: copy the current value once (no live binding).
if (k2 === undefined) k2 = k;
o[k2] = m[k];
}));
// Module-interop helper: stores `v` as the "default" property of `o`.
// A helper already present on `this` is reused. Otherwise, with Object.create available the
// property is defined as enumerable via Object.defineProperty; without it, plain assignment is used.
var __setModuleDefault = (this && this.__setModuleDefault) || (Object.create ? (function(o, v) {
Object.defineProperty(o, "default", { enumerable: true, value: v });
}) : function(o, v) {
o["default"] = v;
});
// Module-interop helper: turns the result of require() into a namespace-style object.
// A helper already present on `this` is reused; otherwise one is built by the IIFE below.
var __importStar = (this && this.__importStar) || (function () {
// Lists the own property names of `o`. On first call it replaces itself with
// Object.getOwnPropertyNames, or with a for-in/hasOwnProperty fallback when that is missing,
// then delegates to the replacement.
var ownKeys = function(o) {
ownKeys = Object.getOwnPropertyNames || function (o) {
var ar = [];
for (var k in o) if (Object.prototype.hasOwnProperty.call(o, k)) ar[ar.length] = k;
return ar;
};
return ownKeys(o);
};
return function (mod) {
// Modules flagged __esModule are returned untouched.
if (mod && mod.__esModule) return mod;
var result = {};
// Bind every own property except "default" onto the wrapper via __createBinding.
if (mod != null) for (var k = ownKeys(mod), i = 0; i < k.length; i++) if (k[i] !== "default") __createBinding(result, mod, k[i]);
// The original module itself becomes the wrapper's "default".
__setModuleDefault(result, mod);
return result;
};
})();
// Flag this CommonJS module with __esModule so interop helpers treat it as an ES module.
Object.defineProperty(exports, "__esModule", { value: true });
// Public API. The functions are declared further down with `function` declarations, which
// are hoisted, so assigning them here before their definitions is valid.
exports.createImageClassifier = createImageClassifier;
exports.getClassCount = getClassCount;
exports.topClassifyResult = topClassifyResult;
exports.mapWithClassName = mapWithClassName;
exports.calcClassWeight = calcClassWeight;
exports.checkClassNames = checkClassNames;
exports.attachClassNames = attachClassNames;
const tf = __importStar(require("@tensorflow/tfjs"));
/**
 * Builds a sequential classifier that maps an embedding vector to class probabilities.
 *
 * Layer order: input layer of `spec.embeddingFeatures` features, dropout (rate 0.5), then for
 * each entry of `spec.hiddenLayers` a dense 'gelu' layer of that many units followed by
 * dropout (rate 0.5), and finally a dense 'softmax' layer with `spec.classes` units.
 *
 * @param spec object with `embeddingFeatures`, `classes`, and optional `hiddenLayers`
 *   (unit counts) and `classNames`
 * @returns the model, with `classNames` attached when they were supplied
 * @throws {Error} when `spec.classes` is below 2, or when `classNames` is given and its
 *   length differs from `spec.classes`
 */
function createImageClassifier(spec) {
    const { hiddenLayers, classNames } = spec;
    if (spec.classes < 2) {
        throw new Error('image classifier must be at least 2 classes');
    }
    if (classNames && classNames.length !== spec.classes) {
        throw new Error('classNames length mismatch');
    }
    const model = tf.sequential();
    // Every dropout layer in this model uses the same rate.
    const addDropout = () => {
        model.add(tf.layers.dropout({ rate: 0.5 }));
    };
    model.add(tf.layers.inputLayer({ inputShape: [spec.embeddingFeatures] }));
    addDropout();
    if (hiddenLayers) {
        for (const units of hiddenLayers) {
            model.add(tf.layers.dense({ units, activation: 'gelu' }));
            addDropout();
        }
    }
    model.add(tf.layers.dense({ units: spec.classes, activation: 'softmax' }));
    return attachClassNames(model, classNames);
}
/**
 * Extracts a class count from a (possibly nested) shape array.
 *
 * At each level the last truthy entry is selected (the first entry is used when no later
 * entry is truthy). If the selected entry is itself an array, the search continues inside it.
 *
 * @param shape array of entries; entries may be falsy placeholders or nested arrays
 * @returns the selected non-array, truthy entry
 * @throws {Error} when the selected entry is falsy
 */
function getClassCount(shape) {
    let dims = shape;
    while (true) {
        let picked = dims[0];
        for (let idx = 1; idx < dims.length; idx += 1) {
            const dim = dims[idx];
            if (dim) {
                picked = dim;
            }
        }
        if (!Array.isArray(picked)) {
            if (!picked) {
                throw new Error('failed to get class count');
            }
            return picked;
        }
        // Nested shape: descend and repeat the selection.
        dims = picked;
    }
}
/**
 * Returns the item with the highest `confidence`.
 *
 * On ties the earliest item wins, because a later item only replaces the current best when
 * its confidence is strictly greater. For an empty list the result is `undefined`.
 *
 * @param items list of objects carrying a `confidence` property
 * @returns the item with the greatest confidence, or `undefined` when `items` is empty
 */
function topClassifyResult(items) {
    let best = items[0];
    let position = 1;
    while (position < items.length) {
        const candidate = items[position];
        best = candidate.confidence > best.confidence ? candidate : best;
        position += 1;
    }
    return best;
}
/**
 * @description Pairs each class name with the value at the same index.
 * The values are returned as is (no normalization is applied here), so softmax
 * should already have been applied to them.
 *
 * @param classNames list of class labels; determines the length of the result
 * @param values indexable list of scores, read at the same indices as `classNames`
 * @param options optional; when `options.sort` is truthy the result is ordered by
 *   descending confidence
 * @returns array of `{ label, confidence }` objects
 */
function mapWithClassName(classNames, values, options) {
    const pairs = Array.from({ length: classNames.length }, (_, index) => ({
        label: classNames[index],
        confidence: values[index],
    }));
    if (!options?.sort) {
        return pairs;
    }
    // `pairs` is a fresh array, so sorting it in place does not touch the inputs.
    return pairs.sort((left, right) => right.confidence - left.confidence);
}
/**
 * Computes one weight per class as `total / classes / count`, where `total` is the sum of
 * all class counts. Classes with fewer samples therefore get larger weights.
 *
 * @param options object with `classes` (number of classes) and `classCounts` (sample count
 *   per class)
 * @returns array of weights, in the same order as `options.classCounts`
 */
function calcClassWeight(options) {
    const { classCounts } = options;
    const total = classCounts.reduce((sum, count) => sum + count, 0);
    // Dividing by the class count first keeps the original left-to-right arithmetic order.
    const evenShare = total / options.classes;
    return classCounts.map((count) => evenShare / count);
}
/**
 * Reconciles caller-supplied class names with those stored on a model artifact.
 *
 * When both are present they must serialize to the same JSON, otherwise an error is thrown.
 * When only the artifact has class names, those are returned; in every other case the
 * caller-supplied value is returned unchanged.
 *
 * @param modelArtifact object that may carry a `classNames` property
 * @param classNames class names supplied by the caller, possibly absent
 * @returns the class names to use
 * @throws {Error} when both lists are present and differ
 */
function checkClassNames(modelArtifact, classNames) {
    const stored = modelArtifact.classNames;
    if (!classNames) {
        // Nothing supplied: fall back to the artifact's names when it has any.
        return stored || classNames;
    }
    if (!stored) {
        return classNames;
    }
    const expected = JSON.stringify(classNames);
    const actual = JSON.stringify(stored);
    if (actual !== expected) {
        throw new Error(`classNames mismatch, expected: ${expected}, actual: ${actual}`);
    }
    return classNames;
}
/**
 * Attaches `classNames` to `model` as a `classNames` property when a truthy value is given.
 * The model is mutated in place.
 *
 * @param model object to decorate
 * @param classNames class names to attach; skipped when falsy
 * @returns the result of `Object.assign(model, { classNames })`, or `model` untouched when
 *   `classNames` is falsy
 */
function attachClassNames(model, classNames) {
    if (!classNames) {
        return model;
    }
    return Object.assign(model, { classNames });
}
;