UNPKG

@hoff97/tensor-js

Version:

PyTorch-like deep learning inference library

36 lines 972 B
/**
 * Hands one slice of an incoming gradient to an input variable. If the
 * variable's `backward` returns a falsy value, the slice is deleted again
 * (presumably because the variable did not retain it — confirm against the
 * Variable implementation).
 *
 * @param {*} input - Variable receiving the gradient slice.
 * @param {*} part - Gradient slice produced by `grad.slice(...)`.
 */
function passGradient(input, part) {
  const kept = input.backward(part);
  if (!kept) {
    part.delete();
  }
}

/**
 * Backward edge for the concatenation of two variables `a` and `b` along
 * `axis`. The incoming gradient is split at `a`'s extent on that axis: the
 * leading part is routed to `a`, the remainder to `b`.
 */
export class ConcatBack {
  /**
   * @param {*} a - First concatenated input.
   * @param {*} b - Second concatenated input.
   * @param {number} axis - Concatenation axis; negative values are resolved
   *   against the rank of `a`.
   */
  constructor(a, b, axis) {
    this.a = a;
    this.b = b;
    this.axis = axis;
  }

  /**
   * Splits `grad` along the concatenation axis and propagates each piece to
   * the input it belongs to, skipping inputs whose `noGrad` flag is set.
   *
   * @param {*} grad - Gradient with respect to the concatenated output.
   */
  backward(grad) {
    const { a, b } = this;
    // Resolve a negative axis by counting from the end of `a`'s shape.
    const dim = this.axis < 0 ? this.axis + a.getShape().length : this.axis;

    if (!a.noGrad) {
      const sizeA = a.getShape()[dim];
      passGradient(a, grad.slice([0], [sizeA], [dim]));
    }

    if (!b.noGrad) {
      // `b`'s part starts where `a`'s ends and runs to the end of `grad`.
      const start = a.getShape()[dim];
      const end = grad.getShape()[dim];
      passGradient(b, grad.slice([start], [end], [dim]));
    }
  }

  /**
   * Deletes whichever of the two inputs are not leaves; leaf inputs are left
   * untouched.
   */
  delete() {
    for (const input of [this.a, this.b]) {
      if (!input.isLeaf()) {
        input.delete();
      }
    }
  }
}
//# sourceMappingURL=concatBack.js.map