/**
 * teachable-machine.js
 *
 * A robust and optimized JavaScript library for integrating Google's
 * Teachable Machine models, supporting various image sources and providing
 * efficient classification capabilities.
 */
import * as tf from '@tensorflow/tfjs';
import os from 'os';
import { http } from './utils/net.js';
import { dirExists, ioFromDir, readMetadata, writeMetadata } from './utils/io.js';
import { getImageBuffer, toSizedRGBTensor, setPreprocessOptions } from './preprocess.js';
import { ensureFFmpeg, ensureLocalPathWithCleanup, probeDurationSec, sampleTimestamps, safeTimestamps, extractFrames, getMediaBuffer, probeDurationSecFromBuffer, extractFramesFromBuffer, extractFramesAutoFromBuffer } from './utils/ffmpeg.js';
// Silence TensorFlow.js debug logging for this process (module-level side effect).
tf.env().set('DEBUG', false);
/**
 * Extracts the top-K class predictions from a logits/probabilities tensor.
 * NOTE: disposes `logits` (and its intermediate topk tensors) as a side effect,
 * so the caller must not reuse the tensor afterwards.
 *
 * @param {object} logits - 1-D (or batched) tf tensor of per-class scores.
 * @param {string[]} classes - Class labels indexed by model output position.
 * @param {number} [topK] - How many entries to return; defaults to all classes.
 * @returns {Promise<{class: string, score: number}[]>} Highest-score-first predictions.
 */
const getTopKClasses = async (logits, classes, topK) => {
  const requested = topK ?? classes.length;
  const k = Math.min(requested, classes.length);
  const { values, indices } = tf.topk(logits, k);
  const scoreData = await values.data();
  const indexData = await indices.data();
  values.dispose();
  indices.dispose();
  logits.dispose();
  const predictions = [];
  for (let j = 0; j < indexData.length; j++) {
    predictions.push({ class: classes[indexData[j]], score: scoreData[j] });
  }
  return predictions;
};
export default class TeachableMachine {
  /**
   * Wraps a loaded TensorFlow.js layers model that has its Teachable Machine
   * class labels attached as `model.classes`.
   * Prefer the static async factory `TeachableMachine.create()`.
   * @param {object} model - tf.LayersModel with a `classes` string[] property.
   */
  constructor(model) {
    this.model = model;
  }

