@hoff97/tensor-js

PyTorch-like deep learning inference library

// Backward (gradient) node for the GEMM operation
//   Y = alpha * op(A) * op(B) + beta * C,
// where op(X) is X or its transpose depending on transA / transB.
export class GemmBack {
  constructor(a, b, transA, transB, alpha, beta, c) {
    this.a = a;
    this.b = b;
    this.transA = transA;
    this.transB = transB;
    this.alpha = alpha;
    this.beta = beta;
    this.c = c;
  }

  backward(grad) {
    // Gradient with respect to B.
    if (!this.b.noGrad) {
      let gradB;
      if (this.transB) {
        // Y = alpha * op(A) * B^T  =>  dL/dB = alpha * grad^T * op(A)
        gradB = grad.gemm(this.a.value, true, this.transA, this.alpha);
      } else {
        // Y = alpha * op(A) * B  =>  dL/dB = alpha * op(A)^T * grad
        gradB = this.a.value.gemm(grad, !this.transA, false, this.alpha);
      }
      const needed = this.b.backward(gradB);
      if (!needed) {
        gradB.delete();
      }
    }
    // Gradient with respect to A.
    if (!this.a.noGrad) {
      let gradA;
      if (this.transA) {
        // Y = alpha * A^T * op(B)  =>  dL/dA = alpha * op(B) * grad^T
        gradA = this.b.value.gemm(grad, this.transB, true, this.alpha);
      } else {
        // Y = alpha * A * op(B)  =>  dL/dA = alpha * grad * op(B)^T
        gradA = grad.gemm(this.b.value, false, !this.transB, this.alpha);
      }
      const needed = this.a.backward(gradA);
      if (!needed) {
        gradA.delete();
      }
    }
    // Gradient with respect to the (optionally broadcast) bias C:
    // sum the incoming gradient over the broadcast dimensions, reshape
    // back to C's shape, and scale by beta.
    if (this.c !== undefined && !this.c.noGrad) {
      const gradShape = grad.getShape();
      const cShape = this.c.getShape();
      const cSumDims = [];
      for (let i = 0; i < gradShape.length; i++) {
        if (cShape[i] < gradShape[i]) {
          cSumDims.push(i);
        }
      }
      let gradC = grad.sum(cSumDims).reshape(cShape, false);
      if (this.beta !== 1) {
        const oldGradC = gradC;
        gradC = gradC.multiplyScalar(this.beta);
        oldGradC.delete();
      }
      const needed = this.c.backward(gradC);
      if (!needed) {
        gradC.delete();
      }
    }
  }

  // Release non-leaf inputs when this node is freed.
  delete() {
    if (!this.a.isLeaf()) {
      this.a.delete();
    }
    if (!this.b.isLeaf()) {
      this.b.delete();
    }
    if (this.c !== undefined && !this.c.isLeaf()) {
      this.c.delete();
    }
  }
}
//# sourceMappingURL=gemmBack.js.map
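For reference, GemmBack implements the standard matrix-product gradient rules for Y = alpha * op(A) * op(B) + beta * C, with the bias gradient reduced over broadcast dimensions and scaled by beta. The snippet below is a minimal, standalone sketch (not part of tensor-js) that checks the no-transpose rule dL/dB = alpha * A^T * G against a finite-difference estimate; it uses plain nested arrays, so no tensor-js API is assumed.

// Hedged, self-contained check of the gradient rule used by GemmBack
// when transA = transB = false:
//   Y = alpha * A * B,  L = sum(Y)  =>  dL/dB = alpha * A^T * G,  G = dL/dY = ones.
const matmul = (x, y) =>
  x.map(row => y[0].map((_, j) => row.reduce((s, v, k) => s + v * y[k][j], 0)));
const transpose = x => x[0].map((_, j) => x.map(row => row[j]));

const alpha = 0.5;
const A = [[1, 2], [3, 4], [5, 6]];          // 3x2
const B = [[1, -1, 2], [0, 3, 1]];           // 2x3
const G = [[1, 1, 1], [1, 1, 1], [1, 1, 1]]; // dL/dY for L = sum(Y), 3x3

// Analytic gradient, matching the no-transpose branch of backward():
const gradB = matmul(transpose(A), G).map(row => row.map(v => v * alpha));

// Finite-difference estimate of dL/dB[0][0].
const loss = b => matmul(A, b).flat().reduce((s, v) => s + alpha * v, 0);
const eps = 1e-6;
const Bp = B.map(row => row.slice());
Bp[0][0] += eps;
const numeric = (loss(Bp) - loss(B)) / eps;

console.log(gradB[0][0], numeric); // both ≈ 4.5 (= alpha * (1 + 3 + 5))

The same style of check carries over to the other transpose combinations and to the A and C gradients by substituting the corresponding formulas noted in the comments of backward() above.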