UNPKG

@hoff97/tensor-js

Version:

PyTorch-like deep learning inference library

277 lines (263 loc) 8.49 kB
import { getSize } from '../../../util/shape';
import { Operation } from '../operation';
import { defaultAllocator } from '../../../tensor/gpu/gl';

/**
 * Derives the GEMM dimensions from the operation input.
 *
 * The last two axes of `a` are read as (M, N), or (N, M) when `aTranspose`
 * is truthy; the last two axes of `b` supply O. Note that `a`'s rank is used
 * to index into `b.shape` as well, so both inputs are expected to share the
 * same rank.
 *
 * @param input - Object with `a`, `b`, `aTranspose` and `bTranspose`.
 * @returns `{ rank, M, N, O }`.
 */
function getGemmDims(input) {
  const rank = input.a.shape.length;
  const M = input.aTranspose
    ? input.a.shape[rank - 1]
    : input.a.shape[rank - 2];
  const N = input.aTranspose
    ? input.a.shape[rank - 2]
    : input.a.shape[rank - 1];
  const O = input.bTranspose
    ? input.b.shape[rank - 2]
    : input.b.shape[rank - 1];
  return { rank, M, N, O };
}

/**
 * Builds the result shape: the leading (batch) axes of `a` followed by [M, O].
 *
 * @param input - Object with `a` (only `a.shape` is read).
 * @param dims - Result of `getGemmDims(input)`.
 * @returns The output shape as a new array.
 */
function getGemmResultShape(input, { rank, M, O }) {
  const batchShape = input.a.shape.slice(0, rank - 2);
  return [...batchShape, M, O];
}

/**
 * Builds the uniform values handed to `compute` for a non-static run.
 * The transpose flags are normalised to the integers 0/1 that the shader
 * compares against.
 *
 * @param input - Object with `aTranspose`, `bTranspose`, `alpha`, `beta`.
 * @param dims - Result of `getGemmDims(input)`.
 * @returns The uniforms object.
 */
function getGemmUniforms(input, { rank, M, N, O }) {
  return {
    M,
    N,
    O,
    rank,
    aTranspose: input.aTranspose ? 1 : 0,
    bTranspose: input.bTranspose ? 1 : 0,
    alpha: input.alpha,
    beta: input.beta,
  };
}

/**
 * GPU operation computing `alpha * (A x B)` over the last two axes of the
 * textures `A` and `B`, with optional transposition of either operand.
 */
export class GemmOperation extends Operation {
  constructor(tensorConstructor, dtype, allocator) {
    super(tensorConstructor, dtype, allocator);
    // Compile-time upper bound of the shader's reduction loop. The loop exits
    // early once `n >= N`; a reduction axis longer than this bound would be
    // truncated.
    this.maxIterations = 1000000;
  }

  /**
   * Returns the GLSL statements that compute `res`, the alpha-scaled dot
   * product for the output element addressed by `index`.
   *
   * `ixA` / `ixB` are the read positions in A and B: the batch axes are copied
   * from `index`, the row/column axes are set once, and the reduction axis is
   * updated on every iteration. The loops over `maxRank` that test
   * `i == rank-2` exist so that arrays are only ever indexed by a loop
   * counter.
   */
  getMainBody() {
    // NOTE(review): initIndex presumably emits GLSL that initialises the named
    // array - confirm in Operation. The original initialised `ixA` twice and
    // never `ixB`, leaving the entries of `ixB` past `rank` unset.
    return `
      int ixA[${this.maxRank}];
      ${this.initIndex('ixA')}
      int ixB[${this.maxRank}];
      ${this.initIndex('ixB')}

      for (int i = 0; i < ${this.maxRank}; i++) {
        if (i >= rank - 2) {
          break;
        }
        ixA[i] = index[i];
        ixB[i] = index[i];
      }

      int m = 0;
      int o = 0;
      for (int i = 0; i < ${this.maxRank}; i++) {
        if (i == rank-2) {
          m = index[i];
          o = index[i+1];

          if (aTranspose == 0) {
            ixA[i] = m;
          } else {
            ixA[i+1] = m;
          }

          if (bTranspose == 0) {
            ixB[i+1] = o;
          } else {
            ixB[i] = o;
          }

          break;
        }
      }

      float res = 0.0;

      for (int n = 0; n < ${this.maxIterations}; n++) {
        if (n >= N) {
          break;
        }

        for (int i = 0; i < ${this.maxRank}; i++) {
          if (i == rank-2) {
            if (aTranspose == 0) {
              ixA[i+1] = n;
            } else {
              ixA[i] = n;
            }

            if (bTranspose == 0) {
              ixB[i] = n;
            } else {
              ixB[i+1] = n;
            }

            break;
          }
        }
        res += _A(ixA) * _B(ixB);
      }

      res = res*alpha;
    `;
  }

