@hoff97/tensor-js
Version:
PyTorch-like deep learning inference library
63 lines • 1.93 kB
JavaScript
/**
 * Returns the axes of `outShape` along which an operand of shape `inShape`
 * was broadcast, i.e. the axes its gradient has to be summed over.
 *
 * Shapes are compared right-aligned: when `inShape` has fewer dimensions
 * than `outShape`, the missing leading axes also count as broadcast.
 *
 * @param {readonly number[]} inShape - Shape of the operand.
 * @param {readonly number[]} outShape - Shape of the multiplication result.
 * @returns {number[]} Broadcast axes of `outShape`, in ascending order.
 */
function broadcastAxes(inShape, outShape) {
  // Clamped at 0 so equal-rank shapes are compared index-by-index.
  const offset = Math.max(0, outShape.length - inShape.length);
  const axes = [];
  for (let i = 0; i < outShape.length; i++) {
    const j = i - offset;
    if (j < 0 || inShape[j] < outShape[i]) {
      axes.push(i);
    }
  }
  return axes;
}

/**
 * Computes the gradient for one operand of the multiplication and hands it
 * to that operand's `backward`. Does nothing if the operand has `noGrad` set.
 *
 * @param {object} target - Operand receiving the gradient.
 * @param {object} other - The other operand; `grad` is multiplied by its `value`.
 * @param {object} grad - Gradient flowing into the multiplication result.
 * @param {readonly number[]} outShape - Shape of the multiplication result.
 * @param {number} alpha - Scale factor forwarded to `multiply`.
 */
function backpropOperand(target, other, grad, outShape, alpha) {
  if (target.noGrad) {
    return;
  }
  const targetShape = target.getShape();
  const axes = broadcastAxes(targetShape, outShape);
  const product = grad.multiply(other.value, alpha);

  let gradient;
  if (axes.length === 0) {
    // NOTE(review): `reshape(..., false)` presumably reuses `product`'s storage
    // instead of copying, which is why `product` is not freed here — confirm.
    gradient = product.reshape(targetShape, false);
  } else {
    // Undo broadcasting: sum over the axes the operand was expanded along,
    // then restore the operand's own shape.
    const summed = product.sum(axes);
    product.delete();
    gradient = summed.reshape(targetShape, false);
  }

  // A falsy return from `backward` is taken to mean the operand did not keep
  // the gradient, so it is released here.
  if (!target.backward(gradient)) {
    gradient.delete();
  }
}

/**
 * Backward edge for a (broadcasting) element-wise multiplication of `a` and `b`.
 * The gradient sent to each operand is the incoming gradient times the other
 * operand's value, reduced over the axes that operand was broadcast along.
 */
export class MultiplyBack {
  /**
   * @param {object} a - First operand.
   * @param {object} b - Second operand.
   * @param {readonly number[]} shape - Shape of the multiplication result.
   * @param {number} alpha - Scale factor forwarded to `multiply` when forming
   *   the operand gradients.
   */
  constructor(a, b, shape, alpha) {
    this.a = a;
    this.b = b;
    this.shape = shape;
    this.alpha = alpha;
  }

  /**
   * Propagates `grad` to `a` and then to `b`, skipping any operand with
   * `noGrad` set.
   *
   * @param {object} grad - Gradient of the multiplication result (presumably
   *   of shape `this.shape`).
   */
  backward(grad) {
    backpropOperand(this.a, this.b, grad, this.shape, this.alpha);
    backpropOperand(this.b, this.a, grad, this.shape, this.alpha);
  }

  /** Deletes whichever of the two operands is not a leaf. */
  delete() {
    if (!this.a.isLeaf()) {
      this.a.delete();
    }
    if (!this.b.isLeaf()) {
      this.b.delete();
    }
  }
}
//# sourceMappingURL=multiplyBack.js.map