@hoff97/tensor-js
Version:
PyTorch-like deep learning inference library
34 lines • 1.08 kB
JavaScript
/**
 * Backward edge for a slice operation with unit steps.
 *
 * The gradient of a slice is the incoming gradient zero-padded back to the
 * shape of the sliced input: elements outside the slice get a gradient of 0.
 */
export class SliceBack {
  /**
   * @param {*} a - Input that was sliced. Must expose `noGrad`, `getShape()`,
   *   `backward(grad)`, `isLeaf()` and `delete()`.
   * @param {number[]} starts - Start index for each entry of `axes`; negative
   *   values count from the end of that axis.
   * @param {number[]} ends - Exclusive end index for each entry of `axes`;
   *   negative values count from the end, values past the axis are clamped.
   * @param {number[]} axes - Axes the slice was applied to; negative values
   *   count from the last axis.
   * @param {number[]} [steps] - Step for each entry of `axes`. Only 1 is
   *   supported; when omitted every step is treated as 1.
   * @throws {Error} If any step differs from 1.
   */
  constructor(a, starts, ends, axes, steps) {
    this.a = a;
    this.starts = starts;
    this.ends = ends;
    this.axes = axes;
    this.steps = steps;
    // Zero-padding can only invert a contiguous (step 1) slice.
    if (steps !== undefined && steps.some(x => x !== 1)) {
      throw new Error('Slice backward pass only supports step size of 1');
    }
  }

  /**
   * Propagates `grad` to the sliced input by zero-padding it back to the
   * input's shape. Does nothing when the input has `noGrad` set.
   *
   * @param {*} grad - Gradient with respect to the slice output; must expose
   *   `pad(pads, mode, value)`.
   */
  backward(grad) {
    if (this.a.noGrad) {
      return;
    }
    const shapeA = this.a.getShape();
    const rank = shapeA.length;
    // Resolves a possibly negative / out-of-range index into [0, size].
    const resolve = (index, size) =>
      Math.min(Math.max(index < 0 ? index + size : index, 0), size);
    // Layout: entries [0, rank) are leading pads per axis, entries
    // [rank, 2 * rank) are trailing pads per axis.
    const pads = new Array(rank * 2).fill(0);
    for (let i = 0; i < this.axes.length; i++) {
      const rawAxis = this.axes[i];
      const axis = rawAxis < 0 ? rawAxis + rank : rawAxis;
      const size = shapeA[axis];
      const start = resolve(this.starts[i], size);
      // An empty slice (end before start) contributes nothing, so pad the
      // whole axis: leading + trailing pads must still sum to `size`.
      const end = Math.max(start, resolve(this.ends[i], size));
      pads[axis] = start;
      pads[rank + axis] = size - end;
    }
    const gradA = grad.pad(pads, 'constant', 0);
    const needed = this.a.backward(gradA);
    // A falsy return presumably means the input did not retain gradA,
    // so it is released here.
    if (!needed) {
      gradA.delete();
    }
  }

  /** Releases the sliced input unless `a.isLeaf()` reports it as a leaf. */
  delete() {
    if (!this.a.isLeaf()) {
      this.a.delete();
    }
  }
}
//# sourceMappingURL=sliceBack.js.map