UNPKG

@hoff97/tensor-js

Version:

PyTorch-like deep learning inference library

34 lines 1.08 kB
/**
 * Backward step for a slice of the tensor `a`.
 *
 * The gradient of a unit-stride slice is the incoming gradient placed back at
 * the sliced position inside a zero tensor shaped like `a`; this is done by
 * zero-padding `grad` on every sliced axis.
 */
export class SliceBack {
  /**
   * @param a - the tensor that was sliced; must expose `noGrad`, `getShape()`,
   *   `backward(grad)`, `isLeaf()` and `delete()` (all used below)
   * @param {number[]} starts - start index for each entry of `axes`
   * @param {number[]} ends - exclusive end index for each entry of `axes`
   * @param {number[]} axes - axes the slice was applied to
   * @param {number[]} steps - slice steps; every entry must be 1
   * @throws {Error} if any step differs from 1
   */
  constructor(a, starts, ends, axes, steps) {
    this.a = a;
    this.starts = starts;
    this.ends = ends;
    this.axes = axes;
    this.steps = steps;

    // With step 1 the sliced region is a contiguous block, so padding alone
    // restores the shape of `a`; strided slices would need a scatter instead.
    if (steps.some((step) => step !== 1)) {
      throw new Error('Slice backward pass only supports step size of 1');
    }
  }

  /**
   * Propagates `grad` (the gradient of the slice result) back to `a`.
   *
   * Slice bounds are resolved before the padding is built: a negative axis
   * counts from the last axis, negative starts/ends count from the end of the
   * axis, bounds are clamped into [0, dim], and an end below its start is
   * treated as an empty range. Bounds that are already resolved and in range
   * pass through unchanged.
   *
   * Does nothing when `a.noGrad` is set.
   *
   * @param grad - gradient w.r.t. the slice output; must provide
   *   `pad(pads, mode, value)` returning a tensor with `delete()`
   */
  backward(grad) {
    if (this.a.noGrad) {
      return;
    }

    const shapeA = this.a.getShape();
    const rank = shapeA.length;
    const resolve = (index, dim) => {
      const wrapped = index < 0 ? index + dim : index;
      return Math.min(Math.max(wrapped, 0), dim);
    };

    // pads[axis] is the zero padding before `axis`, pads[rank + axis] the
    // padding after it; unsliced axes keep 0 on both sides.
    const pads = new Array(rank * 2).fill(0);
    for (let i = 0; i < this.axes.length; i++) {
      const rawAxis = this.axes[i];
      const axis = rawAxis < 0 ? rawAxis + rank : rawAxis;
      const dim = shapeA[axis];
      const start = resolve(this.starts[i], dim);
      // An end before the start means an empty slice; pin it to the start so
      // before + 0 + after still adds up to `dim`.
      const end = Math.max(resolve(this.ends[i], dim), start);
      pads[axis] = start;
      pads[rank + axis] = dim - end;
    }

    const gradA = grad.pad(pads, 'constant', 0);
    const needed = this.a.backward(gradA);
    // NOTE(review): a falsy return presumably means `a` did not retain gradA,
    // so it is released here to avoid leaking it.
    if (!needed) {
      gradA.delete();
    }
  }

  /** Releases `a` unless it is a leaf tensor. */
  delete() {
    if (!this.a.isLeaf()) {
      this.a.delete();
    }
  }
}
//# sourceMappingURL=sliceBack.js.map