@tensorflow/tfjs-core
Version:
Hardware-accelerated JavaScript library for machine intelligence
154 lines • 5.5 kB
JavaScript
;
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
Object.defineProperty(exports, "__esModule", { value: true });
/**
 * Computes the shape the input is reshaped to in the first step of
 * batchToSpaceND (or, with `batchToSpace === false`, spaceToBatchND).
 *
 * Batch-to-space (default):
 *   [blockShape[0], ..., blockShape[M-1], inputShape[0] / prod,
 *    inputShape[1], ..., inputShape[N-1]]
 *
 * Space-to-batch:
 *   [inputShape[0],
 *    inputShape[1] / blockShape[0], blockShape[0], ...,
 *    inputShape[M] / blockShape[M-1], blockShape[M-1],
 *    inputShape[M+1], ..., inputShape[N-1]]
 *
 * @param {number[]} inputShape Shape of the input tensor (rank N).
 * @param {number[]} blockShape Block size for each of the M spatial dims.
 * @param {number} prod Product of `blockShape`; only read in batch-to-space
 *     mode.
 * @param {boolean=} batchToSpace Which direction to compute; defaults to true.
 * @return {number[]} The reshaped dimensions.
 *
 * See step 1: https://www.tensorflow.org/api_docs/python/tf/batch_to_space_nd
 */
function getReshaped(inputShape, blockShape, prod, batchToSpace) {
    if (batchToSpace === void 0) { batchToSpace = true; }
    if (!batchToSpace) {
        // Split every spatial dim into (outer, block) pairs.
        var splitSpatial = [];
        for (var d = 0; d < blockShape.length; ++d) {
            splitSpatial.push(inputShape[d + 1] / blockShape[d], blockShape[d]);
        }
        return [].concat(
            inputShape[0], splitSpatial, inputShape.slice(blockShape.length + 1));
    }
    // Block dims lead, followed by the shrunken batch and the untouched dims.
    return [].concat(
        blockShape.slice(0), [inputShape[0] / prod], inputShape.slice(1));
}
exports.getReshaped = getReshaped;
/**
 * Computes the transpose permutation applied to the tensor produced by
 * `getReshaped` (step 2 of batchToSpaceND / spaceToBatchND). Let
 * M = blockShapeRank and R = reshapedRank.
 *
 * Batch-to-space (default) yields
 *   [M, M+1, 0, M+2, 1, ..., 2M, M-1, 2M+1, ..., R-1]
 * which moves the shape to
 *   [batch / prod(blockShape), inputShape[1], blockShape[0], ...,
 *    inputShape[M], blockShape[M-1], inputShape[M+1], ..., inputShape[N-1]]
 *
 * Space-to-batch yields
 *   [2, 4, ..., 2M, 0, 1, 3, ..., 2M-1, 2M+1, ..., R-1]
 * i.e. the block axes first, then the batch axis, then everything else.
 *
 * @param {number} reshapedRank Rank of the reshaped tensor.
 * @param {number} blockShapeRank Number of spatial (block) dims, M.
 * @param {boolean=} batchToSpace Which direction to compute; defaults to true.
 * @return {number[]} The axis permutation.
 *
 * See step 2: https://www.tensorflow.org/api_docs/python/tf/batch_to_space_nd
 */
function getPermuted(reshapedRank, blockShapeRank, batchToSpace) {
    if (batchToSpace === void 0) { batchToSpace = true; }
    var m = blockShapeRank;
    var axis;
    if (batchToSpace) {
        var order = [m];
        for (axis = m + 1; axis < reshapedRank; ++axis) {
            order.push(axis);
            // Each of the first M dims after the batch axis is interleaved
            // with the block axis that will be merged into it.
            if (axis <= 2 * m) {
                order.push(axis - (m + 1));
            }
        }
        return order;
    }
    var blockAxes = [];
    var otherAxes = [];
    for (axis = 1; axis < reshapedRank; ++axis) {
        // Block axes sit at the even positions 2, 4, ..., 2M of the reshape.
        if (axis % 2 === 0 && axis < m * 2 + 1) {
            blockAxes.push(axis);
        }
        else {
            otherAxes.push(axis);
        }
    }
    return blockAxes.concat([0], otherAxes);
}
exports.getPermuted = getPermuted;
/**
 * Computes the shape of the reshaped-and-permuted input before any cropping
 * or padding is applied (step 3 of batchToSpaceND / spaceToBatchND).
 *
 * Batch-to-space (default):
 *   [inputShape[0] / prod, inputShape[1] * blockShape[0], ...,
 *    inputShape[M] * blockShape[M-1], inputShape[M+1], ..., inputShape[N-1]]
 *
 * Space-to-batch:
 *   [inputShape[0] * prod, inputShape[1] / blockShape[0], ...,
 *    inputShape[M] / blockShape[M-1], inputShape[M+1], ..., inputShape[N-1]]
 *
 * @param {number[]} inputShape Shape of the input tensor (rank N).
 * @param {number[]} blockShape Block size for each of the M spatial dims.
 * @param {number} prod Product of `blockShape`.
 * @param {boolean=} batchToSpace Which direction to compute; defaults to true.
 * @return {number[]} The resulting shape, same rank as `inputShape`.
 *
 * See step 3: https://www.tensorflow.org/api_docs/python/tf/batch_to_space_nd
 */
