UNPKG

@tensorflow/tfjs-core

Version:

Hardware-accelerated JavaScript library for machine intelligence

72 lines 11.4 kB
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { ENGINE } from '../engine';
import { AvgPool3DGrad } from '../kernel_names';
import { convertToTensor } from '../tensor_util_env';
import * as util from '../util';
import { checkPadOnDimRoundingMode } from './conv_util';
import { op } from './operation';
import { reshape } from './reshape';
/**
 * Computes the backprop of a 3d avg pool by dispatching the `AvgPool3DGrad`
 * kernel.
 *
 * When `input` is rank 4, both `dy` and `input` are reshaped to rank 5 by
 * prepending a dimension of size 1 before the kernel runs, and that leading
 * dimension is dropped again from the kernel result. Otherwise both tensors
 * are passed through as they are and the kernel result is returned unchanged.
 * After the optional reshape, `dy` and `input` are each asserted to be rank 5.
 *
 * @param dy The dy error, of rank 5 (or rank 4, when `input` is rank 4) of
 *     shape [batchSize, depth, height, width, channels].
 * @param input The original input image, of rank 5 or rank 4, of shape
 *     [batchSize, depth, height, width, channels].
 * @param filterSize The filter size:
 *     `[filterDepth, filterHeight, filterWidth]`. If `filterSize` is a single
 *     number, then `filterDepth == filterHeight == filterWidth`.
 * @param strides The strides of the pooling:
 *     `[strideDepth, strideHeight, strideWidth]`. If `strides` is a single
 *     number, then `strideHeight == strideWidth` (presumably `strideDepth` as
 *     well; the kernel decides, this wrapper only forwards the value).
 * @param pad A string from: 'same', 'valid'. The type of padding algorithm
 *     used in the forward prop of the op. Validated together with
 *     `dimRoundingMode` by `checkPadOnDimRoundingMode`.
 * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. Forwarded
 *     to the kernel untouched; the upstream documentation states that
 *     omitting it defaults to truncation.
 * @returns The kernel result, with the same rank as `input`.
 */
