UNPKG

teachable-machine.js

Version:

A robust and optimized JavaScript library for integrating Google's Teachable Machine models, supporting various image sources and providing efficient classification capabilities.

147 lines (133 loc) 5.84 kB
import * as tf from '@tensorflow/tfjs';
import sharp from 'sharp';
import got from 'got';
import fs from 'fs/promises';

/**
 * Explicitly turns off the TensorFlow.js `DEBUG` flag. `tf.env()` is shared,
 * so this applies to every TensorFlow.js user in the same process.
 */
tf.env().set('DEBUG', false);

/** Matches absolute HTTP(S) URLs; the scheme is matched case-insensitively. */
const HTTP_URL_PATTERN = /^https?:\/\//i;

/**
 * Reads the prediction logits back from the tensor, disposes the tensor, and
 * pairs each score with its class label, ordered from highest to lowest score.
 * @param {tf.Tensor} logits - The output tensor from the model prediction. It is
 *   always disposed by this function, even when reading it fails.
 * @param {string[]} classes - Class labels, indexed in the same order as the logits.
 * @returns {Promise<Array<{class: string, score: number}>>} At most
 *   `min(classes.length, number of logit values)` entries, best first.
 */
const getTopKClasses = async (logits, classes) => {
  let values;
  try {
    values = await logits.data();
  } finally {
    // Release the tensor even if reading it back rejects.
    logits.dispose();
  }

  const topK = Math.min(classes.length, values.length);
  return Array.from(values, (value, index) => ({ value, index }))
    .sort((a, b) => b.value - a.value)
    .slice(0, topK)
    .map(({ value, index }) => ({
      class: classes[index],
      score: value,
    }));
};

/**
 * Coerces a decoded [height, width, channels] image tensor to exactly three
 * channels. Call it inside `tf.tidy` so the intermediate tensors are released.
 * Images with fewer than three channels (e.g. greyscale) have their first
 * channel replicated into R, G and B; images with more than three channels
 * keep only the first three.
 * @param {tf.Tensor3D} image - Decoded image tensor.
 * @returns {tf.Tensor3D} A [height, width, 3] tensor.
 */
const toRgb = (image) => {
  const channels = image.shape[2];
  if (channels === 3) {
    return image;
  }
  if (channels < 3) {
    return tf.tile(image.slice([0, 0, 0], [-1, -1, 1]), [1, 1, 3]);
  }
  return image.slice([0, 0, 0], [-1, -1, 3]);
};

export default class TeachableMachine {
  /**
   * Wraps an already-loaded model. Prefer `TeachableMachine.create`, which also
   * attaches the `classes` label array that classification reads.
   * @param {tf.LayersModel} model - The loaded TensorFlow.js model; `classify`
   *   expects it to carry a `classes` array of labels.
   */
  constructor(model) {
    this.model = model;
  }

  /**
   * Asynchronously downloads a Teachable Machine model plus its labels and
   * returns a ready-to-use instance.
   * @param {object} [options] - Options for creating the instance.
   * @param {string} options.modelUrl - The base URL where `model.json` and
   *   `metadata.json` are located. A trailing `/` is appended if missing.
   * @returns {Promise<TeachableMachine>} A promise that resolves to a new instance.
   * @throws {Error} "Model URL is missing!" when no URL is given. Any failure while
   *   fetching/parsing metadata, validating labels, or loading the model is
   *   rethrown as "Model loading failed: …" with the original error as `cause`.
   */
  static async create({ modelUrl } = {}) {
    if (!modelUrl) {
      throw new Error("Model URL is missing!");
    }

    // Accept both ".../model-id/" and ".../model-id"; files are resolved relative
    // to the directory either way.
    const rawUrl = String(modelUrl);
    const baseUrl = rawUrl.endsWith('/') ? rawUrl : `${rawUrl}/`;

    try {
      const metadataResponse = await got(`${baseUrl}metadata.json`).buffer();
      const metadata = JSON.parse(metadataResponse.toString());
      if (!Array.isArray(metadata?.labels)) {
        throw new Error("Invalid metadata: 'labels' field not found or is not an array.");
      }

      // Load the model only after the labels are known to be valid, so a bad
      // metadata file cannot leave an orphaned model holding tensor memory.
      const model = await tf.loadLayersModel(`${baseUrl}model.json`);
      model.classes = metadata.labels;
      return new TeachableMachine(model);
    } catch (e) {
      throw new Error(`Model loading failed: ${e.message}`, { cause: e });
    }
  }

