@hoff97/tensor-js
Version:
PyTorch-like deep learning inference library
56 lines • 1.81 kB
JavaScript
/**
 * Autograd backward step for a convolution node.
 *
 * Stores the forward-pass operands (`x`, `w`, optional `b`) and the
 * convolution hyper-parameters. `backward` distributes the output gradient
 * to each operand whose `noGrad` flag is not set.
 */
export class ConvBack {
  /**
   * @param {*} x - Input variable of the forward convolution.
   * @param {*} w - Kernel variable of the forward convolution.
   * @param {number[]} strides - Stride values, forwarded to `conv` / `convTranspose`.
   * @param {number[]} padding - Padding values. Only the first
   *   `dilations.length` entries are read by `backward`.
   * @param {number[]} dilations - Dilation values. Their count drives every
   *   per-axis loop (presumably the number of spatial dimensions — confirm).
   * @param {number} group - Group count, forwarded to `conv` / `convTranspose`.
   * @param {*} [b] - Optional bias variable; `undefined` means "no bias".
   */
  constructor(x, w, strides, padding, dilations, group, b) {
    Object.assign(this, { x, w, strides, padding, dilations, group, b });
  }

  /**
   * Propagates `grad` (the gradient of the convolution output) to the kernel,
   * the bias and the input, in that order. Any variable with `noGrad` set is
   * skipped.
   *
   * Each computed gradient tensor is passed to the target variable's
   * `backward`. If that call returns a falsy value, the tensor is deleted here.
   *
   * @param {*} grad - Gradient tensor of the forward output.
   */
  backward(grad) {
    const { x, w, b, strides, padding, dilations, group } = this;

    // Hand `tensor` to `target`. If `target.backward` returns falsy (the
    // tensor was not kept), free it here.
    const propagate = (target, tensor) => {
      if (!target.backward(tensor)) {
        tensor.delete();
      }
    };

    if (!w.noGrad) {
      // NOTE(review): `strides` and `dilations` occupy the opposite argument
      // slots from the `convTranspose` call below. This is presumably
      // intentional (stride and dilation trade roles in the kernel gradient);
      // confirm against the tensor class's `conv` signature.
      propagate(w, x.value.conv(grad, undefined, strides, group, padding, dilations));
    }

    if (b !== undefined && !b.noGrad) {
      // Reduce over axis 0 and axes 2 .. dilations.length + 1, leaving axis 1.
      const reduceAxes = [0, ...Array.from({ length: dilations.length }, (_, axis) => axis + 2)];
      propagate(b, grad.sum(reduceAxes));
    }

    if (!x.noGrad) {
      const kernelShape = w.getShape();
      // One value per axis: kernelShape[axis + 2] - padding[axis] + dilation - 2.
      // The list is then repeated, so both halves carry the same values.
      // NOTE(review): only the leading `dilations.length` padding entries are
      // used. If `padding` can be asymmetric, verify this is intended.
      const perAxis = Array.from(
        { length: dilations.length },
        (_, axis) => kernelShape[axis + 2] - padding[axis] + dilations[axis] - 2,
      );
      propagate(x, grad.convTranspose(w.value, dilations, group, [...perAxis, ...perAxis], strides));
    }
  }

  /**
   * Deletes every stored operand that is not a leaf: `x`, `w`, then `b`
   * (when a bias is present). Leaf variables are left untouched.
   */
  delete() {
    const operands = this.b === undefined ? [this.x, this.w] : [this.x, this.w, this.b];
    for (const operand of operands) {
      if (!operand.isLeaf()) {
        operand.delete();
      }
    }
  }
}
//# sourceMappingURL=convBack.js.map