UNPKG

@hoff97/tensor-js

Version:

PyTorch-like deep learning inference library

136 lines (130 loc) 4.2 kB
import { defaultAllocator } from '../../../tensor/gpu/gl';
import { getSize } from '../../../util/shape';
import { Operation } from '../operation';

/**
 * WebGL operation that pads a tensor along every axis.
 *
 * Modes (see getModeFlag):
 * - 'constant' (flag 0): output positions outside the input read `value`.
 * - 'reflect'  (flag 1): out-of-range positions mirror around the border
 *   without repeating the border element itself.
 * - any other mode string (flag 2; presumably 'edge'): out-of-range positions
 *   clamp to the nearest border element.
 *
 * `pads` holds 2 * rank entries: the first `rank` are the amounts added before
 * each axis, the following `rank` the amounts added after it (see
 * getOutputShape, which sums `pads[i]` and `pads[i + rank]`).
 */
export class PadOperation extends Operation {
  constructor(tensorConstructor, dtype, allocator) {
    super(tensorConstructor, dtype, allocator);
  }

  /**
   * Builds the GLSL `process` function for one output element.
   *
   * Each output index is shifted back by its leading pad (`pads[i]`) to obtain
   * the candidate input index; what happens when that index falls outside
   * `shapeX` depends on `mode`. The per-axis loops stop at the first `index`
   * entry equal to -1 (presumably the terminator for ranks below maxRank).
   *
   * @param {object} info - Compilation info (unused here).
   * @returns {string} GLSL source including the default main function.
   */
  // eslint-disable-next-line @typescript-eslint/no-unused-vars
  getFragmentShader(info) {
    return `
float process(int index[${this.maxRank}]) {
  int inputIx[${this.maxRank}];
  ${this.initIndex('inputIx')}

  if (mode == 0) {
    float res = value;
    int outOfBounds = 0;
    for (int i = 0; i < ${this.maxRank}; i++) {
      if (index[i] == -1) {
        break;
      }
      inputIx[i] = index[i] - pads[i];
      if (inputIx[i] < 0 || inputIx[i] >= shapeX[i]) {
        outOfBounds = 1;
        break;
      }
    }
    if (outOfBounds == 0) {
      res = _X(inputIx);
    }
    return res;
  } else if (mode == 1) {
    for (int i = 0; i < ${this.maxRank}; i++) {
      if (index[i] == -1) {
        break;
      }
      inputIx[i] = index[i] - pads[i];
      if (inputIx[i] < 0) {
        inputIx[i] = -inputIx[i];
      } else if (inputIx[i] >= shapeX[i]) {
        inputIx[i] = 2*shapeX[i] - inputIx[i] - 2;
      }
    }
    return _X(inputIx);
  } else {
    for (int i = 0; i < ${this.maxRank}; i++) {
      if (index[i] == -1) {
        break;
      }
      inputIx[i] = index[i] - pads[i];
      if (inputIx[i] < 0) {
        inputIx[i] = 0;
      } else if (inputIx[i] >= shapeX[i]) {
        inputIx[i] = shapeX[i] - 1;
      }
    }
    return _X(inputIx);
  }
}

${this.getDefaultMain()}
`;
  }

  /** @returns {string[]} Names of the input textures read by the shader. */
  getTextureNames() {
    return ['X'];
  }

  /**
   * Declares the shader variables `pads`, `value` and `mode`. Each gets the
   * modifier chosen by getVarModifier (defined on Operation).
   *
   * @returns {string} GLSL declarations.
   */
  getVariables() {
    return `
${this.getVarModifier('pads')} int pads[${this.maxRank * 2}];
${this.getVarModifier('value')} float value;
${this.getVarModifier('mode')} int mode;
`;
  }

  /** @returns {object[]} Uniform descriptors matching getVariables. */
  getUniformAttrs() {
    return [
      { name: 'value', type: 'float' },
      { name: 'pads', length: this.maxRank * 2 },
      { name: 'mode' },
    ];
  }

  /**
   * Maps a pad mode string to the integer flag the shader switches on.
   *
   * @param {string} mode - 'constant', 'reflect', or any other value.
   * @returns {number} 0 for 'constant', 1 for 'reflect', 2 otherwise.
   */
  getModeFlag(mode) {
    if (mode === 'constant') {
      return 0;
    }
    if (mode === 'reflect') {
      return 1;
    }
    return 2;
  }

  /**
   * Runs the padding on the GPU.
   *
   * When the operation was compiled fully static with a known output shape,
   * no uniforms are passed. Otherwise the output shape is derived from the
   * input and the pads, value and mode are supplied as uniforms.
   *
   * @param {object} input - `{ input, pads, value, mode }`.
   * @returns The padded tensor produced by `compute`.
   */
  calc(input) {
    if (this.fullyStatic && this.outputShape !== undefined) {
      return this.compute(this.outputShape, { X: input.input });
    }

    const resultShape = this.getOutputShape(input);
    return this.compute(
      resultShape,
      { X: input.input },
      {
        pads: this.copyPad(input.pads, this.maxRank * 2),
        value: input.value,
        mode: this.getModeFlag(input.mode),
      }
    );
  }

  /**
   * Computes the padded shape: each axis grows by its leading pad
   * (`pads[i]`) plus its trailing pad (`pads[i + rank]`).
   *
   * @param {object} input - `{ input, pads }`.
   * @returns {number[]} The output shape (a new array).
   */
  getOutputShape(input) {
    const rank = input.input.shape.length;
    const resultShape = [...input.input.shape];
    for (let i = 0; i < rank; i++) {
      resultShape[i] += input.pads[i] + input.pads[i + rank];
    }
    return resultShape;
  }

  /**
   * Compiles the shader for the given (possibly partial) compilation info.
   *
   * @param {object} info - Compilation info. A string `mode` is converted to
   *   its numeric flag before being handed to Operation.compile.
   */
  compile(info) {
    if (info.shapeX !== undefined) {
      this.maxRank = info.shapeX.length;
    }

    // The shader only understands numeric mode flags. Convert on a shallow
    // copy so the caller's info object is not mutated.
    const compileInfo =
      typeof info.mode === 'string'
        ? Object.assign({}, info, { mode: this.getModeFlag(info.mode) })
        : info;

    super.compile(compileInfo);
  }

  /**
   * Collects everything needed to compile a fully specialized shader for
   * this input.
   *
   * @param {object} input - `{ input, pads, value, mode }`.
   * @returns {object} Shapes, texture dimensions and the pad parameters.
   */
  getCompilationInfo(input) {
    const outputShape = this.getOutputShape(input);
    const outputSize = defaultAllocator.getAllocationDimensions(
      getSize(outputShape),
      this.dtype
    );

    return {
      shapeX: input.input.shape,
      widthX: input.input.memory.width,
      heightX: input.input.memory.height,
      shapeOutput: outputShape,
      widthOutput: outputSize.width,
      heightOutput: outputSize.height,
      pads: input.pads,
      mode: this.getModeFlag(input.mode),
      value: input.value,
    };
  }

  /**
   * @param {object} input - `{ input, pads, mode, value }`.
   * @returns {string} A string built from the input shape, pads, mode and
   *   value.
   */
  getInputInfoString(input) {
    //TODO: Format value with enough precision?
    return `${input.input.shape}-${input.pads}-${input.mode}-${input.value}`;
  }
}
//# sourceMappingURL=pad.js.map