@hoff97/tensor-js
Version:
PyTorch-like deep learning inference library
108 lines (102 loc) • 3.57 kB
JavaScript
import { defaultAllocator } from '../../../tensor/gpu/gl';
import { gpuConstructor } from '../../../tensor/gpu/tensor';
import { computeStrides, getSize } from '../../../util/shape';
import { Dispatcher } from '../../gpu/dispatcher';
import { Operation } from '../../gpu/operation';
/**
 * GPU operation producing the index rows of a repeated sparse tensor.
 *
 * Texture `A` holds the source indices as a 2D tensor; the output has
 * `A.shape[0] * repeatsProd` rows and `A.shape[1]` columns. Every output row
 * is a source row whose entries are shifted by
 * `repeatIx[column] * sparseShape[column]`, where `repeatIx` is derived in the
 * shader from the repeat number of that row.
 */
export class RepeatIndexOperation extends Operation {
  /**
   * Forwards all three arguments unchanged to the base Operation.
   */
  constructor(tensorConstructor, dtype, allocator) {
    super(tensorConstructor, dtype, allocator);
  }

  /**
   * Declares the two integer arrays read by the shader, each of length
   * `maxRank` and prefixed with whatever `getVarModifier` returns for it.
   *
   * @returns {string} Declarations, one per line, wrapped in newlines.
   */
  getVariables() {
    const rank = this.maxRank;
    const declarations = ['sparseShape', 'repeatStrides'].map(
      (name) => `${this.getVarModifier(name)} int ${name}[${rank}];`
    );
    return `\n${declarations.join('\n')}\n`;
  }

  /**
   * Describes the uniforms of this operation: both arrays have `maxRank`
   * entries.
   *
   * @returns {{name: string, length: number}[]}
   */
  getUniformAttrs() {
    const length = this.maxRank;
    return ['sparseShape', 'repeatStrides'].map((name) => ({ name, length }));
  }

  /**
   * Builds the fragment shader source.
   *
   * For output position `index`, with `nnz = shapeA[0]`:
   *  - `repeatPos = index[0] / nnz` is the repeat number of the row,
   *  - `oldPos = index[0] - repeatPos*nnz` is the source row in `A`,
   *  - `repeatIx` is filled by the `posToIndex` snippet from `repeatPos` and
   *    `repeatStrides` (presumably the multi-index of the repeat — confirm
   *    against the base class),
   *  - the result is `A[oldPos, index[1]]` plus
   *    `repeatIx[index[1]] * sparseShape[index[1]]`.
   *
   * NOTE(review): `index[1]` and `oldIx[1]` are accessed unconditionally while
   * the arrays have `maxRank` entries, so the shader assumes `maxRank >= 2` —
   * confirm the behaviour for rank-1 sparse shapes.
   *
   * @param info Compilation info; not used by this shader.
   * @returns {string} GLSL source.
   */
  // eslint-disable-next-line @typescript-eslint/no-unused-vars
  getFragmentShader(info) {
    const rank = this.maxRank;
    // Helper snippets are requested in the order they appear in the shader.
    const initOldIx = this.initIndex('oldIx');
    const initRepeatIx = this.initIndex('repeatIx');
    const repeatIxFromPos = this.posToIndex('repeatStrides', 'repeatIx', 'repeatPos');
    const defaultMain = this.getDefaultMain();
    // The shader text is kept flush-left on purpose so that no additional
    // whitespace ends up in the emitted source. The loop at the end picks the
    // entry for column index[1] with a constant-bounded loop instead of
    // indexing the arrays with index[1] directly (presumably because of
    // dynamic-indexing restrictions of the target GLSL version — confirm).
    return `
float process(int index[${rank}]) {
int newPos = index[0];
int nnz = shapeA[0];
int repeatPos = newPos / nnz;
int oldPos = newPos - repeatPos*nnz;
int oldIx[${rank}];
${initOldIx}
oldIx[0] = oldPos;
oldIx[1] = index[1];
int repeatIx[${rank}];
${initRepeatIx}
${repeatIxFromPos}
float res = _A(oldIx);
for (int i = 0; i < ${rank}; i++) {
if (i == index[1]) {
res += float(repeatIx[i]*sparseShape[i]);
break;
}
}
return res;
}
${defaultMain}
`;
  }

