@tensorflow/tfjs-core
Version:
Hardware-accelerated JavaScript library for machine intelligence
155 lines (145 loc) • 6.21 kB
text/typescript
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import {Conv2DInfo} from '../../ops/conv_util';
import * as ops from '../../ops/ops';
import {buffer} from '../../ops/ops';
import {TensorBuffer} from '../../tensor';
import {DataType, Rank, TypedArray} from '../../types';
/**
 * Runs 2D max or average pooling over a flat, channel-last input array.
 *
 * Input elements are addressed as
 * `b * strides[0] + row * strides[1] + col * strides[2] + channel`; results
 * are written into a new buffer of shape `convInfo.outShape` using the same
 * batch / row / column / channel ordering.
 *
 * Window taps that fall into the padding are skipped. For dilated windows the
 * first tap stays on the window's dilation grid (`corner + k * dilation`)
 * rather than being clamped to index 0, so the pixels sampled here are the
 * same ones `maxPoolPositions` visits.
 *
 * NaN inputs never replace the running maximum (comparisons with NaN are
 * false) but do propagate through the average. A window with no in-bounds
 * tap yields the initial value for 'max' and `0 / 0` (NaN) for 'avg'.
 *
 * @param xValues Flat input values.
 * @param xShape Shape of the input. Not read by this function; kept so the
 *     signature stays unchanged for callers.
 * @param dtype Data type of the returned buffer.
 * @param strides Element strides of the input for the batch, row and column
 *     dimensions.
 * @param convInfo Pooling geometry (strides, dilations, padding, shapes).
 * @param poolType `'max'` keeps the largest value of each window, `'avg'`
 *     the arithmetic mean of the in-bounds values.
 * @returns A buffer of shape `convInfo.outShape` holding the pooled values.
 */
export function pool(
    xValues: TypedArray, xShape: number[], dtype: DataType, strides: number[],
    convInfo: Conv2DInfo, poolType: 'max'|'avg'): TensorBuffer<Rank, DataType> {
  const strideHeight = convInfo.strideHeight;
  const strideWidth = convInfo.strideWidth;
  const dilationHeight = convInfo.dilationHeight;
  const dilationWidth = convInfo.dilationWidth;
  const effectiveFilterHeight = convInfo.effectiveFilterHeight;
  const effectiveFilterWidth = convInfo.effectiveFilterWidth;
  const padTop = convInfo.padInfo.top;
  const padLeft = convInfo.padInfo.left;

  const initialValue =
      (poolType === 'max' ? Number.NEGATIVE_INFINITY :
                            Number.POSITIVE_INFINITY);

  const output = ops.buffer(convInfo.outShape, dtype);
  const outputVals = output.values;

  const outputBatchStrides =
      convInfo.outShape[1] * convInfo.outShape[2] * convInfo.outShape[3];
  const outputRowStrides = convInfo.outShape[2] * convInfo.outShape[3];
  const outputColStrides = convInfo.outShape[3];

  for (let b = 0; b < convInfo.batchSize; ++b) {
    const outputBatchOffset = b * outputBatchStrides;
    const inputBatchOffset = b * strides[0];
    for (let d = 0; d < convInfo.inChannels; ++d) {
      for (let yR = 0; yR < convInfo.outHeight; ++yR) {
        const xRCorner = yR * strideHeight - padTop;
        // Step from the (possibly negative) corner in dilation-sized
        // increments so the first in-bounds row is a real window tap.
        // Clamping to 0 would shift the whole window off its dilation grid.
        let xRMin = xRCorner;
        while (xRMin < 0) {
          xRMin += dilationHeight;
        }
        const xRMax =
            Math.min(convInfo.inHeight, effectiveFilterHeight + xRCorner);
        const outputRowOffset = outputBatchOffset + yR * outputRowStrides;
        for (let yC = 0; yC < convInfo.outWidth; ++yC) {
          const xCCorner = yC * strideWidth - padLeft;
          // Same dilation-grid alignment as for rows.
          let xCMin = xCCorner;
          while (xCMin < 0) {
            xCMin += dilationWidth;
          }
          const xCMax =
              Math.min(convInfo.inWidth, effectiveFilterWidth + xCCorner);

          let minMaxValue = initialValue;
          let avgValue = 0;
          let count = 0;
          for (let xR = xRMin; xR < xRMax; xR += dilationHeight) {
            const xROffset = inputBatchOffset + xR * strides[1];
            for (let xC = xCMin; xC < xCMax; xC += dilationWidth) {
              const xCOffset = xROffset + xC * strides[2];
              const pixel = xValues[xCOffset + d];
              if ((poolType === 'max' && pixel > minMaxValue)) {
                minMaxValue = pixel;
              } else if (poolType === 'avg') {
                avgValue += pixel;
                count++;
              }
            }
            // No NaN early-exit here: minMaxValue is only ever assigned a
            // pixel that compared greater than it, so it can never be NaN.
          }

          const outputOffset = outputRowOffset + yC * outputColStrides + d;
          outputVals[outputOffset] =
              poolType === 'avg' ? avgValue / count : minMaxValue;
        }
      }
    }
  }
  return output;
}
/**
 * Records, for every pooling window, where its maximum input value sits.
 *
 * The input is wrapped in a buffer of shape `xShape` and read as
 * `(batch, row, col, channel)`. Each window starts at its corner
 * (`out * stride - pad`) and is sampled every `dilation` elements; taps that
 * land before index 0 or past the input extent are skipped. Ties keep the
 * first maximum met in row-major scan order because the comparison is strict.
 *
 * The stored position depends on the flags:
 *  - `flattenPositions === false`: offset inside the effective filter window,
 *    `wR * effectiveFilterWidth + wC`.
 *  - `flattenPositions === true`: flat index into the input,
 *    `(row * inWidth + col) * inChannels + channel`, additionally offset by
 *    the batch when `includeBatchInIndex` is true.
 *
 * A window in which no value compares greater than negative infinity (no
 * in-bounds tap, or only NaN values) is recorded as -1.
 *
 * @param xValues Flat input values.
 * @param xShape Shape used to index `xValues`.
 * @param dtype Data type of the input values.
 * @param convInfo Pooling geometry (strides, dilations, padding, shapes).
 * @param flattenPositions Whether to store flat input indices instead of
 *     in-window offsets.
 * @param includeBatchInIndex Whether flat indices also span the batch
 *     dimension. Only consulted when `flattenPositions` is true.
 * @returns An int32 buffer of shape `convInfo.outShape` with one position per
 *     output element.
 */
