UNPKG

@tensorflow/tfjs-core

Version:

Hardware-accelerated JavaScript library for machine intelligence

39 lines 3.63 kB
"use strict"; Object.defineProperty(exports, "__esModule", { value: true }); var ReduceProgram = (function () { function ReduceProgram(reduceInfo, reduceType) { this.variableNames = ['x']; var windowSize = reduceInfo.windowSize; var batchSize = reduceInfo.batchSize; var inSize = reduceInfo.inSize; var outSize = Math.ceil(inSize / windowSize); this.outputShape = [batchSize, outSize]; var isReduceSum = reduceType === 'sum'; var initializationValue = '0.0'; if (!isReduceSum) { if (reduceType === 'min') { initializationValue = '1.0 / 0.0'; } else { initializationValue = '-1.0 / 0.0'; } } var compareOp = reduceType === 'min' ? 'min' : 'max'; var returnValue = reduceType + "(" + reduceType + "(" + reduceType + "(" + 'minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])'; if (reduceType === 'sum') { returnValue = "sumValue"; } var windowSizeNearestVec4 = Math.floor(windowSize / 4) * 4; var windowSizeVec4Remainder = windowSize % 4; var updateSnippet = "\n if (" + isReduceSum + ") {\n sumValue += dot(values, ones);\n } else {\n minMaxValue = " + compareOp + "(values, minMaxValue);\n }\n "; var checkOutOfBounds = ''; if (inSize % windowSize > 0) { checkOutOfBounds = "\n if (inIdx < 0 || inIdx >= " + inSize + ") {\n return initializationValue;\n }\n "; } this.userCode = "\n const float initializationValue = " + initializationValue + ";\n const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);\n\n float getValue(int batch, int inIdx) {\n " + checkOutOfBounds + "\n return getX(batch, inIdx);\n }\n\n void main() {\n ivec2 coords = getOutputCoords();\n int batch = coords[0];\n int outIdx = coords[1];\n int inOffset = outIdx * " + windowSize + ";\n\n vec4 minMaxValue = vec4(" + initializationValue + ");\n float sumValue = 0.0;\n\n for (int i = 0; i < " + windowSizeNearestVec4 + "; i += 4) {\n int inIdx = inOffset + i;\n vec4 values = vec4(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1),\n getValue(batch, inIdx + 2),\n getValue(batch, inIdx + 3)\n );\n\n " + updateSnippet + "\n }\n\n int inIdx = inOffset + " + windowSizeNearestVec4 + ";\n if (" + (windowSizeVec4Remainder === 1) + ") {\n vec4 values = vec4(\n getValue(batch, inIdx),\n initializationValue,\n initializationValue,\n initializationValue\n );\n " + updateSnippet + "\n } else if (" + (windowSizeVec4Remainder === 2) + ") {\n vec4 values = vec4(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1),\n initializationValue,\n initializationValue\n );\n " + updateSnippet + "\n } else if (" + (windowSizeVec4Remainder === 3) + ") {\n vec4 values = vec4(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1),\n getValue(batch, inIdx + 2),\n initializationValue\n );\n " + updateSnippet + "\n }\n setOutput(" + returnValue + ");\n }\n "; } return ReduceProgram; }()); exports.ReduceProgram = ReduceProgram; //# sourceMappingURL=reduce_gpu.js.map