teachable-machine.js
Version:
A robust and optimized JavaScript library for integrating Google's Teachable Machine models, supporting various image sources and providing efficient classification capabilities.
147 lines (133 loc) • 5.84 kB
JavaScript
import * as tf from '@tensorflow/tfjs';
import sharp from 'sharp';
import got from 'got';
import fs from 'fs/promises';
/**
 * Module-load side effect: sets the TensorFlow.js `DEBUG` environment flag to `false`.
 * This runs once, when the module is first imported, before any model is created.
 */
tf.env().set('DEBUG', false);
/**
 * Ranks the model outputs by score and pairs each one with its class label.
 *
 * The `logits` tensor is consumed: it is disposed before this function settles,
 * whether reading its data succeeded or failed.
 *
 * Despite the name, no separate K is applied: at most
 * `min(classes.length, values.length)` entries are returned. When there are more
 * output values than labels, an entry whose index has no label gets `class: undefined`.
 *
 * @param {tf.Tensor} logits - The output tensor from the model prediction.
 * @param {string[]} classes - Class labels, indexed by output position.
 * @returns {Promise<Array<{class: string, score: number}>>} Entries sorted by descending score.
 * @throws Rethrows whatever `logits.data()` rejects with (after disposing the tensor).
 */
const getTopKClasses = async (logits, classes) => {
  let values;
  try {
    values = await logits.data();
  } finally {
    // Release the tensor even if the read rejected, so a failed prediction does not leak it.
    logits.dispose();
  }
  const topK = Math.min(classes.length, values.length);
  return Array.from(values, (value, index) => ({ value, index }))
    .sort((a, b) => b.value - a.value)
    .slice(0, topK)
    .map(({ value, index }) => ({
      class: classes[index],
      score: value,
    }));
};
export default class TechableMachine {
  /**
   * Wraps an already loaded model. Prefer {@link TechableMachine.create}, which also
   * fetches the labels and attaches them to the model as `model.classes`.
   * @param {tf.LayersModel} model - The loaded TensorFlow.js model.
   */
  constructor(model) {
    this.model = model;
  }

  /**
   * Asynchronously creates and initializes a TechableMachine instance.
   * Fetches `metadata.json`, validates its `labels`, then loads `model.json`.
   * @param {object} options - Options for creating the instance.
   * @param {string} options.modelUrl - The base URL where the model.json and metadata.json
   *   files are located. A trailing slash is optional.
   * @returns {Promise<TechableMachine>} A promise that resolves to a new instance.
   * @throws {Error} If the model URL is missing, metadata cannot be loaded, labels are
   *   invalid, or the model fails to load. Wrapped errors carry the original as `cause`.
   */
  static async create({ modelUrl }) {
    if (!modelUrl) {
      throw new Error("Model URL is missing!");
    }
    // Accept base URLs with or without a trailing slash; without this,
    // ".../models/abc" would request ".../models/abcmodel.json".
    const rawBase = String(modelUrl);
    const baseUrl = rawBase.endsWith('/') ? rawBase : `${rawBase}/`;
    try {
      const metadataBuffer = await got(`${baseUrl}metadata.json`).buffer();
      const metadata = JSON.parse(metadataBuffer.toString());
      // Validate before loading the model, so invalid metadata fails without
      // leaving a loaded (and never disposed) model behind.
      if (!metadata || !Array.isArray(metadata.labels)) {
        throw new Error("Invalid metadata: 'labels' field not found or is not an array.");
      }
      const model = await tf.loadLayersModel(`${baseUrl}model.json`);
      model.classes = metadata.labels;
      return new TechableMachine(model);
    } catch (e) {
      throw new Error(`Model loading failed: ${e.message}`, { cause: e });
    }
  }