  /**
   * Loads a Teachable Machine model from a URL or a local directory and
   * returns a ready-to-use instance.
   *
   * @param {object} [options]
   * @param {string} [options.modelUrl] - Base URL ending with '/'; `model.json` and `metadata.json` are appended.
   * @param {string} [options.modelDir] - Local directory holding a previously saved model + metadata.
   * @param {'auto'|'dir'} [options.loadFrom='auto'] - 'dir' forces local load; 'auto' uses the dir when it exists.
   * @param {string} [options.saveToDir] - Where to cache a URL-loaded model (falls back to `modelDir`).
   * @param {boolean} [options.warmup=true] - Run one dummy prediction so kernels compile up front.
   * @param {'ram'|'disk'} [options.ioMode='ram'] - How video media is staged for FFmpeg.
   * @param {'tfjs'|'tfjs-node'} [options.backend='tfjs'] - Request the native Node backend if installed.
   * @param {boolean} [options.preprocessUseWorkers=false] - Enable worker-threaded preprocessing.
   * @returns {Promise<TeachableMachine>}
   * @throws {Error} "Model loading failed: ..." wrapping the underlying error as `cause`.
   */
  static async create({ modelUrl, modelDir, loadFrom = 'auto', saveToDir, warmup = true, ioMode = 'ram', backend = 'tfjs', preprocessUseWorkers = false } = {}) {
    try {
      if (backend === 'tfjs-node') {
        // Dynamically register the native backend only if the optional
        // dependency is present. The try covers ONLY the import so that a
        // backend-switch failure below is not misreported as a missing install.
        try {
          await import("@tensorflow/tfjs-node");
        } catch (e) {
          throw new Error("Requested backend 'tfjs-node' but '@tensorflow/tfjs-node' is not installed. Install it with: npm i @tensorflow/tfjs-node");
        }
        // Prefer the native backend when present.
        if (tf.getBackend() !== 'tensorflow') {
          await tf.setBackend('tensorflow');
          await tf.ready();
        }
      }
      let model;
      let classes;
      if (loadFrom === 'dir' || (loadFrom === 'auto' && modelDir && await dirExists(modelDir))) {
        // Local-directory path: model weights and labels were cached earlier.
        model = await tf.loadLayersModel(ioFromDir(modelDir));
        const metadata = await readMetadata(modelDir);
        if (!metadata.labels || !Array.isArray(metadata.labels)) throw new Error('Invalid metadata in local dir.');
        classes = metadata.labels;
      } else {
        if (!modelUrl) throw new Error('Model URL is missing!');
        const modelURL = `${modelUrl}model.json`;
        const metadataResponse = await http(`${modelUrl}metadata.json`).buffer();
        const metadata = JSON.parse(metadataResponse.toString());
        if (!metadata.labels || !Array.isArray(metadata.labels)) throw new Error("Invalid metadata: 'labels' field not found or is not an array.");
        classes = metadata.labels;
        model = await tf.loadLayersModel(modelURL);
        // Cache the downloaded model locally so later runs can load from disk.
        const targetDir = saveToDir ?? modelDir;
        if (targetDir) {
          await model.save(ioFromDir(targetDir));
          await writeMetadata(targetDir, { labels: classes });
        }
      }
      model.classes = classes;
      // Warmup: one throwaway prediction compiles kernels so the first real
      // classification is not penalized by lazy initialization.
      if (warmup && model.inputs?.[0]?.shape) {
        const h = model.inputs[0].shape[1];
        const w = model.inputs[0].shape[2];
        if (typeof h === 'number' && typeof w === 'number') {
          const dummy = tf.zeros([1, h, w, 3]);
          const out = model.predict(dummy);
          if (Array.isArray(out)) out.forEach(t => t.dispose()); else out.dispose();
          dummy.dispose();
        }
      }
      const tm = new TeachableMachine(model);
      tm.ioMode = ioMode === 'disk' ? 'disk' : 'ram';
      tm.backend = backend === 'tfjs-node' ? 'tfjs-node' : 'tfjs';
      // Optional: enable worker-threaded preprocessing (best-effort; failures ignored).
      try { setPreprocessOptions({ useWorkers: !!preprocessUseWorkers }); } catch {}
      return tm;
    } catch (e) {
      // Keep the legacy message shape but preserve the original stack via `cause`.
      throw new Error(`Model loading failed: ${e.message}`, { cause: e });
    }
  }

