@hoff97/tensor-js
Version:
PyTorch-like deep learning inference library
26 lines • 713 B
JavaScript
export class RepeatBack {
constructor(a, repeats) {
this.a = a;
this.repeats = repeats;
}
backward(grad) {
const shapeA = this.a.getShape();
const gradNewShape = [];
const sumAxes = [];
for (let i = 0; i < shapeA.length; i++) {
gradNewShape.push(this.repeats[i], shapeA[i]);
sumAxes.push(i * 2);
}
const gradA = grad.reshape(gradNewShape, false).sum(sumAxes, false);
const needed = this.a.backward(gradA);
if (!needed) {
gradA.delete();
}
}
delete() {
if (!this.a.isLeaf()) {
this.a.delete();
}
}
}
//# sourceMappingURL=repeatBack.js.map