UNPKG

@hoff97/tensor-js

Version:

PyTorch-like deep learning inference library

56 lines 1.81 kB
/**
 * Backward node for a convolution whose forward operands were `x` (input),
 * `w` (kernel) and an optional bias `b`.
 *
 * `backward` routes the incoming output gradient to `w`, `b` and `x`, in that
 * order. `delete` releases the operands that are not graph leaves.
 */
export class ConvBack {
  /**
   * @param x - Input variable of the forward convolution.
   * @param w - Kernel variable of the forward convolution.
   * @param strides - Forward strides, forwarded to `conv` / `convTranspose`.
   * @param padding - Forward padding; in the input-gradient path only its
   *   first `dilations.length` entries are read.
   * @param dilations - Forward dilations; `dilations.length` is used as the
   *   number of spatial axes (axes 2, 3, ...).
   * @param group - Forward group count.
   * @param b - Optional bias variable, `undefined` when there is none.
   */
  constructor(x, w, strides, padding, dilations, group, b) {
    this.x = x;
    this.w = w;
    this.strides = strides;
    this.padding = padding;
    this.dilations = dilations;
    this.group = group;
    this.b = b;
  }

  /**
   * Propagates `grad` (gradient w.r.t. the convolution output) to `w`, `b`
   * and `x`, in that order, skipping any operand flagged `noGrad`.
   * A freshly computed gradient tensor is deleted when the receiving
   * variable's `backward` returns a falsy value, i.e. did not keep it.
   *
   * @param grad - Gradient tensor of the convolution output.
   */
  backward(grad) {
    // Hand `gradient` to `target`; free it unless the target retained it.
    const pass = (target, gradient) => {
      const kept = target.backward(gradient);
      if (!kept) {
        gradient.delete();
      }
    };

    if (!this.w.noGrad) {
      // Kernel gradient: convolve the forward input with the output gradient.
      // NOTE(review): strides are passed before `group` and dilations last.
      // Presumably the signature is conv(kernel, bias, dilations, group, pads,
      // strides), so the forward strides act as dilations here and vice versa.
      // Confirm against Tensor.conv.
      const kernelGrad = this.x.value.conv(
        grad,
        undefined,
        this.strides,
        this.group,
        this.padding,
        this.dilations,
      );
      pass(this.w, kernelGrad);
    }

    if (this.b !== undefined && !this.b.noGrad) {
      // Bias gradient: reduce over axis 0 and every spatial axis (2, 3, ...).
      const spatialAxes = Array.from(
        { length: this.dilations.length },
        (_, k) => k + 2,
      );
      const biasGrad = grad.sum([0, ...spatialAxes]);
      pass(this.b, biasGrad);
    }

    if (!this.x.noGrad) {
      const kernelShape = this.w.getShape();
      // One pad per spatial axis: k - pad + d - 2, built from the kernel extent
      // k, the *leading* forward pad and the dilation d. It is then mirrored to
      // both sides.
      // NOTE(review): trailing forward pads are ignored, and for d > 1 this is
      // (k - 1) + (d - 1) - pad rather than d * (k - 1) - pad. Confirm against
      // convTranspose's pads semantics.
      const leadingPads = Array.from(
        { length: this.dilations.length },
        (_, k) => kernelShape[k + 2] - this.padding[k] + this.dilations[k] - 2,
      );
      const inputGrad = grad.convTranspose(
        this.w.value,
        this.dilations,
        this.group,
        [...leadingPads, ...leadingPads],
        this.strides,
      );
      pass(this.x, inputGrad);
    }
  }

  /**
   * Deletes each forward operand (`x`, then `w`, then `b` if present) whose
   * `isLeaf()` is falsy. Leaf variables are left untouched.
   */
  delete() {
    const releaseIfInner = (operand) => {
      if (!operand.isLeaf()) {
        operand.delete();
      }
    };
    releaseIfInner(this.x);
    releaseIfInner(this.w);
    if (this.b !== undefined) {
      releaseIfInner(this.b);
    }
  }
}
//# sourceMappingURL=convBack.js.map