// ml5-save
// Version:
// 288 lines (257 loc) • 10.4 kB
// JavaScript
// Copyright (c) 2019 ml5
//
// This software is released under the MIT License.
// https://opensource.org/licenses/MIT
/*
Image Classifier using pre-trained networks
*/
import * as tf from '@tensorflow/tfjs';
import * as mobilenet from '@tensorflow-models/mobilenet';
import * as darknet from './darknet';
import * as doodlenet from './doodlenet';
import callCallback from '../utils/callcallback';
import { imgToTensor } from '../utils/imageUtilities';
// Fallback MobileNet settings, applied by the ImageClassifier constructor for
// any of `version`, `alpha` or `topk` that the caller's options leave unset.
const DEFAULTS = {
  mobilenet: { version: 2, alpha: 1.0, topk: 3 },
};

// Edge length (in pixels) of the square that input images are resized to
// before being handed to a model.
const IMAGE_SIZE = 224;

// Names of the built-in models. Any other string given to the classifier is
// treated as the URL of a custom model.
const MODEL_OPTIONS = [
  'mobilenet',
  'darknet',
  'darknet-tiny',
  'doodlenet',
];
class ImageClassifier {
  /**
   * Create an ImageClassifier.
   * @param {string} modelNameOrUrl - The name or the URL of the model to use. Current model name options
   *    are: 'mobilenet', 'darknet', 'darknet-tiny', and 'doodlenet'. Any other string is treated as
   *    the URL of a custom model.
   * @param {HTMLVideoElement} video - An HTMLVideoElement used as the default classification input.
   * @param {object} options - An object with options (`version`, `alpha`, `topk` for 'mobilenet').
   *    May be omitted or null.
   * @param {function} callback - A callback to be called when the model is ready.
   */
  constructor(modelNameOrUrl, video, options, callback) {
    this.video = video;
    this.model = null;
    this.mapStringToIndex = [];

    // Tolerate a missing/null options argument instead of crashing on `options.version`.
    const settings = options || {};

    if (typeof modelNameOrUrl === 'string') {
      if (MODEL_OPTIONS.includes(modelNameOrUrl)) {
        this.modelName = modelNameOrUrl;
        this.modelUrl = null;
        switch (this.modelName) {
          case 'mobilenet':
            this.modelToUse = mobilenet;
            this.version = settings.version || DEFAULTS.mobilenet.version;
            this.alpha = settings.alpha || DEFAULTS.mobilenet.alpha;
            this.topk = settings.topk || DEFAULTS.mobilenet.topk;
            break;
          case 'darknet':
            this.version = 'reference'; // this is a 28mb model
            this.modelToUse = darknet;
            break;
          case 'darknet-tiny':
            this.version = 'tiny'; // this is a 4mb model
            this.modelToUse = darknet;
            break;
          case 'doodlenet':
            this.modelToUse = doodlenet;
            break;
          default:
            this.modelToUse = null;
        }
      } else {
        // Not a built-in name: treat the string as the URL of a custom model.
        this.modelUrl = modelNameOrUrl;
      }
    }

    // Load the model. `ready` is whatever callCallback returns for the loading promise.
    this.ready = callCallback(this.loadModel(this.modelUrl), callback);
  }

  /**
   * Load the model and set it to this.model
   * @param {string} modelUrl - URL of a custom model; when falsy the built-in model
   *    selected in the constructor is loaded instead.
   * @return {this} The ImageClassifier.
   * @throws {Error} When neither a URL nor a built-in model is available.
   */
  async loadModel(modelUrl) {
    if (modelUrl) {
      this.model = await this.loadModelFrom(modelUrl);
      return this;
    }
    if (!this.modelToUse) {
      throw new Error(
        'No model available to load. Pass one of the built-in model names or the URL of a model.',
      );
    }
    this.model = await this.modelToUse.load({ version: this.version, alpha: this.alpha });
    return this;
  }

  /**
   * Load a custom layers model together with its class labels.
   * The labels are fetched concurrently with the model, and both are awaited so that
   * `this.mapStringToIndex` is populated by the time the returned promise resolves.
   * @param {string} path - URL of the model's JSON file.
   * @return {object} The loaded model (also stored on this.model).
   */
  async loadModelFrom(path = null) {
    // loadLabels never rejects, so it is safe to leave it pending if the model load fails.
    const labelsLoaded = this.loadLabels(path);
    this.model = await tf.loadLayersModel(path);
    await labelsLoaded;
    return this.model;
  }

