// @tensorflow/tfjs-core
// Version:
// Hardware-accelerated JavaScript library for machine intelligence
// 39 lines • 3.08 kB
// JavaScript
"use strict";
/**
* @license
* Copyright 2017 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
Object.defineProperty(exports, "__esModule", { value: true });
var DepthwiseConv2DProgram = /** @class */ (function () {
    /**
     * Generates the shader-style program text for a depthwise 2D convolution.
     *
     * The constructor sets three instance fields:
     *   - `variableNames`: always `['x', 'W']`, the two inputs the generated
     *     code reads through `getX(...)` and `getW(...)`.
     *   - `outputShape`: `convInfo.outShape`, passed through by reference.
     *   - `userCode`: the generated program text, with the numeric values
     *     from `convInfo` baked in as literals.
     *
     * @param convInfo Convolution description. Fields read here: `outShape`,
     *     `inHeight`, `inWidth`, `padInfo.top`, `padInfo.left`,
     *     `strideHeight`, `strideWidth`, `dilationHeight`, `dilationWidth`,
     *     `filterHeight`, `filterWidth`, `outChannels`, `inChannels`.
     */
    function DepthwiseConv2DProgram(convInfo) {
        this.variableNames = ['x', 'W'];
        this.outputShape = convInfo.outShape;

        var inputRows = convInfo.inHeight;
        var inputCols = convInfo.inWidth;
        var topPad = convInfo.padInfo.top;
        var leftPad = convInfo.padInfo.left;
        var rowStride = convInfo.strideHeight;
        var colStride = convInfo.strideWidth;
        var rowDilation = convInfo.dilationHeight;
        var colDilation = convInfo.dilationWidth;
        var kernelRows = convInfo.filterHeight;
        var kernelCols = convInfo.filterWidth;
        // Number of output channels produced per input channel. The generated
        // code uses it to split the output channel d2 into (d1, q).
        var channelMultiplier = convInfo.outChannels / convInfo.inChannels;

        // One array entry per line of generated text; entries are joined with
        // '\n'. The leading '' and trailing ' ' entries reproduce the newline
        // that opens the text and the single space that closes it.
        var shaderLines = [
            '',
            ' const ivec2 strides = ivec2(' + rowStride + ', ' + colStride + ');',
            ' const ivec2 pads = ivec2(' + topPad + ', ' + leftPad + ');',
            '',
            ' void main() {',
            ' ivec4 coords = getOutputCoords();',
            ' int batch = coords.x;',
            ' ivec2 xRCCorner = coords.yz * strides - pads;',
            ' int d2 = coords.w;',
            ' int d1 = d2 / ' + channelMultiplier + ';',
            ' int q = d2 - d1 * ' + channelMultiplier + ';',
            '',
            ' int xRCorner = xRCCorner.x;',
            ' int xCCorner = xRCCorner.y;',
            '',
            ' // Convolve x(?, ?, d1) with w(:, :, d1, q) to get y(yR, yC, d2).',
            ' // ? = to be determined. : = across all values in that axis.',
            ' float dotProd = 0.0;',
            ' // TODO(dsmilkov): Flatten the two for loops and vec4 the operations.',
            ' for (int wR = 0; wR < ' + kernelRows + '; wR++) {',
            ' int xR = xRCorner + wR * ' + rowDilation + ';',
            '',
            // Filter rows that land outside the input are skipped.
            ' if (xR < 0 || xR >= ' + inputRows + ') {',
            ' continue;',
            ' }',
            '',
            ' for (int wC = 0; wC < ' + kernelCols + '; wC++) {',
            ' int xC = xCCorner + wC * ' + colDilation + ';',
            '',
            // Same bounds check for filter columns.
            ' if (xC < 0 || xC >= ' + inputCols + ') {',
            ' continue;',
            ' }',
            '',
            ' float xVal = getX(batch, xR, xC, d1);',
            ' float wVal = getW(wR, wC, d1, q);',
            ' dotProd += xVal * wVal;',
            ' }',
            ' }',
            ' setOutput(dotProd);',
            ' }',
            ' '
        ];
        this.userCode = shaderLines.join('\n');
    }
    return DepthwiseConv2DProgram;
}());
exports.DepthwiseConv2DProgram = DepthwiseConv2DProgram;
//# sourceMappingURL=conv_gpu_depthwise.js.map