ml5-save
Version:
237 lines (213 loc) • 7.74 kB
JavaScript
// Copyright (c) 2018 ml5
//
// This software is released under the MIT License.
// https://opensource.org/licenses/MIT
/*
* A K-nearest neighbors (KNN) classifier that allows fast
* custom model training on top of any tensor input. Useful for transfer
* learning with an embedding from another pretrained model.
*/
import * as tf from '@tensorflow/tfjs';
import * as knnClassifier from '@tensorflow-models/knn-classifier';
import * as io from '../utils/io';
import callCallback from '../utils/callcallback';
class KNN {
  /**
   * Create a KNNClassifier instance.
   */
  constructor() {
    this.knnClassifier = knnClassifier.create();
    // Position i holds the string label of class index i. Stays empty while the
    // classifier is used with numeric class indices only.
    this.mapStringToIndex = [];
  }

  /**
   * Add an example to a class.
   * @param {*} input - An example to add to the dataset, usually an activation from another
   *   model, or an array of numbers.
   * @param {number|string} classIndexOrLabel - The class index (number) or label (string) of the example.
   */
  addExample(input, classIndexOrLabel) {
    let classIndex;
    if (typeof classIndexOrLabel === 'string') {
      classIndex = this.mapStringToIndex.indexOf(classIndexOrLabel);
      if (classIndex === -1) {
        // First time this label is seen: it gets the next free class index.
        classIndex = this.mapStringToIndex.push(classIndexOrLabel) - 1;
      }
    } else if (typeof classIndexOrLabel === 'number') {
      classIndex = classIndexOrLabel;
    }
    // An array is wrapped in a temporary tensor owned (and released) by this method;
    // a caller-supplied tensor is left for the caller to dispose.
    const ownsExample = Array.isArray(input);
    const example = ownsExample ? tf.tensor(input) : input;
    try {
      this.knnClassifier.addExample(example, classIndex);
    } finally {
      if (ownsExample) {
        example.dispose();
      }
    }
  }

  /**
   * Classify a new input. Resolves to an object with the top `classIndex` (plus `label` when
   * string labels are in use), `confidences` mapping class indices to confidence, and
   * `confidencesByLabel` mapping labels to confidence.
   * @param {*} input - An example to make a prediction on, could be an activation from another model or an array of numbers.
   * @param {number} k - Optional. The K value to use in K-nearest neighbors. The algorithm will first find the K nearest examples from those it was previously shown, and then choose the class that appears the most as the final prediction for the input example. Defaults to 3. If examples < k, k = examples.
   * @param {function} callback - Optional. A function to be called once the input has been classified. If no callback is provided, it will return a promise that will be resolved once the model has classified the new input.
   */
  async classify(input, kOrCallback, cb) {
    let k = 3;
    let callback = cb;
    if (typeof kOrCallback === 'number') {
      k = kOrCallback;
    } else if (typeof kOrCallback === 'function') {
      callback = kOrCallback;
    }
    return callCallback(this.classifyInputInternal(input, k), callback);
  }

  /**
   * Wrap array input in a temporary tensor, classify it, and dispose the temporary tensor
   * once the prediction has settled. Caller-supplied tensors are never disposed here.
   * @param {*} input - A tensor or an array of numbers.
   * @param {number} k - The K value forwarded to classifyInternal.
   * @returns {Promise<Object>} The prediction produced by classifyInternal.
   */
  async classifyInputInternal(input, k) {
    const ownsExample = Array.isArray(input);
    const example = ownsExample ? tf.tensor(input) : input;
    try {
      return await this.classifyInternal(example, k);
    } finally {
      if (ownsExample) {
        example.dispose();
      }
    }
  }

  /**
   * Run the underlying prediction and decorate the result with string labels.
   * @param {*} input - The tensor to classify.
   * @param {number} k - The K value passed to predictClass.
   * @returns {Promise<Object>} The predictClass result, plus `label` / `confidencesByLabel`
   *   when string labels are in use.
   * @throws {Error} When no class holds any example.
   */
  async classifyInternal(input, k) {
    const numClass = this.knnClassifier.getNumClasses();
    if (numClass <= 0) {
      throw new Error('There is no example in any class');
    }
    const res = await this.knnClassifier.predictClass(input, k);
    if (this.mapStringToIndex.length > 0) {
      if (res.classIndex !== undefined && res.classIndex !== null) {
        const label = this.mapStringToIndex[res.classIndex];
        if (label !== undefined) {
          res.label = label;
        }
      }
      if (res.confidences) {
        const { confidences } = res;
        res.confidencesByLabel = {};
        Object.keys(confidences).forEach((index) => {
          const label = this.mapStringToIndex[index];
          // Classes added by numeric index have no label; skip them rather than
          // collapsing all of them onto a single "undefined" key.
          if (label !== undefined) {
            res.confidencesByLabel[label] = confidences[index];
          }
        });
      }
    }
    return res;
  }

  /**
   * Clear all examples in a class. The label itself stays in mapStringToIndex, so the
   * remaining labels keep their class indices.
   * @param {number|string} labelIndex - The class index (number) or label (string).
   */
  clearLabel(labelIndex) {
    let classIndex;
    if (typeof labelIndex === 'string') {
      const found = this.mapStringToIndex.indexOf(labelIndex);
      if (found !== -1) {
        classIndex = found;
      }
    } else if (typeof labelIndex === 'number') {
      classIndex = labelIndex;
    }
    this.knnClassifier.clearClass(classIndex);
  }

