@hoff97/tensor-js
Version:
PyTorch-like deep learning inference library
39 lines • 1.21 kB
JavaScript
/**
 * Backward operation for a log-sum reduction, `y = log(sum(x, sumDims, keepDims))`.
 *
 * Since d/dx log(sum(x)) = 1 / sum(x) for every element that contributed to
 * the sum, the incoming gradient is divided by the recomputed sum and then
 * expanded back to the input's shape.
 */
export class LogSumBack {
  /**
   * @param input - Variable the reduction was applied to. Its `value` tensor
   *   is read during `backward`, and its `backward` method receives the
   *   propagated gradient.
   * @param {number[]} sumDims - Dimensions that were reduced.
   * @param {boolean} keepDims - Whether the reduced dimensions were kept
   *   (with size 1) in the forward output.
   */
  constructor(input, sumDims, keepDims) {
    this.input = input;
    this.sumDims = sumDims;
    this.keepDims = keepDims;
  }

  /**
   * Propagates `grad` (the gradient w.r.t. the forward output) to `this.input`.
   *
   * The recomputed sum and the divided gradient are deleted here. The
   * expanded gradient is deleted only when `this.input.backward` returns a
   * falsy value; otherwise it is left alive for the input to use.
   *
   * @param grad - Gradient shaped like the forward output.
   */
  backward(grad) {
    const inShape = this.input.value.getShape();

    const sum = this.input.value.sum(this.sumDims, this.keepDims);
    let gradLogSum = grad.divide(sum);
    sum.delete();

    if (!this.keepDims) {
      // Re-insert every reduced dimension with size 1 so the gradient can be
      // expanded back to `inShape`. A Set lookup (rather than walking
      // `sumDims` in lockstep with the axes) stays correct when `sumDims` is
      // unsorted or contains duplicates; negative dims count from the end.
      const rank = inShape.length;
      const reduced = new Set();
      for (const dim of this.sumDims) {
        reduced.add(dim < 0 ? dim + rank : dim);
      }
      const newShape = Array.from(inShape, (size, i) => (reduced.has(i) ? 1 : size));
      // NOTE(review): second argument `false` kept from the original; it
      // presumably requests a non-copying reshape — confirm in Tensor.reshape.
      gradLogSum = gradLogSum.reshape(newShape, false);
    }

    const expanded = gradLogSum.expand(inShape);
    gradLogSum.delete();

    const needed = this.input.backward(expanded);
    if (!needed) {
      expanded.delete();
    }
  }

  /**
   * Releases the input variable, unless it is a leaf.
   */
  delete() {
    if (!this.input.isLeaf()) {
      this.input.delete();
    }
  }
}
//# sourceMappingURL=logSumBack.js.map