@hoff97/tensor-js
Version:
PyTorch-like deep learning inference library
46 lines • 1.38 kB
JavaScript
import { CPUTensor } from '../../../../tensor/cpu/tensor';
import { WASMTensor } from '../../../../tensor/wasm/tensor';
import { bceBack } from './cpu';
import { defaultBCEBackD } from './gpu';
/**
 * Multiplies an upstream gradient by the output of the BCE backward kernel
 * that matches the concrete class of `grad`.
 *
 * Dispatch order: `CPUTensor` -> `bceBack` from './cpu'; `WASMTensor` -> the
 * raw wasm handle's `bce_back`; anything else -> `defaultBCEBackD.calc` from
 * './gpu'. The kernel's intermediate output is released before returning.
 *
 * @param {*} grad - Upstream gradient tensor.
 * @param {*} x - Operand exposing `value` and `getShape()`; its value feeds the kernel.
 * @param {*} y - Operand exposing `value`; its value feeds the kernel.
 * @returns {*} The product tensor; ownership passes to the caller.
 */
function scaleByBCEDerivative(grad, x, y) {
  if (grad instanceof CPUTensor) {
    const kernelOut = bceBack(x.value, y.value);
    const scaled = grad.multiply(kernelOut);
    kernelOut.delete();
    return scaled;
  }
  if (grad instanceof WASMTensor) {
    // This path works on raw wasm handles, so the temporary is released
    // with `free()` instead of `delete()`, and the product is re-wrapped.
    const rawKernelOut = x.value.wasmTensor.bce_back(y.value.wasmTensor);
    const scaled = new WASMTensor(grad.wasmTensor.multiply(rawKernelOut, 1.0), grad.dtype);
    rawKernelOut.free();
    return scaled;
  }
  // Fallback: the kernel imported from './gpu' (presumably GPU-backed tensors).
  const kernelOut = defaultBCEBackD.calc({
    A: x.value,
    B: y.value,
    outputShape: x.getShape(),
  }, x.value.dtype);
  const scaled = grad.multiply(kernelOut);
  kernelOut.delete();
  return scaled;
}

/**
 * Backward operation for the BCE loss.
 *
 * Only `x` receives a gradient. `y` is kept so its value can be fed to the
 * backward kernel and so it can be released in `delete()`.
 * NOTE(review): presumably x holds predictions and y holds targets — confirm.
 */
export class BCEBack {
  /**
   * @param {*} x - Operand that the gradient is propagated to.
   * @param {*} y - Second operand; only used as kernel input.
   */
  constructor(x, y) {
    Object.assign(this, { x, y });
  }

  /**
   * Computes the gradient for `x` and passes it to `x.backward`. When that
   * call returns a falsy value, the gradient tensor is deleted here, because
   * no one else holds on to it.
   *
   * @param {*} grad - Upstream gradient tensor.
   */
  backward(grad) {
    const gradX = scaleByBCEDerivative(grad, this.x, this.y);
    const retained = this.x.backward(gradX);
    if (retained) {
      return;
    }
    gradX.delete();
  }

  /**
   * Deletes each operand (x first, then y) unless it is a leaf of the graph.
   */
  delete() {
    for (const key of ['x', 'y']) {
      // Read the property lazily so `this.y` is looked up after x is handled.
      const operand = this[key];
      if (!operand.isLeaf()) {
        operand.delete();
      }
    }
  }
}
//# sourceMappingURL=back.js.map