@tensorflow/tfjs-core
Version:
Hardware-accelerated JavaScript library for machine intelligence
70 lines • 11.2 kB
JavaScript
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import * as tf from '../index';
import { ALL_ENVS, describeWithFlags } from '../jasmine_util';
import { expectArraysClose } from '../test_util';
describeWithFlags('conv3dTranspose', ALL_ENVS, () => {
  // Both cases use a 2x2x2 single-channel filter and an input with a single
  // spatial element (per batch entry). With stride 1 and 'valid' padding, the
  // expected output volume is that input element multiplied by each filter
  // value, which is how the `expected` arrays below are formed.
  //
  // NOTE(review): the `orig*` names appear to describe the forward conv3d
  // that this op transposes: `x` carries `origOutputDepth` channels and the
  // filter is laid out as [f, f, f, origInputDepth, origOutputDepth].

  // Reference Python TensorFlow code
  // ```python
  // import numpy as np
  // import tensorflow as tf
  // tf.enable_eager_execution()
  // x = np.array([2], dtype = np.float32).reshape(1, 1, 1, 1, 1)
  // w = np.array([5, 4, 8, 7, 1, 2, 6, 3], dtype = np.float32).reshape(2, 2, 2,
  //   1, 1)
  // tf.nn.conv3d_transpose(x, w, output_shape=[1, 2, 2, 2, 1], padding='VALID')
  // ```
  it('input=2x2x2x1,d2=1,f=2,s=1,p=valid', async () => {
    const origInputDepth = 1;
    const origOutputDepth = 1;
    // Unbatched 4-D input: [depth, height, width, channels].
    const inputShape = [1, 1, 1, origOutputDepth];
    const fSize = 2;
    const origPad = 'valid';
    const origStride = 1;

    const x = tf.tensor4d([2], inputShape);
    const w = tf.tensor5d(
        [5, 4, 8, 7, 1, 2, 6, 3],
        [fSize, fSize, fSize, origInputDepth, origOutputDepth]);

    const result = tf.conv3dTranspose(x, w, [2, 2, 2, 1], origStride, origPad);
    // 2 * each filter value.
    const expected = [10, 8, 16, 14, 2, 4, 12, 6];

    expect(result.shape).toEqual([2, 2, 2, 1]);
    expectArraysClose(await result.data(), expected);
  });

  // Reference Python TensorFlow code
  // ```python
  // import numpy as np
  // import tensorflow as tf
  // tf.enable_eager_execution()
  // x = np.array([2, 3], dtype = np.float32).reshape(2, 1, 1, 1, 1)
  // w = np.array([5, 4, 8, 7, 1, 2, 6, 3], dtype = np.float32).reshape(2,
  //   2, 2, 1, 1)
  // tf.nn.conv3d_transpose(x, w, output_shape=[2, 2, 2, 2, 1], padding='VALID')
  // ```
  it('input=2x2x2x1,d2=1,f=2,s=1,p=valid, batch=2', async () => {
    const origInputDepth = 1;
    const origOutputDepth = 1;
    // Batched 5-D input: [batch, depth, height, width, channels].
    const inputShape = [2, 1, 1, 1, origOutputDepth];
    const fSize = 2;
    const origPad = 'valid';
    const origStride = 1;

    const x = tf.tensor5d([2, 3], inputShape);
    const w = tf.tensor5d(
        [5, 4, 8, 7, 1, 2, 6, 3],
        [fSize, fSize, fSize, origInputDepth, origOutputDepth]);

    const result =
        tf.conv3dTranspose(x, w, [2, 2, 2, 2, 1], origStride, origPad);
    // Batch entries are independent: the first 8 values are 2 * filter, the
    // last 8 values are 3 * filter.
    const expected = [10, 8, 16, 14, 2, 4, 12, 6, 15, 12, 24, 21, 3, 6, 18, 9];

    expect(result.shape).toEqual([2, 2, 2, 2, 1]);
    expectArraysClose(await result.data(), expected);
  });
});
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiY29udjNkX3RyYW5zcG9zZV90ZXN0LmpzIiwic291cmNlUm9vdCI6IiIsInNvdXJjZXMiOlsiLi4vLi4vLi4vLi4vLi4vLi4vdGZqcy1jb3JlL3NyYy9vcHMvY29udjNkX3RyYW5zcG9zZV90ZXN0LnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBOzs7Ozs7Ozs7Ozs7Ozs7R0FlRztBQUVILE9BQU8sS0FBSyxFQUFFLE1BQU0sVUFBVSxDQUFDO0FBQy9CLE9BQU8sRUFBQyxRQUFRLEVBQUUsaUJBQWlCLEVBQUMsTUFBTSxpQkFBaUIsQ0FBQztBQUM1RCxPQUFPLEVBQUMsaUJBQWlCLEVBQUMsTUFBTSxjQUFjLENBQUM7QUFFL0MsaUJBQWlCLENBQUMsaUJBQWlCLEVBQUUsUUFBUSxFQUFFLEdBQUcsRUFBRTtJQUNsRCxtQ0FBbUM7SUFDbkMsWUFBWTtJQUNaLHFCQUFxQjtJQUNyQiwwQkFBMEI7SUFDMUIsOEJBQThCO0lBQzlCLCtEQUErRDtJQUMvRCw4RUFBOEU7SUFDOUUsVUFBVTtJQUNWLDhFQUE4RTtJQUM5RSxNQUFNO0lBQ04sRUFBRSxDQUFDLG9DQUFvQyxFQUFFLEtBQUssSUFBSSxFQUFFO1FBQ2xELE1BQU0sY0FBYyxHQUFHLENBQUMsQ0FBQztRQUN6QixNQUFNLGVBQWUsR0FBRyxDQUFDLENBQUM7UUFDMUIsTUFBTSxVQUFVLEdBQ1osQ0FBQyxDQUFDLEVBQUUsQ0FBQyxFQUFFLENBQUMsRUFBRSxlQUFlLENBQUMsQ0FBQztRQUMvQixNQUFNLEtBQUssR0FBRyxDQUFDLENBQUM7UUFDaEIsTUFBTSxPQUFPLEdBQUcsT0FBTyxDQUFDO1FBQ3hCLE1BQU0sVUFBVSxHQUFHLENBQUMsQ0FBQztRQUVyQixNQUFNLENBQUMsR0FBRyxFQUFFLENBQUMsUUFBUSxDQUFDLENBQUMsQ0FBQyxDQUFDLEVBQUUsVUFBVSxDQUFDLENBQUM7UUFDdkMsTUFBTSxDQUFDLEdBQUcsRUFBRSxDQUFDLFFBQVEsQ0FDakIsQ0FBQyxDQUFDLEVBQUUsQ0FBQyxFQUFFLENBQUMsRUFBRSxDQUFDLEVBQUUsQ0FBQyxFQUFFLENBQUMsRUFBRSxDQUFDLEVBQUUsQ0FBQyxDQUFDLEVBQ3hCLENBQUMsS0FBSyxFQUFFLEtBQUssRUFBRSxLQUFLLEVBQUUsY0FBYyxFQUFFLGVBQWUsQ0FBQyxDQUFDLENBQUM7UUFFNUQsTUFBTSxNQUFNLEdBQUcsRUFBRSxDQUFDLGVBQWUsQ0FBQyxDQUFDLEVBQUUsQ0FBQyxFQUFFLENBQUMsQ0FBQyxFQUFFLENBQUMsRUFBRSxDQUFDLEVBQUUsQ0FBQyxDQUFDLEVBQUUsVUFBVSxFQUFFLE9BQU8sQ0FBQyxDQUFDO1FBQzNFLE1BQU0sUUFBUSxHQUFHLENBQUMsRUFBRSxFQUFFLENBQUMsRUFBRSxFQUFFLEVBQUUsRUFBRSxFQUFFLENBQUMsRUFBRSxDQUFDLEVBQUUsRUFBRSxFQUFFLENBQUMsQ0FBQyxDQUFDO1FBRTlDLE1BQU0sQ0FBQyxNQUFNLENBQUMsS0FBSyxDQUFDLENBQUMsT0FBTyxDQUFDLENBQUMsQ0FBQyxFQUFFLENBQUMsRUFBRSxDQUFDLEVBQUUsQ0FBQyxDQUFDLENBQUMsQ0FBQztRQUMzQyxpQkFBaUIsQ0FBQyxNQUFNLE1BQU0sQ0FBQyxJQUFJLEVBQUUsRUFBRSxRQUFRLENBQUMsQ0FBQztJQUNuRCxDQUFDLENBQU
MsQ0FBQztJQUVILG1DQUFtQztJQUNuQyxZQUFZO0lBQ1oscUJBQXFCO0lBQ3JCLDBCQUEwQjtJQUMxQiw4QkFBOEI7SUFDOUIscUVBQXFFO0lBQ3JFLHdFQUF3RTtJQUN4RSxnQkFBZ0I7SUFDaEIsOEVBQThFO0lBQzlFLE1BQU07SUFDTixFQUFFLENBQUMsNkNBQTZDLEVBQUUsS0FBSyxJQUFJLEVBQUU7UUFDM0QsTUFBTSxjQUFjLEdBQUcsQ0FBQyxDQUFDO1FBQ3pCLE1BQU0sZUFBZSxHQUFHLENBQUMsQ0FBQztRQUMxQixNQUFNLFVBQVUsR0FDWixDQUFDLENBQUMsRUFBRSxDQUFDLEVBQUUsQ0FBQyxFQUFFLENBQUMsRUFBRSxlQUFlLENBQUMsQ0FBQztRQUNsQyxNQUFNLEtBQUssR0FBRyxDQUFDLENBQUM7UUFDaEIsTUFBTSxPQUFPLEdBQUcsT0FBTyxDQUFDO1FBQ3hCLE1BQU0sVUFBVSxHQUFHLENBQUMsQ0FBQztRQUVyQixNQUFNLENBQUMsR0FBRyxFQUFFLENBQUMsUUFBUSxDQUFDLENBQUMsQ0FBQyxFQUFFLENBQUMsQ0FBQyxFQUFFLFVBQVUsQ0FBQyxDQUFDO1FBQzFDLE1BQU0sQ0FBQyxHQUFHLEVBQUUsQ0FBQyxRQUFRLENBQ2pCLENBQUMsQ0FBQyxFQUFFLENBQUMsRUFBRSxDQUFDLEVBQUUsQ0FBQyxFQUFFLENBQUMsRUFBRSxDQUFDLEVBQUUsQ0FBQyxFQUFFLENBQUMsQ0FBQyxFQUN4QixDQUFDLEtBQUssRUFBRSxLQUFLLEVBQUUsS0FBSyxFQUFFLGNBQWMsRUFBRSxlQUFlLENBQUMsQ0FBQyxDQUFDO1FBRTVELE1BQU0sTUFBTSxHQUNSLEVBQUUsQ0FBQyxlQUFlLENBQUMsQ0FBQyxFQUFFLENBQUMsRUFBRSxDQUFDLENBQUMsRUFBRSxDQUFDLEVBQUUsQ0FBQyxFQUFFLENBQUMsRUFBRSxDQUFDLENBQUMsRUFBRSxVQUFVLEVBQUUsT0FBTyxDQUFDLENBQUM7UUFDbkUsTUFBTSxRQUFRLEdBQUcsQ0FBQyxFQUFFLEVBQUUsQ0FBQyxFQUFFLEVBQUUsRUFBRSxFQUFFLEVBQUUsQ0FBQyxFQUFFLENBQUMsRUFBRSxFQUFFLEVBQUUsQ0FBQyxFQUFFLEVBQUUsRUFBRSxFQUFFLEVBQUUsRUFBRSxFQUFFLEVBQUUsRUFBRSxDQUFDLEVBQUUsQ0FBQyxFQUFFLEVBQUUsRUFBRSxDQUFDLENBQUMsQ0FBQztRQUUzRSxNQUFNLENBQUMsTUFBTSxDQUFDLEtBQUssQ0FBQyxDQUFDLE9BQU8sQ0FBQyxDQUFDLENBQUMsRUFBRSxDQUFDLEVBQUUsQ0FBQyxFQUFFLENBQUMsRUFBRSxDQUFDLENBQUMsQ0FBQyxDQUFDO1FBQzlDLGlCQUFpQixDQUFDLE1BQU0sTUFBTSxDQUFDLElBQUksRUFBRSxFQUFFLFFBQVEsQ0FBQyxDQUFDO0lBQ25ELENBQUMsQ0FBQyxDQUFDO0FBQ0wsQ0FBQyxDQUFDLENBQUMiLCJzb3VyY2VzQ29udGVudCI6WyIvKipcbiAqIEBsaWNlbnNlXG4gKiBDb3B5cmlnaHQgMjAxNyBHb29nbGUgTExDLiBBbGwgUmlnaHRzIFJlc2VydmVkLlxuICogTGljZW5zZWQgdW5kZXIgdGhlIEFwYWNoZSBMaWNlbnNlLCBWZXJzaW9uIDIuMCAodGhlIFwiTGljZW5zZVwiKTtcbiAqIHlvdSBtYXkgbm90IHVzZSB0aGlzIGZpbGUgZXhjZXB0IGluIGNvbXBsaWFuY2Ugd2l0aCB0aGUgTGljZW5zZS5cbiAqIFlvdSBtYXkgb2J0YWluIGEgY2
9weSBvZiB0aGUgTGljZW5zZSBhdFxuICpcbiAqIGh0dHA6Ly93d3cuYXBhY2hlLm9yZy9saWNlbnNlcy9MSUNFTlNFLTIuMFxuICpcbiAqIFVubGVzcyByZXF1aXJlZCBieSBhcHBsaWNhYmxlIGxhdyBvciBhZ3JlZWQgdG8gaW4gd3JpdGluZywgc29mdHdhcmVcbiAqIGRpc3RyaWJ1dGVkIHVuZGVyIHRoZSBMaWNlbnNlIGlzIGRpc3RyaWJ1dGVkIG9uIGFuIFwiQVMgSVNcIiBCQVNJUyxcbiAqIFdJVEhPVVQgV0FSUkFOVElFUyBPUiBDT05ESVRJT05TIE9GIEFOWSBLSU5ELCBlaXRoZXIgZXhwcmVzcyBvciBpbXBsaWVkLlxuICogU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZFxuICogbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuXG4gKiA9PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PVxuICovXG5cbmltcG9ydCAqIGFzIHRmIGZyb20gJy4uL2luZGV4JztcbmltcG9ydCB7QUxMX0VOVlMsIGRlc2NyaWJlV2l0aEZsYWdzfSBmcm9tICcuLi9qYXNtaW5lX3V0aWwnO1xuaW1wb3J0IHtleHBlY3RBcnJheXNDbG9zZX0gZnJvbSAnLi4vdGVzdF91dGlsJztcblxuZGVzY3JpYmVXaXRoRmxhZ3MoJ2NvbnYzZFRyYW5zcG9zZScsIEFMTF9FTlZTLCAoKSA9PiB7XG4gIC8vIFJlZmVyZW5jZSBQeXRob24gVGVuc29yRmxvdyBjb2RlXG4gIC8vIGBgYHB5dGhvblxuICAvLyBpbXBvcnQgbnVtcHkgYXMgbnBcbiAgLy8gaW1wb3J0IHRlbnNvcmZsb3cgYXMgdGZcbiAgLy8gdGYuZW5hYmxlX2VhZ2VyX2V4ZWN1dGlvbigpXG4gIC8vIHggPSBucC5hcnJheShbMl0sIGR0eXBlID0gbnAuZmxvYXQzMikucmVzaGFwZSgxLCAxLCAxLCAxLCAxKVxuICAvLyB3ID0gbnAuYXJyYXkoWzUsIDQsIDgsIDcsIDEsIDIsIDYsIDNdLCBkdHlwZSA9IG5wLmZsb2F0MzIpLnJlc2hhcGUoMiwgMiwgMixcbiAgLy8gICAxLCAxKVxuICAvLyB0Zi5ubi5jb252M2RfdHJhbnNwb3NlKHgsIHcsIG91dHB1dF9zaGFwZT1bMSwgMiwgMiwgMiwgMV0sIHBhZGRpbmc9J1ZBTElEJylcbiAgLy8gYGBgXG4gIGl0KCdpbnB1dD0yeDJ4MngxLGQyPTEsZj0yLHM9MSxwPXZhbGlkJywgYXN5bmMgKCkgPT4ge1xuICAgIGNvbnN0IG9yaWdJbnB1dERlcHRoID0gMTtcbiAgICBjb25zdCBvcmlnT3V0cHV0RGVwdGggPSAxO1xuICAgIGNvbnN0IGlucHV0U2hhcGU6IFtudW1iZXIsIG51bWJlciwgbnVtYmVyLCBudW1iZXJdID1cbiAgICAgICAgWzEsIDEsIDEsIG9yaWdPdXRwdXREZXB0aF07XG4gICAgY29uc3QgZlNpemUgPSAyO1xuICAgIGNvbnN0IG9yaWdQYWQgPSAndmFsaWQnO1xuICAgIGNvbnN0IG9yaWdTdHJpZGUgPSAxO1xuXG4gICAgY29uc3QgeCA9IHRmLnRlbnNvcjRkKFsyXSwgaW5wdXRTaGFwZSk7XG4gICAgY29uc3QgdyA9IHRmLnRlbnNvcjVkKFxuICAgICAgICBbNSwgNCwgOCwgNywgMSwgMiwgNiwgM10sXG
4gICAgICAgIFtmU2l6ZSwgZlNpemUsIGZTaXplLCBvcmlnSW5wdXREZXB0aCwgb3JpZ091dHB1dERlcHRoXSk7XG5cbiAgICBjb25zdCByZXN1bHQgPSB0Zi5jb252M2RUcmFuc3Bvc2UoeCwgdywgWzIsIDIsIDIsIDFdLCBvcmlnU3RyaWRlLCBvcmlnUGFkKTtcbiAgICBjb25zdCBleHBlY3RlZCA9IFsxMCwgOCwgMTYsIDE0LCAyLCA0LCAxMiwgNl07XG5cbiAgICBleHBlY3QocmVzdWx0LnNoYXBlKS50b0VxdWFsKFsyLCAyLCAyLCAxXSk7XG4gICAgZXhwZWN0QXJyYXlzQ2xvc2UoYXdhaXQgcmVzdWx0LmRhdGEoKSwgZXhwZWN0ZWQpO1xuICB9KTtcblxuICAvLyBSZWZlcmVuY2UgUHl0aG9uIFRlbnNvckZsb3cgY29kZVxuICAvLyBgYGBweXRob25cbiAgLy8gaW1wb3J0IG51bXB5IGFzIG5wXG4gIC8vIGltcG9ydCB0ZW5zb3JmbG93IGFzIHRmXG4gIC8vIHRmLmVuYWJsZV9lYWdlcl9leGVjdXRpb24oKVxuICAvLyB4ID0gbnAuYXJyYXkoWzIsIDNdLCBkdHlwZSA9IG5wLmZsb2F0MzIpLnJlc2hhcGUoMiwgMSwgMSwgMSwgMSwgMSlcbiAgLy8gdyA9IG5wLmFycmF5KFs1LCA0LCA4LCA3LCAxLCAyLCA2LCAzXSwgZHR5cGUgPSBucC5mbG9hdDMyKS5yZXNoYXBlKDIsXG4gIC8vICAgMiwgMiwgMSwgMSlcbiAgLy8gdGYubm4uY29udjNkX3RyYW5zcG9zZSh4LCB3LCBvdXRwdXRfc2hhcGU9WzIsIDIsIDIsIDIsIDFdLCBwYWRkaW5nPSdWQUxJRCcpXG4gIC8vIGBgYFxuICBpdCgnaW5wdXQ9MngyeDJ4MSxkMj0xLGY9MixzPTEscD12YWxpZCwgYmF0Y2g9MicsIGFzeW5jICgpID0+IHtcbiAgICBjb25zdCBvcmlnSW5wdXREZXB0aCA9IDE7XG4gICAgY29uc3Qgb3JpZ091dHB1dERlcHRoID0gMTtcbiAgICBjb25zdCBpbnB1dFNoYXBlOiBbbnVtYmVyLCBudW1iZXIsIG51bWJlciwgbnVtYmVyLCBudW1iZXJdID1cbiAgICAgICAgWzIsIDEsIDEsIDEsIG9yaWdPdXRwdXREZXB0aF07XG4gICAgY29uc3QgZlNpemUgPSAyO1xuICAgIGNvbnN0IG9yaWdQYWQgPSAndmFsaWQnO1xuICAgIGNvbnN0IG9yaWdTdHJpZGUgPSAxO1xuXG4gICAgY29uc3QgeCA9IHRmLnRlbnNvcjVkKFsyLCAzXSwgaW5wdXRTaGFwZSk7XG4gICAgY29uc3QgdyA9IHRmLnRlbnNvcjVkKFxuICAgICAgICBbNSwgNCwgOCwgNywgMSwgMiwgNiwgM10sXG4gICAgICAgIFtmU2l6ZSwgZlNpemUsIGZTaXplLCBvcmlnSW5wdXREZXB0aCwgb3JpZ091dHB1dERlcHRoXSk7XG5cbiAgICBjb25zdCByZXN1bHQgPVxuICAgICAgICB0Zi5jb252M2RUcmFuc3Bvc2UoeCwgdywgWzIsIDIsIDIsIDIsIDFdLCBvcmlnU3RyaWRlLCBvcmlnUGFkKTtcbiAgICBjb25zdCBleHBlY3RlZCA9IFsxMCwgOCwgMTYsIDE0LCAyLCA0LCAxMiwgNiwgMTUsIDEyLCAyNCwgMjEsIDMsIDYsIDE4LCA5XTtcblxuICAgIGV4cGVjdChyZXN1bHQuc2hhcGUpLnRvRXF1YWwoWzIsIDIsIDIsIDIsIDFdKTtcbiAgICBleHBlY3RBcnJheXNDbG9zZShhd2FpdCByZXN1bHQuZGF0YSgpLCBleHBlY3RlZCk7XG
4gIH0pO1xufSk7XG4iXX0=