// Package: @tensorflow/tfjs-core
// Version: (not recorded in this listing)
// Hardware-accelerated JavaScript library for machine intelligence
// Listing size: 117 lines, 5.58 kB
// Language: JavaScript
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
var tf = require("../index");
var jasmine_util_1 = require("../jasmine_util");
var test_util_1 = require("../test_util");
/**
 * Test suite for tf.fused.matMul, run in every registered environment.
 * Forward tests check the shape and values of the fused product with optional
 * transposition, bias (same-shape or broadcast) and activation ('relu' or
 * 'linear'). Gradient tests check that the fused op yields the same gradients
 * as the equivalent chain of unfused ops (matMul, add, relu).
 */
jasmine_util_1.describeWithFlags('fused matmul', test_util_1.ALL_ENVS, function () {
    // Forward-pass fixtures. The left operand is always 2x3. The right operand
    // holds the same six values and is shaped 3x2 for a plain product, or 2x3
    // when the test asks the op to transpose it.
    function makeLhs() {
        return tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
    }
    function makeRhs(shape) {
        return tf.tensor2d([0, 1, -3, 2, 2, 1], shape);
    }
    // A 2x2 tensor of ones, used as a full-shape bias.
    function makeOnes2x2() {
        return tf.tensor2d([1, 1, 1, 1], [2, 2]);
    }
    // Checks the result's shape exactly and its values within tolerance.
    function expectResult(result, shape, values) {
        expect(result.shape).toEqual(shape);
        test_util_1.expectArraysClose(result, values);
    }
    // Gradient fixtures. The left operand's shape is chosen by the caller
    // ([2, 3] normally, [3, 2] when transposeA is set). The right operand is
    // 3x2 and the upstream gradient is 2x2.
    function makeGradLhs(shape) {
        return tf.tensor2d([1, 2, 3, 10, 20, -30], shape);
    }
    function makeGradRhs() {
        return tf.tensor2d([2, 3, 4, -1, 2, 3], [3, 2]);
    }
    function makeUpstream() {
        return tf.tensor2d([1, 10, 20, 30], [2, 2]);
    }
    // Differentiates `reference` (unfused ops) and `fused` with respect to
    // every input, using the same upstream gradient `dy`. Each pair of
    // gradients must match element-wise, compared in input order.
    function expectSameGradients(reference, fused, inputs, dy) {
        var expected = tf.grads(reference)(inputs, dy);
        var actual = tf.grads(fused)(inputs, dy);
        for (var i = 0; i < expected.length; i++) {
            test_util_1.expectArraysClose(expected[i], actual[i]);
        }
    }
    it('A x B', function () {
        var result = tf.fused.matMul(makeLhs(), makeRhs([3, 2]));
        expectResult(result, [2, 2], [0, 8, -3, 20]);
    });
    it('A x B with relu', function () {
        var result = tf.fused.matMul(makeLhs(), makeRhs([3, 2]), false, false, null, 'relu');
        expectResult(result, [2, 2], [0, 8, 0, 20]);
    });
    it('A x B with relu transpose', function () {
        // The right operand is stored 2x3 and transposed by the op (transposeB).
        var result = tf.fused.matMul(makeLhs(), makeRhs([2, 3]), false, true, null, 'relu');
        expectResult(result, [2, 2], [0, 9, 0, 24]);
    });
    it('A x B with relu and bias', function () {
        var result = tf.fused.matMul(makeLhs(), makeRhs([3, 2]), false, false, makeOnes2x2(), 'relu');
        expectResult(result, [2, 2], [1, 9, 0, 21]);
    });
    it('A x B with relu and broadcasted bias', function () {
        // A rank-1 bias of length 2 is expected to add across every row.
        var result = tf.fused.matMul(makeLhs(), makeRhs([3, 2]), false, false, tf.tensor1d([1, 1]), 'relu');
        expectResult(result, [2, 2], [1, 9, 0, 21]);
    });
    it('A x B with relu and broadcasted bias different rank', function () {
        // A batch of two (2x3)x(3x2) products. The [1, 2]-shaped bias is
        // expected to add to every row of both batch entries.
        var batchedLhs = tf.tensor3d([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [2, 2, 3]);
        var batchedRhs = tf.tensor3d([0, 1, -3, 2, 2, 1, 0, 1, -3, 2, 2, 1], [2, 3, 2]);
        var rowBias = tf.tensor2d([1, 2], [1, 2]);
        var result = tf.fused.matMul(batchedLhs, batchedRhs, false, false, rowBias, 'relu');
        expectResult(result, [2, 2, 2], [2, 6, 0, 18, 0, 30, 0, 42]);
    });
    it('A x B with bias only', function () {
        // 'linear' applies no activation, so the negative entry survives.
        var result = tf.fused.matMul(makeLhs(), makeRhs([3, 2]), false, false, makeOnes2x2(), 'linear');
        expectResult(result, [2, 2], [1, 9, -2, 21]);
    });
    it('A x B with relu gradient', function () {
        var inputs = [makeGradLhs([2, 3]), makeGradRhs()];
        var dy = makeUpstream();
        expectSameGradients(function (x, y) {
            return tf.relu(tf.matMul(x, y, false, false));
        }, function (x, y) {
            return tf.fused.matMul(x, y, false, false, null, 'relu');
        }, inputs, dy);
    });
    it('A x B with relu bias gradient', function () {
        var inputs = [makeGradLhs([2, 3]), makeGradRhs(), makeOnes2x2()];
        var dy = makeUpstream();
        expectSameGradients(function (x, y, bias) {
            return tf.relu(tf.add(tf.matMul(x, y, false, false), bias));
        }, function (x, y, bias) {
            return tf.fused.matMul(x, y, false, false, bias, 'relu');
        }, inputs, dy);
    });
    it('A x B with relu bias gradient transpose', function () {
        // The left operand is stored 3x2 and transposed by the op (transposeA).
        var inputs = [makeGradLhs([3, 2]), makeGradRhs(), makeOnes2x2()];
        var dy = makeUpstream();
        expectSameGradients(function (x, y, bias) {
            return tf.relu(tf.add(tf.matMul(x, y, true, false), bias));
        }, function (x, y, bias) {
            return tf.fused.matMul(x, y, true, false, bias, 'relu');
        }, inputs, dy);
    });
});
//# sourceMappingURL=fused_test.js.map