ml5-save
Version:
263 lines (224 loc) • 8.72 kB
JavaScript
// Copyright (c) 2018 ml5
//
// This software is released under the MIT License.
// https://opensource.org/licenses/MIT
/* eslint max-len: ["error", { "code": 180 }] */
/*
YOLO Object detection
Heavily derived from https://github.com/ModelDepot/tfjs-yolo-tiny (ModelDepot: modeldepot.io)
*/
import * as tf from '@tensorflow/tfjs';
import Video from './../../utils/Video';
import {
imgToTensor,
isInstanceOfSupportedElement
} from "./../../utils/imageUtilities";
import callCallback from './../../utils/callcallback';
import CLASS_NAMES from './../../utils/COCO_CLASSES';
import modelLoader from './../../utils/modelLoader';
import {
nonMaxSuppression,
boxesToCorners,
head,
filterBoxes,
} from './postprocess';
// Hosted weights used when the caller does not pass `modelUrl`.
const DEFAULT_MODEL_URL = 'https://raw.githubusercontent.com/ml5js/ml5-data-and-training/master/models/YOLO/model.json';

/**
 * Fallback values for the options accepted by the YOLO constructor.
 */
const DEFAULTS = {
  modelUrl: DEFAULT_MODEL_URL,
  // Threshold handed to filterBoxes() (see ./postprocess) before NMS.
  filterBoxesThreshold: 0.01,
  // Overlap threshold handed to nonMaxSuppression().
  IOUThreshold: 0.4,
  // Detections whose class probability is below this value are dropped.
  classProbThreshold: 0.4,
};

// Side length, in pixels, of the square model input: frames are converted with
// imgToTensor(..., [imageSize, imageSize]) and box coordinates are reported in
// this pixel space.
const imageSize = 416;
/**
 * Read a numeric option, falling back to DEFAULTS only when the caller left it
 * unset (undefined or null). Unlike `options[key] || DEFAULTS[key]`, this keeps
 * an explicit 0 — e.g. `classProbThreshold: 0` to keep every detection.
 * @param {Object} options - Caller-supplied options object.
 * @param {string} key - Option name; must also be a key of DEFAULTS.
 * @returns {number} The caller's value, or the default.
 */
function resolveNumericOption(options, key) {
  const value = options[key];
  return value === undefined || value === null ? DEFAULTS[key] : value;
}

/**
 * YOLO object detector that works on a video element or on still images.
 * @deprecated Please use ObjectDetector class instead
 */
class YOLOBase extends Video {
  /**
   * @typedef {Object} options
   * @property {string} modelUrl - URL of the model.json (absolute) or a model folder path; default: hosted ml5 YOLO model
   * @property {number} filterBoxesThreshold - default 0.01
   * @property {number} IOUThreshold - default 0.4
   * @property {number} classProbThreshold - default 0.4
   * @property {boolean} disableDeprecationNotice - set to true to silence the deprecation warning
   */
  /**
   * Create YOLO model. Works on video and images.
   * @param {HTMLVideoElement|HTMLImageElement|HTMLCanvasElement|ImageData} video - Optional. The video to be used for object detection and classification.
   * @param {options} options - Optional. A set of options.
   * @param {function} callback - Optional. A callback function that is called once the model has loaded.
   */
  constructor(video, options, callback) {
    super(video, imageSize);
    // Tolerate a missing or null options object when the class is built directly.
    const opts = options || {};
    this.modelUrl = opts.modelUrl || DEFAULTS.modelUrl;
    this.filterBoxesThreshold = resolveNumericOption(opts, 'filterBoxesThreshold');
    this.IOUThreshold = resolveNumericOption(opts, 'IOUThreshold');
    this.classProbThreshold = resolveNumericOption(opts, 'classProbThreshold');
    this.modelReady = false;
    this.isPredicting = false;
    this.callback = callback;
    this.ready = callCallback(this.loadModel(), this.callback);
    if (!opts.disableDeprecationNotice) {
      console.warn("WARNING! Function YOLO has been deprecated, please use the new ObjectDetector function instead");
    }
  }

