@tensorflow/tfjs-core
Version:
Hardware-accelerated JavaScript library for machine intelligence
79 lines • 5.89 kB
JavaScript
"use strict";
/**
* @license
* Copyright 2017 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Set the non-enumerable `__esModule` flag on `exports`.
// NOTE(review): this is the ES-module interop marker that transpilers emit;
// this file looks like TypeScript output (see the reduce_gpu.js.map trailer).
// Confirm against the build.
Object.defineProperty(exports, "__esModule", { value: true });
var ReduceProgram = /** @class */ (function () {
    // Reductions this program knows how to emit GLSL for. Any other value would
    // fall through to the min/max code path and produce a shader calling a
    // non-existent GLSL function (e.g. `mean(mean(mean(...)))`), which would
    // only fail later, at shader-compile time, with an opaque error.
    var SUPPORTED_REDUCE_TYPES = ['sum', 'prod', 'min', 'max', 'all', 'any'];
    /**
     * GPGPU program that reduces a 2D input read as `getX(batch, inIdx)` along
     * its second axis.
     *
     * Each output element `(batch, outIdx)` combines the `windowSize` input
     * elements starting at `outIdx * windowSize`. The output shape is
     * `[batchSize, ceil(inSize / windowSize)]`.
     *
     * The generated shader works as follows:
     * - It consumes each window four elements at a time into a `vec4`
     *   (`bvec4` for 'all'/'any').
     * - It then handles the 1-3 leftover elements of the window, padding the
     *   unused lanes with `initializationValue`.
     * - When the last window runs past `inSize`, out-of-range reads also
     *   yield `initializationValue`.
     *
     * @param {{windowSize: number, batchSize: number, inSize: number}} reduceInfo
     *     Reduction geometry.
     * @param {string} reduceType One of 'sum', 'prod', 'min', 'max', 'all', 'any'.
     * @throws {Error} If `reduceType` is not one of the supported reductions.
     */
    function ReduceProgram(reduceInfo, reduceType) {
        if (SUPPORTED_REDUCE_TYPES.indexOf(reduceType) === -1) {
            throw new Error("ReduceProgram: unsupported reduceType '" + reduceType +
                "'; expected one of " + SUPPORTED_REDUCE_TYPES.join(', ') + ".");
        }
        this.variableNames = ['x'];
        var windowSize = reduceInfo.windowSize;
        var batchSize = reduceInfo.batchSize;
        var inSize = reduceInfo.inSize;
        var outSize = Math.ceil(inSize / windowSize);
        this.outputShape = [batchSize, outSize];
        // GLSL literal used to seed minMaxValue, pad partial vec4s and stand in
        // for out-of-range reads.
        var initializationValue = '0.0';
        // GLSL built-in that folds `values` into minMaxValue. It stays empty for
        // sum/prod: their generated else-branch is unreachable there because an
        // earlier branch condition is the literal `true`.
        var compareOp = "";
        if (reduceType === 'prod') {
            initializationValue = '1.0';
        }
        else if (reduceType === 'min') {
            // WebGL on Firefox Linux can't compile 1/0 so we do 1/eps.
            initializationValue = '1.0 / 1e-20';
            compareOp = "min";
        }
        else if (reduceType === 'max') {
            // WebGL on Firefox Linux can't compile 1/0 so we do 1/eps.
            initializationValue = '-1.0 / 1e-20';
            compareOp = "max";
        }
        // GLSL expression passed to setOutput() once the window is consumed.
        var returnValue;
        if (reduceType === 'sum') {
            returnValue = "sumValue";
        }
        else if (reduceType === 'prod') {
            returnValue = "prodValue";
        }
        else if (reduceType === 'all') {
            returnValue = "allValue";
        }
        else if (reduceType === 'any') {
            returnValue = "anyValue";
        }
        else {
            // 'min' / 'max': fold the four lanes of minMaxValue into one scalar.
            returnValue = reduceType + "(" + reduceType + "(" + reduceType + "(" +
                'minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])';
        }
        // Largest multiple of 4 not exceeding windowSize, and what is left over.
        var windowSizeNearestVec4 = Math.floor(windowSize / 4) * 4;
        var windowSizeVec4Remainder = windowSize % 4;
        // Per-vec4 accumulation step. The branch conditions are baked in as the
        // literals `true`/`false` from the JS comparisons below.
        var updateSnippet = "\n if (" + (reduceType === 'sum') + ") {\n sumValue += dot(values, ones);\n } else if (" +
            (reduceType === 'prod') +
            ") {\n vec2 tmp = vec2(values[0], values[1]) * vec2(values[2], values[3]);\n prodValue *= tmp[0] * tmp[1];\n } else {\n minMaxValue = " +
            compareOp + "(values, minMaxValue);\n }\n ";
        var vecType = "vec4";
        if (reduceType === 'all') {
            initializationValue = '1.0';
            updateSnippet = "\n bool reducedAllValue = all(values);\n float floatedReducedAllValue = float(reducedAllValue);\n allValue = float(allValue >= 1.0 && floatedReducedAllValue >= 1.0);\n ";
            vecType = "bvec4";
        }
        else if (reduceType === 'any') {
            initializationValue = '0.0';
            updateSnippet = "\n bool reducedAnyValue = any(values);\n float floatedReducedAnyValue = float(reducedAnyValue);\n anyValue = float(anyValue >= 1.0 || floatedReducedAnyValue >= 1.0);\n ";
            vecType = "bvec4";
        }
        // Only emit a bounds check when the last window can run past inSize.
        var checkOutOfBounds = '';
        if (inSize % windowSize > 0) {
            checkOutOfBounds = "\n if (inIdx < 0 || inIdx >= " + inSize + ") {\n return initializationValue;\n }\n ";
        }
        this.userCode = "\n const float initializationValue = " + initializationValue +
            ";\n const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);\n\n float getValue(int batch, int inIdx) {\n " +
            checkOutOfBounds +
            "\n return getX(batch, inIdx);\n }\n\n void main() {\n ivec2 coords = getOutputCoords();\n int batch = coords[0];\n int outIdx = coords[1];\n int inOffset = outIdx * " +
            windowSize +
            ";\n\n vec4 minMaxValue = vec4(" + initializationValue +
            ");\n float prodValue = 1.0;\n float sumValue = 0.0;\n float allValue = 1.0;\n float anyValue = 0.0;\n\n for (int i = 0; i < " +
            windowSizeNearestVec4 +
            // Main loop: four full elements per iteration.
            "; i += 4) {\n int inIdx = inOffset + i;\n " +
            vecType + " values = " + vecType +
            "(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1),\n getValue(batch, inIdx + 2),\n getValue(batch, inIdx + 3)\n );\n\n " +
            updateSnippet +
            "\n }\n\n int inIdx = inOffset + " + windowSizeNearestVec4 +
            // Tail: exactly one of these branches is baked in as `true` when
            // windowSize is not a multiple of 4.
            ";\n if (" + (windowSizeVec4Remainder === 1) + ") {\n " +
            vecType + " values = " + vecType +
            "(\n getValue(batch, inIdx),\n initializationValue,\n initializationValue,\n initializationValue\n );\n\n " +
            updateSnippet +
            "\n } else if (" + (windowSizeVec4Remainder === 2) + ") {\n " +
            vecType + " values = " + vecType +
            "(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1),\n initializationValue,\n initializationValue\n );\n\n " +
            updateSnippet +
            "\n } else if (" + (windowSizeVec4Remainder === 3) + ") {\n " +
            vecType + " values = " + vecType +
            "(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1),\n getValue(batch, inIdx + 2),\n initializationValue\n );\n\n " +
            updateSnippet +
            "\n }\n setOutput(" + returnValue + ");\n }\n ";
    }
    return ReduceProgram;
}());
// Public CommonJS export of the reduction shader program class.
exports.ReduceProgram = ReduceProgram;
//# sourceMappingURL=reduce_gpu.js.map