  /** Returns the GLSL declarations of the variables used by the shader. */
  getVariables() {
    return `
      ${this.getVarModifier('M')} int M;
      ${this.getVarModifier('N')} int N;
      ${this.getVarModifier('O')} int O;
      ${this.getVarModifier('rank')} int rank;
      ${this.getVarModifier('aTranspose')} int aTranspose;
      ${this.getVarModifier('bTranspose')} int bTranspose;
      ${this.getVarModifier('alpha')} float alpha;
      ${this.getVarModifier('beta')} float beta;
    `;
  }

  /**
   * Returns the fragment shader source: a `process` function wrapping the
   * main body, followed by the default main. `info` is not used here.
   */
  // eslint-disable-next-line @typescript-eslint/no-unused-vars
  getFragmentShader(info) {
    return `
      float process(int index[${this.maxRank}]) {
        ${this.getMainBody()}

        return res;
      }

      ${this.getDefaultMain()}
    `;
  }

  /** Names of the input textures read by the shader. */
  getTextureNames() {
    return ['A', 'B'];
  }

  /** Uniform descriptors; entries without a `type` carry only a name. */
  getUniformAttrs() {
    return [
      { name: 'M' },
      { name: 'N' },
      { name: 'O' },
      { name: 'rank' },
      { name: 'aTranspose' },
      { name: 'bTranspose' },
      { name: 'alpha', type: 'float' },
      { name: 'beta', type: 'float' },
    ];
  }

  /**
   * Runs the operation.
   *
   * @param input - Object with `a`, `b`, `aTranspose`, `bTranspose`, `alpha`
   *   and `beta`.
   * @returns Whatever `this.compute` returns for the result shape.
   */
  calc(input) {
    const textures = { A: input.a, B: input.b };

    // When fully static with a known output shape, no uniforms are passed.
    if (this.fullyStatic && this.outputShape !== undefined) {
      return this.compute(this.outputShape, textures);
    }

    const dims = getGemmDims(input);
    return this.compute(
      getGemmResultShape(input, dims),
      textures,
      getGemmUniforms(input, dims),
    );
  }

  /**
   * Computes the output shape: batch axes of `a` followed by [M, O].
   *
   * @param input - Object with `a`, `b`, `aTranspose` and `bTranspose`.
   * @returns The output shape.
   */
  getOutputShape(input) {
    return getGemmResultShape(input, getGemmDims(input));
  }

  /**
   * Fills in derived compilation values on `info` (mutating it) and forwards
   * it to `super.compile`.
   *
   * With `shapeA` present, `rank` is set on `info` and `this.maxRank` is set
   * to that rank; with `aTranspose` also present, `M` and `N` are derived.
   * With `shapeB` and `bTranspose` present, `O` is derived. Transpose flags
   * that are present are normalised to 0/1.
   *
   * @param info - Compilation info; modified in place.
   */
  compile(info) {
    if (info.shapeA !== undefined) {
      const rank = info.shapeA.length;
      info.rank = rank;
      this.maxRank = rank;

      if (info.aTranspose !== undefined) {
        // Read the dimensions before the flag is overwritten with 0/1.
        const M = info.aTranspose
          ? info.shapeA[rank - 1]
          : info.shapeA[rank - 2];
        const N = info.aTranspose
          ? info.shapeA[rank - 2]
          : info.shapeA[rank - 1];
        info.M = M;
        info.N = N;
        info.aTranspose = info.aTranspose ? 1 : 0;
      }
    }

    if (info.shapeB !== undefined && info.bTranspose !== undefined) {
      const rank = info.shapeB.length;
      const O = info.bTranspose
        ? info.shapeB[rank - 2]
        : info.shapeB[rank - 1];
      info.O = O;
      info.bTranspose = info.bTranspose ? 1 : 0;
    }

    super.compile(info);
  }