  /**
   * Forget every label and clear every class of the underlying classifier.
   */
  clearAllLabels() {
    this.mapStringToIndex = [];
    this.knnClassifier.clearAllClasses();
  }

  /**
   * Get the example count for each label. When string labels are in use, returns an object
   * mapping label to example count (classes without a label are omitted); otherwise returns
   * the index-keyed counts from getCount().
   * @returns {Object}
   */
  getCountByLabel() {
    const countByIndex = this.knnClassifier.getClassExampleCount();
    if (this.mapStringToIndex.length > 0) {
      const countByLabel = {};
      Object.keys(countByIndex).forEach((key) => {
        if (this.mapStringToIndex[key]) {
          const label = this.mapStringToIndex[key];
          countByLabel[label] = countByIndex[key];
        }
      });
      return countByLabel;
    }
    return countByIndex;
  }

  /**
   * Get the example count for each class. It returns an object that maps class index to example count for each class.
   * @returns {Object}
   */
  getCount() {
    return this.knnClassifier.getClassExampleCount();
  }

  /**
   * Return the underlying classifier's dataset object (as returned by getClassifierDataset).
   * @returns {Object}
   */
  getClassifierDataset() {
    return this.knnClassifier.getClassifierDataset();
  }

  /**
   * Replace the underlying classifier's dataset. Does not touch mapStringToIndex.
   * @param {Object} dataset - A dataset in the shape returned by getClassifierDataset().
   */
  setClassifierDataset(dataset) {
    this.knnClassifier.setClassifierDataset(dataset);
  }

  /**
   * It returns the total number of classes held by the classifier.
   * @returns {number}
   */
  getNumLabels() {
    return this.knnClassifier.getNumClasses();
  }

  /**
   * Dispose the underlying classifier.
   */
  dispose() {
    this.knnClassifier.dispose();
  }

  /**
   * Download the whole dataset as a JSON file. It's useful for saving state.
   * The file holds `dataset` (class key -> { label, shape, dtype }) and `tensors`
   * (the values of each class, in the same order as the keys of `dataset`).
   * @param {String} name - Optional. The name of the JSON file that will be downloaded. e.g. "myKNN" or "myKNN.json". If no fileName is provided, the default file name is "myKNN.json".
   */
  async save(name) {
    const dataset = this.knnClassifier.getClassifierDataset();
    // Skip empty slots so `tensors` stays position-aligned with the keys of `datasetMeta`.
    const classKeys = Object.keys(dataset).filter((key) => dataset[key]);
    // Serialize plain metadata rather than the live tensor objects, so the classifier's
    // tensors are not mutated. An undefined label is dropped by JSON.stringify.
    const datasetMeta = {};
    const tensors = classKeys.map((key) => {
      const t = dataset[key];
      datasetMeta[key] = { label: this.mapStringToIndex[key], shape: t.shape, dtype: t.dtype };
      return Array.from(t.dataSync());
    });
    let fileName = 'myKNN.json';
    if (name) {
      fileName = name.endsWith('.json') ? name : `${name}.json`;
    }
    await io.saveBlob(JSON.stringify({ dataset: datasetMeta, tensors }), fileName, 'application/octet-stream');
  }

  /**
   * Load a dataset from a JSON file. It's useful for restoring state.
   * @param {String|Object} pathOrData - The path for a valid JSON file, or the already-parsed data.
   * @param {function} callback - Optional. A function to run once the dataset has been loaded. If no callback is provided, it will return a promise that will be resolved once the dataset has loaded.
   */
  async load(pathOrData, callback) {
    const data = typeof pathOrData === 'object' ? pathOrData : await io.loadFile(pathOrData);
    if (!data) {
      return;
    }
    const { dataset, tensors } = data;
    const classDataset = {};
    const labels = [];
    // `tensors` is positional: entry i holds the values for the i-th key of `dataset`
    // (integer keys enumerate in ascending order). Class indices come from the keys
    // themselves, so gaps left by clearLabel() survive a save/load round trip.
    Object.keys(dataset).forEach((key, position) => {
      const meta = dataset[key];
      const stored = tensors[position];
      if (!meta || !stored) {
        return;
      }
      // Files from earlier versions stored each Float32Array as an index-keyed object;
      // reading values through Object.keys handles both that form and a plain array.
      const values = Object.keys(stored).map((i) => stored[i]);
      classDataset[key] = tf.tensor(values, meta.shape, meta.dtype);
      if (meta.label !== undefined) {
        labels[key] = meta.label;
      }
    });
    // Drop the current classes before installing the loaded ones so classes absent from
    // the file do not linger. NOTE(review): assumes clearAllClasses also disposes the old
    // class tensors, as in @tensorflow-models/knn-classifier — confirm for the pinned version.
    this.knnClassifier.clearAllClasses();
    this.mapStringToIndex = labels;
    this.knnClassifier.setClassifierDataset(classDataset);
    if (callback) {
      callback();
    }
  }
}
/**
 * Module entry point: every call builds and returns a brand-new, empty KNN classifier.
 * @returns {KNN} A fresh KNN instance.
 */
const KNNClassifier = () => {
  const instance = new KNN();
  return instance;
};
export default KNNClassifier;