UNPKG

ml5-save

Version:
568 lines (513 loc) 19 kB
// Copyright (c) 2018 ml5 // // This software is released under the MIT License. // https://opensource.org/licenses/MIT /* A class that extract features from Mobilenet */ import * as tf from '@tensorflow/tfjs'; import Video from './../utils/Video'; import { imgToTensor } from '../utils/imageUtilities'; import { saveBlob } from '../utils/io'; import callCallback from '../utils/callcallback'; const IMAGE_SIZE = 224; const BASE_URL = 'https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v'; const DEFAULTS = { version: 1, alpha: 0.25, topk: 3, learningRate: 0.0001, hiddenUnits: 100, epochs: 20, numLabels: 2, batchSize: 0.4, layer: 'conv_pw_13_relu', }; const MODEL_INFO = { 1: { 0.25: 'https://tfhub.dev/google/imagenet/mobilenet_v1_025_224/classification/1', 0.50: 'https://tfhub.dev/google/imagenet/mobilenet_v1_050_224/classification/1', 0.75: 'https://tfhub.dev/google/imagenet/mobilenet_v1_075_224/classification/1', 1.00: 'https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/classification/1' }, 2: { 0.50: 'https://tfhub.dev/google/imagenet/mobilenet_v2_050_224/classification/2', 0.75: 'https://tfhub.dev/google/imagenet/mobilenet_v2_075_224/classification/2', 1.00: 'https://tfhub.dev/google/imagenet/mobilenet_v2_100_224/classification/2' } }; const EMBEDDING_NODES = { 1: 'module_apply_default/MobilenetV1/Logits/global_pool', 2: 'module_apply_default/MobilenetV2/Logits/AvgPool' }; class Mobilenet { constructor(options, callback) { this.mobilenet = null; this.topKPredictions = 10; /** * Boolean value that specifies if new data has been added to the model * @type {boolean} * @public */ this.hasAnyTrainedClass = false; this.customModel = null; this.jointModel = null; this.config = { epochs: options.epochs || DEFAULTS.epochs, version: options.version || DEFAULTS.version, hiddenUnits: options.hiddenUnits || DEFAULTS.hiddenUnits, numLabels: options.numLabels || DEFAULTS.numLabels, learningRate: options.learningRate || DEFAULTS.learningRate, batchSize: options.batchSize || DEFAULTS.batchSize, layer: options.layer || DEFAULTS.layer, alpha: options.alpha || DEFAULTS.alpha, } // for graph model this.model = null; this.url = MODEL_INFO[this.config.version][this.config.alpha]; this.normalizationOffset = tf.scalar(127.5); // check if a mobilenet URL is given this.mobilenetURL = options.mobilenetURL || `${BASE_URL}${this.config.version}_${this.config.alpha}_${IMAGE_SIZE}/model.json`; this.graphModelURL = options.graphModelURL || this.url; /** * Boolean value to check if the model is predicting. * @public * @type {boolean} */ this.isPredicting = false; this.mapStringToIndex = []; /** * String that specifies how is the Extractor being used. * Possible values are 'regressor' and 'classifier' * @type {String} * @public */ this.usageType = null; this.ready = callCallback(this.loadModel(), callback); } async loadModel() { this.mobilenet = await tf.loadLayersModel(this.mobilenetURL); if(this.graphModelURL.includes('https://tfhub.dev/')){ this.model = await tf.loadGraphModel(this.graphModelURL, {fromTFHub: true}); } else { this.model = await tf.loadGraphModel(this.graphModelURL, {fromTFHub: false}); } const layer = this.mobilenet.getLayer(this.config.layer); this.mobilenetFeatures = await tf.model({ inputs: this.mobilenet.inputs, outputs: layer.output }); if (this.video) { await this.mobilenetFeatures.predict(imgToTensor(this.video)); // Warm up } return this; } /** * Use the features of MobileNet as a classifier. * @param {HTMLVideoElement || p5.Video} video - Optional. * An HTML video element or a p5.js video element. * @param {Object || function} objOrCallback - Optional. * Callback function or config object. * @param {function} callback - Optional. A function to be called once * the video is ready. If no callback is provided, it will return a * promise that will be resolved once the video element has loaded. */ classification(video, objOrCallback = null, callback) { let cb; this.usageType = 'classifier'; if (typeof objOrCallback === 'object') { Object.assign(this.config, objOrCallback); } else if (typeof objOrCallback === 'function') { cb = objOrCallback; } if (typeof callback === 'function') { cb = callback; } if (video) { callCallback(this.loadVideo(video), cb); } return this; } /** * Use the features of MobileNet as a regressor. * @param {HTMLVideoElement || p5.Video} video - Optional. * An HTML video element or a p5.js video element. * @param {function} callback - Optional. A function to be called once * the video is ready. If no callback is provided, it will return a * promise that will be resolved once the video element has loaded. */ regression(video, callback) { this.usageType = 'regressor'; if (video) { callCallback(this.loadVideo(video), callback); } return this; } async loadVideo(video) { let inputVideo = null; if (video instanceof HTMLVideoElement) { inputVideo = video; } else if (typeof video === 'object' && video.elt instanceof HTMLVideoElement) { inputVideo = video.elt; // p5.js video element } if (inputVideo) { const vid = new Video(inputVideo, IMAGE_SIZE); this.video = await vid.loadVideo(); } return this; } /** * Adds a new image element to Mobilenet * @param {HTMLVideoElement || p5.Video || String} inputOrLabel * @param {String || function} labelOrCallback * @param {function} cb */ async addImage(inputOrLabel, labelOrCallback, cb) { let imgToAdd; let label; let callback = cb; if (inputOrLabel instanceof HTMLImageElement || inputOrLabel instanceof HTMLVideoElement || inputOrLabel instanceof HTMLCanvasElement || inputOrLabel.elt instanceof ImageData) { imgToAdd = inputOrLabel; } else if (typeof inputOrLabel === 'object' && (inputOrLabel.elt instanceof HTMLImageElement || inputOrLabel.elt instanceof HTMLVideoElement || inputOrLabel.elt instanceof HTMLCanvasElement || inputOrLabel.elt instanceof ImageData)) { imgToAdd = inputOrLabel.elt; } else if (typeof inputOrLabel === 'string' || typeof inputOrLabel === 'number') { imgToAdd = this.video; label = inputOrLabel; } if (typeof labelOrCallback === 'string' || typeof labelOrCallback === 'number') { label = labelOrCallback; } else if (typeof labelOrCallback === 'function') { callback = labelOrCallback; } if (typeof label === 'string') { if (!this.mapStringToIndex.includes(label)) { label = this.mapStringToIndex.push(label) - 1; } else { label = this.mapStringToIndex.indexOf(label); } } return callCallback(this.addImageInternal(imgToAdd, label), callback); } async addImageInternal(imgToAdd, label) { await this.ready; tf.tidy(() => { const imageResize = (imgToAdd === this.video) ? null : [IMAGE_SIZE, IMAGE_SIZE]; const processedImg = imgToTensor(imgToAdd, imageResize); const prediction = this.mobilenetFeatures.predict(processedImg); let y; if (this.usageType === 'classifier') { y = tf.tidy(() => tf.oneHot(tf.tensor1d([label], 'int32'), this.config.numLabels)); } else if (this.usageType === 'regressor') { y = tf.tensor2d([[label]]); } if (this.xs == null) { this.xs = tf.keep(prediction); this.ys = tf.keep(y); this.hasAnyTrainedClass = true; } else { const oldX = this.xs; this.xs = tf.keep(oldX.concat(prediction, 0)); const oldY = this.ys; this.ys = tf.keep(oldY.concat(y, 0)); oldX.dispose(); oldY.dispose(); y.dispose(); } }); return this; } /** * Retrain the model with the provided images and labels using the * models original features as starting point. * @param {function} onProgress - A function to be called to follow * the progress of the training. */ async train(onProgress) { if (!this.hasAnyTrainedClass) { throw new Error('Add some examples before training!'); } this.isPredicting = false; if (this.usageType === 'classifier') { this.loss = 'categoricalCrossentropy'; this.customModel = tf.sequential({ layers: [ tf.layers.flatten({ inputShape: [7, 7, 256] }), tf.layers.dense({ units: this.config.hiddenUnits, activation: 'relu', kernelInitializer: 'varianceScaling', useBias: true, }), tf.layers.dense({ units: this.config.numLabels, kernelInitializer: 'varianceScaling', useBias: false, activation: 'softmax', }), ], }); } else if (this.usageType === 'regressor') { this.loss = 'meanSquaredError'; this.customModel = tf.sequential({ layers: [ tf.layers.flatten({ inputShape: [7, 7, 256] }), tf.layers.dense({ units: this.config.hiddenUnits, activation: 'relu', kernelInitializer: 'varianceScaling', useBias: true, }), tf.layers.dense({ units: 1, useBias: false, kernelInitializer: 'Zeros', activation: 'linear', }), ], }); } this.jointModel = tf.sequential(); this.jointModel.add(this.mobilenetFeatures); // mobilenet this.jointModel.add(this.customModel); // transfer layer const optimizer = tf.train.adam(this.config.learningRate); this.customModel.compile({ optimizer, loss: this.loss }); const batchSize = Math.floor(this.xs.shape[0] * this.config.batchSize); if (!(batchSize > 0)) { throw new Error('Batch size is 0 or NaN. Please choose a non-zero fraction.'); } return this.customModel.fit(this.xs, this.ys, { batchSize, epochs: this.config.epochs, callbacks: { onBatchEnd: async (batch, logs) => { onProgress(logs.loss.toFixed(5)); await tf.nextFrame(); }, onTrainEnd: () => onProgress(null), }, }); } /** * Classifies an an image based on a new retrained model. * .classification() needs to be used with this. * @param {HTMLVideoElement || p5.Video || function} inputOrCallback * @param {function} cb */ /* eslint max-len: ["error", { "code": 180 }] */ async classify(inputOrCallback, cb) { let imgToPredict; let callback; if (inputOrCallback instanceof HTMLImageElement || inputOrCallback instanceof HTMLVideoElement || inputOrCallback instanceof HTMLCanvasElement || inputOrCallback instanceof ImageData) { imgToPredict = inputOrCallback; } else if (typeof inputOrCallback === 'object' && (inputOrCallback.elt instanceof HTMLImageElement || inputOrCallback.elt instanceof HTMLVideoElement || inputOrCallback.elt instanceof HTMLCanvasElement || inputOrCallback.elt instanceof ImageData)) { imgToPredict = inputOrCallback.elt; // p5.js image element } else if (typeof inputOrCallback === 'function') { imgToPredict = this.video; callback = inputOrCallback; } if (typeof cb === 'function') { callback = cb; } return callCallback(this.classifyInternal(imgToPredict), callback); } async classifyInternal(imgToPredict) { if (this.usageType !== 'classifier') { throw new Error('Mobilenet Feature Extraction has not been set to be a classifier.'); } await tf.nextFrame(); this.isPredicting = true; const predictedClasses = tf.tidy(() => { const imageResize = (imgToPredict === this.video) ? null : [IMAGE_SIZE, IMAGE_SIZE]; const processedImg = imgToTensor(imgToPredict, imageResize); const predictions = this.jointModel.predict(processedImg); return Array.from(predictions.as1D().dataSync()); }); const results = await predictedClasses.map((confidence, index) => { const label = (this.mapStringToIndex.length > 0 && this.mapStringToIndex[index]) ? this.mapStringToIndex[index] : index; return { label, confidence, }; }).sort((a, b) => b.confidence - a.confidence); return results; } /** * Predicts a continues values based on a new retrained model. * .regression() needs to be used with this. * @param {HTMLVideoElement || p5.Video || function} inputOrCallback * @param {function} cb */ /* eslint max-len: ["error", { "code": 180 }] */ async predict(inputOrCallback, cb) { let imgToPredict; let callback; if (inputOrCallback instanceof HTMLImageElement || inputOrCallback instanceof HTMLVideoElement || inputOrCallback instanceof HTMLCanvasElement || inputOrCallback instanceof ImageData) { imgToPredict = inputOrCallback; } else if (typeof inputOrCallback === 'object' && (inputOrCallback.elt instanceof HTMLImageElement || inputOrCallback.elt instanceof HTMLVideoElement || inputOrCallback.elt instanceof HTMLCanvasElement || inputOrCallback.elt instanceof ImageData)) { imgToPredict = inputOrCallback.elt; // p5.js image element } else if (typeof inputOrCallback === 'function') { imgToPredict = this.video; callback = inputOrCallback; } if (typeof cb === 'function') { callback = cb; } return callCallback(this.predictInternal(imgToPredict), callback); } async predictInternal(imgToPredict) { if (this.usageType !== 'regressor') { throw new Error('Mobilenet Feature Extraction has not been set to be a regressor.'); } await tf.nextFrame(); this.isPredicting = true; const predictedClass = tf.tidy(() => { const imageResize = (imgToPredict === this.video) ? null : [IMAGE_SIZE, IMAGE_SIZE]; const processedImg = imgToTensor(imgToPredict, imageResize); const predictions = this.jointModel.predict(processedImg); return predictions.as1D(); }); const prediction = await predictedClass.data(); predictedClass.dispose(); return { value: prediction[0] }; } async load(filesOrPath = null, callback) { if (typeof filesOrPath !== 'string') { let model = null; let weights = null; Array.from(filesOrPath).forEach((file) => { if (file.name.includes('.json')) { model = file; const fr = new FileReader(); fr.onload = (d) => { if (JSON.parse(d.target.result).ml5Specs) { this.mapStringToIndex = JSON.parse(d.target.result).ml5Specs.mapStringToIndex; } }; fr.readAsText(file); } else if (file.name.includes('.bin')) { weights = file; } }); this.jointModel = await tf.loadLayersModel(tf.io.browserFiles([model, weights])); } else { fetch(filesOrPath) .then(r => r.json()) .then((r) => { if (r.ml5Specs) { this.mapStringToIndex = r.ml5Specs.mapStringToIndex; } }); this.jointModel = await tf.loadLayersModel(filesOrPath); if (callback) { callback(); } } return this.jointModel; } async save(callback, name) { if (!this.jointModel) { throw new Error('No model found.'); } this.jointModel.save(tf.io.withSaveHandler(async (data) => { let modelName = 'model'; if(name) modelName = name; this.weightsManifest = { modelTopology: data.modelTopology, weightsManifest: [{ paths: [`./${modelName}.weights.bin`], weights: data.weightSpecs, }], ml5Specs: { mapStringToIndex: this.mapStringToIndex, }, }; await saveBlob(data.weightData, `${modelName}.weights.bin`, 'application/octet-stream'); await saveBlob(JSON.stringify(this.weightsManifest), `${modelName}.json`, 'text/plain'); if (callback) { callback(); } })); } mobilenetInfer(input, embedding=false) { let img = input; if (img instanceof tf.Tensor || img instanceof ImageData || img instanceof HTMLImageElement || img instanceof HTMLCanvasElement || img instanceof HTMLVideoElement ) { return tf.tidy(() => { if (!(img instanceof tf.Tensor)) { img = tf.browser.fromPixels(img); } const normalized = img.toFloat().sub(this.normalizationOffset) .div(this.normalizationOffset); // Resize the image to let resized = normalized; if (img.shape[0] !== IMAGE_SIZE || img.shape[1] !== IMAGE_SIZE) { const alignCorners = true; resized = tf.image.resizeBilinear( normalized, [IMAGE_SIZE, IMAGE_SIZE], alignCorners); } // Reshape so we can pass it to predict. const batched = resized.reshape([-1, IMAGE_SIZE, IMAGE_SIZE, 3]); let result; if (embedding) { const embeddingName = EMBEDDING_NODES[this.config.version]; const internal = this.model.execute(batched, embeddingName); result = internal.squeeze([1, 2]); } else { const logits1001 = this.model.predict(batched); result = logits1001.slice([0, 1], [-1, 1000]); } return result; } ); } return null; } infer(input, endpoint) { let imgToPredict; let endpointToPredict; if (input instanceof HTMLImageElement || input instanceof HTMLVideoElement || input instanceof HTMLCanvasElement || input instanceof ImageData) { imgToPredict = input; } else if (typeof input === 'object' && (input.elt instanceof HTMLImageElement || input.elt instanceof HTMLVideoElement || input.elt instanceof HTMLCanvasElement || input.elt instanceof ImageData)) { imgToPredict = input.elt; // p5.js image/canvas/video element } else { throw new Error('No input image found.'); } if (endpoint && typeof endpoint === 'string') { endpointToPredict = endpoint; } else { endpointToPredict = 'conv_preds'; } return this.mobilenetInfer(imgToPredict, endpointToPredict); } } export default Mobilenet;