  /**
   * Load the video (if one was given and not yet loaded) and the layers model.
   * Non-absolute model URLs are resolved through modelLoader.getModelPath() and
   * have '/model.json' appended.
   * @returns {Promise<YOLOBase>} Resolves to this instance once the model is ready.
   */
  async loadModel() {
    if (this.videoElt && !this.video) {
      this.video = await this.loadVideo();
    }
    if (modelLoader.isAbsoluteURL(this.modelUrl) !== true) {
      const modelPath = modelLoader.getModelPath(this.modelUrl);
      this.modelUrl = `${modelPath}/model.json`;
    }
    this.model = await tf.loadLayersModel(this.modelUrl);
    this.modelReady = true;
    return this;
  }

  /**
   * Detect objects in a subject, returning bounding boxes, labels and confidence scores.
   * Accepted call forms:
   *   detect(subject[, callback]) — subject is a supported element, or a p5.js object exposing one as `.elt` or `.canvas`;
   *   detect(callback) — runs on the video given to the constructor;
   *   detect() — runs on the video given to the constructor and returns a promise.
   * @param {HTMLVideoElement|HTMLImageElement|HTMLCanvasElement|ImageData|Object|function} inputOrCallback - Subject of the detection, or the callback.
   * @param {function} cb - Optional. Called with the results once detection is done. If no callback is provided, the returned promise
   * resolves with the results.
   * @returns {Promise<ObjectDetectorPrediction[]>}
   * @throws {Error} 'Detection subject not supported' when no usable subject can be determined.
   */
  async detect(inputOrCallback, cb) {
    await this.ready;
    let imgToPredict;
    let callback = cb;
    // `typeof null === 'object'`, so exclude null before reading properties off it.
    const isObject = inputOrCallback !== null && typeof inputOrCallback === 'object';
    if (isInstanceOfSupportedElement(inputOrCallback)) {
      imgToPredict = inputOrCallback;
    } else if (isObject && isInstanceOfSupportedElement(inputOrCallback.elt)) {
      imgToPredict = inputOrCallback.elt; // Handle p5.js image and video.
    } else if (isObject && isInstanceOfSupportedElement(inputOrCallback.canvas)) {
      imgToPredict = inputOrCallback.canvas; // Handle p5.js image and video.
    } else if (typeof inputOrCallback === "function") {
      imgToPredict = this.video;
      callback = inputOrCallback;
    } else if (inputOrCallback === undefined && this.video) {
      // Promise-style call with no arguments: use the constructor's video.
      imgToPredict = this.video;
    } else {
      throw new Error('Detection subject not supported');
    }
    return callCallback(this.detectInternal(imgToPredict), callback);
  }

