UNPKG

ml5-save

Version:
218 lines (188 loc) 7.06 kB
// Copyright (c) 2018 ml5
//
// This software is released under the MIT License.
// https://opensource.org/licenses/MIT

/*
UNET image segmentation using a pre-trained network
*/

import * as tf from '@tensorflow/tfjs';
import callCallback from '../utils/callcallback';
import { array3DToImage } from '../utils/imageUtilities';
import p5Utils from '../utils/p5Utils';

const DEFAULTS = {
  modelPath: 'https://raw.githubusercontent.com/zaidalyafeai/HostedModels/master/unet-128/model.json',
  imageSize: 128,
  returnTensors: false,
};

// Prediction values strictly above this threshold are treated as "feature" pixels.
const MASK_THRESHOLD = 0.3;

/**
 * Tell whether a value is something that can be handed to tf.browser.fromPixels()
 * by this module (image, video, canvas element or ImageData).
 * @param {*} candidate - The value to test.
 * @returns {boolean} True when `candidate` is one of the supported pixel sources.
 */
const isPixelSource = candidate =>
  candidate instanceof HTMLImageElement ||
  candidate instanceof HTMLVideoElement ||
  candidate instanceof HTMLCanvasElement ||
  candidate instanceof ImageData;

/**
 * Return the configured value for `key`, or the default when the option is undefined.
 * @param {Object} options - User supplied options.
 * @param {string} key - Name of the option to read.
 * @returns {*} The option value, or DEFAULTS[key] when it is undefined.
 */
const optionOrDefault = (options, key) =>
  (typeof options[key] !== 'undefined' ? options[key] : DEFAULTS[key]);

class UNET {
  /**
   * Create UNET class.
   * @param {HTMLVideoElement | HTMLImageElement} video - The video or image to be used for segmentation.
   * @param {Object} options - Optional. A set of options (modelPath, imageSize, returnTensors).
   * @param {function} callback - Optional. A callback function that is called once the model has loaded.
   *    If no callback is provided, it will return a promise that will be resolved once the model has loaded.
   */
  constructor(video, options, callback) {
    // Tolerate a missing options object when the class is constructed directly.
    const opts = options || {};

    // Remember the source so that segment(callback) has something to fall back on.
    this.video = video;
    this.modelReady = false;
    this.isPredicting = false;
    this.config = {
      modelPath: optionOrDefault(opts, 'modelPath'),
      imageSize: optionOrDefault(opts, 'imageSize'),
      returnTensors: optionOrDefault(opts, 'returnTensors'),
    };
    this.ready = callCallback(this.loadModel(), callback);
  }

  /**
   * Load the layers model found at `config.modelPath` and flag the instance as ready.
   * @returns {Promise<UNET>} Resolves with this instance once the model is loaded.
   */
  async loadModel() {
    this.model = await tf.loadLayersModel(this.config.modelPath);
    this.modelReady = true;
    return this;
  }

  /**
   * Segment an image, video frame, canvas or ImageData.
   * @param {HTMLImageElement | HTMLVideoElement | HTMLCanvasElement | ImageData | Object | function} inputOrCallback -
   *    The pixel source, an object exposing one as `.elt`, or the callback. When it is a
   *    callback (or is omitted) the video given to the constructor is used.
   * @param {function} cb - Optional. Callback handed to callCallback together with the result promise.
   * @returns {*} Whatever callCallback returns for the segmentation promise.
   */
  async segment(inputOrCallback, cb) {
    await this.ready;

    let imgToPredict;
    let callback = cb;

    if (isPixelSource(inputOrCallback)) {
      imgToPredict = inputOrCallback;
    } else if (typeof inputOrCallback === 'function') {
      imgToPredict = this.video;
      callback = inputOrCallback;
    } else if (inputOrCallback === undefined || inputOrCallback === null) {
      // No explicit input: use the source supplied at construction time.
      imgToPredict = this.video;
    } else if (typeof inputOrCallback === 'object' && isPixelSource(inputOrCallback.elt)) {
      // Wrapper object (e.g. a p5 element) exposing the DOM element as `.elt`.
      imgToPredict = inputOrCallback.elt;
    }

    return callCallback(this.segmentInternal(imgToPredict), callback);
  }

  /**
   * Convert a base64 data URL into a Blob carrying the same bytes and mime type.
   * @param {string} dataurl - A data URL of the form `data:<mime>;base64,<payload>`.
   * @returns {Blob} Blob holding the decoded payload.
   */
  static dataURLtoBlob(dataurl) {
    const arr = dataurl.split(',');
    const mime = arr[0].match(/:(.*?);/)[1];
    const bstr = atob(arr[1]);
    const u8arr = new Uint8Array(bstr.length);
    // Copy every byte, index 0 included.
    for (let i = 0; i < bstr.length; i += 1) {
      u8arr[i] = bstr.charCodeAt(i);
    }
    return new Blob([u8arr], { type: mime });
  }

  /**
   * Convert raw pixel data into a p5 image of imageSize x imageSize.
   * @param {Uint8ClampedArray} tfBrowserPixelImage - Pixel data as produced by tf.browser.toPixels().
   * @returns {Promise<*>} The value produced by p5Utils.blobToP5Image().
   */
  async convertToP5Image(tfBrowserPixelImage) {
    const { imageSize } = this.config;
    const blob1 = await p5Utils.rawToBlob(tfBrowserPixelImage, imageSize, imageSize);
    const p5Image1 = await p5Utils.blobToP5Image(blob1);
    return p5Image1;
  }

