UNPKG

@hoff97/tensor-js

Version:

PyTorch-like deep learning inference library

35 lines 1.46 kB
import { CPUTensor } from '../../tensor/cpu/tensor';
import {
  getSize,
  incrementIndex,
  indexToPos,
  computeStrides,
} from '../../util/shape';
import { poolResultShape } from '../util/pool';

/**
 * Reduces a CPU tensor over `axes` by folding each input element into the
 * output cell it maps to.
 *
 * Input elements are visited in flat order `0 .. size - 1`. For each output
 * cell, the first element that lands there is passed to `operation` alone,
 * i.e. `operation(value)`. Every later element is passed together with the
 * value currently stored in that cell, i.e. `operation(value, accumulated)`.
 * Whatever `operation` returns is written back into the cell.
 *
 * @param a - Input tensor; read through `getShape()`, `dtype` and
 *   `get(flatIndex)`.
 * @param axes - Axes to reduce; forwarded unchanged to `poolResultShape`.
 * @param operation - Accumulator called as `(value)` or `(value, accumulated)`.
 * @param keepDims - Forwarded unchanged to `poolResultShape`.
 * @param postProcess - Optional. When truthy, it is applied to every output
 *   value after accumulation, and its result replaces that value.
 * @returns A new CPUTensor constructed with `a.dtype`, whose shape is the one
 *   computed by `poolResultShape`.
 */
export function pool(a, axes, operation, keepDims, postProcess) {
  const shape = a.getShape();
  const total = getSize(shape);
  const [outShape, axisMap] = poolResultShape(shape, axes, keepDims);
  const outCount = getSize(outShape);
  const outStrides = computeStrides(outShape);
  const out = new CPUTensor(outShape, undefined, a.dtype);

  // One flag per output cell: false until that cell receives its first value.
  // The flag decides between the one- and two-argument form of `operation`.
  const seen = new Array(outCount).fill(false);
  // Multi-dimensional position in the input.
  // NOTE(review): this assumes `incrementIndex` walks the shape in the same
  // order as the flat index used by `a.get(flat)`; confirm against util/shape.
  const cursor = new Array(shape.length).fill(0);
  // Output position: output axis k takes its coordinate from input axis axisMap[k].
  const target = new Array(outShape.length).fill(0);

  for (let flat = 0; flat < total; flat++) {
    for (let k = 0; k < axisMap.length; k++) {
      target[k] = cursor[axisMap[k]];
    }
    const cell = indexToPos(target, outStrides);
    const value = a.get(flat);
    if (seen[cell]) {
      out.set(target, operation(value, out.get(target)));
    } else {
      seen[cell] = true;
      out.set(target, operation(value));
    }
    incrementIndex(cursor, shape);
  }

  if (!postProcess) {
    return out;
  }
  for (let pos = 0; pos < out.size; pos++) {
    out.set(pos, postProcess(out.get(pos)));
  }
  return out;
}
//# sourceMappingURL=pool.js.map