export function maxPoolPositions(
    xValues: TypedArray, xShape: number[], dtype: DataType,
    convInfo: Conv2DInfo, flattenPositions = false,
    includeBatchInIndex = false): TensorBuffer<Rank, 'int32'> {
  const positions = ops.buffer(convInfo.outShape, 'int32');
  const {
    strideHeight,
    strideWidth,
    dilationHeight,
    dilationWidth,
    effectiveFilterHeight,
    effectiveFilterWidth,
  } = convInfo;
  const {top: padTop, left: padLeft} = convInfo.padInfo;
  const input = buffer(xShape, dtype, xValues);

  // Walks from a window corner in dilation-sized steps until the index is no
  // longer negative, i.e. finds the first tap that is not in the padding.
  const firstTapInBounds = (corner: number, step: number): number => {
    let tap = corner;
    while (tap < 0) {
      tap += step;
    }
    return tap;
  };

  for (let batch = 0; batch < convInfo.batchSize; ++batch) {
    for (let channel = 0; channel < convInfo.inChannels; ++channel) {
      for (let outRow = 0; outRow < convInfo.outHeight; ++outRow) {
        const rowCorner = outRow * strideHeight - padTop;
        const rowStart = firstTapInBounds(rowCorner, dilationHeight);
        const rowEnd =
            Math.min(convInfo.inHeight, effectiveFilterHeight + rowCorner);

        for (let outCol = 0; outCol < convInfo.outWidth; ++outCol) {
          const colCorner = outCol * strideWidth - padLeft;
          const colStart = firstTapInBounds(colCorner, dilationWidth);
          const colEnd =
              Math.min(convInfo.inWidth, effectiveFilterWidth + colCorner);

          let best = Number.NEGATIVE_INFINITY;
          let bestPosition = -1;

          for (let row = rowStart; row < rowEnd; row += dilationHeight) {
            for (let col = colStart; col < colEnd; col += dilationWidth) {
              const candidate = input.get(batch, row, col, channel);
              // Strict comparison: ties and NaN leave the current best alone.
              if (!(candidate > best)) {
                continue;
              }
              best = candidate as number;

              if (!flattenPositions) {
                // Offset of the tap relative to the window corner.
                bestPosition = (row - rowCorner) * effectiveFilterWidth +
                    (col - colCorner);
              } else if (includeBatchInIndex) {
                bestPosition =
                    ((batch * convInfo.inHeight + row) * convInfo.inWidth +
                     col) *
                        convInfo.inChannels +
                    channel;
              } else {
                bestPosition =
                    (row * convInfo.inWidth + col) * convInfo.inChannels +
                    channel;
              }
            }
          }

          positions.set(bestPosition, batch, outRow, outCol, channel);
        }
      }
    }
  }
  return positions;
}