@hoff97/tensor-js
Version:
PyTorch-like deep learning inference library
71 lines • 2.17 kB
JavaScript
/**
 * Backward node for a general matrix multiplication (gemm).
 *
 * NOTE(review): judging from the gradient formulas below, the forward pass is
 * presumably `alpha * op(a) * op(b) + beta * c`, where `op` transposes its
 * operand when the matching `transA` / `transB` flag is set, and
 * `x.gemm(y, tX, tY, alpha)` presumably computes `alpha * op(x) * op(y)`.
 * Confirm against the forward implementation.
 */
export class GemmBack {
    /**
     * @param a First input; must expose `value`, `noGrad`, `backward`, `isLeaf` and `delete`.
     * @param b Second input; same requirements as `a`.
     * @param transA Whether `a` was transposed in the forward pass.
     * @param transB Whether `b` was transposed in the forward pass.
     * @param alpha Scalar passed on to every gradient `gemm` call.
     * @param beta Scalar applied to the gradient of `c` (skipped when it is exactly 1).
     * @param c Optional addend; must expose `getShape`, `noGrad`, `backward`, `isLeaf` and `delete`.
     */
    constructor(a, b, transA, transB, alpha, beta, c) {
        this.a = a;
        this.b = b;
        this.transA = transA;
        this.transB = transB;
        this.alpha = alpha;
        this.beta = beta;
        this.c = c;
    }

    /**
     * Propagates `grad` to `b`, `a` and (if present) `c`, in that order.
     * Inputs flagged with `noGrad` are skipped. Every gradient created here is
     * deleted again when the receiving input's `backward` returns a falsy value.
     *
     * @param grad Gradient with respect to the output of the gemm operation.
     */
    backward(grad) {
        if (!this.b.noGrad) {
            // The operand order and transpose flags depend on whether b was
            // transposed in the forward pass.
            const gradB = this.transB
                ? grad.gemm(this.a.value, true, this.transA, this.alpha)
                : this.a.value.gemm(grad, !this.transA, false, this.alpha);
            if (!this.b.backward(gradB)) {
                gradB.delete();
            }
        }

        if (!this.a.noGrad) {
            const gradA = this.transA
                ? this.b.value.gemm(grad, this.transB, true, this.alpha)
                : grad.gemm(this.b.value, false, !this.transB, this.alpha);
            if (!this.a.backward(gradA)) {
                gradA.delete();
            }
        }

        if (this.c !== undefined && !this.c.noGrad) {
            const gradShape = grad.getShape();
            const cShape = this.c.getShape();

            // Broadcasting aligns shapes on their trailing dimensions, so when c
            // has a lower rank than grad its dims line up with the last dims of
            // grad. The offset is clamped at 0 so that equal (or unexpectedly
            // larger) ranks keep a plain index-by-index comparison.
            const rankOffset = Math.max(gradShape.length - cShape.length, 0);
            const cSumDims = [];
            for (let i = 0; i < gradShape.length; i++) {
                const missingInC = i < rankOffset;
                if (missingInC || cShape[i - rankOffset] < gradShape[i]) {
                    // c was broadcast along this dim, so its gradient is the
                    // sum of grad over it.
                    cSumDims.push(i);
                }
            }

            let gradC = grad.sum(cSumDims).reshape(cShape, false);
            if (this.beta !== 1) {
                // Replace the unscaled gradient by the scaled one and free the
                // unscaled tensor.
                const unscaledGradC = gradC;
                gradC = gradC.multiplyScalar(this.beta);
                unscaledGradC.delete();
            }
            if (!this.c.backward(gradC)) {
                gradC.delete();
            }
        }
    }

    /**
     * Deletes every input that is not a leaf (`isLeaf()` returns false);
     * leaf inputs are left untouched.
     */
    delete() {
        if (!this.a.isLeaf()) {
            this.a.delete();
        }
        if (!this.b.isLeaf()) {
            this.b.delete();
        }
        if (this.c !== undefined && !this.c.isLeaf()) {
            this.c.delete();
        }
    }
}
//# sourceMappingURL=gemmBack.js.map