UNPKG

@tensorflow/tfjs-core

Version:

Hardware-accelerated JavaScript library for machine intelligence

114 lines 5.59 kB
"use strict"; Object.defineProperty(exports, "__esModule", { value: true }); var tf = require("../index"); var jasmine_util_1 = require("../jasmine_util"); var test_util_1 = require("../test_util"); jasmine_util_1.describeWithFlags('transpose', test_util_1.ALL_ENVS, function () { it('of scalar is no-op', function () { var a = tf.scalar(3); test_util_1.expectArraysClose(tf.transpose(a), [3]); }); it('of 1D is no-op', function () { var a = tf.tensor1d([1, 2, 3]); test_util_1.expectArraysClose(tf.transpose(a), [1, 2, 3]); }); it('of scalar with perm of incorrect rank throws error', function () { var a = tf.scalar(3); var perm = [0]; expect(function () { return tf.transpose(a, perm); }).toThrowError(); }); it('of 1d with perm out of bounds throws error', function () { var a = tf.tensor1d([1, 2, 3]); var perm = [1]; expect(function () { return tf.transpose(a, perm); }).toThrowError(); }); it('of 1d with perm incorrect rank throws error', function () { var a = tf.tensor1d([1, 2, 3]); var perm = [0, 0]; expect(function () { return tf.transpose(a, perm); }).toThrowError(); }); it('2D (no change)', function () { var t = tf.tensor2d([1, 11, 2, 22, 3, 33, 4, 44], [2, 4]); var t2 = tf.transpose(t, [0, 1]); expect(t2.shape).toEqual(t.shape); test_util_1.expectArraysClose(t2, t); }); it('2D (transpose)', function () { var t = tf.tensor2d([1, 11, 2, 22, 3, 33, 4, 44], [2, 4]); var t2 = tf.transpose(t, [1, 0]); expect(t2.shape).toEqual([4, 2]); test_util_1.expectArraysClose(t2, [1, 3, 11, 33, 2, 4, 22, 44]); }); it('3D [r, c, d] => [d, r, c]', function () { var t = tf.tensor3d([1, 11, 2, 22, 3, 33, 4, 44], [2, 2, 2]); var t2 = tf.transpose(t, [2, 0, 1]); expect(t2.shape).toEqual([2, 2, 2]); test_util_1.expectArraysClose(t2, [1, 2, 3, 4, 11, 22, 33, 44]); }); it('3D [r, c, d] => [d, c, r]', function () { var t = tf.tensor3d([1, 11, 2, 22, 3, 33, 4, 44], [2, 2, 2]); var t2 = tf.transpose(t, [2, 1, 0]); expect(t2.shape).toEqual([2, 2, 2]); test_util_1.expectArraysClose(t2, [1, 3, 2, 4, 11, 33, 22, 44]); }); it('5D [r, c, d, e, f] => [r, c, d, f, e]', function () { var t = tf.tensor5d(new Array(32).fill(0).map(function (x, i) { return i + 1; }), [2, 2, 2, 2, 2]); var t2 = tf.transpose(t, [0, 1, 2, 4, 3]); expect(t2.shape).toEqual([2, 2, 2, 2, 2]); test_util_1.expectArraysClose(t2, [ 1, 3, 2, 4, 5, 7, 6, 8, 9, 11, 10, 12, 13, 15, 14, 16, 17, 19, 18, 20, 21, 23, 22, 24, 25, 27, 26, 28, 29, 31, 30, 32 ]); }); it('5D [r, c, d, e, f] => [c, r, d, e, f]', function () { var t = tf.tensor5d(new Array(32).fill(0).map(function (x, i) { return i + 1; }), [2, 2, 2, 2, 2]); var t2 = tf.transpose(t, [1, 0, 2, 3, 4]); expect(t2.shape).toEqual([2, 2, 2, 2, 2]); test_util_1.expectArraysClose(t2, [ 1, 2, 3, 4, 5, 6, 7, 8, 17, 18, 19, 20, 21, 22, 23, 24, 9, 10, 11, 12, 13, 14, 15, 16, 25, 26, 27, 28, 29, 30, 31, 32 ]); }); it('6D [r, c, d, e, f] => [r, c, d, f, e]', function () { var t = tf.tensor6d(new Array(64).fill(0).map(function (x, i) { return i + 1; }), [2, 2, 2, 2, 2, 2]); var t2 = tf.transpose(t, [0, 1, 2, 3, 5, 4]); expect(t2.shape).toEqual([2, 2, 2, 2, 2, 2]); test_util_1.expectArraysClose(t2, [ 1, 3, 2, 4, 5, 7, 6, 8, 9, 11, 10, 12, 13, 15, 14, 16, 17, 19, 18, 20, 21, 23, 22, 24, 25, 27, 26, 28, 29, 31, 30, 32, 33, 35, 34, 36, 37, 39, 38, 40, 41, 43, 42, 44, 45, 47, 46, 48, 49, 51, 50, 52, 53, 55, 54, 56, 57, 59, 58, 60, 61, 63, 62, 64 ]); }); it('6D [r, c, d, e, f, g] => [c, r, d, e, f, g]', function () { var t = tf.tensor6d(new Array(64).fill(0).map(function (x, i) { return i + 1; }), [2, 2, 2, 2, 2, 2]); var t2 = tf.transpose(t, [1, 0, 2, 3, 4, 5]); expect(t2.shape).toEqual([2, 2, 2, 2, 2, 2]); test_util_1.expectArraysClose(t2, [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64 ]); }); it('gradient 3D [r, c, d] => [d, c, r]', function () { var t = tf.tensor3d([1, 11, 2, 22, 3, 33, 4, 44], [2, 2, 2]); var perm = [2, 1, 0]; var dy = tf.tensor3d([111, 211, 121, 221, 112, 212, 122, 222], [2, 2, 2]); var dt = tf.grad(function (t) { return t.transpose(perm); })(t, dy); expect(dt.shape).toEqual(t.shape); expect(dt.dtype).toEqual('float32'); test_util_1.expectArraysClose(dt, [111, 112, 121, 122, 211, 212, 221, 222]); }); it('throws when passed a non-tensor', function () { expect(function () { return tf.transpose({}); }) .toThrowError(/Argument 'x' passed to 'transpose' must be a Tensor/); }); it('accepts a tensor-like object', function () { var t = [[1, 11, 2, 22], [3, 33, 4, 44]]; var res = tf.transpose(t, [1, 0]); expect(res.shape).toEqual([4, 2]); test_util_1.expectArraysClose(res, [1, 3, 11, 33, 2, 4, 22, 44]); }); }); //# sourceMappingURL=transpose_test.js.map