@hoff97/tensor-js
Version:
PyTorch-like deep learning inference library
81 lines • 2.73 kB
JavaScript
export class PowerBack {
  /**
   * Captures what is needed to back-propagate through the element-wise
   * power `a ** b`.
   *
   * @param a - Base operand. Must expose `value`, `getShape()`, `noGrad`,
   *   `backward(grad)`, `isLeaf()` and `delete()`.
   * @param b - Exponent operand, with the same interface as `a`.
   * @param powerResult - Tensor holding the forward result `a ** b`; reused
   *   here so the power does not have to be recomputed.
   * @param shape - Shape of the forward result.
   */
  constructor(a, b, powerResult, shape) {
    this.a = a;
    this.b = b;
    this.powerResult = powerResult;
    this.shape = shape;
  }

  /**
   * Propagates `grad` (gradient w.r.t. the power result) to `a` and `b`.
   *
   * Uses d(a^b)/da = b * a^b / a and d(a^b)/db = a^b * ln(a). An operand
   * whose extent along some axis is smaller than the result's has its
   * gradient summed over those axes, then reshaped back to the operand's
   * own shape. Operands flagged `noGrad` are skipped.
   *
   * @param grad - Gradient of the loss w.r.t. the power result.
   */
  backward(grad) {
    const outShape = this.shape;
    const shapeA = this.a.getShape();
    const shapeB = this.b.getShape();

    // Axes along which `inShape` is smaller than the output shape.
    const broadcastAxes = (inShape) => {
      const axes = [];
      for (let axis = 0; axis < outShape.length; axis += 1) {
        if (inShape[axis] < outShape[axis]) {
          axes.push(axis);
        }
      }
      return axes;
    };
    const axesA = broadcastAxes(shapeA);
    const axesB = broadcastAxes(shapeB);

    // Sums `full` over `axes` (freeing `full` when a sum was taken) and
    // reshapes to `targetShape` with reshape's second argument set to
    // false, as the original code does.
    const reduceTo = (full, axes, targetShape) => {
      if (axes.length === 0) {
        return full.reshape(targetShape, false);
      }
      const summed = full.sum(axes);
      full.delete();
      return summed.reshape(targetShape, false);
    };

    // Hands `gradient` to `variable`; when its backward() returns a falsy
    // value the gradient tensor is freed here.
    const sendTo = (variable, gradient) => {
      if (!variable.backward(gradient)) {
        gradient.delete();
      }
    };

    if (!this.a.noGrad) {
      // b * a^b / a, built from the cached forward result.
      // NOTE(review): where an element of `a` is 0 this divides by zero and
      // presumably yields a non-finite value instead of the analytic limit.
      const scaled = this.powerResult.multiply(this.b.value);
      const localA = scaled.divide(this.a.value);
      scaled.delete();
      const fullA = grad.multiply(localA);
      localA.delete();
      sendTo(this.a, reduceTo(fullA, axesA, shapeA));
    }

    if (!this.b.noGrad) {
      // a^b * ln(a)
      const logA = this.a.value.log();
      const localB = this.powerResult.multiply(logA);
      logA.delete();
      const fullB = grad.multiply(localB);
      localB.delete();
      sendTo(this.b, reduceTo(fullB, axesB, shapeB));
    }
  }

  /**
   * Deletes whichever of the two operands are not leaf variables
   * (`a` first, then `b`).
   */
  delete() {
    for (const operand of [this.a, this.b]) {
      if (!operand.isLeaf()) {
        operand.delete();
      }
    }
  }
}
//# sourceMappingURL=powerBack.js.map