function avgPool3dGrad_(dy, input, filterSize, strides, pad, dimRoundingMode) {
    // Prepends a unit dimension using the first four shape entries explicitly,
    // so a tensor of unexpected rank still fails inside `reshape`.
    const prependUnitDim = (tensor) => {
        const dims = tensor.shape;
        return reshape(tensor, [1, dims[0], dims[1], dims[2], dims[3]]);
    };

    const gradTensor = convertToTensor(dy, 'dy', 'avgPool3dGrad');
    const inputTensor = convertToTensor(input, 'input', 'avgPool3dGrad');

    // The rank of `input` alone decides whether both tensors are lifted.
    const liftedFromRank4 = inputTensor.rank === 4;
    const grad5D = liftedFromRank4 ? prependUnitDim(gradTensor) : gradTensor;
    const input5D = liftedFromRank4 ? prependUnitDim(inputTensor) : inputTensor;

    util.assert(
        grad5D.rank === 5,
        () => `Error in avgPool3dGrad: dy must be rank 5 but got rank ${grad5D.rank}.`);
    util.assert(
        input5D.rank === 5,
        () => `Error in avgPool3dGrad: input must be rank 5 but got rank ${input5D.rank}.`);
    checkPadOnDimRoundingMode('avgPool3dGrad', pad, dimRoundingMode);

    const kernelResult = ENGINE.runKernel(
        AvgPool3DGrad,
        { dy: grad5D, input: input5D },
        { filterSize, strides, pad, dimRoundingMode });

    if (!liftedFromRank4) {
        return kernelResult;
    }
    // Strip the unit dimension that was added above.
    const outDims = kernelResult.shape;
    return reshape(kernelResult, [outDims[1], outDims[2], outDims[3], outDims[4]]);
}
export const avgPool3dGrad = /* @__PURE__ */ op({ avgPool3dGrad_ });
//# 
sourceMappingURL=data:application/json;base64,{"version":3,"file":"avg_pool_3d_grad.js","sourceRoot":"","sources":["../../../../../../tfjs-core/src/ops/avg_pool_3d_grad.ts"],"names":[],"mappings":"AACA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,EAAC,MAAM,EAAC,MAAM,WAAW,CAAC;AACjC,OAAO,EAAC,aAAa,EAA0C,MAAM,iBAAiB,CAAC;AAIvF,OAAO,EAAC,eAAe,EAAC,MAAM,oBAAoB,CAAC;AAEnD,OAAO,KAAK,IAAI,MAAM,SAAS,CAAC;AAEhC,OAAO,EAAC,yBAAyB,EAAC,MAAM,aAAa,CAAC;AACtD,OAAO,EAAC,EAAE,EAAC,MAAM,aAAa,CAAC;AAC/B,OAAO,EAAC,OAAO,EAAC,MAAM,WAAW,CAAC;AAElC;;;;;;;;;;;;;;;;;;;GAmBG;AACH,SAAS,cAAc,CACnB,EAAgB,EAAE,KAAmB,EACrC,UAA2C,EAC3C,OAAwC,EAAE,GAA0B,EACpE,eAAwC;IAC1C,MAAM,GAAG,GAAG,eAAe,CAAC,EAAE,EAAE,IAAI,EAAE,eAAe,CAAC,CAAC;IACvD,MAAM,MAAM,GAAG,eAAe,CAAC,KAAK,EAAE,OAAO,EAAE,eAAe,CAAC,CAAC;IAEhE,IAAI,IAAI,GAAG,GAAe,CAAC;IAC3B,IAAI,OAAO,GAAG,MAAkB,CAAC;IACjC,IAAI,YAAY,GAAG,KAAK,CAAC;IAEzB,IAAI,MAAM,CAAC,IAAI,KAAK,CAAC,EAAE;QACrB,YAAY,GAAG,IAAI,CAAC;QACpB,IAAI,GAAG,OAAO,CACV,GAAG,EAAE,CAAC,CAAC,EAAE,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;QACtE,OAAO,GAAG,OAAO,CAAC,MAAM,EAAE;YACxB,CAAC,EAAE,MAAM,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,MAAM,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,MAAM,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,MAAM,CAAC,KAAK,CAAC,CAAC,CAAC;SACtE,CAAC,CAAC;KACJ;IAED,IAAI,CAAC,MAAM,CACP,IAAI,CAAC,IAAI,KAAK,CAAC,EACf,GAAG,EAAE,CAAC,yDAAyD;QAC3D,GAAG,IAAI,CAAC,IAAI,GAAG,CAAC,CAAC;IACzB,IAAI,CAAC,MAAM,CACP,OAAO,CAAC,IAAI,KAAK,CAAC,EAClB,GAAG,EAAE,CAAC,4DAA4D;QAC9D,GAAG,OAAO,CAAC,IAAI,GAAG,CAAC,CAAC;IAC5B,yBAAyB,CAAC,eAAe,EAAE,GAAG,EAAE,eAAe,CAAC,CAAC;IACjE,MAAM,MAAM,GAAwB,EAAC,EAAE,EAAE,IAAI,EAAE,KAAK,EAAE,OAAO,EAAC,CAAC;IAC/D,MAAM,KAAK,GAAuB,EAAC,UAAU,EAAE,OAAO,EAAE,GAAG,EAAE,eAAe,EAAC,CAAC;IAE9E,0DAA0D;IAC1D,MAAM,GAAG,GAAG,MAAM,CAAC,SAAS,CACZ,aAAa,EAAE,MAAmC,EAClD,KAAgC,CAAM,CAAC;IAEvD,IAAI,YAAY,EAAE;QAChB,OAAO,OAAO,CACH,GAAG,EAAE,CAAC,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,GAAG,CAAC,KAAK,CAAC,
CAAC,CAAC,EAAE,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CACnE,CAAC;KACP;IAED,OAAO,GAAG,CAAC;AACb,CAAC;AAED,MAAM,CAAC,MAAM,aAAa,GAAG,eAAe,CAAC,EAAE,CAAC,EAAC,cAAc,EAAC,CAAC,CAAC","sourcesContent":["\n/**\n * @license\n * Copyright 2020 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {ENGINE} from '../engine';\nimport {AvgPool3DGrad, AvgPool3DGradAttrs, AvgPool3DGradInputs} from '../kernel_names';\nimport {NamedAttrMap} from '../kernel_registry';\nimport {Tensor4D, Tensor5D} from '../tensor';\nimport {NamedTensorMap} from '../tensor_types';\nimport {convertToTensor} from '../tensor_util_env';\nimport {TensorLike} from '../types';\nimport * as util from '../util';\n\nimport {checkPadOnDimRoundingMode} from './conv_util';\nimport {op} from './operation';\nimport {reshape} from './reshape';\n\n/**\n * Computes the backprop of a 3d avg pool.\n *\n * @param dy The dy error, of rank 5 of shape\n *     [batchSize, depth, height, width, channels].\n * assumed.\n * @param input The original input image, of rank 5 or rank4 of shape\n *     [batchSize, depth, height, width, channels].\n * @param filterSize The filter size:\n *     `[filterDepth, filterHeight, filterWidth]`.\n *     `filterSize` is a single number,\n *     then `filterDepth == filterHeight == filterWidth`.\n * @param strides The strides of the pooling:\n *     `[strideDepth, 
strideHeight, strideWidth]`. If\n *     `strides` is a single number, then `strideHeight == strideWidth`.\n * @param pad A string from: 'same', 'valid'. The type of padding algorithm\n *     used in the forward prop of the op.\n * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is\n *     provided, it will default to truncate.\n */\nfunction avgPool3dGrad_<T extends Tensor4D|Tensor5D>(\n    dy: T|TensorLike, input: T|TensorLike,\n    filterSize: [number, number, number]|number,\n    strides: [number, number, number]|number, pad: 'valid'|'same'|number,\n    dimRoundingMode?: 'floor'|'round'|'ceil'): T {\n  const $dy = convertToTensor(dy, 'dy', 'avgPool3dGrad');\n  const $input = convertToTensor(input, 'input', 'avgPool3dGrad');\n\n  let dy5D = $dy as Tensor5D;\n  let input5D = $input as Tensor5D;\n  let reshapedTo5D = false;\n\n  if ($input.rank === 4) {\n    reshapedTo5D = true;\n    dy5D = reshape(\n        $dy, [1, $dy.shape[0], $dy.shape[1], $dy.shape[2], $dy.shape[3]]);\n    input5D = reshape($input, [\n      1, $input.shape[0], $input.shape[1], $input.shape[2], $input.shape[3]\n    ]);\n  }\n\n  util.assert(\n      dy5D.rank === 5,\n      () => `Error in avgPool3dGrad: dy must be rank 5 but got rank ` +\n          `${dy5D.rank}.`);\n  util.assert(\n      input5D.rank === 5,\n      () => `Error in avgPool3dGrad: input must be rank 5 but got rank ` +\n          `${input5D.rank}.`);\n  checkPadOnDimRoundingMode('avgPool3dGrad', pad, dimRoundingMode);\n  const inputs: AvgPool3DGradInputs = {dy: dy5D, input: input5D};\n  const attrs: AvgPool3DGradAttrs = {filterSize, strides, pad, dimRoundingMode};\n\n  // tslint:disable-next-line: no-unnecessary-type-assertion\n  const res = ENGINE.runKernel(\n                  AvgPool3DGrad, inputs as unknown as NamedTensorMap,\n                  attrs as unknown as NamedAttrMap) as T;\n\n  if (reshapedTo5D) {\n    return reshape(\n               res, [res.shape[1], res.shape[2], res.shape[3], 
res.shape[4]]) as\n        T;\n  }\n\n  return res;\n}\n\nexport const avgPool3dGrad = /* @__PURE__ */ op({avgPool3dGrad_});\n"]}