UNPKG

@tensorflow/tfjs-core

Version:

Hardware-accelerated JavaScript library for machine intelligence

154 lines (144 loc) 4.94 kB
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

/**
 * Computes the intermediate shape the input is reshaped to (step 1).
 *
 * With `batchToSpace` set, the result is
 *   [blockShape[0], ..., blockShape[M-1], inputShape[0] / prod,
 *    inputShape[1], ..., inputShape[N-1]]
 *
 * Otherwise every one of the first M spatial dimensions is split into a
 * (quotient, block) pair:
 *   [inputShape[0],
 *    inputShape[1] / blockShape[0], blockShape[0], ...,
 *    inputShape[M] / blockShape[M-1], blockShape[M-1],
 *    inputShape[M+1], ..., inputShape[N-1]]
 *
 * See step 1: https://www.tensorflow.org/api_docs/python/tf/batch_to_space_nd
 *
 * @param inputShape Shape of the input; entry 0 is treated as the batch.
 * @param blockShape Block size for each of the M spatial dimensions.
 * @param prod Divisor applied to the batch entry; only read when
 *     `batchToSpace` is true. Presumably the product of `blockShape` - it is
 *     supplied by the caller and not recomputed here.
 * @param batchToSpace Selects which of the two layouts above is produced.
 * @returns A newly allocated shape array; the arguments are not modified.
 */
export function getReshaped(
    inputShape: number[], blockShape: number[], prod: number,
    batchToSpace = true): number[] {
  if (batchToSpace) {
    // Block dims first, then the shrunken batch, then everything after batch.
    return blockShape.concat(inputShape[0] / prod, inputShape.slice(1));
  }

  const numSpatialDims = blockShape.length;
  const interleaved: number[] = [inputShape[0]];
  for (let dim = 0; dim < numSpatialDims; ++dim) {
    const block = blockShape[dim];
    interleaved.push(inputShape[dim + 1] / block, block);
  }
  // Dimensions beyond the spatial ones are carried over unchanged.
  return interleaved.concat(inputShape.slice(numSpatialDims + 1));
}

/**
 * Builds the axis permutation applied to the reshaped tensor (step 2).
 *
 * With `batchToSpace` set and M = `blockShapeRank`, the permutation is
 *   [M, M+1, 0, M+2, 1, ..., 2M, M-1, 2M+1, ..., reshapedRank-1]
 * i.e. the batch axis, then each spatial axis followed by its block axis,
 * then the remaining axes, so that the transposed shape is
 *   [batch / prod(block_shape), inputShape[1], blockShape[0], ...,
 *    inputShape[M], blockShape[M-1], inputShape[M+1], ..., inputShape[N-1]]
 *
 * Otherwise the even axes below 2M+1 (2, 4, ..., 2M) come first, followed by
 * axis 0, followed by all other axes (odd axes and every axis >= 2M+1) in
 * ascending order.
 *
 * See step 2: https://www.tensorflow.org/api_docs/python/tf/batch_to_space_nd
 *
 * @param reshapedRank Number of axes in the reshaped tensor.
 * @param blockShapeRank Number of spatial (block) dimensions, M.
 * @param batchToSpace Selects which of the two orderings above is produced.
 * @returns The permutation as a list of axis indices.
 */
export function getPermuted(
    reshapedRank: number, blockShapeRank: number,
    batchToSpace = true): number[] {
  if (batchToSpace) {
    const order: number[] = [blockShapeRank];
    for (let axis = blockShapeRank + 1; axis < reshapedRank; ++axis) {
      order.push(axis);
      if (axis <= 2 * blockShapeRank) {
        // Pair the spatial axis with the block axis it was split from.
        order.push(axis - (blockShapeRank + 1));
      }
    }
    return order;
  }

  const firstTrailingAxis = blockShapeRank * 2 + 1;
  const aheadOfBatch: number[] = [];
  const behindBatch: number[] = [];
  for (let axis = 1; axis < reshapedRank; ++axis) {
    const goesBehindBatch = axis >= firstTrailingAxis || axis % 2 === 1;
    (goesBehindBatch ? behindBatch : aheadOfBatch).push(axis);
  }
  return [...aheadOfBatch, 0, ...behindBatch];
}