  /**
   * @typedef {Object} ObjectDetectorPrediction
   * @property {number} x - top left x coordinate of the prediction box in pixels.
   * @property {number} y - top left y coordinate of the prediction box in pixels.
   * @property {number} width - width of the prediction box in pixels.
   * @property {number} height - height of the prediction box in pixels.
   * @property {string} label - the label given.
   * @property {number} confidence - the confidence score (0 to 1).
   * @property {ObjectDetectorPredictionNormalized} normalized - a normalized object of the prediction
   */
  /**
   * @typedef {Object} ObjectDetectorPredictionNormalized
   * @property {number} x - top left x coordinate of the prediction box (0 to 1).
   * @property {number} y - top left y coordinate of the prediction box (0 to 1).
   * @property {number} width - width of the prediction box (0 to 1).
   * @property {number} height - height of the prediction box (0 to 1).
   */
  /**
   * Run the model on one subject and post-process its output into predictions.
   * All intermediate tensors are released, and isPredicting is reset, on every
   * exit path (normal return, the "nothing detected" early return, or an error).
   * @param {HTMLVideoElement|HTMLImageElement|HTMLCanvasElement|ImageData} imgToPredict - Subject of the detection.
   * @returns {Promise<ObjectDetectorPrediction[]>}
   */
  async detectInternal(imgToPredict) {
    await this.ready;
    await tf.nextFrame();
    this.isPredicting = true;
    let boxes = null;
    let scores = null;
    let classes = null;
    try {
      const [allBoxes, boxConfidence, boxClassProbs] = tf.tidy(() => {
        // Anchor priors are created inside tidy so they are disposed together
        // with the other intermediates instead of leaking on early return.
        const anchors = tf.tensor2d([
          [0.57273, 0.677385],
          [1.87446, 2.06253],
          [3.33843, 5.47434],
          [7.88282, 3.52778],
          [9.77052, 9.16828],
        ]);
        const input = imgToTensor(imgToPredict, [imageSize, imageSize]);
        const activation = this.model.predict(input);
        // NOTE(review): 80 presumably equals the number of COCO classes in CLASS_NAMES — confirm.
        const [boxXY, boxWH, bConfidence, bClassProbs] = head(activation, anchors, 80);
        const aBoxes = boxesToCorners(boxXY, boxWH);
        return [aBoxes, bConfidence, bClassProbs];
      });
      try {
        [boxes, scores, classes] = await filterBoxes(allBoxes, boxConfidence, boxClassProbs, this.filterBoxesThreshold);
      } finally {
        allBoxes.dispose();
        boxConfidence.dispose();
        boxClassProbs.dispose();
      }
      // If all boxes have been filtered out
      if (boxes == null) {
        return [];
      }
      return tf.tidy(() => {
        // Box rows are [y1, x1, y2, x2]; scale them by the model input size.
        const imageDims = tf.tensor2d([[imageSize, imageSize, imageSize, imageSize]]);
        const boxesModified = tf.mul(boxes, imageDims);
        const [keepIndx, boxesArr, keepScores] = nonMaxSuppression(
          boxesModified.dataSync(),
          scores.dataSync(),
          this.IOUThreshold,
        );
        const classesIndxArr = classes.gather(tf.tensor1d(keepIndx, 'int32')).dataSync();
        const results = [];
        classesIndxArr.forEach((classIndx, i) => {
          const classProb = keepScores[i];
          if (classProb < this.classProbThreshold) {
            return;
          }
          const [top, left, bottom, right] = boxesArr[i];
          // Clip the box to the image, then convert the far corner to width/height.
          const y = Math.max(0, top);
          const x = Math.max(0, left);
          const h = Math.min(imageSize, bottom) - y;
          const w = Math.min(imageSize, right) - x;
          results.push({
            label: CLASS_NAMES[classIndx],
            confidence: classProb,
            x,
            y,
            width: w,
            height: h,
            normalized: {
              x: x / imageSize,
              y: y / imageSize,
              width: w / imageSize,
              height: h / imageSize,
            },
          });
        });
        return results;
      });
    } finally {
      // filterBoxes' outputs live outside any tidy scope, so release them here.
      [boxes, scores, classes].forEach((tensor) => {
        if (tensor && typeof tensor.dispose === 'function') {
          tensor.dispose();
        }
      });
      this.isPredicting = false;
    }
  }
}
/**
 * Create a (deprecated) YOLO object detector.
 *
 * Arguments are positional but flexible — YOLO(video?, options?, callback?):
 *   - `videoOr` may be an HTMLVideoElement, a p5.js wrapper exposing one as `.elt`,
 *     the callback, or the options object;
 *   - `optionsOr` may be the options object or the callback;
 *   - `cb` is the callback when both earlier slots are used.
 * `null` in any slot is treated as "not provided".
 * @param {HTMLVideoElement|Object|function} [videoOr]
 * @param {Object|function} [optionsOr]
 * @param {function} [cb] - Called once the model has loaded.
 * @returns {YOLOBase}
 */
const YOLO = (videoOr, optionsOr, cb) => {
  let video = null;
  let options = {};
  let callback = cb;
  // `typeof null === 'object'`, so null must be excluded before property access.
  const isObject = (value) => typeof value === 'object' && value !== null;
  if (videoOr instanceof HTMLVideoElement) {
    video = videoOr;
  } else if (isObject(videoOr) && videoOr.elt instanceof HTMLVideoElement) {
    video = videoOr.elt; // Handle p5.js video
  } else if (typeof videoOr === 'function') {
    callback = videoOr;
  } else if (isObject(videoOr)) {
    options = videoOr;
  }
  if (isObject(optionsOr)) {
    options = optionsOr;
  } else if (typeof optionsOr === 'function') {
    callback = optionsOr;
  }
  return new YOLOBase(video, options, callback);
};
export default YOLO;