UNPKG

@hoff97/tensor-js

Version:

PyTorch-like deep learning inference library

108 lines (102 loc) 3.57 kB
import { defaultAllocator } from '../../../tensor/gpu/gl';
import { gpuConstructor } from '../../../tensor/gpu/tensor';
import { computeStrides, getSize } from '../../../util/shape';
import { Dispatcher } from '../../gpu/dispatcher';
import { Operation } from '../../gpu/operation';

// Integer-array uniforms used by the shader, in the order they are declared
// and in the order `compile` inspects them.
const UNIFORM_NAMES = ['sparseShape', 'repeatStrides'];

/**
 * GPU operation whose fragment shader reads one element of texture `A` per
 * output element and adds an offset computed from the `repeatStrides` and
 * `sparseShape` uniforms.
 *
 * The output has `A.shape[0] * repeatsProd` rows and `A.shape[1]` columns
 * (see `getOutputShape`). The shader names `shapeA[0]` "nnz", so `A`
 * presumably holds sparse-tensor indices — TODO confirm against callers.
 */
export class RepeatIndexOperation extends Operation {
  /**
   * Forwards all three arguments unchanged to the `Operation` constructor.
   */
  constructor(tensorConstructor, dtype, allocator) {
    super(tensorConstructor, dtype, allocator);
  }

  /**
   * @returns {string} GLSL declarations of the `sparseShape` and
   *   `repeatStrides` int arrays, each of length `this.maxRank`.
   */
  getVariables() {
    const declarations = UNIFORM_NAMES.map(
      (name) => `${this.getVarModifier(name)} int ${name}[${this.maxRank}];`,
    );
    return ` ${declarations.join(' ')} `;
  }

  /**
   * @returns {Array<{name: string, length: number}>} one descriptor per
   *   uniform array, each with length `this.maxRank`.
   */
  getUniformAttrs() {
    return UNIFORM_NAMES.map((name) => ({ name, length: this.maxRank }));
  }

  /**
   * Builds the fragment shader source.
   *
   * `process` splits the output row `index[0]` into `repeatPos = row / nnz`
   * and `oldPos = row - repeatPos * nnz`, reads `_A` at
   * `(oldPos, index[1])`, decodes `repeatPos` into `repeatIx` via
   * `posToIndex` with `repeatStrides`, and adds
   * `repeatIx[c] * sparseShape[c]` for the column `c = index[1]`.
   *
   * The fragments are joined with single spaces so the emitted text is the
   * same as a single-line template would produce.
   *
   * @param info unused here.
   * @returns {string} the GLSL source, ending with the default `main`.
   */
  // eslint-disable-next-line @typescript-eslint/no-unused-vars
  getFragmentShader(info) {
    const rank = this.maxRank;
    const fragments = [
      `float process(int index[${rank}]) {`,
      'int newPos = index[0];',
      'int nnz = shapeA[0];',
      'int repeatPos = newPos / nnz;',
      'int oldPos = newPos - repeatPos*nnz;',
      `int oldIx[${rank}];`,
      `${this.initIndex('oldIx')}`,
      'oldIx[0] = oldPos;',
      'oldIx[1] = index[1];',
      `int repeatIx[${rank}];`,
      `${this.initIndex('repeatIx')}`,
      `${this.posToIndex('repeatStrides', 'repeatIx', 'repeatPos')}`,
      'float res = _A(oldIx);',
      // The shader scans for i == index[1] instead of indexing the arrays
      // with index[1] directly — presumably to work around dynamic-indexing
      // limits in GLSL ES; TODO confirm.
      `for (int i = 0; i < ${rank}; i++) {`,
      'if (i == index[1]) {',
      'res += float(repeatIx[i]*sparseShape[i]);',
      'break;',
      '}',
      '}',
      'return res;',
      '}',
      `${this.getDefaultMain()}`,
    ];
    return ` ${fragments.join(' ')} `;
  }

  /**
   * @returns {string[]} the single input texture name, `'A'`.
   */
  getTextureNames() {
    return ['A'];
  }