  /**
   * Best-effort loading of class labels into this.mapStringToIndex.
   * Labels are read from `ml5Specs.mapStringToIndex` in the model file. When loading a model
   * generated by Teachable Machine 2.0 the `ml5Specs` entry is missing, which would leave the
   * classifier without labels; in that case the labels are read from the `metadata.json`
   * file located next to the model file (fix contributed by Yang Yang, Oct 2, 2019).
   * Failures are logged and never thrown: labels are optional.
   * @param {string} path - URL of the model's JSON file.
   */
  async loadLabels(path) {
    try {
      const response = await fetch(path);
      const modelJson = await response.json();
      if (modelJson.ml5Specs && modelJson.ml5Specs.mapStringToIndex) {
        this.mapStringToIndex = modelJson.ml5Specs.mapStringToIndex;
      }
    } catch (err) {
      // The model file itself could not be read; the metadata fallback is skipped.
      console.log('Error when loading labels from the model file');
      return;
    }

    if (this.mapStringToIndex.length !== 0) return;

    // Keep everything up to and including the last "/" so that a bare file name
    // ("model.json") resolves to a sibling "metadata.json" rather than the site root.
    const directory = String(path).slice(0, String(path).lastIndexOf('/') + 1);
    const metadataUrl = `${directory}metadata.json`;
    try {
      const res = await fetch(metadataUrl);
      if (!res.ok) {
        console.log("Tried to fetch metadata.json, but it seems to be missing.");
        throw Error(res.statusText);
      }
      const metadataJson = await res.json();
      if (metadataJson.labels) {
        this.mapStringToIndex = metadataJson.labels;
      }
    } catch (err) {
      console.log("Error when loading metadata.json");
    }
  }

  /**
   * Classifies the given input and returns an object with labels and confidence
   * @param {HTMLImageElement | HTMLCanvasElement | HTMLVideoElement} imgToPredict -
   *    takes an image to run the classification on.
   * @param {number} numberOfClasses - a number of labels to return for the image
   *    classification. Only forwarded to built-in models; custom (URL) models return
   *    one entry per model output.
   * @return {object} an array of {label, confidence} objects.
   */
  async classifyInternal(imgToPredict, numberOfClasses) {
    // Wait for the model to be ready
    await this.ready;
    await tf.nextFrame();

    if (imgToPredict instanceof HTMLVideoElement && imgToPredict.readyState === 0) {
      const video = imgToPredict;
      // Wait for the video to be ready
      await new Promise((resolve) => {
        video.onloadeddata = () => resolve();
      });
    }

    if (this.video && this.video.readyState === 0) {
      await new Promise((resolve) => {
        this.video.onloadeddata = () => resolve();
      });
    }

    // Process the images
    const imageResize = [IMAGE_SIZE, IMAGE_SIZE];

    if (this.modelUrl) {
      await tf.nextFrame();
      // tf.tidy releases the intermediate tensors; only the plain array escapes.
      const confidences = tf.tidy(() => {
        const processedImg = imgToTensor(imgToPredict, imageResize);
        const predictions = this.model.predict(processedImg);
        return Array.from(predictions.as1D().dataSync());
      });
      return confidences
        .map((confidence, index) => {
          // Fall back to the numeric index when no label is known for this output.
          const hasLabel = this.mapStringToIndex.length > 0 && this.mapStringToIndex[index];
          const label = hasLabel ? this.mapStringToIndex[index] : index;
          return { label, confidence };
        })
        .sort((a, b) => b.confidence - a.confidence);
    }

    const processedImg = imgToTensor(imgToPredict, imageResize);
    try {
      const classes = await this.model.classify(processedImg, numberOfClasses);
      return classes.map(c => ({ label: c.className, confidence: c.probability }));
    } finally {
      // Release the input tensor only once the model is done with it, even on failure.
      processedImg.dispose();
    }
  }

