UNPKG

@hoff97/tensor-js

Version:

PyTorch-like deep learning inference library

80 lines (76 loc) 2.47 kB
import { defaultAllocator } from '../../../tensor/gpu/gl';
import { getSize } from '../../../util/shape';
import { Operation } from '../operation';

/**
 * GPU operation that enlarges (or shrinks) a tensor by a per-axis scale
 * factor. Every output coordinate is divided by the scale of its axis and
 * floored to obtain the input coordinate whose value is copied.
 */
export class UpsampleOperation extends Operation {
  /**
   * Forwards all arguments unchanged to the base `Operation` constructor.
   *
   * @param tensorConstructor passed through to `Operation`
   * @param dtype passed through to `Operation`
   * @param allocator passed through to `Operation`
   */
  constructor(tensorConstructor, dtype, allocator) {
    super(tensorConstructor, dtype, allocator);
  }

  /**
   * Builds the GLSL source for this operation.
   *
   * The fragments below are joined with single spaces (with an empty first
   * and last fragment), so the emitted text carries one leading and one
   * trailing space and no line breaks.
   *
   * @param info compilation info; not read by this method
   * @returns {string} GLSL source containing `process` and the default main
   */
  // eslint-disable-next-line @typescript-eslint/no-unused-vars
  getFragmentShader(info) {
    const shaderFragments = [
      '',
      `float process(int index[${this.maxRank}]) {`,
      `int inIx[${this.maxRank}];`,
      `${this.initIndex('inIx')}`,
      `for (int i = 0; i < ${this.maxRank}; i++) {`,
      // A -1 entry ends the loop early; presumably it marks unused trailing
      // axes when the tensor rank is below maxRank -- TODO confirm.
      'if (index[i] == -1) {',
      'break;',
      '}',
      // Map the output coordinate back to the input coordinate.
      'inIx[i] = int(floor(float(index[i]) / scales[i]));',
      '}',
      'return _X(inIx);',
      '}',
      `${this.getDefaultMain()}`,
      '',
    ];
    return shaderFragments.join(' ');
  }

  /**
   * @returns {string[]} names of the input textures used by the shader
   */
  getTextureNames() {
    return ['X'];
  }

  /**
   * Declares the `scales` float array, sized by `maxRank`, in the shader.
   * The qualifier in front of it comes from `getVarModifier('scales')`.
   *
   * @returns {string} GLSL variable declarations
   */
  getVariables() {
    return ` ${this.getVarModifier('scales')} float scales[${this.maxRank}]; `;
  }

  /**
   * @returns {Array<{name: string, length: number, type: string}>}
   *   description of the `scales` uniform array
   */
  getUniformAttrs() {
    const scalesAttr = { name: 'scales', length: this.maxRank, type: 'float' };
    return [scalesAttr];
  }

  /**
   * Runs the operation on `input.X`.
   *
   * When the operation is fully static and has a stored output shape, that
   * shape is used and no `scales` uniform is passed (presumably the value
   * was fixed at compile time -- TODO confirm against `Operation`).
   * Otherwise the output shape is derived from the input and the padded
   * scales are passed as a uniform.
   *
   * @param input object with the tensor `X` and the per-axis `scales`
   * @returns whatever `compute` returns
   */
  calc(input) {
    const needsDynamicShape = !(this.fullyStatic && this.outputShape !== undefined);
    if (needsDynamicShape) {
      const dynamicShape = this.getOutputShape(input);
      return this.compute(dynamicShape, { X: input.X }, { scales: this.copyPad(input.scales) });
    }
    return this.compute(this.outputShape, { X: input.X });
  }

  /**
   * Computes the output shape: each input dimension multiplied by the scale
   * of the same axis, rounded down.
   *
   * @param input object with the tensor `X` and the per-axis `scales`
   * @returns {number[]} a new array; `input.X.shape` is not modified
   */
  getOutputShape(input) {
    return Array.from(input.X.shape, (dim, axis) => Math.floor(dim * input.scales[axis]));
  }

  /**
   * Sets `maxRank` to the rank of X when its shape is known, then delegates
   * to the base class compilation.
   *
   * @param info compilation info, optionally containing `shapeX`
   */
  compile(info) {
    const { shapeX } = info;
    if (shapeX !== undefined) {
      this.maxRank = shapeX.length;
    }
    super.compile(info);
  }

  /**
   * Collects the shapes, texture dimensions and scales describing one
   * concrete invocation.
   *
   * @param input object with the tensor `X` and the per-axis `scales`
   * @returns compilation info for X, the output and the scales
   */
  getCompilationInfo(input) {
    const shapeOutput = this.getOutputShape(input);
    const elementCount = getSize(shapeOutput);
    const allocation = defaultAllocator.getAllocationDimensions(elementCount, this.dtype);
    const { X, scales } = input;
    return {
      shapeX: X.shape,
      widthX: X.memory.width,
      heightX: X.memory.height,
      shapeOutput,
      widthOutput: allocation.width,
      heightOutput: allocation.height,
      scales,
    };
  }

  /**
   * @param input object with the tensor `X` and the per-axis `scales`
   * @returns {string} the shape of X and the scales, separated by `-`
   */
  getInputInfoString(input) {
    const { X, scales } = input;
    return `${X.shape}-${scales}`;
  }
}
//# sourceMappingURL=upsample.js.map