UNPKG

@tensorflow/tfjs-core

Version:

Hardware-accelerated JavaScript library for machine intelligence

22 lines 3.99 kB
"use strict"; Object.defineProperty(exports, "__esModule", { value: true }); var Conv2DProgram = (function () { function Conv2DProgram(convInfo) { this.variableNames = ['x', 'W']; this.outputShape = convInfo.outShape; var padTop = convInfo.padInfo.top; var padLeft = convInfo.padInfo.left; var strideHeight = convInfo.strideHeight; var strideWidth = convInfo.strideWidth; var dilationHeight = convInfo.dilationHeight; var dilationWidth = convInfo.dilationWidth; var filterHeight = convInfo.filterHeight; var filterWidth = convInfo.filterWidth; var inputDepthNearestVec4 = Math.floor(convInfo.inChannels / 4) * 4; var inputDepthVec4Remainder = convInfo.inChannels % 4; this.userCode = "\n const ivec2 strides = ivec2(" + strideHeight + ", " + strideWidth + ");\n const ivec2 pads = ivec2(" + padTop + ", " + padLeft + ");\n\n void main() {\n ivec4 coords = getOutputCoords();\n int batch = coords[0];\n int d2 = coords[3];\n\n ivec2 xRCCorner = coords.yz * strides - pads;\n int xRCorner = xRCCorner.x;\n int xCCorner = xRCCorner.y;\n\n // Convolve x(?, ?, d1) with w(:, :, d1, d2) to get y(yR, yC, d2).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n for (int wR = 0; wR < " + filterHeight + "; wR++) {\n int xR = xRCorner + wR * " + dilationHeight + ";\n\n if (xR < 0 || xR >= " + convInfo.inHeight + ") {\n continue;\n }\n\n for (int wC = 0; wC < " + filterWidth + "; wC++) {\n int xC = xCCorner + wC * " + dilationWidth + ";\n\n if (xC < 0 || xC >= " + convInfo.inWidth + ") {\n continue;\n }\n\n for (int d1 = 0; d1 < " + inputDepthNearestVec4 + "; d1 += 4) {\n vec4 xValues = vec4(\n getX(batch, xR, xC, d1),\n getX(batch, xR, xC, d1 + 1),\n getX(batch, xR, xC, d1 + 2),\n getX(batch, xR, xC, d1 + 3)\n );\n vec4 wValues = vec4(\n getW(wR, wC, d1, d2),\n getW(wR, wC, d1 + 1, d2),\n getW(wR, wC, d1 + 2, d2),\n getW(wR, wC, d1 + 3, d2)\n );\n\n dotProd += dot(xValues, wValues);\n }\n\n if (" + (inputDepthVec4Remainder === 1) + ") {\n dotProd +=\n getX(batch, xR, xC, " + inputDepthNearestVec4 + ") *\n getW(wR, wC, " + inputDepthNearestVec4 + ", d2);\n } else if (" + (inputDepthVec4Remainder === 2) + ") {\n vec2 xValues = vec2(\n getX(batch, xR, xC, " + inputDepthNearestVec4 + "),\n getX(batch, xR, xC, " + inputDepthNearestVec4 + " + 1)\n );\n vec2 wValues = vec2(\n getW(wR, wC, " + inputDepthNearestVec4 + ", d2),\n getW(wR, wC, " + inputDepthNearestVec4 + " + 1, d2)\n );\n dotProd += dot(xValues, wValues);\n } else if (" + (inputDepthVec4Remainder === 3) + ") {\n vec3 xValues = vec3(\n getX(batch, xR, xC, " + inputDepthNearestVec4 + "),\n getX(batch, xR, xC, " + inputDepthNearestVec4 + " + 1),\n getX(batch, xR, xC, " + inputDepthNearestVec4 + " + 2)\n );\n vec3 wValues = vec3(\n getW(wR, wC, " + inputDepthNearestVec4 + ", d2),\n getW(wR, wC, " + inputDepthNearestVec4 + " + 1, d2),\n getW(wR, wC, " + inputDepthNearestVec4 + " + 2, d2)\n );\n dotProd += dot(xValues, wValues);\n }\n }\n }\n setOutput(dotProd);\n }\n "; } return Conv2DProgram; }()); exports.Conv2DProgram = Conv2DProgram; //# sourceMappingURL=conv_gpu.js.map