UNPKG

@hoff97/tensor-js

Version:

PyTorch-like deep learning inference library

158 lines (150 loc) 5.45 kB
import { defaultAllocator } from '../../../tensor/gpu/gl';
import { computeStrides, getSize } from '../../../util/shape';
import { poolResultShape } from '../../util/pool';
import { Operation } from '../operation';

/**
 * Derives the shape-dependent pooling parameters needed by `calc`,
 * `compile` and `getCompilationInfo`, so the three stay in sync.
 *
 * @param {number[]} shape - Shape of the input tensor X.
 * @param {number[]} axes - Axes of X that are pooled over.
 * @param {boolean} keepDims - Forwarded unchanged to `poolResultShape`.
 * @returns {{outputShape: number[], mappedInputStrides: number[], sumDims: number[], sumSize: number}}
 *   - `outputShape`: first element returned by `poolResultShape`.
 *   - `mappedInputStrides`: entry k is the stride of X along input axis
 *     `ixMap[k]` (`ixMap` is the second element returned by
 *     `poolResultShape`). The shader multiplies entry k by output index
 *     component k.
 *   - `sumDims`: one entry per input axis, 1 for pooled axes and 0 otherwise.
 *   - `sumSize`: the product of the pooled axis lengths. It is 1 when `axes`
 *     is empty.
 */
function computePoolParameters(shape, axes, keepDims) {
  const [outputShape, ixMap] = poolResultShape(shape, axes, keepDims);
  const inputStrides = computeStrides(shape);
  // Array.from accepts any iterable, exactly like the original for...of loop.
  const mappedInputStrides = Array.from(ixMap, (i) => inputStrides[i]);

  const sumDims = new Array(shape.length).fill(0);
  let sumSize = 1;
  for (const axis of axes) {
    sumDims[axis] = 1;
    sumSize *= shape[axis];
  }

  return { outputShape, mappedInputStrides, sumDims, sumSize };
}

/**
 * GPU (GLSL shader) operation that reduces the input texture X over a set of
 * axes.
 *
 * The first visited element is transformed with `init`, and each later element
 * is combined with the running result via `update(curr, res)`. `post` may then
 * emit a final statement on `res`. `update` is not defined in this class; it
 * is presumably supplied by subclasses (TODO confirm against the subclasses).
 */
export class PoolOperation extends Operation {
  constructor(tensorConstructor, dtype, allocator) {
    super(tensorConstructor, dtype, allocator);
    // Fixed upper bound of the shader's reduction loop. The loop breaks once
    // `sumSize` elements have been visited, so at most this many input
    // elements can be pooled into one output element.
    this.maxIterations = 1000000;
  }

  /**
   * GLSL declarations of the pooling parameters.
   * Each declaration asks `getVarModifier` about its own variable name.
   * Previously `sumDims` and `sumSize` queried 'mappedInputStrides' by mistake.
   */
  getVariables() {
    return `
${this.getVarModifier('mappedInputStrides')} int mappedInputStrides[${this.maxRank}];
${this.getVarModifier('sumDims')} int sumDims[${this.maxRank}];
${this.getVarModifier('sumSize')} int sumSize;
`;
  }

  /**
   * Builds the fragment shader. `process` runs once per output index and
   * works in three steps:
   * 1. It computes `inputPos`, the linear position in X of the first pooled
   *    element, from `mappedInputStrides`. The loop stops at a -1 entry;
   *    this assumes `pad` fills unused slots with -1 (TODO confirm).
   * 2. It converts `inputPos` back to the multi-index `inputIx` via
   *    `posToIndex` with `stridesX`.
   * 3. It visits `sumSize` elements, advancing `inputIx` with
   *    `incrementConditional` over `shapeX`/`sumDims`.
   */
  // eslint-disable-next-line @typescript-eslint/no-unused-vars
  getFragmentShader(info) {
    return `
float process(int index[${this.maxRank}]) {
  int inputIx[${this.maxRank}];
  ${this.initIndex('inputIx')}

  int inputPos = 0;
  for (int i = 0; i < ${this.maxRank}; i++) {
    if (mappedInputStrides[i] == -1 || index[i] == -1) {
      break;
    }
    inputPos += mappedInputStrides[i]*index[i];
  }
  ${this.posToIndex('stridesX', 'inputIx', 'inputPos')}

  float res = 0.0;
  for (int i = 0; i < ${this.maxIterations}; i++) {
    if (i >= sumSize) {
      break;
    }
    float curr = _X(inputIx);
    if (i == 0) {
      res = ${this.init('curr')};
    } else {
      res = ${this.update('curr', 'res')};
    }
    ${this.incrementConditional('inputIx', 'shapeX', 'sumDims')}
  }
  ${this.post('res')}
  return res;
}

${this.getDefaultMain()}
`;
  }

