@tensorflow/tfjs-core
Version:
Hardware-accelerated JavaScript library for machine intelligence
23 lines • 2.51 kB
JavaScript
"use strict";
// Define `__esModule = true` on the CommonJS `exports` object. The flag is
// conventionally read by module-interop helpers to recognise output that was
// compiled from ES module syntax.
Object.defineProperty(exports, "__esModule", { value: true });
var AvgPool2DBackpropProgram = (function () {
    /**
     * Shader program description for the gradient of a 2D average pool.
     *
     * The constructor only assembles data; it sets three instance fields:
     *   - variableNames: ['dy'] (the shader reads its input through getDy)
     *   - outputShape:   convInfo.inShape (same reference, not a copy)
     *   - userCode:      shader source text with the pooling geometry baked
     *                    in as literals
     *
     * For each output coordinate the generated shader walks the effective
     * filter window in dilation-sized steps, maps every window offset back to
     * a dy position by dividing by the stride, skips positions that are
     * fractional or outside [0, outHeight) x [0, outWidth), and accumulates
     * dy * 1 / (filterHeight * filterWidth).
     *
     * @param convInfo Object supplying inShape, filterHeight, filterWidth,
     *     strideHeight, strideWidth, dilationHeight, dilationWidth,
     *     effectiveFilterHeight, effectiveFilterWidth, outHeight, outWidth
     *     and padInfo.top / padInfo.left. The strides and output sizes are
     *     emitted with a ".0" suffix, so they are presumably integers
     *     (NOTE(review): confirm against the caller).
     */
    function AvgPool2DBackpropProgram(convInfo) {
        this.variableNames = ['dy'];
        this.outputShape = convInfo.inShape;

        var windowRows = convInfo.effectiveFilterHeight;
        var windowCols = convInfo.effectiveFilterWidth;
        var padding = convInfo.padInfo;

        // Offsets subtracted from the output coordinate to find the corner
        // of the dy region that contributes to it.
        var rowPad = windowRows - 1 - padding.top;
        var colPad = windowCols - 1 - padding.left;

        // Each contributing dy value is weighted by one over the filter area.
        var scale = 1 / (convInfo.filterHeight * convInfo.filterWidth);

        // One array entry per line of shader source; joined with "\n" below.
        var shaderLines = [
            '',
            ' const ivec2 pads = ivec2(' + rowPad + ', ' + colPad + ');',
            ' const float avgMultiplier = float(' + scale + ');',
            '',
            ' void main() {',
            ' ivec4 coords = getOutputCoords();',
            ' int b = coords[0];',
            ' int d = coords[3];',
            '',
            ' ivec2 dyRCCorner = coords.yz - pads;',
            ' int dyRCorner = dyRCCorner.x;',
            ' int dyCCorner = dyRCCorner.y;',
            '',
            ' // Convolve dy(?, ?, d) with pos mask(:, :, d) to get dx(xR, xC, d).',
            ' // ? = to be determined. : = across all values in that axis.',
            ' float dotProd = 0.0;',
            ' for (int wR = 0; wR < ' + windowRows + ';',
            ' wR += ' + convInfo.dilationHeight + ') {',
            ' float dyR = float(dyRCorner + wR) / ' + convInfo.strideHeight + '.0;',
            '',
            ' if (dyR < 0.0 || dyR >= ' + convInfo.outHeight + '.0 || fract(dyR) > 0.0) {',
            ' continue;',
            ' }',
            ' int idyR = int(dyR);',
            '',
            ' for (int wC = 0; wC < ' + windowCols + ';',
            ' wC+= ' + convInfo.dilationWidth + ') {',
            ' float dyC = float(dyCCorner + wC) / ' + convInfo.strideWidth + '.0;',
            '',
            ' if (dyC < 0.0 || dyC >= ' + convInfo.outWidth + '.0 ||',
            ' fract(dyC) > 0.0) {',
            ' continue;',
            ' }',
            ' int idyC = int(dyC);',
            '',
            ' float dyValue = getDy(b, idyR, idyC, d);',
            '',
            ' dotProd += dyValue * avgMultiplier;',
            ' }',
            ' }',
            ' setOutput(dotProd);',
            ' }',
            ' '
        ];
        this.userCode = shaderLines.join('\n');
    }
    return AvgPool2DBackpropProgram;
}());
// Publish the constructor on this module's CommonJS `exports` object.
exports.AvgPool2DBackpropProgram = AvgPool2DBackpropProgram;
//# sourceMappingURL=avg_pool_backprop_gpu.js.map