/**
 * Computes the shape after the reshape + permute steps, before any cropping
 * is applied (step 3).
 *
 * With `batchToSpace` set, the result is
 *   [inputShape[0] / prod, inputShape[1] * blockShape[0], ...,
 *    inputShape[M] * blockShape[M-1], inputShape[M+1], ..., inputShape[N-1]]
 *
 * Otherwise the operations are inverted:
 *   [inputShape[0] * prod, inputShape[1] / blockShape[0], ...,
 *    inputShape[M] / blockShape[M-1], inputShape[M+1], ..., inputShape[N-1]]
 *
 * See step 3: https://www.tensorflow.org/api_docs/python/tf/batch_to_space_nd
 *
 * @param inputShape Shape of the input; entry 0 is treated as the batch.
 * @param blockShape Block size for each of the M spatial dimensions.
 * @param prod Factor the batch entry is divided or multiplied by. Presumably
 *     the product of `blockShape` - it is supplied by the caller.
 * @param batchToSpace Selects which of the two layouts above is produced.
 * @returns A newly allocated shape array.
 */
export function getReshapedPermuted(
    inputShape: number[], blockShape: number[], prod: number,
    batchToSpace = true): number[] {
  const numSpatialDims = blockShape.length;
  const batch = batchToSpace ? inputShape[0] / prod : inputShape[0] * prod;
  const result: number[] = [batch];

  for (let axis = 1; axis < inputShape.length; ++axis) {
    const size = inputShape[axis];
    if (axis > numSpatialDims) {
      // Past the spatial dimensions: carried over unchanged.
      result.push(size);
      continue;
    }
    const block = blockShape[axis - 1];
    result.push(batchToSpace ? block * size : size / block);
  }
  return result;
}

/**
 * Converts the crops argument into the beginning coordinates of a slice
 * operation.
 *
 * The result starts with 0 and is followed by `crops[i][0]` for each
 * `i < blockShape`.
 *
 * @param crops One `[start, end]` pair per spatial dimension; only the first
 *     entry of each pair is read here.
 * @param blockShape Despite its name, a count: how many leading `crops`
 *     entries to consume.
 * @returns Begin coordinates of length `1 + blockShape`.
 */
export function getSliceBeginCoords(
    crops: number[][], blockShape: number): number[] {
  const begin: number[] = [0];
  let dim = 0;
  while (dim < blockShape) {
    const cropPair = crops[dim];
    begin.push(cropPair[0]);
    ++dim;
  }
  return begin;
}

/**
 * Converts the crops argument into the size of a slice operation. When
 * combined with getSliceBeginCoords this function allows the reshaped and
 * permuted Tensor to be cropped to its final output shape of:
 *
 *   [batch / prod(blockShape),
 *    inputShape[1] * blockShape[0] - crops[0,0] - crops[0,1], ...,
 *    inputShape[M] * blockShape[M-1] - crops[M-1,0] - crops[M-1,1],
 *    inputShape[M+1], ..., inputShape[N-1]]
 *
 * Note that the array returned here only holds the leading entry of
 * `uncroppedShape` plus one entry per cropped dimension; sizes for any
 * dimensions after those are not included.
 *
 * See step 4: https://www.tensorflow.org/api_docs/python/tf/batch_to_space_nd
 *
 * @param uncroppedShape Shape before cropping.
 * @param crops One `[start, end]` pair per spatial dimension.
 * @param blockShape Despite its name, a count: how many leading `crops`
 *     entries to consume.
 * @returns Slice sizes for the leading entry and the cropped dimensions.
 */
export function getSliceSize(
    uncroppedShape: number[], crops: number[][], blockShape: number): number[] {
  // slice(0, 1) keeps the leading entry as-is (and stays empty for an empty
  // shape).
  const size = uncroppedShape.slice(0, 1);
  for (let dim = 0; dim < blockShape; ++dim) {
    const full = uncroppedShape[dim + 1];
    const cropPair = crops[dim];
    size.push(full - cropPair[0] - cropPair[1]);
  }
  return size;
}