@hoff97/tensor-js
Version:
PyTorch-like deep learning inference library
41 lines • 1.25 kB
JavaScript
/**
 * Backward edge for a mean reduction over `sumDims`.
 *
 * Given the gradient of the reduced output, it scales the gradient by
 * 1 / (number of reduced elements), broadcasts it back to the input's shape
 * and forwards it to `this.input.backward`.
 */
export class MeanBack {
  /**
   * @param input - The reduced variable. It is used via `value.getShape()`,
   *   `backward(grad)`, `isLeaf()` and `delete()`.
   * @param {number[]} sumDims - The reduced axes. NOTE(review): assumed to be
   *   sorted ascending; the keepDims=false shape reconstruction in `backward`
   *   only matches them in increasing order.
   * @param {boolean} keepDims - Whether the forward reduction kept the reduced
   *   axes as size-1 dimensions.
   */
  constructor(input, sumDims, keepDims) {
    this.input = input;
    this.sumDims = sumDims;
    this.keepDims = keepDims;
  }

  /**
   * Propagates `grad / sumSize`, expanded to the input shape, to the input.
   *
   * `grad` is owned by the caller and is never deleted here. Intermediate
   * tensors created by this method are freed. The tensor passed to the
   * input is freed only when the input reports (falsy return) that it did
   * not keep it.
   *
   * @param grad - Gradient of the mean's output.
   */
  backward(grad) {
    const inShape = this.input.value.getShape();

    let shapedGrad = grad;
    if (!this.keepDims) {
      // Re-insert the reduced axes as size-1 dims so the gradient can be
      // broadcast back to the input shape.
      const newShape = [];
      let sumI = 0;
      for (let i = 0; i < inShape.length; i++) {
        if (sumI < this.sumDims.length && this.sumDims[sumI] === i) {
          newShape.push(1);
          sumI++;
        } else {
          newShape.push(inShape[i]);
        }
      }
      shapedGrad = grad.reshape(newShape, false);
    }

    // Number of input elements averaged into each output element.
    let sumSize = 1;
    for (let i = 0; i < this.sumDims.length; i++) {
      sumSize *= inShape[this.sumDims[i]];
    }

    const multiplied = shapedGrad.multiplyScalar(1 / sumSize);
    const gradIn = multiplied.expand(inShape);
    // `multiplied` is only an intermediate. Free it, unless expand handed
    // back the very same tensor object, which is about to be passed on.
    if (gradIn !== multiplied) {
      multiplied.delete();
    }

    const needed = this.input.backward(gradIn);
    if (!needed) {
      // The input did not retain gradIn, so this edge owns it and must free
      // it. `grad` belongs to the caller and must not be deleted here.
      gradIn.delete();
    }
  }

  /**
   * Releases the input unless it is a leaf.
   */
  delete() {
    if (!this.input.isLeaf()) {
      this.input.delete();
    }
  }
}
//# sourceMappingURL=meanBack.js.map