@tensorflow/tfjs-core
Version:
Hardware-accelerated JavaScript library for machine intelligence
217 lines • 10.1 kB
JavaScript
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
var tf = require("../index");
var test_util_1 = require("../test_util");
var jasmine_util_1 = require("../jasmine_util");
jasmine_util_1.describeWithFlags('concat1d', test_util_1.ALL_ENVS, function () {
it('3 + 5', function () {
var a = tf.tensor1d([3]);
var b = tf.tensor1d([5]);
var result = tf.concat1d([a, b]);
var expected = [3, 5];
test_util_1.expectArraysClose(result, expected);
});
it('3 + [5,7]', function () {
var a = tf.tensor1d([3]);
var b = tf.tensor1d([5, 7]);
var result = tf.concat1d([a, b]);
var expected = [3, 5, 7];
test_util_1.expectArraysClose(result, expected);
});
it('[3,5] + 7', function () {
var a = tf.tensor1d([3, 5]);
var b = tf.tensor1d([7]);
var result = tf.concat1d([a, b]);
var expected = [3, 5, 7];
test_util_1.expectArraysClose(result, expected);
});
it('3 + 5 + 7 + 9', function () {
var a = tf.tensor1d([3]);
var b = tf.tensor1d([5]);
var c = tf.tensor1d([7]);
var d = tf.tensor1d([9]);
var result = tf.concat1d([a, b, c, d]);
test_util_1.expectArraysClose(result, [3, 5, 7, 9]);
});
it('single tensor', function () {
var a = tf.tensor1d([3]);
var result = tf.concat1d([a]);
test_util_1.expectArraysClose(result, [3]);
});
});
jasmine_util_1.describeWithFlags('concat2d', test_util_1.ALL_ENVS, function () {
it('[[3]] + [[5]], axis=0', function () {
var axis = 0;
var a = tf.tensor2d([3], [1, 1]);
var b = tf.tensor2d([5], [1, 1]);
var result = tf.concat2d([a, b], axis);
var expected = [3, 5];
expect(result.shape).toEqual([2, 1]);
test_util_1.expectArraysClose(result, expected);
});
it('[[3]] + [[5]], axis=1', function () {
var axis = 1;
var a = tf.tensor2d([3], [1, 1]);
var b = tf.tensor2d([5], [1, 1]);
var result = tf.concat2d([a, b], axis);
var expected = [3, 5];
expect(result.shape).toEqual([1, 2]);
test_util_1.expectArraysClose(result, expected);
});
it('[[1, 2], [3, 4]] + [[5, 6]], axis=0', function () {
var axis = 0;
var a = tf.tensor2d([[1, 2], [3, 4]], [2, 2]);
var b = tf.tensor2d([[5, 6]], [1, 2]);
var result = tf.concat2d([a, b], axis);
var expected = [1, 2, 3, 4, 5, 6];
expect(result.shape).toEqual([3, 2]);
test_util_1.expectArraysClose(result, expected);
});
it('[[1, 2],[3, 4]] + [[5, 6]] + [[7, 8]], axis=0', function () {
var axis = 0;
var a = tf.tensor2d([[1, 2], [3, 4]]);
var b = tf.tensor2d([[5, 6]]);
var c = tf.tensor2d([[7, 8]]);
var result = tf.concat2d([a, b, c], axis);
var expected = [1, 2, 3, 4, 5, 6, 7, 8];
expect(result.shape).toEqual([4, 2]);
test_util_1.expectArraysClose(result, expected);
});
it('[[1, 2], [3, 4]] + [[5, 6]], axis=1 throws error', function () {
var axis = 1;
var a = tf.tensor2d([[1, 2], [3, 4]], [2, 2]);
var b = tf.tensor2d([[5, 6]], [1, 2]);
expect(function () { return tf.concat2d([a, b], axis); }).toThrowError();
});
it('[[1, 2], [3, 4]] + [[5, 6], [7, 8]], axis=1', function () {
var axis = 1;
var a = tf.tensor2d([[1, 2], [3, 4]], [2, 2]);
var b = tf.tensor2d([[5, 6], [7, 8]], [2, 2]);
var result = tf.concat2d([a, b], axis);
var expected = [1, 2, 5, 6, 3, 4, 7, 8];
expect(result.shape).toEqual([2, 4]);
test_util_1.expectArraysClose(result, expected);
});
it('[[1, 2],[3, 4]] + [[5, 6],[7, 8]] + [[9, 10],[11, 12]], axis=1', function () {
var axis = 1;
var a = tf.tensor2d([[1, 2], [3, 4]]);
var b = tf.tensor2d([[5, 6], [7, 8]]);
var c = tf.tensor2d([[9, 10], [11, 12]]);
var result = tf.concat2d([a, b, c], axis);
var expected = [1, 2, 5, 6, 9, 10, 3, 4, 7, 8, 11, 12];
expect(result.shape).toEqual([2, 6]);
test_util_1.expectArraysClose(result, expected);
});
});
jasmine_util_1.describeWithFlags('concat3d', test_util_1.ALL_ENVS, function () {
it('shapes correct concat axis=0', function () {
var tensor1 = tf.tensor3d([1, 2, 3], [1, 1, 3]);
var tensor2 = tf.tensor3d([4, 5, 6], [1, 1, 3]);
var values = tf.concat3d([tensor1, tensor2], 0);
expect(values.shape).toEqual([2, 1, 3]);
test_util_1.expectArraysClose(values, [1, 2, 3, 4, 5, 6]);
});
it('concat axis=0', function () {
var tensor1 = tf.tensor3d([1, 11, 111, 2, 22, 222], [1, 2, 3]);
var tensor2 = tf.tensor3d([5, 55, 555, 6, 66, 666, 7, 77, 777, 8, 88, 888], [2, 2, 3]);
var values = tf.concat3d([tensor1, tensor2], 0);
expect(values.shape).toEqual([3, 2, 3]);
test_util_1.expectArraysClose(values, [
1, 11, 111, 2, 22, 222, 5, 55, 555, 6, 66, 666, 7, 77, 777, 8, 88, 888
]);
});
it('shapes correct concat axis=1', function () {
var tensor1 = tf.tensor3d([1, 2, 3], [1, 1, 3]);
var tensor2 = tf.tensor3d([4, 5, 6], [1, 1, 3]);
var values = tf.concat3d([tensor1, tensor2], 1);
expect(values.shape).toEqual([1, 2, 3]);
test_util_1.expectArraysClose(values, [1, 2, 3, 4, 5, 6]);
});
it('concat axis=1', function () {
var tensor1 = tf.tensor3d([1, 11, 111, 3, 33, 333], [2, 1, 3]);
var tensor2 = tf.tensor3d([5, 55, 555, 6, 66, 666, 7, 77, 777, 8, 88, 888], [2, 2, 3]);
var values = tf.concat3d([tensor1, tensor2], 1);
expect(values.shape).toEqual([2, 3, 3]);
test_util_1.expectArraysClose(values, [
1, 11, 111, 5, 55, 555, 6, 66, 666, 3, 33, 333, 7, 77, 777, 8, 88, 888
]);
});
it('shapes correct concat axis=2', function () {
var tensor1 = tf.tensor3d([1, 2, 3], [1, 1, 3]);
var tensor2 = tf.tensor3d([4, 5, 6], [1, 1, 3]);
var values = tf.concat3d([tensor1, tensor2], 2);
expect(values.shape).toEqual([1, 1, 6]);
test_util_1.expectArraysClose(values, [1, 2, 3, 4, 5, 6]);
});
it('concat axis=2', function () {
var tensor1 = tf.tensor3d([1, 11, 2, 22, 3, 33, 4, 44], [2, 2, 2]);
var tensor2 = tf.tensor3d([5, 55, 555, 6, 66, 666, 7, 77, 777, 8, 88, 888], [2, 2, 3]);
var values = tf.concat3d([tensor1, tensor2], 2);
expect(values.shape).toEqual([2, 2, 5]);
test_util_1.expectArraysClose(values, [
1, 11, 5, 55, 555, 2, 22, 6, 66, 666,
3, 33, 7, 77, 777, 4, 44, 8, 88, 888
]);
});
it('concat throws when invalid non-axis shapes, axis=0', function () {
var axis = 0;
var x1 = tf.tensor3d([1, 11, 111], [1, 1, 3]);
var x2 = tf.tensor3d([5, 55, 555, 6, 66, 666, 7, 77, 777, 8, 88, 888], [2, 2, 3]);
expect(function () { return tf.concat3d([x1, x2], axis); }).toThrowError();
});
it('concat throws when invalid non-axis shapes, axis=1', function () {
var axis = 1;
var x1 = tf.tensor3d([1, 11, 111], [1, 1, 3]);
var x2 = tf.tensor3d([5, 55, 555, 6, 66, 666, 7, 77, 777, 8, 88, 888], [2, 2, 3]);
expect(function () { return tf.concat3d([x1, x2], axis); }).toThrowError();
});
it('concat throws when invalid non-axis shapes, axis=2', function () {
var axis = 2;
var x1 = tf.tensor3d([1, 11, 2, 22], [1, 2, 2]);
var x2 = tf.tensor3d([5, 55, 555, 6, 66, 666, 7, 77, 777, 8, 88, 888], [2, 2, 3]);
expect(function () { return tf.concat3d([x1, x2], axis); }).toThrowError();
});
it('gradient concat axis=0', function () {
var x1 = tf.tensor3d([1, 11, 2, 22], [1, 2, 2]);
var x2 = tf.tensor3d([5, 55, 6, 66, 7, 77, 8, 88], [2, 2, 2]);
var dy = tf.tensor3d([66, 6, 55, 5, 44, 4, 33, 3, 22, 2, 11, 1], [3, 2, 2]);
var axis = 0;
var grads = tf.grads(function (x1, x2) { return tf.concat3d([x1, x2], axis); });
var _a = grads([x1, x2], dy), dx1 = _a[0], dx2 = _a[1];
expect(dx1.shape).toEqual(x1.shape);
test_util_1.expectArraysClose(dx1, [66, 6, 55, 5]);
expect(dx2.shape).toEqual(x2.shape);
test_util_1.expectArraysClose(dx2, [44, 4, 33, 3, 22, 2, 11, 1]);
});
it('gradient concat axis=1', function () {
var x1 = tf.tensor3d([1, 11, 2, 22], [2, 1, 2]);
var x2 = tf.tensor3d([3, 33, 4, 44, 5, 55, 6, 66], [2, 2, 2]);
var dy = tf.tensor3d([66, 6, 55, 5, 44, 4, 33, 3, 22, 2, 11, 1], [2, 3, 2]);
var axis = 1;
var grads = tf.grads(function (x1, x2) { return tf.concat3d([x1, x2], axis); });
var _a = grads([x1, x2], dy), dx1 = _a[0], dx2 = _a[1];
expect(dx1.shape).toEqual(x1.shape);
test_util_1.expectArraysClose(dx1, [66, 6, 33, 3]);
expect(dx2.shape).toEqual(x2.shape);
test_util_1.expectArraysClose(dx2, [55, 5, 44, 4, 22, 2, 11, 1]);
});
it('gradient concat axis=2', function () {
var x1 = tf.tensor3d([1, 2, 3, 4], [2, 2, 1]);
var x2 = tf.tensor3d([5, 55, 6, 66, 7, 77, 8, 88], [2, 2, 2]);
var dy = tf.tensor3d([4, 40, 400, 3, 30, 300, 2, 20, 200, 1, 10, 100], [2, 2, 3]);
var axis = 2;
var grads = tf.grads(function (x1, x2) { return tf.concat3d([x1, x2], axis); });
var _a = grads([x1, x2], dy), dx1 = _a[0], dx2 = _a[1];
expect(dx1.shape).toEqual(x1.shape);
test_util_1.expectArraysClose(dx1, [4, 3, 2, 1]);
expect(dx2.shape).toEqual(x2.shape);
test_util_1.expectArraysClose(dx2, [40, 400, 30, 300, 20, 200, 10, 100]);
});
});
jasmine_util_1.describeWithFlags('concat throws for non-tensors', test_util_1.ALL_ENVS, function () {
it('throws when passed a non-tensor', function () {
expect(function () { return tf.concat([{}]); })
.toThrowError(/Argument 'tensors\[0\]' passed to 'concat' must be a Tensor/);
});
});
//# sourceMappingURL=concat_test.js.map