UNPKG

@hoff97/tensor-js

Version:

PyTorch-like deep learning inference library

81 lines (78 loc) 2.5 kB
import { defaultAllocator } from '../../../tensor/gpu/gl';
import { gpuConstructor } from '../../../tensor/gpu/tensor';
import { getSize } from '../../../util/shape';
import { Dispatcher } from '../../gpu/dispatcher';
import { Operation } from '../../gpu/operation';

/**
 * GPU operation over a single input texture `A`. Every element whose
 * second index coordinate (`index[1]`) equals the `axis` variable gets
 * `float(count)` added to it. All other elements are copied unchanged.
 * The output shape equals the input shape.
 *
 * NOTE(review): presumably `A` is an indices tensor shaped
 * [numIndices, rank], so this offsets one coordinate column — confirm
 * against callers.
 */
export class AddIndexOperation extends Operation {
  /**
   * @param tensorConstructor - forwarded unchanged to `Operation`.
   * @param dtype - forwarded unchanged to `Operation`.
   * @param allocator - forwarded unchanged to `Operation`; the dispatcher
   *   in this file does not pass one.
   */
  constructor(tensorConstructor, dtype, allocator) {
    super(tensorConstructor, dtype, allocator);
  }

  /**
   * Declares the integer `axis` and `count` shader variables. Their
   * qualifier comes from `getVarModifier`.
   */
  getVariables() {
    const axisModifier = this.getVarModifier('axis');
    const countModifier = this.getVarModifier('count');
    return ` ${axisModifier} int axis; ${countModifier} int count; `;
  }

  /** Uniform attributes consumed by the shader: `axis` and `count`. */
  getUniformAttrs() {
    return ['axis', 'count'].map((name) => ({ name }));
  }

  /**
   * Builds the fragment shader. `process` reads `A` at `index` and adds
   * `count` when `index[1] == axis`. The base class supplies `main`.
   */
  // eslint-disable-next-line @typescript-eslint/no-unused-vars
  getFragmentShader(info) {
    const rank = this.maxRank;
    const mainSource = this.getDefaultMain();
    return ` float process(int index[${rank}]) { float res = _A(index); if (index[1] == axis) { res += float(count); } return res; } ${mainSource} `;
  }

  /** The only input texture is named `A`. */
  getTextureNames() {
    return ['A'];
  }

  /**
   * Runs the operation.
   *
   * @param input - `{ A, axis, count }`, where `A` is the input tensor.
   * @returns the tensor produced by `this.compute`.
   */
  calc(input) {
    const hasStaticShape = this.fullyStatic && this.outputShape !== undefined;
    if (hasStaticShape) {
      // Fully static compile: no uniforms are passed here
      // (presumably axis/count were fixed at compile time — confirm in Operation).
      return this.compute(this.outputShape, { A: input.A });
    }

    const shape = this.getOutputShape(input);
    const { axis, count } = this.getCompilationInfo(input);
    const textures = { A: input.A };
    const uniforms = { axis, count };
    return this.compute(shape, textures, uniforms);
  }

  /** Output shape is a fresh copy of the input shape. */
  getOutputShape(input) {
    return Array.from(input.A.shape);
  }

  compile(info) {
    super.compile(info);
  }

  /**
   * Gathers the compilation info. This covers:
   * - the input shape and the input's memory width/height;
   * - the output shape and its allocation dimensions, from
   *   `defaultAllocator.getAllocationDimensions`;
   * - the `axis` and `count` values.
   */
  getCompilationInfo(input) {
    const { A: source, axis, count } = input;
    const outputShape = this.getOutputShape(input);
    const { width, height } = defaultAllocator.getAllocationDimensions(
      getSize(outputShape),
      this.dtype,
    );

    return {
      shapeA: source.shape,
      widthA: source.memory.width,
      heightA: source.memory.height,
      shapeOutput: outputShape,
      widthOutput: width,
      heightOutput: height,
      axis,
      count,
    };
  }

  /**
   * String identifying an input configuration: shape, axis and count.
   * NOTE(review): presumably used by `Dispatcher` as a cache key — confirm.
   */
  getInputInfoString(input) {
    const { A, axis, count } = input;
    return `${A.shape}-${axis}-${count}`;
  }
}

/** Shared dispatcher that builds an `AddIndexOperation` for a given dtype. */
export const defaultAddIndexD = new Dispatcher(
  (dtype) => new AddIndexOperation(gpuConstructor, dtype),
);

/**
 * Adds `count` to the entries of `indices` whose second index coordinate
 * equals `axis`. It runs on the GPU through `defaultAddIndexD` with dtype
 * `'uint32'`.
 *
 * @param indices - input GPU tensor, passed as texture `A`.
 * @param axis - value compared against `index[1]` in the shader.
 * @param count - amount added to the matching entries.
 */
export function addIndexGPU(indices, axis, count) {
  const input = { A: indices, axis, count };
  return defaultAddIndexD.calc(input, 'uint32');
}
//# sourceMappingURL=gpu.js.map