UNPKG

@hoff97/tensor-js

Version:

PyTorch-like deep learning inference library

166 lines (154 loc) 5.38 kB
import { getSize } from '../../../util/shape';
import { outputDimsSize } from '../../util/conv';
import { Operation } from '../operation';
import { defaultAllocator } from '../../../tensor/gpu/gl';

/**
 * GPU operation computing an average pool over the dimensions of `X` that
 * follow the first two (which are treated as `n` and `c` in the shader and
 * copied through unchanged).
 *
 * The class only generates GLSL source and collects uniform values; the
 * helpers it relies on (`getVarModifier`, `initIndex`, `incrementIndex`,
 * `getDefaultMain`, `copyPad`, `compute`, `fullyStatic`, `outputShape`,
 * `maxRank`) are not defined in this file and presumably come from the
 * `Operation` base class.
 */
export class AveragePoolOperation extends Operation {
  /**
   * Forwards all arguments to the base class and sets the compile-time upper
   * bound used for the kernel loop in the generated shader.
   */
  constructor(tensorConstructor, dtype, allocator) {
    super(tensorConstructor, dtype, allocator);
    // The shader loop needs a constant bound; the real trip count is
    // `kernelSize`, enforced with an early `break` inside the loop.
    this.maxIterations = 1000000;
  }

  /**
   * Builds the GLSL snippet that derives the input index of the current
   * kernel position for every pooled dimension.
   *
   * The snippet sets `skip = true` when the position falls outside `shapeX`
   * (i.e. into the padding) and stops at the first stride equal to -1
   * (presumably the fill value `copyPad` uses for unused slots - confirm
   * against the base class).
   *
   * @returns {string} GLSL source fragment.
   */
  updateInputIx() {
    return `
      for (int d = 0; d < ${this.maxRank - 2}; d++) {
        int stride = strides[d];
        int pad = pads[d];
        if (stride == -1) {
          break;
        }

        inputIx[d+2] = index[d+2]*stride - pad + kernelIx[d];
        if (inputIx[d+2] < 0 || inputIx[d+2] >= shapeX[d+2]) {
          skip = true;
          break;
        }
      }
    `;
  }

  /**
   * Declares the shader variables used by this operation.
   *
   * `pads` is declared with `maxRank * 2` entries so that it matches the
   * length announced in `getUniformAttrs` and the array passed from `calc`;
   * the shader itself only reads the leading entries.
   *
   * @returns {string} GLSL variable declarations.
   */
  getVariables() {
    return `
    ${this.getVarModifier('kernelSize')} int kernelSize;
    ${this.getVarModifier('dataRank')} int dataRank;
    ${this.getVarModifier('includePad')} int includePad;
    ${this.getVarModifier('pads')} int pads[${this.maxRank * 2}];
    ${this.getVarModifier('strides')} int strides[${this.maxRank}];
    ${this.getVarModifier('kernelShape')} int kernelShape[${this.maxRank}];
    `;
  }

  /**
   * Builds the fragment shader source.
   *
   * `process` walks over all `kernelSize` kernel positions, sums the input
   * values that are inside the input bounds and divides by the number of
   * counted positions. Out-of-bounds positions are counted only when
   * `includePad == 1`. If no position was counted the sum (0.0) is returned
   * as is instead of dividing by zero.
   *
   * @param {object} info - Compilation info (unused here).
   * @returns {string} GLSL fragment shader source.
   */
  // eslint-disable-next-line @typescript-eslint/no-unused-vars
  getFragmentShader(info) {
    return `
    float process(int index[${this.maxRank}]) {
      float res = 0.0;
      int count = 0;

      int n = index[0];
      int c = index[1];

      int kernelIx[${this.maxRank}];
      ${this.initIndex('kernelIx')}
      for (int i = 0; i < ${this.maxRank}; i++) {
        if (i >= dataRank) {
          break;
        }
        kernelIx[i] = 0;
      }

      int inputIx[${this.maxRank}];
      ${this.initIndex('inputIx')}
      inputIx[0] = n;
      inputIx[1] = c;

      for (int kIx = 0; kIx < ${this.maxIterations}; kIx++) {
        if (kIx >= kernelSize) {
          break;
        }

        bool skip = false;

        ${this.updateInputIx()}

        if (!skip) {
          res += _X(inputIx);
        }
        if (!skip || includePad == 1) {
          count += 1;
        }

        ${this.incrementIndex('kernelIx', 'kernelShape')}
      }

      if (count > 0) {
        res = res / float(count);
      }

      return res;
    }

    ${this.getDefaultMain()}
    `;
  }

