@tensorflow/tfjs-node
Version:
This repository provides native TensorFlow execution in backend JavaScript applications under the Node.js runtime, accelerated by the TensorFlow C binary under the hood. It provides the same API as [TensorFlow.js](https://js.tensorflow.org/api/latest/).
256 lines (255 loc) • 11.7 kB
JavaScript
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Mark this compiled module as an ES module for CommonJS interop.
Object.defineProperty(exports, "__esModule", { value: true });
var path = require("path");
// node-pre-gyp is used to locate the native addon; `find` is given the
// absolute path of the package.json one directory above this file.
// tslint:disable-next-line:no-require-imports
var binary = require('@mapbox/node-pre-gyp');
var bindingPath = binary.find(path.resolve(path.join(__dirname, '../package.json')));
// Load the native addon from the resolved path.
// tslint:disable-next-line:no-require-imports
var bindings = require(bindingPath);
// `binding` is the alias every spec below uses to reach the native addon.
var binding = bindings;
describe('Exposes TF_DataType enum values', function () {
    // Each entry pairs a TF_DataType constant name with the numeric value the
    // binding is expected to expose for it. One spec is generated per entry,
    // in this order.
    var expectedDataTypes = [
        ['TF_FLOAT', 1],
        ['TF_INT32', 3],
        ['TF_BOOL', 10],
        ['TF_COMPLEX64', 8],
        ['TF_STRING', 7]
    ];
    expectedDataTypes.forEach(function (entry) {
        var enumName = entry[0];
        var enumValue = entry[1];
        it('contains ' + enumName, function () {
            expect(binding[enumName]).toEqual(enumValue);
        });
    });
});
describe('Exposes TF_AttrType enum values', function () {
    // The expected numeric value of each constant is its position in this
    // list (TF_ATTR_STRING is 0 ... TF_ATTR_SHAPE is 5), so order matters.
    var attrTypeNamesInValueOrder = [
        'TF_ATTR_STRING',
        'TF_ATTR_INT',
        'TF_ATTR_FLOAT',
        'TF_ATTR_BOOL',
        'TF_ATTR_TYPE',
        'TF_ATTR_SHAPE'
    ];
    attrTypeNamesInValueOrder.forEach(function (attrTypeName, expectedValue) {
        it('contains ' + attrTypeName, function () {
            expect(binding[attrTypeName]).toEqual(expectedValue);
        });
    });
});
describe('Exposes TF Version', function () {
    // The binding must expose something under `TF_Version`; only its presence
    // is checked here, not its format.
    function checkVersionIsExposed() {
        var reportedVersion = binding.TF_Version;
        expect(reportedVersion).toBeDefined();
    }
    it('contains a version string', checkVersionIsExposed);
});
describe('tensor management', function () {
    /**
     * Asserts that `binding.createTensor` throws for the given arguments.
     *
     * @param {number[]} shape Shape passed to createTensor.
     * @param {number} dtype Data type enum value passed to createTensor.
     * @param {Int32Array} values Backing data passed to createTensor.
     */
    function expectCreateTensorToThrow(shape, dtype, values) {
        expect(function () {
            binding.createTensor(shape, dtype, values);
        }).toThrowError();
    }
    it('Creates and deletes a valid tensor', function () {
        var values = new Int32Array([1, 2]);
        var id = binding.createTensor([2], binding.TF_INT32, values);
        expect(id).toBeDefined();
        binding.deleteTensor(id);
    });
    it('throws exception when shape does not match data', function () {
        // More values than the shape allows.
        expectCreateTensorToThrow([2], binding.TF_INT32, new Int32Array([1, 2, 3]));
        // Fewer values than the shape requires.
        expectCreateTensorToThrow([4], binding.TF_INT32, new Int32Array([1, 2, 3]));
    });
    it('throws exception with invalid dtype', function () {
        // 1000 is not one of the TF_DataType values exposed by the binding.
        expectCreateTensorToThrow([1], 1000, new Int32Array([1]));
    });
    it('works with 0-dim tensors', function () {
        // Reduce op (e.g 'Max') will produce a 0-dim TFE_Tensor.
        var inputId = binding.createTensor([3], binding.TF_INT32, new Int32Array([1, 2, 3]));
        var axesId = binding.createTensor([1], binding.TF_INT32, new Int32Array([0]));
        var outputMetadata = [];
        try {
            var attrs = [
                { name: 'keep_dims', type: binding.TF_ATTR_BOOL, value: false },
                { name: 'T', type: binding.TF_ATTR_TYPE, value: binding.TF_INT32 },
                { name: 'Tidx', type: binding.TF_ATTR_TYPE, value: binding.TF_INT32 }
            ];
            outputMetadata = binding.executeOp('Max', attrs, [inputId, axesId], 1);
            expect(outputMetadata.length).toBe(1);
            expect(outputMetadata[0].id).toBeDefined();
            expect(outputMetadata[0].shape).toEqual([]);
            expect(outputMetadata[0].dtype).toEqual(binding.TF_INT32);
            expect(binding.tensorDataSync(outputMetadata[0].id))
                .toEqual(new Int32Array([3]));
        }
        finally {
            // Release every tensor this spec allocated so it does not leak
            // into later specs, even if executeOp throws.
            outputMetadata.forEach(function (metadata) {
                binding.deleteTensor(metadata.id);
            });
            binding.deleteTensor(axesId);
            binding.deleteTensor(inputId);
        }
    });
});
describe('executeOp', function () {
    var name = 'MatMul';
    var matMulOpAttrs = [
        { name: 'transpose_a', type: binding.TF_ATTR_BOOL, value: false },
        { name: 'transpose_b', type: binding.TF_ATTR_BOOL, value: false },
        { name: 'T', type: binding.TF_ATTR_TYPE, value: binding.TF_FLOAT }
    ];
    // Shared 2x2 float inputs. They are allocated in beforeAll (not while the
    // suite is being defined) and released in afterAll so the tensors do not
    // outlive this suite.
    var aId;
    var bId;
    var matMulInput;
    beforeAll(function () {
        aId = binding.createTensor([2, 2], binding.TF_FLOAT, new Float32Array([1, 2, 3, 4]));
        bId = binding.createTensor([2, 2], binding.TF_FLOAT, new Float32Array([4, 3, 2, 1]));
        matMulInput = [aId, bId];
    });
    afterAll(function () {
        binding.deleteTensor(aId);
        binding.deleteTensor(bId);
    });
    /**
     * Asserts that executing MatMul with a single op attribute named 'T' of
     * the given attribute type throws for each of the supplied values.
     *
     * @param {number} attrType One of the binding's TF_ATTR_* enum values.
     * @param {Array} badValues Values that are invalid for `attrType`.
     */
    function expectBadAttrValuesToThrow(attrType, badValues) {
        badValues.forEach(function (badValue) {
            expect(function () {
                var badOpAttrs = [{ name: 'T', type: attrType, value: badValue }];
                binding.executeOp(name, badOpAttrs, matMulInput, 1);
            }).toThrowError();
        });
    }
    it('throws exception with invalid Op Name', function () {
        expect(function () {
            binding.executeOp(null, [], [], null);
        }).toThrowError();
    });
    it('throws exception with invalid TFEOpAttr', function () {
        expect(function () {
            binding.executeOp('Equal', null, [], null);
        }).toThrowError();
    });
    it('throws exception with invalid inputs', function () {
        expect(function () {
            binding.executeOp(name, matMulOpAttrs, [], null);
        }).toThrowError();
    });
    it('throws exception with invalid output number', function () {
        expect(function () {
            binding.executeOp(name, matMulOpAttrs, matMulInput, null);
        }).toThrowError();
    });
    it('throws exception with invalid TF_ATTR_STRING op attr', function () {
        expectBadAttrValuesToThrow(binding.TF_ATTR_STRING, [null, false, 1, {}, [1, 2, 3]]);
    });
    it('throws exception with invalid TF_ATTR_INT op attr', function () {
        expectBadAttrValuesToThrow(binding.TF_ATTR_INT, [null, false, {}, 'test']);
    });
    it('throws exception with invalid TF_ATTR_FLOAT op attr', function () {
        expectBadAttrValuesToThrow(binding.TF_ATTR_FLOAT, [null, false, {}, 'test']);
    });
    it('throws exception with invalid TF_ATTR_BOOL op attr', function () {
        expectBadAttrValuesToThrow(binding.TF_ATTR_BOOL, [null, 10, {}, 'test', [1, 2, 3]]);
    });
    it('throws exception with invalid TF_ATTR_TYPE op attr', function () {
        expectBadAttrValuesToThrow(binding.TF_ATTR_TYPE, [null, {}, 'test', [1, 2, 3]]);
    });
    it('throws exception with invalid TF_ATTR_SHAPE op attr', function () {
        // This spec previously passed TF_ATTR_TYPE, so it only repeated the
        // TYPE spec above; it now exercises TF_ATTR_SHAPE as its title says.
        // An array is deliberately absent from the list of bad values.
        // NOTE(review): confirm the binding rejects each non-array value for
        // a shape attribute.
        expectBadAttrValuesToThrow(binding.TF_ATTR_SHAPE, [null, {}, 'test']);
    });
    it('should work for matmul', function () {
        var output = binding.executeOp(name, matMulOpAttrs, matMulInput, 1);
        try {
            // [[1, 2], [3, 4]] x [[4, 3], [2, 1]] = [[8, 5], [20, 13]]
            expect(binding.tensorDataSync(output[0].id)).toEqual(new Float32Array([
                8, 5, 20, 13
            ]));
        }
        finally {
            // Release the result tensor produced by this spec.
            binding.deleteTensor(output[0].id);
        }
    });
});
;