UNPKG

@hoff97/tensor-js

Version:

PyTorch-like deep learning inference library

26 lines 713 B
/**
 * Backward node for a `repeat` operation applied to tensor `a`.
 *
 * NOTE(review): the gradient reduction below assumes the forward op tiled `a`
 * so that output axis i has size `repeats[i] * shapeA[i]`, with the repeat
 * index varying slowest (np.tile-style layout) — TODO confirm against the
 * forward repeat implementation.
 */
export class RepeatBack {
  /**
   * @param a - Input tensor of the forward repeat; receives the reduced gradient.
   * @param repeats - Per-axis repeat counts, indexed like `a.getShape()`.
   */
  constructor(a, repeats) {
    this.a = a;
    this.repeats = repeats;
  }

  /**
   * Reduces the incoming gradient back to `a`'s shape and passes it to `a`.
   *
   * Each gradient axis is viewed as a (repeat count, original size) pair, and
   * every repeat-count axis is then summed away.
   *
   * @param grad - Gradient with respect to the repeated output.
   */
  backward(grad) {
    const shapeA = this.a.getShape();

    // After the split, the repeat-count axes sit at positions 0, 2, 4, ...
    const sumAxes = Array.from({ length: shapeA.length }, (_, axis) => 2 * axis);

    // Interleave repeat counts with original sizes,
    // e.g. shapeA [2, 3] with repeats [4, 5] -> [4, 2, 5, 3].
    const gradNewShape = [];
    sumAxes.forEach((_, axis) => {
      gradNewShape.push(this.repeats[axis], shapeA[axis]);
    });

    // The `false` argument presumably disables copying — TODO confirm in Tensor.reshape.
    const splitGrad = grad.reshape(gradNewShape, false);
    const gradA = splitGrad.sum(sumAxes, false);

    // A falsy result from a.backward means the gradient was not needed there,
    // so gradA is deleted here instead.
    const needed = this.a.backward(gradA);
    if (needed) {
      return;
    }
    gradA.delete();
  }

  /** Calls `a.delete()` unless `a` is a leaf. */
  delete() {
    if (this.a.isLeaf()) {
      return;
    }
    this.a.delete();
  }
}
//# sourceMappingURL=repeatBack.js.map