@tensorflow/tfjs-core
Version:
Hardware-accelerated JavaScript library for machine intelligence
68 lines • 3.08 kB
JavaScript
;
Object.defineProperty(exports, "__esModule", { value: true });
var tf = require("../index");
var jasmine_util_1 = require("../jasmine_util");
var test_util_1 = require("../test_util");
/**
 * Jasmine suite for `tf.transpose`, registered through `describeWithFlags`
 * with the `ALL_ENVS` constraint.
 *
 * Covers: rank-0 and rank-1 inputs, invalid `perm` arguments, 2D and 3D
 * permutations, the gradient of a 3D transpose, and rejection of a
 * non-tensor argument.
 */
jasmine_util_1.describeWithFlags('transpose', test_util_1.ALL_ENVS, function () {
    /**
     * Builds a thunk that transposes `input` with `perm`, so the call can be
     * handed to `expect(...)` and checked for a thrown error.
     */
    function transposing(input, perm) {
        return function () {
            return tf.transpose(input, perm);
        };
    }

    // --- Rank 0 / rank 1: values come back unchanged ---

    it('of scalar is no-op', function () {
        var scalar = tf.scalar(3);
        var result = tf.transpose(scalar);
        test_util_1.expectArraysClose(result, [3]);
    });

    it('of 1D is no-op', function () {
        var vector = tf.tensor1d([1, 2, 3]);
        var result = tf.transpose(vector);
        test_util_1.expectArraysClose(result, [1, 2, 3]);
    });

    // --- Invalid `perm` arguments must throw ---

    it('of scalar with perm of incorrect rank throws error', function () {
        // One perm entry supplied for a scalar input.
        var scalar = tf.scalar(3);
        expect(transposing(scalar, [0])).toThrowError();
    });

    it('of 1d with perm out of bounds throws error', function () {
        // Axis 1 is requested on a tensor created with tensor1d.
        var vector = tf.tensor1d([1, 2, 3]);
        expect(transposing(vector, [1])).toThrowError();
    });

    it('of 1d with perm incorrect rank throws error', function () {
        // Two perm entries supplied for a tensor created with tensor1d.
        var vector = tf.tensor1d([1, 2, 3]);
        expect(transposing(vector, [0, 0])).toThrowError();
    });

    // --- 2D ---

    it('2D (no change)', function () {
        var matrix = tf.tensor2d([1, 11, 2, 22, 3, 33, 4, 44], [2, 4]);
        // The identity permutation must keep both shape and values.
        var same = tf.transpose(matrix, [0, 1]);
        expect(same.shape).toEqual(matrix.shape);
        test_util_1.expectArraysClose(same, matrix);
    });

    it('2D (transpose)', function () {
        var matrix = tf.tensor2d([1, 11, 2, 22, 3, 33, 4, 44], [2, 4]);
        var flipped = tf.transpose(matrix, [1, 0]);
        // The [2, 4] input becomes [4, 2].
        expect(flipped.shape).toEqual([4, 2]);
        test_util_1.expectArraysClose(flipped, [1, 3, 11, 33, 2, 4, 22, 44]);
    });

    // --- 3D ---

    it('3D [r, c, d] => [d, r, c]', function () {
        var cube = tf.tensor3d([1, 11, 2, 22, 3, 33, 4, 44], [2, 2, 2]);
        var permuted = tf.transpose(cube, [2, 0, 1]);
        expect(permuted.shape).toEqual([2, 2, 2]);
        test_util_1.expectArraysClose(permuted, [1, 2, 3, 4, 11, 22, 33, 44]);
    });

    it('3D [r, c, d] => [d, c, r]', function () {
        var cube = tf.tensor3d([1, 11, 2, 22, 3, 33, 4, 44], [2, 2, 2]);
        var permuted = tf.transpose(cube, [2, 1, 0]);
        expect(permuted.shape).toEqual([2, 2, 2]);
        test_util_1.expectArraysClose(permuted, [1, 3, 2, 4, 11, 33, 22, 44]);
    });

    // --- Gradient ---

    it('gradient 3D [r, c, d] => [d, c, r]', function () {
        var cube = tf.tensor3d([1, 11, 2, 22, 3, 33, 4, 44], [2, 2, 2]);
        var perm = [2, 1, 0];
        var upstream = tf.tensor3d([111, 211, 121, 221, 112, 212, 122, 222], [2, 2, 2]);
        var transposeWithPerm = function (x) {
            return x.transpose(perm);
        };
        var gradient = tf.grad(transposeWithPerm)(cube, upstream);
        // The gradient must match the input's shape and be float32.
        expect(gradient.shape).toEqual(cube.shape);
        expect(gradient.dtype).toEqual('float32');
        // The expected values are the upstream values rearranged by the same
        // [2, 1, 0] permutation.
        test_util_1.expectArraysClose(gradient, [111, 112, 121, 122, 211, 212, 221, 222]);
    });

    // --- Argument validation ---

    it('throws when passed a non-tensor', function () {
        var callWithPlainObject = function () {
            return tf.transpose({});
        };
        expect(callWithPlainObject)
            .toThrowError(/Argument 'x' passed to 'transpose' must be a Tensor/);
    });
});
//# sourceMappingURL=transpose_test.js.map