tractjs
A library for running ONNX and TensorFlow inference in the browser.
import { Model, Tensor, load, terminate } from "../dist/tractjs.js";
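
// These tests exercise the tractjs surface used below: load() builds a Model from an
// ONNX (.onnx) or TensorFlow (.pb / saved-model directory) file, Tensor wraps typed-array
// data with a shape, Model.predict() / predict_one() run inference, Model.destroy() frees
// a model, and terminate() shuts down the worker. A minimal usage sketch (the path and
// input facts here are illustrative examples, not part of this suite):
//
//   const model = await load('./tests/plus3.pb', {
//     inputFacts: { 0: ['float32', [1, 3]] },
//   });
//   const output = await model.predict_one(new Tensor(new Float32Array([1, 2, 3]), [1, 3]));
//   await model.destroy();
//   terminate();
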
describe('model', () => {
  afterAll((done) => {
    terminate();
    // TODO: Figure out how to detect when the worker is terminated
    // https://stackoverflow.com/questions/33044817/how-can-i-know-if-a-web-worker-has-closed
    setTimeout(() => done(), 2000);
  });

  test('cannot be created directly', () => {
    // @ts-ignore: Constructor of class 'Model' is private.
    expect(() => new Model()).toThrow();
  });

  test('fails gracefully when loaded incorrectly', async () => {
    await expect(load('./tests/plus3.pb')).rejects.toThrow();
    await expect(load('./tests/plus3', { optimize: false })).rejects.toThrow();
  });

  test('can load a model', async () => {
    await expect(load('./tests/plus3.pb', { optimize: false })).resolves.toBeInstanceOf(Model);
    await expect(load('./tests/plus3', { format: 'tensorflow', optimize: false })).resolves.toBeInstanceOf(Model);
  });

  test('can predict on a single input tensor', async () => {
    const model: Model = await load('./tests/plus3.pb', {
      inputFacts: {
        0: ['float32', [1, 3]],
      },
    });
    const input = new Tensor(new Float32Array([1, 2, 3]), [1, 3]);
    const prediction = await model.predict_one(input);
    expect(Array.from(prediction.data)).toEqual([4, 5, 6]);
  });

  test('can predict on multiple input tensors', async () => {
    const model: Model = await load('./tests/plus3.pb', {
      inputFacts: {
        0: ['float32', [1, 3]],
      },
    });
    const input = new Tensor(new Float32Array([1, 2, 3]), [1, 3]);
    const predictions = await model.predict([input]);
    expect(Array.from(predictions[0].data)).toEqual([4, 5, 6]);
  });

  test('can be destroyed', async () => {
    const model: Model = await load('./tests/plus3.pb', {
      inputFacts: {
        0: ['float32', [1, 3]],
      },
    });
    await model.destroy();
    const input = new Tensor(new Float32Array([1, 2, 3]), [1, 3]);
    await expect(model.predict([input])).rejects.toThrow();
  });
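
  // In the inputFacts below, the string 's' in place of a concrete dimension appears to
  // declare that dimension as symbolic (resolved only at predict time), which is why a
  // [1, 4] input is accepted even though no fixed length was given when loading.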
  test('can predict with dynamic dimension', async () => {
    const model: Model = await load('./tests/plus3.pb', {
      inputFacts: {
        0: ['float32', [1, 's']],
      },
    });
    const input = new Tensor(new Float32Array([1, 2, 3, 4]), [1, 4]);
    const predictions = await model.predict([input]);
    expect(Array.from(predictions[0].data)).toEqual([4, 5, 6, 7]);
  });
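
  // The object form used below ({ id, slope, intercept }) appears to describe a dimension
  // as an affine expression of the symbolic dim 's' (roughly slope * s + intercept); this
  // reading is inferred from the field names and the shape check in the test, not from
  // tractjs documentation.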
  test('can predict with dynamic dimension (and dimension arithmetic)', async () => {
    const model: Model = await load('./tests/model.onnx', {
      inputFacts: {
        0: ['uint8', [1, {
          id: 's',
          slope: 2,
          intercept: 0
        }]],
      },
    });
    const input = new Tensor(new Uint8Array([1, 2, 3, 4]), [1, 4]);
    const predictions = await model.predict([input]);
    expect(predictions[0].shape).toEqual([1, 4, 2]);
  });
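
  // get_metadata() appears to return a plain object of string values; for this ONNX model
  // the 'split_sequence' value is itself a JSON document, but it is compared here as a raw
  // string, so the embedded formatting of the template literal below must match exactly.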
  test('can access model metadata', async () => {
    const model: Model = await load('./tests/model.onnx', {
      inputFacts: {
        0: ['uint8', [1, {
          id: 's',
          slope: 2,
          intercept: 0
        }]],
      },
    });
    expect(await model.get_metadata()).toEqual({
      "split_sequence": `{
"instructions": [
[
"Sentence",
{
"PredictionIndex": 0
}
],
[
"Token",
{
"PredictionIndex": 1
}
],
[
"_Whitespace",
{
"Function": "whitespace"
}
]
]
}`
    });
  });

  test('returns empty metadata for TF models', async () => {
    const model: Model = await load('./tests/plus3.pb', {
      inputFacts: {
        0: ['float32', [1, 's']],
      },
    });
    expect(await model.get_metadata()).toEqual({});
  });
});