  /**
   * Decodes an image buffer, resizes it to the model's input size, and runs a
   * single prediction.
   * @param {Buffer|Uint8Array} imageBuffer - Encoded image bytes.
   * @param {object} [options]
   * @param {number} [options.topK] - Number of classes to return (defaults to all).
   * @param {boolean} [options.centerCrop=true] - Center-crop before resizing.
   * @param {boolean} [options.resizeOnCPU=true] - Echoed into the result only;
   *   resizing is always performed by toSizedRGBTensor — TODO confirm intent.
   * @returns {Promise<object>} Predictions plus backend/timing metadata.
   */
  async _decodeAndPredict(imageBuffer, { topK, centerCrop = true, resizeOnCPU = true } = {}) {
    const inShape = this.model?.inputs?.[0]?.shape;
    const targetH = inShape?.[1];
    const targetW = inShape?.[2];
    if (typeof targetH !== 'number' || typeof targetW !== 'number') throw new Error('Model input shape is not fully defined.');
    const t0 = Date.now();
    const sized = await toSizedRGBTensor(imageBuffer, targetW, targetH, { centerCrop });
    const imageTensor = tf.tensor3d(sized.data, [targetH, targetW, 3], 'int32');
    const t1 = Date.now();
    // Normalize uint8 [0,255] pixels to [-1,1], the Teachable Machine input range.
    const logits = tf.tidy(() => {
      const offset = tf.scalar(127.5);
      const normalized = imageTensor.toFloat().sub(offset).div(offset);
      const batched = normalized.expandDims(0);
      return this.model.predict(batched);
    });
    const t2 = Date.now();
    imageTensor.dispose();
    // getTopKClasses disposes `logits` for us.
    const top = await getTopKClasses(logits, this.model.classes, topK);
    const t3 = Date.now();
    return {
      backend: tf.getBackend(),
      modelInfo: { classesCount: this.model.classes.length },
      preprocess: { target: { width: targetW, height: targetH }, centerCrop, resizeOnCPU },
      timings: { decodeResizeMs: t1 - t0, inferenceMs: t2 - t1, postprocessMs: t3 - t2, totalMs: t3 - t0 },
      predictions: top.map((p, idx) => ({ ...p, rank: idx + 1 }))
    };
  }

  /**
   * Downloads (or reads) a single image and classifies it.
   * @param {object} options
   * @param {string|Buffer|Uint8Array} options.imageUrl - URL, path, or raw bytes accepted by getImageBuffer.
   * @returns {Promise<object>} Single-image result with download + inference timings.
   */
  async _classifyImage({ imageUrl, topK, centerCrop = true, resizeOnCPU = true }) {
    const tStart = Date.now();
    const imageBuffer = await getImageBuffer(imageUrl);
    const downloadEnd = Date.now();
    const inner = await this._decodeAndPredict(imageBuffer, { topK, centerCrop, resizeOnCPU });
    return {
      input: { imageUrl },
      ...inner,
      timings: { downloadMs: downloadEnd - tStart, ...inner.timings, endToEndMs: (downloadEnd - tStart) + inner.timings.totalMs }
    };
  }

  /**
   * Classifies one image or an array of images.
   * @param {object} [options]
   * @param {any|any[]} options.images - Single input or array of inputs.
   * @param {number} [options.batchSize] - Chunk size for array inputs.
   * @returns {Promise<object>} Single result, or a batch result for arrays.
   */
  async classifyImages({ images, topK, centerCrop = true, resizeOnCPU = true, batchSize } = {}) {
    if (!images) throw new Error('images is required');
    if (Array.isArray(images)) {
      return this.classifyBatch({ imageUrls: images, topK, centerCrop, resizeOnCPU, batchSize });
    }
    return this._classifyImage({ imageUrl: images, topK, centerCrop, resizeOnCPU });
  }

