@tensorflow/tfjs-core
Version:
Hardware-accelerated JavaScript library for machine intelligence
140 lines • 6.94 kB
JavaScript
"use strict";
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
Object.defineProperty(exports, "__esModule", { value: true });
var ops = require("../../ops/ops");
var ops_1 = require("../../ops/ops");
/**
 * Reference CPU implementation of 2D max / average pooling.
 *
 * Input samples are read as
 * `xValues[b * strides[0] + row * strides[1] + col * strides[2] + d]`, and the
 * output buffer is written as `[batch, row, col, channel]` using
 * `convInfo.outShape`.
 *
 * For output cell (yR, yC) the window corner is
 * `(yR * strideHeight - padTop, yC * strideWidth - padLeft)`. The window
 * samples every `dilationHeight` / `dilationWidth`-th position up to the
 * effective filter size. Positions outside the input (padding) are skipped,
 * while the remaining taps stay aligned to `corner + k * dilation`.
 *
 * @param {ArrayLike<number>} xValues Flat input values.
 * @param {number[]} xShape Input shape. Not read here; kept for interface
 *     compatibility.
 * @param {string} dtype dtype passed to `ops.buffer` for the output.
 * @param {number[]} strides Input strides; only indices 0..2 are read.
 * @param {object} convInfo Pooling geometry: strides, dilations, effective
 *     filter size, `padInfo.top/left`, input/output sizes and `outShape`.
 * @param {string} poolType 'max' or 'avg'. Any other value leaves every
 *     output at +Infinity.
 * @returns The output buffer of shape `convInfo.outShape`.
 *     - 'max': NaN samples never win, because `NaN > x` is false. A window
 *       with no winning sample yields -Infinity.
 *     - 'avg': a window with no in-bounds sample yields 0 / 0 = NaN.
 */
function pool(xValues, xShape, dtype, strides, convInfo, poolType) {
    var strideHeight = convInfo.strideHeight;
    var strideWidth = convInfo.strideWidth;
    var dilationHeight = convInfo.dilationHeight;
    var dilationWidth = convInfo.dilationWidth;
    var effectiveFilterHeight = convInfo.effectiveFilterHeight;
    var effectiveFilterWidth = convInfo.effectiveFilterWidth;
    var padTop = convInfo.padInfo.top;
    var padLeft = convInfo.padInfo.left;
    var isMax = poolType === 'max';
    var isAvg = poolType === 'avg';
    var initialValue = isMax ? Number.NEGATIVE_INFINITY : Number.POSITIVE_INFINITY;
    var output = ops.buffer(convInfo.outShape, dtype);
    var outputVals = output.values;
    var outputBatchStrides = convInfo.outShape[1] * convInfo.outShape[2] * convInfo.outShape[3];
    var outputRowStrides = convInfo.outShape[2] * convInfo.outShape[3];
    var outputColStrides = convInfo.outShape[3];
    // Returns the first tap `corner + k * dilation` (k >= 0) that is >= 0.
    // Clamping with Math.max(0, corner) would misalign the taps whenever
    // dilation > 1 and the window starts inside the padding (it would sample
    // 0, dilation, ... instead). This is the same alignment that
    // maxPoolPositions uses; for dilation 1 both forms give identical results.
    function firstInBoundsTap(corner, dilation) {
        if (corner >= 0) {
            return corner;
        }
        return corner + Math.ceil(-corner / dilation) * dilation;
    }
    for (var b = 0; b < convInfo.batchSize; ++b) {
        var outputBatchOffset = b * outputBatchStrides;
        var inputBatchOffset = b * strides[0];
        for (var d = 0; d < convInfo.inChannels; ++d) {
            for (var yR = 0; yR < convInfo.outHeight; ++yR) {
                var xRCorner = yR * strideHeight - padTop;
                var xRMin = firstInBoundsTap(xRCorner, dilationHeight);
                var xRMax = Math.min(convInfo.inHeight, effectiveFilterHeight + xRCorner);
                var outputRowOffset = outputBatchOffset + yR * outputRowStrides;
                for (var yC = 0; yC < convInfo.outWidth; ++yC) {
                    var xCCorner = yC * strideWidth - padLeft;
                    var xCMin = firstInBoundsTap(xCCorner, dilationWidth);
                    var xCMax = Math.min(convInfo.inWidth, effectiveFilterWidth + xCCorner);
                    var minMaxValue = initialValue;
                    var avgValue = 0;
                    var count = 0;
                    for (var xR = xRMin; xR < xRMax; xR += dilationHeight) {
                        var xROffset = inputBatchOffset + xR * strides[1];
                        for (var xC = xCMin; xC < xCMax; xC += dilationWidth) {
                            var pixel = xValues[xROffset + xC * strides[2] + d];
                            if (isMax) {
                                // Strict '>' never selects NaN, so NaN samples are skipped.
                                if (pixel > minMaxValue) {
                                    minMaxValue = pixel;
                                }
                            }
                            else if (isAvg) {
                                avgValue += pixel;
                                count++;
                            }
                        }
                    }
                    var outputOffset = outputRowOffset + yC * outputColStrides + d;
                    outputVals[outputOffset] = isAvg ? avgValue / count : minMaxValue;
                }
            }
        }
    }
    return output;
}
exports.pool = pool;
/**
 * For every output cell of a 2D max pool, records which input sample held
 * the maximum.
 *
 * The input is wrapped in a buffer of shape `xShape` and read as
 * (batch, row, col, channel). Window geometry comes from `convInfo`: strides,
 * dilations, effective filter size and top/left padding. A window's first
 * sampled row/column is its corner advanced by whole dilation steps until it
 * is non-negative. Its end is bounded by the input height/width.
 *
 * @param {ArrayLike<number>} xValues Flat input values.
 * @param {number[]} xShape Shape used to wrap `xValues`.
 * @param {string} dtype dtype used to wrap `xValues`.
 * @param {object} convInfo Pooling geometry and input/output sizes.
 * @param {boolean} [flattenPositions=false] Controls what each cell stores:
 *     - truthy: a flat input index,
 *       `(row * inWidth + col) * inChannels + channel`;
 *     - otherwise: the offset inside the window,
 *       `wR * effectiveFilterWidth + wC`.
 * @param {boolean} [includeBatchInIndex=false] Used only with flattened
 *     positions. When truthy, the batch is folded into the index as well:
 *     `((batch * inHeight + row) * inWidth + col) * inChannels + channel`.
 * @returns An 'int32' buffer of shape `convInfo.outShape`.
 *     - On ties, the first sample in row-major window order wins.
 *     - A cell is -1 when no sample compared greater than -Infinity. This
 *       happens when the window holds only NaN / -Infinity values or lies
 *       entirely in padding.
 */