  /**
   * @returns {string[]} Names of the input textures sampled by the shader.
   */
  getTextureNames() {
    return ['X'];
  }

  /**
   * @returns {Array<{name: string, length?: number}>} Uniform descriptors;
   *   entries with `length` are arrays.
   */
  getUniformAttrs() {
    return [
      { name: 'kernelSize' },
      { name: 'dataRank' },
      { name: 'includePad' },
      { name: 'pads', length: this.maxRank * 2 },
      { name: 'strides', length: this.maxRank },
      { name: 'kernelShape', length: this.maxRank },
    ];
  }

  /**
   * Runs the operation.
   *
   * When the operation is fully static and an output shape is already known,
   * only the texture is passed on. Otherwise the output shape is derived from
   * the input and all pooling parameters are handed over as uniform values.
   *
   * @param {object} input - Object with `X`, `kernelShape`, `pads`,
   *   `strides` and `includePad`.
   * @returns {*} Whatever `this.compute` returns.
   */
  calc(input) {
    if (this.fullyStatic && this.outputShape !== undefined) {
      return this.compute(this.outputShape, { X: input.X });
    }

    const outputShape = this.getOutputShape(input);
    const kernelSize = getSize(input.kernelShape);

    return this.compute(
      outputShape,
      { X: input.X },
      {
        kernelSize,
        includePad: input.includePad ? 1 : 0,
        dataRank: input.X.shape.slice(2).length,
        pads: this.copyPad(input.pads, this.maxRank * 2),
        strides: this.copyPad(input.strides),
        kernelShape: this.copyPad(input.kernelShape),
      },
    );
  }

  /**
   * Computes the output shape: the first two input dimensions followed by
   * the pooled sizes returned by `outputDimsSize`.
   *
   * `input.pads` is split in half; the first half and the second half are
   * passed as separate arguments (presumably begin/end padding per
   * dimension), and an all-ones array is passed as the fifth argument
   * (presumably dilations - confirm against `outputDimsSize`).
   *
   * @param {object} input - Object with `X`, `kernelShape`, `pads` and
   *   `strides`.
   * @returns {number[]} Shape of the result.
   */
  getOutputShape(input) {
    const N = input.X.shape[0];
    const C = input.X.shape[1];
    const D = input.X.shape.slice(2);
    const halfPads = input.pads.length / 2;

    const R = outputDimsSize(
      D,
      input.kernelShape,
      input.pads.slice(0, halfPads),
      input.pads.slice(halfPads),
      new Array(D.length).fill(1),
      input.strides,
    );

    return [N, C].concat(R);
  }

  /**
   * Normalizes the compilation info in place and delegates to the base
   * class.
   *
   * When `shapeX` is known, `dataRank` is derived from it and `maxRank` is
   * set to the rank of `X`. A boolean `includePad` is converted to 1/0,
   * the form the shader compares against.
   *
   * @param {object} info - Compilation info; mutated by this call.
   */
  compile(info) {
    if (info.shapeX !== undefined) {
      info.dataRank = info.shapeX.length - 2;
      this.maxRank = info.shapeX.length;
    }

    if (info.includePad === true) {
      info.includePad = 1;
    } else if (info.includePad === false) {
      info.includePad = 0;
    }

    super.compile(info);
  }

  /**
   * Collects every value derivable from a concrete input into a compilation
   * info object.
   *
   * @param {object} input - Object with `X`, `kernelShape`, `pads`,
   *   `strides` and `includePad`.
   * @returns {object} Compilation info for this input.
   */
  getCompilationInfo(input) {
    const outputShape = this.getOutputShape(input);
    const outputSize = defaultAllocator.getAllocationDimensions(
      getSize(outputShape),
      this.dtype,
    );
    const kernelSize = getSize(input.kernelShape);

    return {
      shapeX: input.X.shape,
      widthX: input.X.memory.width,
      heightX: input.X.memory.height,
      kernelShape: input.kernelShape,
      shapeOutput: outputShape,
      widthOutput: outputSize.width,
      heightOutput: outputSize.height,
      pads: input.pads,
      strides: input.strides,
      kernelSize,
      dataRank: input.X.shape.length - 2,
      includePad: input.includePad ? 1 : 0,
    };
  }

  /**
   * Builds a string from the input shape and the pooling parameters.
   *
   * @param {object} input - Object with `X`, `kernelShape`, `pads`,
   *   `strides` and `includePad`.
   * @returns {string} Dash-separated description of the input configuration.
   */
  getInputInfoString(input) {
    return `${input.X.shape}-${input.kernelShape}-${input.pads}-${input.strides}-${input.includePad}`;
  }
}
//# sourceMappingURL=averagePool.js.map