UNPKG

@hoff97/tensor-js

Version:

PyTorch-like deep learning inference library

46 lines 1.38 kB
import { CPUTensor } from '../../../../tensor/cpu/tensor';
import { WASMTensor } from '../../../../tensor/wasm/tensor';
import { bceBack } from './cpu';
import { defaultBCEBackD } from './gpu';

/**
 * Multiplies the incoming gradient by the BCE backward factor, choosing the
 * backend from the class of `grad`.
 *
 * - `CPUTensor`: the factor comes from `bceBack` ('./cpu').
 * - `WASMTensor`: the factor comes from the wasm tensor's `bce_back`.
 * - anything else: the factor comes from the `defaultBCEBackD` kernel ('./gpu').
 *
 * The temporary factor tensor is always released, even when the
 * multiplication throws, so a failing backward pass does not leak backend
 * memory.
 *
 * @param grad - incoming gradient tensor.
 * @param x - first BCE input variable (its `value` is read).
 * @param y - second BCE input variable (its `value` is read).
 * @returns a new gradient tensor for `x`; the caller owns it.
 */
function computeGradX(grad, x, y) {
  if (grad instanceof CPUTensor) {
    const back = bceBack(x.value, y.value);
    try {
      return grad.multiply(back);
    } finally {
      back.delete();
    }
  }

  if (grad instanceof WASMTensor) {
    // `back` is a raw wasm-side object, so it is released with `free()`
    // rather than `delete()`.
    const back = x.value.wasmTensor.bce_back(y.value.wasmTensor);
    try {
      return new WASMTensor(grad.wasmTensor.multiply(back, 1.0), grad.dtype);
    } finally {
      back.free();
    }
  }

  const back = defaultBCEBackD.calc({
    A: x.value,
    B: y.value,
    outputShape: x.getShape(),
  }, x.value.dtype);
  try {
    return grad.multiply(back);
  } finally {
    back.delete();
  }
}

/**
 * Autograd backward node for binary cross entropy (BCE).
 *
 * Stores both BCE inputs. Gradients are propagated into `x` only; `y` is
 * read but never back-propagated into.
 */
export class BCEBack {
  /**
   * @param x - first BCE input; receives the gradient in `backward`.
   * @param y - second BCE input; only read during `backward`.
   */
  constructor(x, y) {
    this.x = x;
    this.y = y;
  }

  /**
   * Computes the gradient for `x` and forwards it via `x.backward`.
   *
   * If `x.backward` returns a falsy value, the gradient tensor is deleted
   * here (presumably because `x` did not keep a reference to it — confirm
   * against the Variable implementation).
   *
   * @param grad - incoming gradient; its class selects the backend.
   */
  backward(grad) {
    const gradX = computeGradX(grad, this.x, this.y);
    const needed = this.x.backward(gradX);
    if (!needed) {
      gradX.delete();
    }
  }

  /** Releases the stored inputs, except those that are leaf variables. */
  delete() {
    if (!this.x.isLeaf()) {
      this.x.delete();
    }
    if (!this.y.isLeaf()) {
      this.y.delete();
    }
  }
}
//# sourceMappingURL=back.js.map