@hoff97/tensor-js
Version:
PyTorch-like deep learning inference library
36 lines • 972 B
JavaScript
/**
 * Backward edge for the concatenation of two inputs along a single axis.
 *
 * Along `axis`, the incoming gradient is split as follows:
 * - indices [0, a.getShape()[axis]) go to `a`;
 * - the remaining indices go to `b`.
 * NOTE(review): assumes `grad.slice(starts, ends, axes)` selects the
 * half-open range [starts, ends) on each listed axis — confirm against the
 * tensor implementation.
 */
export class ConcatBack {
  /**
   * @param a - first concatenated input (leading part of `axis`)
   * @param b - second concatenated input (trailing part of `axis`)
   * @param {number} axis - concatenation axis; negative values count from
   *   the end of `a`'s shape
   */
  constructor(a, b, axis) {
    Object.assign(this, { a, b, axis });
  }

  /**
   * Splits `grad` along the concatenation axis and forwards each piece to
   * the input it belongs to. Inputs flagged `noGrad` are skipped entirely.
   * If an input's `backward` returns a falsy value, the slice handed to it
   * is deleted here (presumably because the input did not retain it).
   *
   * @param grad - gradient of the concatenated output
   */
  backward(grad) {
    const { a, b } = this;
    // Only consult a's rank when the stored axis actually needs normalizing.
    const axis = this.axis < 0 ? this.axis + a.getShape().length : this.axis;

    const sendTo = (input, piece) => {
      const retained = input.backward(piece);
      if (!retained) {
        piece.delete();
      }
    };

    if (!a.noGrad) {
      const boundary = a.getShape()[axis];
      sendTo(a, grad.slice([0], [boundary], [axis]));
    }
    if (!b.noGrad) {
      // b's part starts where a's part ends and runs to the end of the axis.
      const start = a.getShape()[axis];
      const stop = grad.getShape()[axis];
      sendTo(b, grad.slice([start], [stop], [axis]));
    }
  }

  /**
   * Deletes each input (first `a`, then `b`) whose `isLeaf()` is false;
   * leaf inputs are left untouched.
   */
  delete() {
    for (const input of [this.a, this.b]) {
      if (!input.isLeaf()) {
        input.delete();
      }
    }
  }
}
//# sourceMappingURL=concatBack.js.map