  /**
   * Classifies many images, downloading in parallel and running inference as a
   * single stacked batch per chunk. Per-URL download failures are recorded as
   * error entries instead of failing the whole batch.
   * @param {object} options
   * @param {any[]} options.imageUrls - Non-empty array of image inputs.
   * @param {number} [options.batchSize] - Process in chunks of this size when set.
   * @returns {Promise<object>} { backend, count, modelInfo, timings, results }.
   */
  async classifyBatch({ imageUrls, topK, centerCrop = true, resizeOnCPU = true, batchSize }) {
    if (!Array.isArray(imageUrls) || imageUrls.length === 0) throw new Error('imageUrls must be a non-empty array');
    const inShape = this.model?.inputs?.[0]?.shape;
    const targetH = inShape?.[1];
    const targetW = inShape?.[2];
    if (typeof targetH !== 'number' || typeof targetW !== 'number') throw new Error('Model input shape is not fully defined.');
    const results = [];
    const tBatchStart = Date.now();
    const processChunk = async (urls) => {
      const t0 = Date.now();
      const dlResults = await Promise.all(urls.map(async (u) => {
        try {
          const buf = await getImageBuffer(u);
          return { ok: true, buf };
        } catch (e) {
          return { ok: false, err: e, url: u };
        }
      }));
      const okPairs = dlResults.map((r, i) => ({ idx: i, r })).filter(x => x.r.ok);
      const failPairs = dlResults.map((r, i) => ({ idx: i, r })).filter(x => !x.r.ok);
      // Immediately record failures so batch remains responsive
      for (const { idx, r } of failPairs) {
        results.push({
          input: { imageUrl: urls[idx] },
          error: r.err?.message || String(r.err),
          backend: tf.getBackend(),
          modelInfo: { classesCount: this.model.classes.length },
          timings: { downloadMs: 0, decodeResizeMs: 0, inferenceMs: 0, postprocessMs: 0, totalMs: 0 }
        });
      }
      const buffers = okPairs.map(p => p.r.buf);
      // FIX: if every download in this chunk failed, skip inference entirely —
      // tf.stack throws on an empty tensor list and would crash the batch.
      if (buffers.length === 0) return;
      const tDownloadEnd = Date.now();
      const tensors = [];
      for (const buf of buffers) {
        const sized = await toSizedRGBTensor(buf, targetW, targetH, { centerCrop });
        tensors.push(tf.tensor3d(sized.data, [targetH, targetW, 3], 'int32'));
      }
      const tPrepEnd = Date.now();
      const logits = tf.tidy(() => {
        const offset = tf.scalar(127.5);
        const batch = tf.stack(tensors.map(t => t.toFloat().sub(offset).div(offset)));
        const out = this.model.predict(batch);
        // Multi-output models: use the first output head.
        return Array.isArray(out) ? out[0] : out;
      });
      tensors.forEach(t => t.dispose());
      const tInferEnd = Date.now();
      const k = Math.min(topK ?? this.model.classes.length, this.model.classes.length);
      const { values, indices } = tf.topk(logits, k);
      const [vals, inds] = await Promise.all([values.array(), indices.array()]);
      values.dispose(); indices.dispose(); logits.dispose();
      const tPostEnd = Date.now();
      // Map predictions back onto successful indices only
      okPairs.forEach(({ idx: okIdx }, row) => {
        const u = urls[okIdx];
        const preds = inds[row].map((clsIdx, j) => ({ class: this.model.classes[clsIdx], score: vals[row][j], rank: j + 1 }));
        results.push({
          input: { imageUrl: u },
          backend: tf.getBackend(),
          modelInfo: { classesCount: this.model.classes.length },
          // FIX: echo the caller's resizeOnCPU flag instead of hard-coded true.
          preprocess: { target: { width: targetW, height: targetH }, centerCrop, resizeOnCPU },
          timings: { downloadMs: tDownloadEnd - t0, decodeResizeMs: tPrepEnd - tDownloadEnd, inferenceMs: tInferEnd - tPrepEnd, postprocessMs: tPostEnd - tInferEnd, totalMs: tPostEnd - t0 },
          predictions: preds
        });
      });
    };
    if (batchSize && batchSize > 0 && batchSize < imageUrls.length) {
      for (let i = 0; i < imageUrls.length; i += batchSize) {
        const chunk = imageUrls.slice(i, i + batchSize);
        await processChunk(chunk);
      }
    } else {
      await processChunk(imageUrls);
    }
    const tBatchEnd = Date.now();
    return { backend: tf.getBackend(), count: imageUrls.length, modelInfo: { classesCount: this.model.classes.length }, timings: { endToEndMs: tBatchEnd - tBatchStart }, results };
  }

  /**
   * Backward-compat alias for older code.
   * Delegates to classifyBatch().
   */
  async batchImageClassify(options) {
    return this.classifyBatch(options);
  }