function getReshapedPermuted(inputShape, blockShape, prod, batchToSpace) {
    if (batchToSpace === void 0) { batchToSpace = true; }
    var shape = [batchToSpace ? inputShape[0] / prod : inputShape[0] * prod];
    var spatialRank = blockShape.length;
    for (var dim = 1; dim < inputShape.length; ++dim) {
        if (dim > spatialRank) {
            // Trailing (non-spatial) dims pass through unchanged.
            shape.push(inputShape[dim]);
            continue;
        }
        var block = blockShape[dim - 1];
        shape.push(batchToSpace ? block * inputShape[dim] : inputShape[dim] / block);
    }
    return shape;
}
exports.getReshapedPermuted = getReshapedPermuted;
/**
 * Converts the `crops` argument of batchToSpaceND into the begin coordinates
 * of the slice that removes the cropped regions.
 *
 * The result is `[0, crops[0][0], ..., crops[M-1][0]]`: the batch dimension
 * starts at 0 and each spatial dimension starts at its leading crop amount.
 *
 * @param {number[][]} crops One `[cropStart, cropEnd]` pair per spatial dim.
 * @param {number|number[]} blockShape The number of spatial dims M, or the
 *     blockShape array itself (its length is then used as M).
 * @return {number[]} Begin coordinates of length M + 1.
 */
function getSliceBeginCoords(crops, blockShape) {
    // This parameter was only ever used as a numeric loop bound, while the
    // other helpers in this file take the blockShape *array* under the same
    // name. Passing the array used to make `i < blockShape` false and silently
    // return just [0]; accept either form.
    var spatialRank = Array.isArray(blockShape) ? blockShape.length : blockShape;
    var sliceBeginCoords = [0];
    for (var i = 0; i < spatialRank; ++i) {
        sliceBeginCoords.push(crops[i][0]);
    }
    return sliceBeginCoords;
}
exports.getSliceBeginCoords = getSliceBeginCoords;
/**
 * Converts the `crops` argument of batchToSpaceND into the size of the slice
 * that removes the cropped regions from the uncropped tensor.
 *
 * The result is
 *   [uncroppedShape[0],
 *    uncroppedShape[1] - crops[0][0] - crops[0][1], ...,
 *    uncroppedShape[M] - crops[M-1][0] - crops[M-1][1]]
 * It has only M + 1 entries; the trailing non-spatial dims are not listed.
 * NOTE(review): presumably the slice op treats missing trailing size entries
 * as "full extent", so the final output shape is
 *   [batch, croppedSpatial..., inputShape[M+1], ..., inputShape[N-1]]
 * — confirm against the caller.
 *
 * @param {number[]} uncroppedShape Shape from getReshapedPermuted.
 * @param {number[][]} crops One `[cropStart, cropEnd]` pair per spatial dim.
 * @param {number|number[]} blockShape The number of spatial dims M, or the
 *     blockShape array itself (its length is then used as M).
 * @return {number[]} Slice sizes of length M + 1.
 *
 * See step 4: https://www.tensorflow.org/api_docs/python/tf/batch_to_space_nd
 */
function getSliceSize(uncroppedShape, crops, blockShape) {
    // Accept either the spatial rank or the blockShape array; passing the
    // array used to make `i < blockShape` false and silently drop every
    // spatial dimension from the result.
    var spatialRank = Array.isArray(blockShape) ? blockShape.length : blockShape;
    var sliceSize = uncroppedShape.slice(0, 1);
    for (var i = 0; i < spatialRank; ++i) {
        sliceSize.push(uncroppedShape[i + 1] - crops[i][0] - crops[i][1]);
    }
    return sliceSize;
}
exports.getSliceSize = getSliceSize;
//# sourceMappingURL=array_ops_util.js.map