  /** GLSL statement emitted after the reduction loop; the default emits nothing. */
  // eslint-disable-next-line @typescript-eslint/no-unused-vars
  post(res) {
    return '';
  }

  /** GLSL expression for the first pooled element; the default uses it unchanged. */
  init(res) {
    return res;
  }

  getTextureNames() {
    return ['X'];
  }

  getUniformAttrs() {
    return [
      { name: 'mappedInputStrides', length: this.maxRank },
      { name: 'sumDims', length: this.maxRank },
      { name: 'sumSize' },
    ];
  }

  /**
   * Runs the pooling on `input.X` over `input.axes`.
   *
   * @param {{X: *, axes: number[], keepDims: boolean}} input
   * @returns the result of `this.compute`
   */
  calc(input) {
    if (this.fullyStatic && this.outputShape !== undefined) {
      // No uniforms are passed here. Presumably a fully static compile() has
      // already fixed the pooling parameters (TODO confirm in Operation).
      return this.compute(this.outputShape, { X: input.X });
    }

    const { outputShape, mappedInputStrides, sumDims, sumSize } = computePoolParameters(
      input.X.shape,
      input.axes,
      input.keepDims,
    );

    return this.compute(
      outputShape,
      { X: input.X },
      {
        mappedInputStrides: this.pad(mappedInputStrides),
        sumDims: this.pad(sumDims),
        sumSize,
      },
    );
  }

  /** Output shape for `input`, as computed by `poolResultShape`. */
  getOutputShape(input) {
    const [outputShape] = poolResultShape(input.X.shape, input.axes, input.keepDims);
    return outputShape;
  }

  /**
   * Specializes the operation when the input shape, axes and keepDims are
   * known.
   *
   * Mutates `info`: it adds `sumDims`, `shapeOutput`, `mappedInputStrides` and
   * `sumSize`, removes `keepDims` and `axes`, and sets `this.maxRank` to the
   * input rank before delegating to `super.compile`.
   */
  compile(info) {
    if (info.shapeX !== undefined && info.axes !== undefined && info.keepDims !== undefined) {
      const { outputShape, mappedInputStrides, sumDims, sumSize } = computePoolParameters(
        info.shapeX,
        info.axes,
        info.keepDims,
      );

      info.sumDims = sumDims;
      info.shapeOutput = outputShape;
      info.mappedInputStrides = mappedInputStrides;
      info.sumSize = sumSize;
      // axes/keepDims are fully replaced by the derived values above, so they
      // are stripped before the info reaches super.compile.
      delete info['keepDims'];
      delete info['axes'];

      this.maxRank = info.shapeX.length;
    }

    super.compile(info);
  }

  /**
   * Collects everything `compile` needs to specialize this operation for
   * `input`: the input/output shapes, the texture dimensions and the derived
   * pooling parameters.
   */
  getCompilationInfo(input) {
    const { outputShape, mappedInputStrides, sumDims, sumSize } = computePoolParameters(
      input.X.shape,
      input.axes,
      input.keepDims,
    );
    const outputSize = defaultAllocator.getAllocationDimensions(getSize(outputShape), this.dtype);

    return {
      shapeX: input.X.shape,
      widthX: input.X.memory.width,
      heightX: input.X.memory.height,
      shapeOutput: outputShape,
      widthOutput: outputSize.width,
      heightOutput: outputSize.height,
      mappedInputStrides,
      sumDims,
      sumSize,
    };
  }

  getInputInfoString(input) {
    return `${input.X.shape}-${input.axes}-${input.keepDims}`;
  }
}
//# sourceMappingURL=pool.js.map