UNPKG

@hoff97/tensor-js

Version:

PyTorch-like deep learning inference library

27 lines 788 B
/**
 * Backward operation paired with expanding the variable `a` to `shape`
 * (presumably the broadcasting `expand` op, judging by the name — confirm).
 *
 * The incoming gradient is summed over every dimension along which `a`'s
 * aligned shape is smaller than the target shape. It is then reshaped back to
 * `a.getShape()` and handed to `a.backward`.
 */
export class ExpandBack {
  /**
   * @param {object} a - The expanded variable. Must provide `value.alignShapes`,
   *   `getShape`, `backward`, `isLeaf` and `delete`.
   * @param {number[]} shape - The shape `a` was expanded to (presumably the
   *   forward op's target shape — confirm against the caller).
   */
  constructor(a, shape) {
    this.a = a;
    this.shape = shape;
  }

  /**
   * Propagates `grad` back to `a`.
   *
   * @param {object} grad - Gradient with respect to the expanded output. Only
   *   its `sum(dims)` method is used; this method never deletes `grad` itself.
   */
  backward(grad) {
    // Only the first two results of alignShapes are used: `a`'s aligned shape
    // and the aligned target shape (presumably padded to equal rank — TODO confirm).
    const [alignedShape, goal] = this.a.value.alignShapes(this.a.getShape(), this.shape);

    // Dimensions where `a` is smaller than the goal were broadcast, so the
    // gradient must be summed along them. Kept as a plain Array on purpose.
    const sumDims = [];
    for (let dim = 0; dim < alignedShape.length; dim += 1) {
      if (alignedShape[dim] < goal[dim]) {
        sumDims.push(dim);
      }
    }

    // NOTE(review): the intermediate tensor returned by `grad.sum(...)` is never
    // deleted here. Whether that leaks depends on whether `reshape` copies or
    // shares memory — confirm against the Tensor implementation.
    const gradA = grad.sum(sumDims).reshape(this.a.getShape());

    // A falsy return from `a.backward` presumably means `a` did not retain
    // `gradA`, so it is released here.
    const needed = this.a.backward(gradA);
    if (!needed) {
      gradA.delete();
    }
  }

  /**
   * Deletes the wrapped variable `a` unless it is a leaf.
   */
  delete() {
    if (!this.a.isLeaf()) {
      this.a.delete();
    }
  }
}
//# sourceMappingURL=expandBack.js.map