/*
 * @tensorflow/tfjs-node
 * Version: (not specified)
 * This repository provides native TensorFlow execution in backend JavaScript
 * applications under the Node.js runtime, accelerated by the TensorFlow C
 * binary under the hood. It provides the same API as TensorFlow.js
 * (https://js.tensorflow.org/api/latest/).
 * 256 lines (255 loc) • 11.7 kB
 * JavaScript
 */
;
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
Object.defineProperty(exports, "__esModule", { value: true });
// Resolve the compiled native addon through node-pre-gyp, which reads this
// package's package.json (one directory above this file) to locate it.
var path = require("path");
// tslint:disable-next-line:no-require-imports
var binary = require('@mapbox/node-pre-gyp');
var packageJsonPath = path.resolve(__dirname, '..', 'package.json');
var bindingPath = binary.find(packageJsonPath);
// tslint:disable-next-line:no-require-imports
var bindings = require(bindingPath);
// Short alias used by every spec below.
var binding = bindings;
describe('Exposes TF_DataType enum values', function () {
    // Expected numeric code for each TF_DataType property on the binding.
    // Insertion order is preserved, so specs register in this order.
    var expectedCodes = {
        TF_FLOAT: 1,
        TF_INT32: 3,
        TF_BOOL: 10,
        TF_COMPLEX64: 8,
        TF_STRING: 7
    };
    Object.keys(expectedCodes).forEach(function (enumName) {
        it('contains ' + enumName, function () {
            expect(binding[enumName]).toEqual(expectedCodes[enumName]);
        });
    });
});
describe('Exposes TF_AttrType enum values', function () {
    // Listed so that each name's array index equals its expected value (0-5).
    var attrTypeNames = [
        'TF_ATTR_STRING',
        'TF_ATTR_INT',
        'TF_ATTR_FLOAT',
        'TF_ATTR_BOOL',
        'TF_ATTR_TYPE',
        'TF_ATTR_SHAPE'
    ];
    attrTypeNames.forEach(function (attrName, expectedCode) {
        it('contains ' + attrName, function () {
            expect(binding[attrName]).toEqual(expectedCode);
        });
    });
});
describe('Exposes TF Version', function () {
    it('contains a version string', function () {
        // The spec promises a version *string*: check the type, not just
        // presence, and reject an empty value.
        expect(typeof binding.TF_Version).toEqual('string');
        expect(binding.TF_Version.length).toBeGreaterThan(0);
    });
});
describe('tensor management', function () {
    it('Creates and deletes a valid tensor', function () {
        var values = new Int32Array([1, 2]);
        var id = binding.createTensor([2], binding.TF_INT32, values);
        expect(id).toBeDefined();
        binding.deleteTensor(id);
    });
    it('throws exception when shape does not match data', function () {
        // Data longer than the shape implies (3 > 2) and shorter (3 < 4) must
        // both be rejected.
        expect(function () {
            binding.createTensor([2], binding.TF_INT32, new Int32Array([1, 2, 3]));
        }).toThrowError();
        expect(function () {
            binding.createTensor([4], binding.TF_INT32, new Int32Array([1, 2, 3]));
        }).toThrowError();
    });
    it('throws exception with invalid dtype', function () {
        // 1000 is used as an out-of-range dtype code.
        expect(function () {
            binding.createTensor([1], 1000, new Int32Array([1]));
        }).toThrowError();
    });
    it('works with 0-dim tensors', function () {
        // Reduce op (e.g. 'Max') will produce a 0-dim TFE_Tensor.
        var inputId = binding.createTensor([3], binding.TF_INT32, new Int32Array([1, 2, 3]));
        var axesId = binding.createTensor([1], binding.TF_INT32, new Int32Array([0]));
        var attrs = [
            { name: 'keep_dims', type: binding.TF_ATTR_BOOL, value: false },
            { name: 'T', type: binding.TF_ATTR_TYPE, value: binding.TF_INT32 },
            { name: 'Tidx', type: binding.TF_ATTR_TYPE, value: binding.TF_INT32 }
        ];
        var outputMetadata = binding.executeOp('Max', attrs, [inputId, axesId], 1);
        expect(outputMetadata.length).toBe(1);
        expect(outputMetadata[0].id).toBeDefined();
        expect(outputMetadata[0].shape).toEqual([]);
        expect(outputMetadata[0].dtype).toEqual(binding.TF_INT32);
        expect(binding.tensorDataSync(outputMetadata[0].id))
            .toEqual(new Int32Array([3]));
        // Release every tensor this spec created (inputs and op output) so
        // none of them outlive the spec.
        outputMetadata.forEach(function (output) {
            binding.deleteTensor(output.id);
        });
        binding.deleteTensor(inputId);
        binding.deleteTensor(axesId);
    });
});
describe('executeOp', function () {
    var name = 'MatMul';
    var matMulOpAttrs = [
        { name: 'transpose_a', type: binding.TF_ATTR_BOOL, value: false },
        { name: 'transpose_b', type: binding.TF_ATTR_BOOL, value: false },
        { name: 'T', type: binding.TF_ATTR_TYPE, value: binding.TF_FLOAT }
    ];
    // Shared 2x2 float inputs. They are created in beforeAll rather than while
    // the suite is being defined, and released in afterAll so they do not leak.
    var aId;
    var bId;
    var matMulInput;
    beforeAll(function () {
        aId = binding.createTensor([2, 2], binding.TF_FLOAT, new Float32Array([1, 2, 3, 4]));
        bId = binding.createTensor([2, 2], binding.TF_FLOAT, new Float32Array([4, 3, 2, 1]));
        matMulInput = [aId, bId];
    });
    afterAll(function () {
        binding.deleteTensor(aId);
        binding.deleteTensor(bId);
    });
    /**
     * Expects MatMul to throw when given a single 'T' op attr declared with
     * attr type `type` but carrying `value`.
     * @param {number} type - one of the binding's TF_ATTR_* codes.
     * @param {*} value - the (invalid) attr value to pass.
     */
    function expectBadAttrToThrow(type, value) {
        expect(function () {
            var badOpAttrs = [{ name: 'T', type: type, value: value }];
            binding.executeOp(name, badOpAttrs, matMulInput, 1);
        }).toThrowError();
    }
    it('throws exception with invalid Op Name', function () {
        expect(function () {
            binding.executeOp(null, [], [], null);
        }).toThrowError();
    });
    it('throws exception with invalid TFEOpAttr', function () {
        expect(function () {
            binding.executeOp('Equal', null, [], null);
        }).toThrowError();
    });
    it('throws exception with invalid inputs', function () {
        expect(function () {
            binding.executeOp(name, matMulOpAttrs, [], null);
        }).toThrowError();
    });
    it('throws exception with invalid output number', function () {
        expect(function () {
            binding.executeOp(name, matMulOpAttrs, matMulInput, null);
        }).toThrowError();
    });
    it('throws exception with invalid TF_ATTR_STRING op attr', function () {
        expectBadAttrToThrow(binding.TF_ATTR_STRING, null);
        expectBadAttrToThrow(binding.TF_ATTR_STRING, false);
        expectBadAttrToThrow(binding.TF_ATTR_STRING, 1);
        expectBadAttrToThrow(binding.TF_ATTR_STRING, {});
        expectBadAttrToThrow(binding.TF_ATTR_STRING, [1, 2, 3]);
    });
    it('throws exception with invalid TF_ATTR_INT op attr', function () {
        expectBadAttrToThrow(binding.TF_ATTR_INT, null);
        expectBadAttrToThrow(binding.TF_ATTR_INT, false);
        expectBadAttrToThrow(binding.TF_ATTR_INT, {});
        expectBadAttrToThrow(binding.TF_ATTR_INT, 'test');
    });
    it('throws exception with invalid TF_ATTR_FLOAT op attr', function () {
        expectBadAttrToThrow(binding.TF_ATTR_FLOAT, null);
        expectBadAttrToThrow(binding.TF_ATTR_FLOAT, false);
        expectBadAttrToThrow(binding.TF_ATTR_FLOAT, {});
        expectBadAttrToThrow(binding.TF_ATTR_FLOAT, 'test');
    });
    it('throws exception with invalid TF_ATTR_BOOL op attr', function () {
        expectBadAttrToThrow(binding.TF_ATTR_BOOL, null);
        expectBadAttrToThrow(binding.TF_ATTR_BOOL, 10);
        expectBadAttrToThrow(binding.TF_ATTR_BOOL, {});
        expectBadAttrToThrow(binding.TF_ATTR_BOOL, 'test');
        expectBadAttrToThrow(binding.TF_ATTR_BOOL, [1, 2, 3]);
    });
    it('throws exception with invalid TF_ATTR_TYPE op attr', function () {
        expectBadAttrToThrow(binding.TF_ATTR_TYPE, null);
        expectBadAttrToThrow(binding.TF_ATTR_TYPE, {});
        expectBadAttrToThrow(binding.TF_ATTR_TYPE, 'test');
        expectBadAttrToThrow(binding.TF_ATTR_TYPE, [1, 2, 3]);
    });
    it('throws exception with invalid TF_ATTR_SHAPE op attr', function () {
        // Previously this spec passed TF_ATTR_TYPE, so shape-attr validation
        // was never exercised.
        expectBadAttrToThrow(binding.TF_ATTR_SHAPE, null);
        expectBadAttrToThrow(binding.TF_ATTR_SHAPE, {});
        expectBadAttrToThrow(binding.TF_ATTR_SHAPE, 'test');
    });
    it('should work for matmul', function () {
        var output = binding.executeOp(name, matMulOpAttrs, matMulInput, 1);
        expect(output.length).toBe(1);
        // [[1, 2], [3, 4]] x [[4, 3], [2, 1]] = [[8, 5], [20, 13]].
        expect(binding.tensorDataSync(output[0].id)).toEqual(new Float32Array([
            8, 5, 20, 13
        ]));
        binding.deleteTensor(output[0].id);
    });
});