  /**
   * Unified classify entry. Routes to image or video classification.
   *
   * Routing for 'auto': video only when the caller explicitly passes a positive
   * `frames` count (or mediaType 'video'); otherwise the input is treated as
   * image(s). FIX: the old `frames = 10` destructuring default made the
   * `frames > 0` check always true, so 'auto' routed EVERY input to video.
   *
   * @param {object} options
   * @param {any|any[]|{images?: any[], videos?: any[]}} options.input - Image(s), video(s), or a mixed object.
   * @param {'auto'|'image'|'video'} [options.mediaType='auto']
   * @param {number} [options.frames] - Frames to sample per video (effective default 10 on the video path).
   */
  async classify({ input, mediaType = 'auto', frames, topK, centerCrop = true, resizeOnCPU = true, turboMode = false, extractionConcurrency, preprocessConcurrency, maxConcurrent = 2, maxBytes = 10 * 1024 * 1024, batchSize } = {}) {
    // Apply the video-frame default only where frames are actually used, so
    // its mere presence does not hijack the auto routing below.
    const effectiveFrames = frames ?? 10;
    // Mixed object form: { images: [...], videos: [...] }
    if (input && typeof input === 'object' && !Array.isArray(input) && (input.images || input.videos)) {
      const tasks = [];
      if (input.images && input.images.length) {
        tasks.push(this.classifyImages({ images: input.images, topK, centerCrop, resizeOnCPU, batchSize }));
      } else {
        tasks.push(Promise.resolve(null));
      }
      if (input.videos && input.videos.length) {
        tasks.push(this.classifyVideos({ videos: input.videos, frames: effectiveFrames, topK, centerCrop, resizeOnCPU, turboMode, extractionConcurrency, preprocessConcurrency, maxConcurrent, maxBytes }));
      } else {
        tasks.push(Promise.resolve(null));
      }
      const [imagesRes, videosRes] = await Promise.all(tasks);
      return { images: imagesRes, videos: videosRes };
    }
    // Legacy/array form: decide route.
    const isVideo = mediaType === 'video'
      || (mediaType !== 'image' && frames !== undefined && Number.isFinite(frames) && frames > 0);
    if (isVideo) {
      return this.classifyVideos({ videos: input, frames: effectiveFrames, topK, centerCrop, resizeOnCPU, turboMode, extractionConcurrency, preprocessConcurrency, maxConcurrent, maxBytes });
    }
    return this.classifyImages({ images: input, topK, centerCrop, resizeOnCPU, batchSize });
  }

  /**
   * Public wrapper for single or multiple videos.
   * Delegates to classifyVideo(), which handles both shapes.
   */
  async classifyVideos({ videos, ...rest } = {}) {
    return this.classifyVideo({ videoUrl: videos, ...rest });
  }

  /** Releases the underlying model's GPU/CPU tensors. Instance is unusable afterwards. */
  dispose() { if (this.model) this.model.dispose(); }

