UNPKG

ml5-save

Version:
75 lines (65 loc) 2.37 kB
// Copyright (c) 2018 ml5
//
// This software is released under the MIT License.
// https://opensource.org/licenses/MIT

import * as tf from '@tensorflow/tfjs';
import { getTopKClassesFromTensor } from '../utils/gettopkclasses';
import DOODLENET_CLASSES from '../utils/DOODLENET_CLASSES';

const DEFAULTS = {
  DOODLENET_URL: 'https://cdn.jsdelivr.net/gh/ml5js/ml5-data-and-models@master/models/doodlenet/model.json',
  IMAGE_SIZE_DOODLENET: 28,
};

/**
 * Returns true when `value` is a DOM pixel source that `tf.browser.fromPixels`
 * can read: an HTMLImageElement, HTMLVideoElement, HTMLCanvasElement or ImageData.
 * Each constructor is checked for existence first, so this returns false
 * (instead of throwing from `instanceof`) in environments without the DOM.
 * @param {*} value - candidate input.
 * @returns {boolean}
 */
function isPixelSource(value) {
  return (typeof HTMLImageElement !== 'undefined' && value instanceof HTMLImageElement)
    || (typeof HTMLVideoElement !== 'undefined' && value instanceof HTMLVideoElement)
    || (typeof HTMLCanvasElement !== 'undefined' && value instanceof HTMLCanvasElement)
    || (typeof ImageData !== 'undefined' && value instanceof ImageData);
}

/**
 * Converts a supported input into a pixel tensor.
 * - A tf.Tensor is returned unchanged (it is not copied or disposed here).
 * - DOM pixel sources are read with `tf.browser.fromPixels` (default channel count).
 * - p5.js images/videos/canvases are unwrapped through their `.elt` DOM element.
 * @param {*} img - the input to convert.
 * @returns {tf.Tensor}
 * @throws {Error} if the input is none of the supported types.
 */
function toPixelTensor(img) {
  if (img instanceof tf.Tensor) {
    return img;
  }
  if (isPixelSource(img)) {
    return tf.browser.fromPixels(img);
  }
  // Handle p5.js image, video and canvas, which expose the DOM element as `.elt`.
  // The null check matters because `typeof null === 'object'`.
  if (img !== null && typeof img === 'object' && isPixelSource(img.elt)) {
    return tf.browser.fromPixels(img.elt);
  }
  throw new Error('Doodlenet: unsupported input. Expected a tf.Tensor, an HTML image/video/canvas element, ImageData, or a p5.js element wrapping one of these.');
}

/**
 * Prepares one image for DoodleNet.
 * 1. Scales pixel values to [0, 1] and inverts them (1 - v/255), so dark
 *    pixels map towards 1 and light pixels towards 0.
 * 2. Resizes to `size` x `size` with bilinear interpolation when needed.
 * 3. Averages the colour channels and floors the result.
 * 4. Returns a tensor of shape [1, size, size, 1].
 *
 * Accepts rank-3 ([height, width, channels]) or single-image rank-4
 * ([1, height, width, channels]) data. A 1-channel input is used as-is. With 3
 * or more channels only the first three are averaged, so an alpha channel is
 * ignored. Pixel values are assumed to lie in [0, 255].
 *
 * Intermediate tensors are not disposed here; call this inside `tf.tidy`.
 *
 * @param {*} img - any input accepted by toPixelTensor.
 * @param {number} size - target height and width in pixels.
 * @returns {tf.Tensor} tensor of shape [1, size, size, 1].
 * @throws {Error} on unsupported input type, rank, batch size or channel count.
 */
function preProcess(img, size) {
  const pixels = toPixelTensor(img);
  if (pixels.rank !== 3 && pixels.rank !== 4) {
    throw new Error(`Doodlenet: expected a rank 3 or rank 4 image tensor, got rank ${pixels.rank}.`);
  }
  // Always work in [batch, height, width, channels] so the size check,
  // the resize and the channel split all agree on which axis is which.
  const batched = pixels.rank === 3 ? pixels.expandDims(0) : pixels;
  const [batch, height, width, channels] = batched.shape;
  if (batch !== 1) {
    throw new Error(`Doodlenet: expected a single image, got a batch of ${batch}.`);
  }
  if (channels !== 1 && channels < 3) {
    throw new Error(`Doodlenet: expected 1, 3 or 4 channels, got ${channels}.`);
  }

  const normalized = tf.scalar(1).sub(tf.cast(batched, 'float32').div(tf.scalar(255)));
  const resized = (height === size && width === size)
    ? normalized
    : tf.image.resizeBilinear(normalized, [size, size]);

  let gray;
  if (channels === 1) {
    gray = resized;
  } else {
    // Keep only R, G, B (drops an alpha channel if present), then average them.
    const rgb = channels === 3 ? resized : resized.slice([0, 0, 0, 0], [-1, -1, -1, 3]);
    const [r, g, b] = tf.split(rgb, 3, 3);
    gray = r.add(g).add(b).div(tf.scalar(3));
  }
  // floor() maps averages in [0, 1) to 0 and exactly 1 to 1, so only fully dark
  // pixels become 1. NOTE(review): confirm this binarisation matches the
  // model's training preprocessing.
  return gray.floor().reshape([1, size, size, 1]);
}

/**
 * DoodleNet sketch classifier backed by a TensorFlow.js layers model.
 * `load()` must complete before `classify()` is called.
 */
export class Doodlenet {
  constructor() {
    // Side length, in pixels, of the square single-channel input fed to the model.
    this.imgSize = DEFAULTS.IMAGE_SIZE_DOODLENET;
  }

  /**
   * Loads the layers model from DEFAULTS.DOODLENET_URL. It then runs one
   * prediction on an all-zero [1, imgSize, imgSize, 1] input to warm it up.
   * @returns {Promise<void>}
   */
  async load() {
    this.model = await tf.loadLayersModel(DEFAULTS.DOODLENET_URL);
    // Warmup the model so the first real classify() call is not slowed by setup.
    const result = tf.tidy(() => this.model.predict(tf.zeros([1, this.imgSize, this.imgSize, 1])));
    try {
      await result.data();
    } finally {
      result.dispose();
    }
  }

  /**
   * Classifies an image or sketch.
   * @param {*} img - a tf.Tensor ([h, w, c] or [1, h, w, c]), an HTML
   *   image/video/canvas element, ImageData, or a p5.js element wrapping one of these.
   * @param {number} [topk=10] - number of top classes to return.
   * @returns {Promise<*>} whatever `getTopKClassesFromTensor` resolves to for the
   *   top `topk` classes, labelled with DOODLENET_CLASSES.
   * @throws {Error} if the model is not loaded or the input is unsupported.
   */
  async classify(img, topk = 10) {
    if (!this.model) {
      throw new Error('Doodlenet: model is not loaded. Call load() before classify().');
    }
    // tf.tidy disposes every intermediate tensor except the returned logits.
    const logits = tf.tidy(() => this.model.predict(preProcess(img, this.imgSize)));
    try {
      return await getTopKClassesFromTensor(logits, topk, DOODLENET_CLASSES);
    } finally {
      // Release the logits even if the top-k lookup rejects.
      logits.dispose();
    }
  }
}

/**
 * Creates a Doodlenet instance and waits for its model to load and warm up.
 * @returns {Promise<Doodlenet>}
 */
export async function load() {
  const doodlenet = new Doodlenet();
  await doodlenet.load();
  return doodlenet;
}