  /**
   * Runs the operation on `input`.
   *
   * When the operation is fully static and already has an output shape, the
   * computation is started without any uniforms. Otherwise the output shape
   * is derived from `input` and the padded `sparseShape` / `repeatStrides`
   * arrays are passed as uniforms.
   *
   * @param input object with `A`, `repeats`, `shape` and `repeatsProd`.
   * @returns whatever `this.compute` returns.
   */
  calc(input) {
    const textures = { A: input.A };

    const isPrecompiled = this.fullyStatic && this.outputShape !== undefined;
    if (isPrecompiled) {
      return this.compute(this.outputShape, textures);
    }

    const resultShape = this.getOutputShape(input);
    const { sparseShape, repeatStrides } = this.getCompilationInfo(input);
    const uniforms = {
      sparseShape: this.pad(sparseShape),
      repeatStrides: this.pad(repeatStrides),
    };
    return this.compute(resultShape, textures, uniforms);
  }

  /**
   * @param input object with `A` and `repeatsProd`.
   * @returns {number[]} `[A.shape[0] * repeatsProd, A.shape[1]]`.
   */
  getOutputShape(input) {
    const { shape } = input.A;
    const rows = shape[0] * input.repeatsProd;
    return [rows, shape[1]];
  }

  /**
   * Sets `this.maxRank` from the compilation info, then delegates to
   * `super.compile`. `sparseShape` is applied first and `repeatStrides`
   * second, so the latter wins when both are defined.
   *
   * @param info compilation info; either array may be `undefined`.
   */
  compile(info) {
    for (const name of UNIFORM_NAMES) {
      const dims = info[name];
      if (dims !== undefined) {
        this.maxRank = dims.length;
      }
    }
    super.compile(info);
  }

  /**
   * Collects the shapes, texture dimensions and uniform values for `input`.
   *
   * The output texture dimensions come from
   * `defaultAllocator.getAllocationDimensions` for the output size and this
   * operation's dtype; `repeatStrides` is `computeStrides(input.repeats)`.
   *
   * @param input object with `A`, `repeats`, `shape` and `repeatsProd`.
   * @returns {object} the compilation info record.
   */
  getCompilationInfo(input) {
    const shapeOutput = this.getOutputShape(input);
    const allocation = defaultAllocator.getAllocationDimensions(
      getSize(shapeOutput),
      this.dtype,
    );
    const { A } = input;
    return {
      shapeA: A.shape,
      widthA: A.memory.width,
      heightA: A.memory.height,
      shapeOutput,
      widthOutput: allocation.width,
      heightOutput: allocation.height,
      sparseShape: input.shape,
      repeatStrides: computeStrides(input.repeats),
    };
  }

  /**
   * @param input object with `A`, `repeats` and `shape`.
   * @returns {string} `A.shape`, `repeats` and `shape` stringified and
   *   joined by `-`.
   */
  getInputInfoString(input) {
    const { A, repeats, shape } = input;
    return `${A.shape}-${repeats}-${shape}`;
  }
}

/**
 * Dispatcher that creates a `RepeatIndexOperation` bound to `gpuConstructor`
 * for each dtype it is asked for.
 */
export const defaultRepeatIndexD = new Dispatcher((dtype) => {
  return new RepeatIndexOperation(gpuConstructor, dtype);
});

/**
 * Runs the repeat-index operation through `defaultRepeatIndexD`.
 *
 * @param indices tensor passed to the operation as texture `A`.
 * @param repeats passed to `computeStrides` to obtain `repeatStrides`.
 * @param shape used as the `sparseShape` uniform.
 * @param repeatsProd multiplier applied to `indices.shape[0]` for the
 *   output row count.
 * @returns the result of `defaultRepeatIndexD.calc`, called with `'uint32'`
 *   as its second argument (presumably the dtype — TODO confirm).
 */
export function repeatIndexGPU(indices, repeats, shape, repeatsProd) {
  const input = { A: indices, repeats, shape, repeatsProd };
  return defaultRepeatIndexD.calc(input, 'uint32');
}
//# sourceMappingURL=gpu.js.map