  /**
   * Collects shapes, texture dimensions and scalar parameters for `input`.
   *
   * @param input - Object with `a`, `b` (each exposing `shape` and
   *   `memory.width` / `memory.height`), the transpose flags, `alpha` and
   *   `beta`.
   * @returns The compilation info object.
   */
  getCompilationInfo(input) {
    const outputShape = this.getOutputShape(input);
    const outputSize = defaultAllocator.getAllocationDimensions(
      getSize(outputShape),
      this.dtype,
    );
    const { rank, M, N, O } = getGemmDims(input);

    const info = {
      shapeA: input.a.shape,
      widthA: input.a.memory.width,
      heightA: input.a.memory.height,
      shapeB: input.b.shape,
      widthB: input.b.memory.width,
      heightB: input.b.memory.height,
      shapeOutput: outputShape,
      widthOutput: outputSize.width,
      heightOutput: outputSize.height,
      M,
      N,
      O,
      aTranspose: input.aTranspose ? 1 : 0,
      bTranspose: input.bTranspose ? 1 : 0,
      alpha: input.alpha,
      beta: input.beta,
      rank,
    };
    return info;
  }

  /**
   * Builds a string from the input shapes, transpose flags, alpha and beta.
   *
   * @param input - Same input object as passed to `calc`.
   * @returns The dash-separated description string.
   */
  getInputInfoString(input) {
    //TODO: Check precision of alpha and beta
    return `${input.a.shape}-${input.b.shape}-${input.aTranspose}-${input.bTranspose}-${input.alpha}-${input.beta}`;
  }
}

/**
 * GEMM variant with a third texture `C`: the shader adds `beta * C[index]`
 * to the alpha-scaled product.
 */
export class GemmCOperation extends GemmOperation {
  /** Names of the input textures read by the shader. */
  getTextureNames() {
    return ['A', 'B', 'C'];
  }

  /**
   * Same as the parent's shader, with `beta*_C(index)` added to the result.
   * `info` is not used here.
   */
  // eslint-disable-next-line @typescript-eslint/no-unused-vars
  getFragmentShader(info) {
    return `
      float process(int index[${this.maxRank}]) {
        ${this.getMainBody()}

        res += beta*_C(index);

        return res;
      }

      ${this.getDefaultMain()}
    `;
  }

  /**
   * Runs the operation with the additional `c` input.
   *
   * @param input - Object with `a`, `b`, `c`, `aTranspose`, `bTranspose`,
   *   `alpha` and `beta`.
   * @returns Whatever `this.compute` returns for the result shape.
   */
  calc(input) {
    const textures = { A: input.a, B: input.b, C: input.c };

    // When fully static with a known output shape, no uniforms are passed.
    if (this.fullyStatic && this.outputShape !== undefined) {
      return this.compute(this.outputShape, textures);
    }

    const dims = getGemmDims(input);
    return this.compute(
      getGemmResultShape(input, dims),
      textures,
      getGemmUniforms(input, dims),
    );
  }

  /**
   * Extends the parent's compilation info with the shape and texture
   * dimensions of `c`.
   *
   * @param input - Same as for the parent, plus `c`.
   * @returns A new info object; the parent's result is not modified.
   */
  getCompilationInfo(input) {
    const inf = super.getCompilationInfo(input);
    const info = Object.assign({}, inf, {
      shapeC: input.c.shape,
      widthC: input.c.memory.width,
      heightC: input.c.memory.height,
    });
    return info;
  }

  /**
   * Appends the shape of `c` to the parent's input description string.
   *
   * @param input - Same input object as passed to `calc`.
   * @returns The dash-separated description string.
   */
  getInputInfoString(input) {
    return `${super.getInputInfoString(input)}-${input.c.shape}`;
  }
}
//# sourceMappingURL=gemm.js.map