// ml5-save
// Version:
// 127 lines (103 loc) • 3.5 kB
// JavaScript
// Copyright (c) 2018 ml5
//
// This software is released under the MIT License.
// https://opensource.org/licenses/MIT
// Pull the factory under test off the ml5 global loaded by the test harness.
const { imageClassifier } = ml5;
// Hosted Teachable Machine model loaded by the "with Teachable Machine model" suite.
const TM_URL = 'https://storage.googleapis.com/tm-models/WfgKPytY/model.json';

/**
 * Reference option values for the classifier specs.
 *
 * The MobileNet "instantiate" spec compares `version`, `alpha` and `topk`
 * against a freshly created classifier; the other entries are not read by
 * any spec in this file.
 */
const DEFAULTS = {
  learningRate: 1e-4,
  hiddenUnits: 100,
  epochs: 20,
  numClasses: 2,
  batchSize: 0.4,
  topk: 3,
  alpha: 1,
  version: 2,
};
/**
 * Load the sample image used by the classification specs.
 *
 * @param {string} [url] - Image URL to load; defaults to the ml5 robin sample.
 * @returns {Promise<HTMLImageElement>} Resolves with the image once it has loaded.
 * @throws {Error} Rejects if the image fails to load, so a broken URL fails the
 *   spec immediately instead of hanging until the Jasmine timeout.
 */
async function getImage(url = 'https://cdn.jsdelivr.net/gh/ml5js/ml5-library@development/assets/bird.jpg') {
  const img = new Image();
  // crossOrigin is a string attribute; 'anonymous' is the documented keyword
  // (the previous `true` became the invalid value "true"). A CORS request keeps
  // a canvas drawn from this image readable, provided the server sends CORS headers.
  img.crossOrigin = 'anonymous';
  await new Promise((resolve, reject) => {
    // Attach both handlers before assigning src so neither event can be missed.
    img.onload = resolve;
    img.onerror = () => reject(new Error(`Failed to load image: ${url}`));
    img.src = url;
  });
  return img;
}
/**
 * Draw the sample image onto a new canvas sized to match it.
 *
 * @returns {Promise<HTMLCanvasElement>} Canvas holding the sample image at (0, 0).
 */
async function getCanvas() {
  const source = await getImage();
  const surface = document.createElement('canvas');
  const { width, height } = source;
  surface.width = width;
  surface.height = height;
  const ctx = surface.getContext('2d');
  ctx.drawImage(source, 0, 0);
  return surface;
}
describe('imageClassifier', () => {
  // Loading a model or the sample image can exceed Jasmine's default timeout.
  // The timeout is passed explicitly to beforeAll/it rather than assigned to
  // jasmine.DEFAULT_TIMEOUT_INTERVAL. An assignment made inside a beforeAll
  // does not extend the timer that beforeAll is already running under, and it
  // leaks into every suite that runs afterwards.
  const MODEL_TIMEOUT = 15000;
  const ROBIN_LABEL = 'robin, American robin, Turdus migratorius';

  /**
   * Test imageClassifier with a Teachable Machine model
   */
  describe('with Teachable Machine model', () => {
    // Scoped per suite so sibling suites cannot observe each other's classifier.
    let classifier;

    beforeAll(async () => {
      classifier = await imageClassifier(TM_URL, undefined, {});
    }, MODEL_TIMEOUT);

    describe('instantiate', () => {
      it('Should create a classifier with all the defaults', () => {
        expect(classifier.modelUrl).toBe(TM_URL);
      });
    });
  });

  /**
   * Test imageClassifier with MobileNet
   */
  describe('imageClassifier with Mobilenet', () => {
    let classifier;

    beforeAll(async () => {
      classifier = await imageClassifier('MobileNet', undefined, {});
    }, MODEL_TIMEOUT);

    /**
     * Classify `input` and expect the top prediction to be the robin.
     *
     * @param {*} input - Any input form the specs below pass to classify().
     */
    async function expectRobin(input) {
      const results = await classifier.classify(input);
      expect(results[0].label).toBe(ROBIN_LABEL);
    }

    describe('instantiate', () => {
      it('Should create a classifier with all the defaults', () => {
        expect(classifier.version).toBe(DEFAULTS.version);
        expect(classifier.alpha).toBe(DEFAULTS.alpha);
        expect(classifier.topk).toBe(DEFAULTS.topk);
        expect(classifier.ready).toBeTruthy();
      });
    });

    describe('classify', () => {
      it('Should classify an image of a Robin', async () => {
        const img = await getImage();
        await expectRobin(img);
      }, MODEL_TIMEOUT);

      it('Should support p5 elements with an image on .elt', async () => {
        const img = await getImage();
        await expectRobin({ elt: img });
      }, MODEL_TIMEOUT);

      it('Should support HTMLCanvasElement', async () => {
        const canvas = await getCanvas();
        await expectRobin(canvas);
      }, MODEL_TIMEOUT);

      it('Should support p5 elements with canvas on .canvas', async () => {
        const canvas = await getCanvas();
        await expectRobin({ canvas });
      }, MODEL_TIMEOUT);

      it('Should support p5 elements with canvas on .elt', async () => {
        const canvas = await getCanvas();
        await expectRobin({ elt: canvas });
      }, MODEL_TIMEOUT);
    });
  });
});