  /**
   * @returns {string[]} Names of the input textures; only `A` is used.
   */
  getTextureNames() {
    return ['A'];
  }

  /**
   * Runs the operation.
   *
   * When the operation is fully static and already has an output shape, only
   * the texture is passed to `compute`. Otherwise the output shape is derived
   * from the input and the padded `sparseShape` / `repeatStrides` arrays are
   * passed along as a third argument.
   *
   * @param input Object with `A`, `repeats`, `shape` and `repeatsProd`.
   * @returns Whatever `compute` returns.
   */
  calc(input) {
    const textures = { A: input.A };
    const hasStaticShape = this.fullyStatic && this.outputShape !== undefined;
    if (hasStaticShape) {
      return this.compute(this.outputShape, textures);
    }
    const shape = this.getOutputShape(input);
    const { sparseShape, repeatStrides } = this.getCompilationInfo(input);
    const uniforms = {
      sparseShape: this.pad(sparseShape),
      repeatStrides: this.pad(repeatStrides),
    };
    return this.compute(shape, textures, uniforms);
  }

  /**
   * @param input Object with `A` (2D indices) and `repeatsProd`.
   * @returns {number[]} `[A.shape[0] * repeatsProd, A.shape[1]]`.
   */
  getOutputShape(input) {
    const { shape } = input.A;
    const rows = shape[0] * input.repeatsProd;
    return [rows, shape[1]];
  }

  /**
   * Sets `maxRank` from the lengths of the arrays present in `info` before
   * delegating to the base class. When both arrays are given, the length of
   * `repeatStrides` is the one that remains.
   *
   * @param info Compilation info, possibly partial.
   */
  compile(info) {
    const { sparseShape, repeatStrides } = info;
    for (const dims of [sparseShape, repeatStrides]) {
      if (dims !== undefined) {
        this.maxRank = dims.length;
      }
    }
    super.compile(info);
  }

  /**
   * Collects everything known about a concrete input: shape and memory
   * dimensions of `A`, shape and allocation dimensions of the output, the
   * sparse shape and the strides computed from `input.repeats`.
   *
   * @param input Object with `A`, `repeats`, `shape` and `repeatsProd`.
   * @returns Compilation info object for this input.
   */
  getCompilationInfo(input) {
    const shapeOutput = this.getOutputShape(input);
    const allocation = defaultAllocator.getAllocationDimensions(getSize(shapeOutput), this.dtype);
    const { shape: shapeA, memory } = input.A;
    return {
      shapeA,
      widthA: memory.width,
      heightA: memory.height,
      shapeOutput,
      widthOutput: allocation.width,
      heightOutput: allocation.height,
      sparseShape: input.shape,
      repeatStrides: computeStrides(input.repeats),
    };
  }

  /**
   * @param input Object with `A`, `repeats` and `shape`.
   * @returns {string} `<A.shape>-<repeats>-<shape>`, each part stringified.
   */
  getInputInfoString(input) {
    const { A, repeats, shape } = input;
    return `${A.shape}-${repeats}-${shape}`;
  }
}
/**
 * Module-level dispatcher for the repeat-index operation. The callback handed
 * to the Dispatcher builds a RepeatIndexOperation for a given dtype using
 * gpuConstructor as its tensor constructor.
 */
export const defaultRepeatIndexD = new Dispatcher((dtype) => {
  return new RepeatIndexOperation(gpuConstructor, dtype);
});
/**
 * Runs the repeat-index operation through the default dispatcher with dtype
 * 'uint32'.
 *
 * @param indices Index tensor, passed to the operation as texture `A`.
 * @param repeats Repeat counts; the operation derives its strides from these.
 * @param shape Passed on as the sparse shape used to offset repeated indices.
 * @param repeatsProd Row multiplier for the output (presumably the product of
 *   `repeats` — confirm with the caller).
 * @returns The result of the dispatcher's `calc` call.
 */
export function repeatIndexGPU(indices, repeats, shape, repeatsProd) {
  const input = { A: indices, repeats, shape, repeatsProd };
  return defaultRepeatIndexD.calc(input, 'uint32');
}
//# sourceMappingURL=gpu.js.map