function maxPoolPositions(xValues, xShape, dtype, convInfo, flattenPositions, includeBatchInIndex) {
    var flatten = flattenPositions === void 0 ? false : flattenPositions;
    var foldBatch = includeBatchInIndex === void 0 ? false : includeBatchInIndex;
    var positions = ops.buffer(convInfo.outShape, 'int32');
    var strideH = convInfo.strideHeight;
    var strideW = convInfo.strideWidth;
    var dilH = convInfo.dilationHeight;
    var dilW = convInfo.dilationWidth;
    var filterH = convInfo.effectiveFilterHeight;
    var filterW = convInfo.effectiveFilterWidth;
    var topPad = convInfo.padInfo.top;
    var leftPad = convInfo.padInfo.left;
    var input = ops_1.buffer(xShape, dtype, xValues);
    // Walks a window corner forward by whole dilation steps until it is >= 0,
    // so the remaining taps stay aligned to corner + k * step.
    function firstTap(corner, step) {
        var tap = corner;
        while (tap < 0) {
            tap += step;
        }
        return tap;
    }
    // Flat index of (batch, row, col, channel) in the input, with the batch
    // folded in only when requested.
    function flatIndex(batch, row, col, channel) {
        var rowIndex = foldBatch ? batch * convInfo.inHeight + row : row;
        return (rowIndex * convInfo.inWidth + col) * convInfo.inChannels + channel;
    }
    for (var batch = 0; batch < convInfo.batchSize; ++batch) {
        for (var ch = 0; ch < convInfo.inChannels; ++ch) {
            for (var outRow = 0; outRow < convInfo.outHeight; ++outRow) {
                var rowCorner = outRow * strideH - topPad;
                var rowStart = firstTap(rowCorner, dilH);
                var rowEnd = Math.min(convInfo.inHeight, filterH + rowCorner);
                for (var outCol = 0; outCol < convInfo.outWidth; ++outCol) {
                    var colCorner = outCol * strideW - leftPad;
                    var colStart = firstTap(colCorner, dilW);
                    var colEnd = Math.min(convInfo.inWidth, filterW + colCorner);
                    var best = Number.NEGATIVE_INFINITY;
                    var bestPos = -1;
                    for (var row = rowStart; row < rowEnd; row += dilH) {
                        for (var col = colStart; col < colEnd; col += dilW) {
                            var sample = input.get(batch, row, col, ch);
                            // Strict '>' keeps the first maximum on ties and
                            // never selects NaN.
                            if (!(sample > best)) {
                                continue;
                            }
                            best = sample;
                            bestPos = flatten ?
                                flatIndex(batch, row, col, ch) :
                                (row - rowCorner) * filterW + (col - colCorner);
                        }
                    }
                    positions.set(bestPos, batch, outRow, outCol, ch);
                }
            }
        }
    }
    return positions;
}
exports.maxPoolPositions = maxPoolPositions;
//# sourceMappingURL=pool_utils.js.map