@hoff97/tensor-js
Version:
PyTorch-like deep learning inference library
82 lines (75 loc) • 2.34 kB
JavaScript
import { defaultAllocator } from '../../../tensor/gpu/gl';
import { getSize } from '../../../util/shape';
import { Operation } from '../operation';
/**
 * WebGL matrix multiplication for rank-2 tensors: computes `A @ B`, producing
 * an output of shape `[A.shape[0], B.shape[1]]`.
 *
 * Each output element is computed by a generated GLSL `process` function that
 * accumulates the dot product of one row of `A` with one column of `B`.
 */
export class MatMulOperation extends Operation {
  /**
   * @param tensorConstructor - forwarded unchanged to `Operation`
   * @param dtype - forwarded unchanged to `Operation`
   * @param allocator - forwarded unchanged to `Operation`
   */
  constructor(tensorConstructor, dtype, allocator) {
    super(tensorConstructor, dtype, allocator);
    // Loop bound written literally into the shader source; `compile` narrows
    // it to the actual inner dimension once that is known.
    this.maxIterations = 1000000;
    // Operands and output are indexed with exactly two coordinates.
    this.maxRank = 2;
  }

  /**
   * Builds the GLSL source for this operation.
   *
   * `process(index)` computes output element `[index[0], index[1]]` as
   * `sum_i A[index[0], i] * B[i, index[1]]` for `i < k`, where
   * `k = shapeA[1]` is read inside the shader. The loop bound must be a
   * literal in the generated source, so the loop runs to `maxIterations` and
   * breaks as soon as `i >= k`.
   *
   * @param info - compilation info (not used by this implementation)
   * @returns {string} the fragment shader source
   */
  // eslint-disable-next-line @typescript-eslint/no-unused-vars
  getFragmentShader(info) {
    return `
float process(int index[${this.maxRank}]) {
int ix1[${this.maxRank}];
${this.initIndex('ix1')}
ix1[0] = index[0];
int ix2[${this.maxRank}];
${this.initIndex('ix2')}
ix2[1] = index[1];
int k = shapeA[1];
float res = 0.0;
for (int i = 0; i < ${this.maxIterations}; i++) {
if (i >= k) {
break;
}
ix1[1] = i;
ix2[0] = i;
float v1 = _A(ix1);
float v2 = _B(ix2);
res += v1*v2;
}
return res;
}
${this.getDefaultMain()}
`;
  }

  /**
   * @returns {string[]} names of the input textures referenced by the shader
   */
  getTextureNames() {
    return ['A', 'B'];
  }

  /**
   * Runs the multiplication.
   *
   * @param input - object holding the operand tensors `A` and `B`
   * @returns whatever `this.compute` returns for the computed output shape
   */
  calc(input) {
    const outputShape = this.getOutputShape(input);
    return this.compute(outputShape, { A: input.A, B: input.B });
  }

  /**
   * Output shape of `A @ B`: rows of `A` by columns of `B`.
   *
   * NOTE(review): the inner dimensions (`A.shape[1]` vs `B.shape[0]`) are not
   * validated here; presumably the caller checks them.
   *
   * @param input - object holding the operand tensors `A` and `B`
   * @returns {number[]} `[A.shape[0], B.shape[1]]`
   */
  getOutputShape(input) {
    return [input.A.shape[0], input.B.shape[1]];
  }

  /**
   * Specializes the shader loop bound to the inner dimension, then compiles.
   *
   * Uses `info.shapeA[1]` when present, otherwise `info.shapeB[0]`. If neither
   * shape is given, the current `maxIterations` is kept.
   *
   * @param info - compilation info (presumably as built by `getCompilationInfo`)
   */
  compile(info) {
    if (info.shapeA !== undefined) {
      this.maxIterations = info.shapeA[1];
    } else if (info.shapeB !== undefined) {
      this.maxIterations = info.shapeB[0];
    }
    super.compile(info);
  }

  /**
   * Collects the shapes and texture dimensions needed to compile a shader
   * specialized to these inputs.
   *
   * @param input - object holding the operand tensors `A` and `B`
   * @returns shape and texture width/height for `A`, `B` and the output
   */
  getCompilationInfo(input) {
    const outputShape = this.getOutputShape(input);
    const outputSize = defaultAllocator.getAllocationDimensions(getSize(outputShape), this.dtype);
    return {
      shapeA: input.A.shape,
      widthA: input.A.memory.width,
      heightA: input.A.memory.height,
      // Operand B's metadata must come from B itself (previously copied from A).
      shapeB: input.B.shape,
      widthB: input.B.memory.width,
      heightB: input.B.memory.height,
      shapeOutput: outputShape,
      widthOutput: outputSize.width,
      heightOutput: outputSize.height,
    };
  }

  /**
   * @param input - object holding the operand tensors `A` and `B`
   * @returns {string} key describing the input shapes, e.g. `"2,3-3,4"`
   */
  getInputInfoString(input) {
    return `${input.A.shape}-${input.B.shape}`;
  }
}
//# sourceMappingURL=matmul.js.map