  /**
   * Classifies one or more videos/GIFs by sampling evenly spaced frames with FFmpeg and running the image pipeline.
   * Accepts URL/path, Buffer/Uint8Array, data URI or base64 for single input, or an array of such inputs.
   * For arrays, results are returned IN INPUT ORDER regardless of completion order.
   * @param {object} options
   * @param {string|Buffer|Uint8Array|(string|Buffer|Uint8Array)[]} options.videoUrl
   * @param {number} [options.frames=10] - Frames sampled evenly across the duration.
   * @param {number} [options.topK]
   * @param {boolean} [options.centerCrop=true]
   * @param {boolean} [options.resizeOnCPU=true]
   * @param {boolean} [options.turboMode=false] - Batch all frames into one inference + concurrent preprocessing.
   * @param {number} [options.maxConcurrent=2] - Parallel videos when an array is given.
   * @param {number} [options.maxBytes] - Reject media larger than this many bytes.
   */
  async classifyVideo({ videoUrl, frames = 10, topK, centerCrop = true, resizeOnCPU = true, turboMode = false, extractionConcurrency, preprocessConcurrency, maxConcurrent = 2, maxBytes = 10 * 1024 * 1024 } = {}) {
    // Support single or multiple inputs.
    if (Array.isArray(videoUrl)) {
      // FIX: write results by index so output order matches input order even
      // when videos finish out of order under concurrency.
      const out = new Array(videoUrl.length);
      let next = 0;
      const worker = async () => {
        while (next < videoUrl.length) {
          const idx = next++;
          out[idx] = await this.classifyVideo({ videoUrl: videoUrl[idx], frames, topK, centerCrop, resizeOnCPU, turboMode, extractionConcurrency, preprocessConcurrency, maxConcurrent, maxBytes });
        }
      };
      const runnerCount = Math.max(1, Math.min(maxConcurrent, videoUrl.length));
      await Promise.all(Array.from({ length: runnerCount }, worker));
      return out;
    }
    if (!videoUrl) throw new Error('videoUrl is required');
    if (!Number.isFinite(frames) || frames <= 0) throw new Error('frames must be a positive number');
    const tStart = Date.now();
    // Determine target model input size.
    const inShape = this.model?.inputs?.[0]?.shape;
    const targetH = inShape?.[1];
    const targetW = inShape?.[2];
    if (typeof targetH !== 'number' || typeof targetW !== 'number') throw new Error('Model input shape is not fully defined.');
    const aggregateScores = new Array(this.model.classes.length).fill(0);
    let frameCount = 0;
    const ffmpegPath = await ensureFFmpeg();
    let cleanup = async () => {};
    let durationSec; let framesSource; let usedMode = this.ioMode; let fallbackToDisk = false; let sizeBytes = 0; let tempCleaned = false;
    try {
      if (this.ioMode === 'ram') {
        // RAM mode: media stays in memory; enforce the size cap before probing.
        const mediaBuf = await getMediaBuffer(videoUrl);
        sizeBytes = mediaBuf.length;
        if (maxBytes && sizeBytes > maxBytes) throw new Error(`Media exceeds maxBytes (${sizeBytes} > ${maxBytes})`);
        durationSec = await probeDurationSecFromBuffer(ffmpegPath, mediaBuf);
        framesSource = mediaBuf;
      } else {
        // Disk mode: stage to a local path (with a cleanup handle for temp files).
        const loc = await ensureLocalPathWithCleanup(videoUrl);
        cleanup = loc.cleanup;
        try {
          const { default: fs } = await import('fs/promises');
          const st = await fs.stat(loc.path).catch(() => null);
          sizeBytes = st?.size ?? 0;
          if (maxBytes && sizeBytes > maxBytes) throw new Error(`Media exceeds maxBytes (${sizeBytes} > ${maxBytes})`);
        } catch {}
        durationSec = await probeDurationSec(ffmpegPath, loc.path);
        framesSource = loc.path;
      }
      if (!durationSec || durationSec <= 0) throw new Error('Unable to determine video duration');
      const stampsRaw = sampleTimestamps(durationSec, Math.floor(frames));
      const stamps = safeTimestamps(durationSec, stampsRaw);
      const tPrepEnd = Date.now();
      const cpuCount = (typeof os?.cpus === 'function' && Array.isArray(os.cpus())) ? os.cpus().length : 4;
      const extractConc = Math.max(1, Math.min(16, extractionConcurrency ?? (turboMode ? Math.min(8, cpuCount) : 1)));
      let frameBuffers = usedMode === 'ram'
        ? await extractFramesAutoFromBuffer(ffmpegPath, framesSource, Math.floor(frames), durationSec)
        : await extractFrames(ffmpegPath, framesSource, stamps, { concurrency: extractConc });
      if (usedMode === 'ram' && (!frameBuffers || frameBuffers.length === 0)) {
        // Fallback to disk extraction when in-memory extraction produced nothing.
        const loc = await ensureLocalPathWithCleanup(videoUrl);
        fallbackToDisk = true; usedMode = 'disk';
        try {
          frameBuffers = await extractFrames(ffmpegPath, loc.path, stamps, { concurrency: extractConc });
          const { default: fs } = await import('fs/promises');
          const st = await fs.stat(loc.path).catch(() => null);
          sizeBytes = st?.size ?? sizeBytes;
        } finally {
          await loc.cleanup();
        }
      }
      // FIX: fail with a clear message instead of letting tf.stack / the frame
      // loop operate on an empty list.
      if (!frameBuffers || frameBuffers.length === 0) throw new Error('No frames could be extracted from the media');
      if (turboMode) {
        const t0 = Date.now();
        const prepConc = Math.max(1, Math.min(32, preprocessConcurrency ?? Math.min(8, cpuCount)));
        // Decode/resize frames concurrently; tensors land at their frame index.
        const tensors = new Array(frameBuffers.length);
        let nextFrame = 0;
        const prepWorker = async () => {
          while (nextFrame < frameBuffers.length) {
            const idx = nextFrame++;
            const sized = await toSizedRGBTensor(frameBuffers[idx], targetW, targetH, { centerCrop });
            tensors[idx] = tf.tensor3d(sized.data, [targetH, targetW, 3], 'int32');
          }
        };
        await Promise.all(Array.from({ length: Math.max(1, Math.min(prepConc, frameBuffers.length)) }, prepWorker));
        const t1 = Date.now();
        // Single stacked inference over all frames.
        const logits = tf.tidy(() => {
          const offset = tf.scalar(127.5);
          const batch = tf.stack(tensors.map(t => t.toFloat().sub(offset).div(offset)));
          const out = this.model.predict(batch);
          return Array.isArray(out) ? out[0] : out;
        });
        tensors.forEach(t => t.dispose());
        const t2 = Date.now();
        const probs = await logits.array();
        logits.dispose();
        frameCount = probs.length;
        for (let f = 0; f < probs.length; f++) {
          for (let c = 0; c < probs[f].length; c++) aggregateScores[c] += probs[f][c];
        }
        const k = Math.min(topK ?? this.model.classes.length, this.model.classes.length);
        const results = [];
        for (let i = 0; i < probs.length; i++) {
          const scores = probs[i];
          const idxs = scores.map((s, ci) => ci).sort((a, b) => scores[b] - scores[a]).slice(0, k);
          const preds = idxs.map((ci, j) => ({ class: this.model.classes[ci], score: scores[ci], rank: j + 1 }));
          // NOTE(review): in RAM mode frames come from auto extraction; stamps
          // are assumed to align index-wise with those frames — confirm.
          results.push({ frameIndex: i, timestampSec: stamps[i] ?? null, predictions: preds });
        }
        const t3 = Date.now();
        const tEnd = Date.now();
        const avg = aggregateScores.map(s => s / Math.max(1, frameCount));
        const order = avg.map((s, i) => i).sort((a, b) => avg[b] - avg[a]);
        const overall = order.map((ci, j) => ({ class: this.model.classes[ci], score: avg[ci], rank: j + 1 })).slice(0, k);
        const out = {
          input: { videoUrl, frames: Math.floor(frames), turboMode: true },
          backend: tf.getBackend(),
          modelInfo: { classesCount: this.model.classes.length },
          timings: {
            downloadPrepareMs: tPrepEnd - tStart,
            decodeResizeMs: t1 - t0,
            inferenceMs: t2 - t1,
            postprocessMs: t3 - t2,
            totalMs: tEnd - tStart
          },
          io: { mode: usedMode, fallbackToDisk, tempCleaned, sizeBytes, maxBytes },
          results,
          aggregate: { predictions: overall }
        };
        await cleanup(); tempCleaned = true;
        // FIX: the io object captured tempCleaned=false at construction; report
        // the actual post-cleanup state.
        out.io.tempCleaned = true;
        return out;
      } else {
        // Sequential per-frame pipeline (lower memory than turbo mode).
        const results = [];
        let decodeResizeMs = 0; let inferenceMs = 0; let postprocessMs = 0;
        const k = Math.min(topK ?? this.model.classes.length, this.model.classes.length);
        for (let i = 0; i < frameBuffers.length; i++) {
          const tA = Date.now();
          const sized = await toSizedRGBTensor(frameBuffers[i], targetW, targetH, { centerCrop });
          const imageTensor = tf.tensor3d(sized.data, [targetH, targetW, 3], 'int32');
          const tB = Date.now();
          const logits = tf.tidy(() => {
            const offset = tf.scalar(127.5);
            const norm = imageTensor.toFloat().sub(offset).div(offset).expandDims(0);
            const out = this.model.predict(norm);
            return Array.isArray(out) ? out[0] : out;
          });
          const tC = Date.now();
          imageTensor.dispose();
          const probs = await logits.array();
          logits.dispose();
          frameCount += 1;
          for (let c = 0; c < probs[0].length; c++) aggregateScores[c] += probs[0][c];
          const scores = probs[0];
          const idxs = scores.map((s, ci) => ci).sort((a, b) => scores[b] - scores[a]).slice(0, k);
          const preds = idxs.map((ci, j) => ({ class: this.model.classes[ci], score: scores[ci], rank: j + 1 }));
          const tD = Date.now();
          results.push({ frameIndex: i, timestampSec: stamps[i] ?? null, predictions: preds });
          decodeResizeMs += tB - tA; inferenceMs += tC - tB; postprocessMs += tD - tC;
        }
        const tEnd = Date.now();
        const avg = aggregateScores.map(s => s / Math.max(1, frameCount));
        const order = avg.map((s, i) => i).sort((a, b) => avg[b] - avg[a]);
        const overall = order.map((ci, j) => ({ class: this.model.classes[ci], score: avg[ci], rank: j + 1 })).slice(0, k);
        const out = {
          input: { videoUrl, frames: Math.floor(frames), turboMode: false },
          backend: tf.getBackend(),
          modelInfo: { classesCount: this.model.classes.length },
          timings: {
            downloadPrepareMs: tPrepEnd - tStart,
            decodeResizeMs,
            inferenceMs,
            postprocessMs,
            totalMs: tEnd - tStart
          },
          io: { mode: usedMode, fallbackToDisk, tempCleaned, sizeBytes, maxBytes },
          results,
          aggregate: { predictions: overall }
        };
        await cleanup(); tempCleaned = true;
        // FIX: reflect actual post-cleanup state (see turbo branch).
        out.io.tempCleaned = true;
        return out;
      }
    } finally {
      // cleanup is assumed idempotent (it may already have run on the success
      // path); best-effort here so error paths never leak temp files.
      try { await cleanup(); tempCleaned = true; } catch {}
    }
  }

  /**
   * Classifies multiple videos or GIFs. Processes videos sequentially by default to limit memory.
   * @param {object} options
   * @param {string[]} options.videoUrls
   * @param {number} [options.frames=10]
   * @param {number} [options.topK]
   * @param {boolean} [options.centerCrop=true]
   * @param {boolean} [options.resizeOnCPU=true]
   * @param {boolean} [options.turboMode=false]
   * @param {number} [options.extractionConcurrency]
   * @param {number} [options.preprocessConcurrency]
   * @param {number} [options.maxConcurrent=2]
   * @returns {Promise<object>} Batch result with per-video outputs.
   */
  async classifyVideoBatch({ videoUrls, ...rest }) {
    if (!Array.isArray(videoUrls) || videoUrls.length === 0) throw new Error('videoUrls must be a non-empty array');
    return this.classifyVideo({ videoUrl: videoUrls, ...rest });
  }
}