UNPKG

@tensorflow/tfjs-core

Version:

Hardware-accelerated JavaScript library for machine intelligence

64 lines 2.81 kB
"use strict"; Object.defineProperty(exports, "__esModule", { value: true }); var tf = require("../../index"); var jasmine_util_1 = require("../../jasmine_util"); var test_util_1 = require("../../test_util"); jasmine_util_1.describeWithFlags('expensive reshape', test_util_1.PACKED_ENVS, function () { var cValues = [46, 52, 58, 64, 70, 100, 115, 130, 145, 160, 154, 178, 202, 226, 250]; var c; beforeEach(function () { var a = tf.tensor2d([1, 2, 3, 4, 5, 6, 7, 8, 9], [3, 3]); var b = tf.tensor2d([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], [3, 5]); c = tf.matMul(a, b); }); it('6d --> 1d', function () { var cAs6D = tf.reshape(c, [1, 1, 1, 3, 1, 5]); var cAs1D = tf.reshape(cAs6D, [-1, cValues.length]); test_util_1.expectArraysClose(cAs1D, cValues); }); it('1d --> 2d', function () { var cAs1D = tf.reshape(c, [cValues.length]); var cAs2D = tf.reshape(cAs1D, [5, -1]); test_util_1.expectArraysClose(cAs2D, cValues); }); it('2d --> 3d', function () { var cAs3D = tf.reshape(c, [3, 1, 5]); test_util_1.expectArraysClose(cAs3D, cValues); }); it('3d --> 4d', function () { var cAs3D = tf.reshape(c, [3, 1, 5]); var cAs4D = tf.reshape(cAs3D, [3, 5, 1, 1]); test_util_1.expectArraysClose(cAs4D, cValues); }); it('4d --> 5d', function () { var cAs4D = tf.reshape(c, [3, 5, 1, 1]); var cAs5D = tf.reshape(cAs4D, [1, 1, 1, 5, 3]); test_util_1.expectArraysClose(cAs5D, cValues); }); it('5d --> 6d', function () { var cAs5D = tf.reshape(c, [1, 1, 1, 5, 3]); var cAs6D = tf.reshape(cAs5D, [3, 5, 1, 1, 1, 1]); test_util_1.expectArraysClose(cAs6D, cValues); }); }); jasmine_util_1.describeWithFlags('expensive reshape with even columns', test_util_1.PACKED_ENVS, function () { it('2 --> 4 columns', function () { var maxTextureSize = tf.ENV.get('WEBGL_MAX_TEXTURE_SIZE'); var values = new Array(16).fill(0); values = values.map(function (d, i) { return i + 1; }); var a = tf.tensor2d(values, [8, 2]); var b = tf.tensor2d([1, 2, 3, 4], [2, 2]); tf.ENV.set('WEBGL_MAX_TEXTURE_SIZE', 2); var c = tf.matMul(a, b); var cAs4D = c.reshape([2, 1, 2, 4]); tf.ENV.set('WEBGL_MAX_TEXTURE_SIZE', maxTextureSize); var webglPackFlagSaved = tf.ENV.get('WEBGL_PACK'); tf.ENV.set('WEBGL_PACK', false); cAs4D = cAs4D.add(1); cAs4D = cAs4D.add(-1); tf.ENV.set('WEBGL_PACK', webglPackFlagSaved); var result = [7, 10, 15, 22, 23, 34, 31, 46, 39, 58, 47, 70, 55, 82, 63, 94]; test_util_1.expectArraysClose(cAs4D, result); }); }); //# sourceMappingURL=reshape_packed_test.js.map