@hoff97/tensor-js
Version:
PyTorch-like deep learning inference library
27 lines • 788 B
JavaScript
/**
 * Backward operation for `expand` (broadcasting variable `a` to `shape`).
 *
 * In the backward pass the incoming gradient is summed over every axis along
 * which `a` was stretched, then reshaped back to the shape of `a`.
 */
export class ExpandBack {
  /**
   * @param {object} a - The variable that was expanded. Must provide `value`
   *   (with `alignShapes`), `getShape()`, `backward(grad)`, `isLeaf()` and
   *   `delete()`.
   * @param {number[]} shape - The shape `a` was expanded to.
   */
  constructor(a, shape) {
    this.a = a;
    this.shape = shape;
  }

  /**
   * Propagates `grad` (the gradient of the expanded result) back to `a`.
   *
   * @param {object} grad - Gradient tensor of the expanded result. It must
   *   provide `sum(axes)`, whose result must provide `reshape(shape)`.
   */
  backward(grad) {
    // Only the first two results are used. They are presumably the shapes of
    // `a` and of the target, padded to a common rank (TODO confirm against
    // alignShapes). The third result is intentionally not bound, so no
    // unused variable or lint suppression is needed.
    const [alignedA, goal] = this.a.value.alignShapes(this.a.getShape(), this.shape);

    // An axis where `a` is smaller than the target was stretched by the
    // expand, so the gradient along that axis is reduced by summation.
    const sumDims = [];
    for (const [axis, size] of alignedA.entries()) {
      if (size < goal[axis]) {
        sumDims.push(axis);
      }
    }

    // NOTE(review): the intermediate tensor returned by `grad.sum(...)` is not
    // deleted here. Whether that leaks depends on whether `reshape` copies or
    // shares storage; confirm against the Tensor implementation.
    const gradA = grad.sum(sumDims).reshape(this.a.getShape());
    const needed = this.a.backward(gradA);
    // When `a.backward` returns falsy, `gradA` is released here (presumably
    // because `a` did not retain a reference to it).
    if (!needed) {
      gradA.delete();
    }
  }

  /**
   * Releases this op's input by calling `a.delete()`, unless `a` is a leaf.
   */
  delete() {
    if (!this.a.isLeaf()) {
      this.a.delete();
    }
  }
}
//# sourceMappingURL=expandBack.js.map