  /**
   * Classifies the given input and takes a callback to handle the results
   * @param {HTMLImageElement | HTMLCanvasElement | object | function | number} inputNumOrCallback -
   *    the input to classify (an element, ImageData, or an object exposing one through
   *    `.elt` / `.canvas`), or the number of classes, or the callback.
   * @param {HTMLImageElement | HTMLCanvasElement | object | function | number} numOrCallback -
   *    the number of classes or the callback.
   * @param {function} cb - a callback function that handles the results of the function.
   * @return {function} a promise or the results of a given callback, cb.
   * @throws {Error} When no usable input is given and no video was passed to the constructor.
   */
  async classify(inputNumOrCallback, numOrCallback = null, cb) {
    let imgToPredict = this.video;
    let numberOfClasses = this.topk;
    let callback;

    const isMedia = candidate => candidate instanceof HTMLVideoElement
      || candidate instanceof HTMLImageElement
      || candidate instanceof HTMLCanvasElement
      || candidate instanceof ImageData;
    // `typeof null` is 'object', so null must be excluded before reading properties.
    const isObject = inputNumOrCallback !== null && typeof inputNumOrCallback === 'object';

    // Handle the image to predict
    if (typeof inputNumOrCallback === 'function') {
      callback = inputNumOrCallback;
    } else if (typeof inputNumOrCallback === 'number') {
      numberOfClasses = inputNumOrCallback;
    } else if (isMedia(inputNumOrCallback)) {
      imgToPredict = inputNumOrCallback;
    } else if (isObject && isMedia(inputNumOrCallback.elt)) {
      imgToPredict = inputNumOrCallback.elt; // Handle p5.js image
    } else if (isObject && inputNumOrCallback.canvas instanceof HTMLCanvasElement) {
      imgToPredict = inputNumOrCallback.canvas; // Handle p5.js image
    } else if (!(this.video instanceof HTMLVideoElement)) {
      // Handle unsupported input
      throw new Error(
        'No input image provided. If you want to classify a video, pass the video element in the constructor. ',
      );
    }

    if (typeof numOrCallback === 'number') {
      numberOfClasses = numOrCallback;
    } else if (typeof numOrCallback === 'function') {
      callback = numOrCallback;
    }

    if (typeof cb === 'function') {
      callback = cb;
    }

    return callCallback(this.classifyInternal(imgToPredict, numberOfClasses), callback);
  }

  /**
   * Will be deprecated soon in favor of ".classify()" - does the same as .classify()
   * @param {HTMLImageElement | HTMLCanvasElement | object | function | number} inputNumOrCallback - takes any of the following params
   * @param {HTMLImageElement | HTMLCanvasElement | object | function | number} numOrCallback - takes any of the following params
   * @param {function} cb - a callback function that handles the results of the function.
   * @return {function} a promise or the results of a given callback, cb.
   */
  async predict(inputNumOrCallback, numOrCallback, cb) {
    return this.classify(inputNumOrCallback, numOrCallback || null, cb);
  }
}
/**
 * Factory for ImageClassifier that sorts out its flexible argument list.
 * @param {string} modelName - Name of a built-in model (matched case-insensitively) or the
 *    URL/path of a custom model (passed through unchanged).
 * @param {HTMLVideoElement | object | function} videoOrOptionsOrCallback - A video element,
 *    an object wrapping one in `.elt` (p5.js video), an options object, or the ready callback.
 * @param {object | function} optionsOrCallback - An options object or the ready callback.
 * @param {function} cb - The ready callback.
 * @return {ImageClassifier | Promise} The instance when a callback was given, otherwise the
 *    instance's `ready` value.
 * @throws {Error} When modelName is not a string.
 */
const imageClassifier = (modelName, videoOrOptionsOrCallback, optionsOrCallback, cb) => {
  if (typeof modelName !== 'string') {
    throw new Error('Please specify a model to use. E.g: "MobileNet"');
  }

  // Only built-in names are case-insensitive. URLs and relative paths are kept verbatim
  // because lowercasing them breaks models served from case-sensitive locations.
  const lowered = modelName.toLowerCase();
  const model = MODEL_OPTIONS.includes(lowered) ? lowered : modelName;

  // `typeof null` is 'object', so null has to be ruled out before reading properties.
  const isObject = value => value !== null && typeof value === 'object';

  let video;
  let options = {};
  let callback = cb;

  if (videoOrOptionsOrCallback instanceof HTMLVideoElement) {
    video = videoOrOptionsOrCallback;
  } else if (
    isObject(videoOrOptionsOrCallback) &&
    videoOrOptionsOrCallback.elt instanceof HTMLVideoElement
  ) {
    video = videoOrOptionsOrCallback.elt; // Handle a p5.js video element
  } else if (isObject(videoOrOptionsOrCallback)) {
    options = videoOrOptionsOrCallback;
  } else if (typeof videoOrOptionsOrCallback === 'function') {
    callback = videoOrOptionsOrCallback;
  }

  if (isObject(optionsOrCallback)) {
    options = optionsOrCallback;
  } else if (typeof optionsOrCallback === 'function') {
    callback = optionsOrCallback;
  }

  const instance = new ImageClassifier(model, video, options, callback);
  return callback ? instance : instance.ready;
};
export default imageClassifier;