  /**
   * Returns the loaded model, or throws if `dispose()` has already released it.
   * @returns {tf.LayersModel}
   * @throws {Error} If the model has been disposed.
   */
  _requireModel() {
    if (!this.model) {
      throw new Error('Model has been disposed.');
    }
    return this.model;
  }

  /**
   * Retrieves the image bytes for a given image reference.
   * Supports `data:image/...` URIs (Base64 payload), absolute HTTP/HTTPS URLs,
   * and local file paths (anything else).
   * @param {string} imageUrl - The URL or path of the image.
   * @returns {Promise<Buffer>} A promise that resolves to the image buffer.
   * @throws {Error} If the reference is not a non-empty string, the data URI has
   *   no payload, the download fails, or the local file cannot be read.
   */
  async _getImageBuffer(imageUrl) {
    if (typeof imageUrl !== 'string' || imageUrl.length === 0) {
      throw new Error('Image URL is missing or is not a string.');
    }

    if (imageUrl.startsWith('data:image/')) {
      const base64Data = imageUrl.split(',')[1];
      if (!base64Data) {
        throw new Error("Invalid Base64 data URI: data part not found.");
      }
      return Buffer.from(base64Data, 'base64');
    }

    // Require a real scheme so local paths such as "http-images/cat.png" are read from disk.
    if (HTTP_URL_PATTERN.test(imageUrl)) {
      try {
        return await got(imageUrl).buffer();
      } catch (error) {
        throw new Error(
          `Failed to download image. Status: ${error.response ? error.response.statusCode : error.message}`,
          { cause: error },
        );
      }
    }

    try {
      return await fs.readFile(imageUrl);
    } catch (error) {
      if (error.code === 'ENOENT') {
        throw new Error(`Local file not found: ${imageUrl}`, { cause: error });
      }
      throw error;
    }
  }

  /**
   * Decodes an image buffer and performs a prediction using the loaded model.
   * The image is decoded to raw pixels with alpha removed and coerced to 3 channels.
   * It is then resized (nearest neighbour) to the model's input height/width, scaled
   * from [0, 255] to [-1, 1], and given a batch dimension.
   * @param {Buffer} imageBuffer - The buffer containing the encoded image data.
   * @returns {Promise<Array<{class: string, score: number}>>} A promise that resolves
   *   to the classes with their scores, best first.
   * @throws {Error} If the model has been disposed, or decoding/prediction fails.
   */
  async _decodeAndPredict(imageBuffer) {
    const model = this._requireModel();

    const { data, info } = await sharp(imageBuffer)
      .removeAlpha()
      .raw()
      .toBuffer({ resolveWithObject: true });

    // inputs[0].shape is read as [batch, height, width, ...].
    const [, inputHeight, inputWidth] = model.inputs[0].shape;

    // Everything, including the decoded image tensor, is created inside tidy so
    // intermediates are freed even if any step throws.
    const logits = tf.tidy(() => {
      // Use the channel count sharp actually produced (greyscale yields 1).
      const decoded = tf.tensor3d(data, [info.height, info.width, info.channels], 'int32');
      const resized = tf.image.resizeNearestNeighbor(toRgb(decoded), [inputHeight, inputWidth]);
      const offset = tf.scalar(127.5);
      const normalized = resized.cast('float32').sub(offset).div(offset);
      return model.predict(normalized.expandDims(0));
    });

    return getTopKClasses(logits, model.classes);
  }

  /**
   * Classifies an image from a given URL, data URI, or local path.
   * @param {object} [options] - Options for classification.
   * @param {string} options.imageUrl - The URL or path of the image to classify.
   * @returns {Promise<Array<{class: string, score: number}>>} A promise that resolves
   *   to the classes with their scores, best first.
   * @throws {Error} If the model has been disposed or the image cannot be loaded.
   */
  async classify({ imageUrl } = {}) {
    // Fail fast before downloading anything if the model is gone.
    this._requireModel();
    const imageBuffer = await this._getImageBuffer(imageUrl);
    return this._decodeAndPredict(imageBuffer);
  }

  /**
   * Disposes of the loaded model and its tensors. Safe to call more than once;
   * after disposal, `classify` throws "Model has been disposed.".
   */
  dispose() {
    if (this.model) {
      this.model.dispose();
      // Drop the reference so repeated dispose() calls are no-ops.
      this.model = null;
    }
  }
}