  /**
   * Retrieves the image bytes for a given image reference.
   * Supports data URIs (`data:image/...` with a Base64 payload), `http://` / `https://`
   * URLs, and local file paths (anything else).
   * @param {string} imageUrl - The URL or path of the image.
   * @returns {Promise<Buffer>} A promise that resolves to the image buffer.
   * @throws {Error} If the reference is not a non-empty string, the data URI has no
   *   payload, the download fails, or the local file cannot be read.
   */
  async _getImageBuffer(imageUrl) {
    if (typeof imageUrl !== 'string' || imageUrl.length === 0) {
      throw new Error("Invalid image URL: expected a non-empty string.");
    }
    if (imageUrl.startsWith('data:image/')) {
      const base64Data = imageUrl.split(',')[1];
      if (!base64Data) {
        throw new Error("Invalid Base64 data URI: data part not found.");
      }
      return Buffer.from(base64Data, 'base64');
    }
    // Require the full scheme: a bare 'http' prefix would also capture local
    // paths such as "http_cache/cat.png".
    if (/^https?:\/\//i.test(imageUrl)) {
      try {
        return await got(imageUrl).buffer();
      } catch (error) {
        const detail = error.response ? error.response.statusCode : error.message;
        throw new Error(`Failed to download image. Status: ${detail}`, { cause: error });
      }
    }
    try {
      return await fs.readFile(imageUrl);
    } catch (error) {
      if (error.code === 'ENOENT') {
        throw new Error(`Local file not found: ${imageUrl}`, { cause: error });
      }
      throw error;
    }
  }

  /**
   * Decodes an image buffer and performs a prediction using the loaded model.
   * The image is decoded to raw pixels without alpha, expanded to three channels if it
   * is single-channel, resized (nearest neighbour) to the model's input size, and
   * normalised with `(x - 127.5) / 127.5` before prediction.
   * @param {Buffer} imageBuffer - The buffer containing the encoded image data.
   * @returns {Promise<Array<{class: string, score: number}>>} Classes with their scores,
   *   sorted by descending score.
   * @throws {Error} If no model is available or the decoded image has an unsupported
   *   channel count; decoding and prediction errors propagate unchanged.
   */
  async _decodeAndPredict(imageBuffer) {
    if (!this.model) {
      throw new Error("Model is not available: it was never set or has been disposed.");
    }
    const { data, info } = await sharp(imageBuffer)
      .removeAlpha()
      .raw()
      .toBuffer({ resolveWithObject: true });
    const { height, width, channels } = info;
    // Raw output of a greyscale image has one channel, so the count must come from
    // `info` rather than being assumed to be 3.
    if (channels !== 1 && channels !== 3) {
      throw new Error(`Unsupported image: expected 1 or 3 channels after alpha removal, got ${channels}.`);
    }
    const inputShape = this.model.inputs[0].shape;
    const targetSize = [inputShape[1], inputShape[2]];
    const imageTensor = tf.tensor3d(data, [height, width, channels], 'int32');
    let logits;
    try {
      logits = tf.tidy(() => {
        // Replicate the single grey channel so the model always receives three channels.
        const rgb = channels === 1 ? tf.tile(imageTensor, [1, 1, 3]) : imageTensor;
        const resized = tf.image.resizeNearestNeighbor(rgb, targetSize);
        const offset = tf.scalar(127.5);
        const normalized = resized.toFloat().sub(offset).div(offset);
        return this.model.predict(normalized.expandDims(0));
      });
    } finally {
      // Created outside tidy(), so it must be released by hand, including on failure.
      imageTensor.dispose();
    }
    // getTopKClasses disposes `logits` once it has read the scores.
    return getTopKClasses(logits, this.model.classes);
  }

  /**
   * Classifies an image from a given URL, data URI, or local path.
   * @param {object} options - Options for classification.
   * @param {string} options.imageUrl - The URL or path of the image to classify.
   * @returns {Promise<Array<{class: string, score: number}>>} Classes with their scores,
   *   sorted by descending score.
   */
  async classify({ imageUrl }) {
    const imageBuffer = await this._getImageBuffer(imageUrl);
    return this._decodeAndPredict(imageBuffer);
  }

  /**
   * Disposes of the loaded model and drops the reference to it.
   * Safe to call more than once: later calls are no-ops. After disposal,
   * `classify` rejects until a new instance is created.
   */
  dispose() {
    if (this.model) {
      this.model.dispose();
      this.model = null;
    }
  }
}