  /**
   * Run the model on one input and build every output representation.
   * @param {HTMLImageElement | HTMLVideoElement | HTMLCanvasElement | ImageData} imgToPredict - The pixel source.
   * @returns {Promise<Object>} Object with `segmentation`, `blob`, `tensor`, `raw` and, when p5 is
   *    available, `featureMask`, `backgroundMask` and `mask` p5 images.
   * @throws {Error} When there is no input to segment.
   */
  async segmentInternal(imgToPredict) {
    // Wait for the model to be ready
    await this.ready;

    if (!imgToPredict) {
      throw new Error('UNET: no image, video or canvas was provided to segment.');
    }

    // skip asking for next frame if it's not video
    if (imgToPredict instanceof HTMLVideoElement) {
      await tf.nextFrame();
    }

    const { imageSize } = this.config;
    let masks;

    this.isPredicting = true;
    try {
      masks = tf.tidy(() => {
        // preprocess the input image
        const tfImage = tf.browser.fromPixels(imgToPredict).toFloat();
        const resizedImg = tf.image.resizeBilinear(tfImage, [imageSize, imageSize]);
        const normRGB = resizedImg.div(tf.scalar(255));
        const batchedImage = normRGB.expandDims(0);

        // get the segmentation
        const pred = this.model.predict(batchedImage);

        // add back the alpha channel to the normalized input image;
        // sized from the config so a non-default imageSize still concatenates
        const alpha = tf.ones([imageSize, imageSize, 1]);
        const normTensor = normRGB.concat(alpha, 2);

        // 1 where the prediction is above the threshold, 0 elsewhere
        const featurePixels = pred.squeeze([0]).sub(MASK_THRESHOLD).sign().relu();
        const featurePixels4 = featurePixels.tile([1, 1, 4]);
        const backgroundPixels4 = featurePixels4.neg().add(1);

        // "featureMask" keeps the pixels at or below the threshold (feature removed),
        // "backgroundMask" keeps the pixels above it (background removed)
        const featureMaskInternal = backgroundPixels4.mul(normTensor);
        const backgroundMaskInternal = featurePixels4.mul(normTensor);

        // black/white segmentation with an opaque alpha channel
        const alpha255 = tf.ones([imageSize, imageSize, 1]).mul(255);
        const segmentationInternal = tf
          .cast(featurePixels.tile([1, 1, 3]).mul(255), 'int32')
          .concat(alpha255, 2);

        return {
          featureMask: featureMaskInternal,
          backgroundMask: backgroundMaskInternal,
          segmentation: segmentationInternal,
        };
      });
    } finally {
      // Never leave the flag stuck on when prediction throws.
      this.isPredicting = false;
    }

    const { featureMask, backgroundMask, segmentation } = masks;

    // these come first because array3DToImage() will dispose of the input tensor
    const [maskFeat, maskBg, mask] = await Promise.all([
      tf.browser.toPixels(featureMask),
      tf.browser.toPixels(backgroundMask),
      tf.browser.toPixels(segmentation),
    ]);

    const maskFeatDom = array3DToImage(featureMask);
    const maskBgDom = array3DToImage(backgroundMask);

    const maskFeatBlob = UNET.dataURLtoBlob(maskFeatDom.src);
    const maskBgBlob = UNET.dataURLtoBlob(maskBgDom.src);

    let pFeatureMask;
    let pBgMask;
    let pMask;

    if (p5Utils.checkP5()) {
      pFeatureMask = await this.convertToP5Image(maskFeat);
      pBgMask = await this.convertToP5Image(maskBg);
      pMask = await this.convertToP5Image(mask);
    }

    if (!this.config.returnTensors) {
      featureMask.dispose();
      backgroundMask.dispose();
      segmentation.dispose();
    }

    return {
      segmentation: mask,
      blob: {
        featureMask: maskFeatBlob,
        backgroundMask: maskBgBlob,
      },
      tensor: {
        featureMask,
        backgroundMask,
      },
      raw: {
        featureMask: maskFeat,
        backgroundMask: maskBg,
      },
      // returns if p5 is available
      featureMask: pFeatureMask,
      backgroundMask: pBgMask,
      mask: pMask,
    };
  }
}

/**
 * Create a UNET instance from a flexible argument list.
 * @param {HTMLVideoElement | Object | function} videoOr - A video element, an object exposing one
 *    as `.elt`, an options object, or the callback.
 * @param {Object | function} optionsOr - An options object or the callback.
 * @param {function} cb - Optional. Called once the model has loaded.
 * @returns {UNET} The new UNET instance.
 */
const uNet = (videoOr, optionsOr, cb) => {
  let video = null;
  let options = {};
  let callback = cb;

  if (videoOr instanceof HTMLVideoElement) {
    video = videoOr;
  } else if (typeof videoOr === 'function') {
    callback = videoOr;
  } else if (typeof videoOr === 'object' && videoOr !== null) {
    // typeof null is 'object', hence the explicit null check above
    if (videoOr.elt instanceof HTMLVideoElement) {
      video = videoOr.elt; // Handle p5.js image
    } else {
      options = videoOr;
    }
  }

  if (typeof optionsOr === 'object' && optionsOr !== null) {
    options = optionsOr;
  } else if (typeof optionsOr === 'function') {
    callback = optionsOr;
  }

  return new UNET(video, options, callback);
};

export default uNet;