UNPKG

@hoff97/tensor-js

Version:

PyTorch-like deep learning inference library

85 lines (84 loc) 2.91 kB
import { defaultAllocator } from '../../../tensor/gpu/gl';
import { Operation } from '../operation';

/**
 * Derives the clip-bound values shared by `calc` (passed as uniforms) and
 * `getCompilationInfo` (passed as compilation info).
 *
 * A bound that is `undefined` is reported as 0 with its flag set to 0; a bound
 * that is present is forwarded unchanged with its flag set to 1.
 *
 * @param {{minVal?: *, maxVal?: *}} input - only `minVal` and `maxVal` are read.
 * @returns {{minVal: *, maxVal: *, doMin: number, doMax: number}}
 */
function clipBoundValues(input) {
  const hasMin = input.minVal !== undefined;
  const hasMax = input.maxVal !== undefined;
  return {
    minVal: hasMin ? input.minVal : 0,
    maxVal: hasMax ? input.maxVal : 0,
    doMin: hasMin ? 1 : 0,
    doMax: hasMax ? 1 : 0,
  };
}

/**
 * GPU operation for the backward pass of clipping.
 *
 * For each element, the fragment shader returns 0 when `X` is strictly below
 * `minVal` (only if `doMin == 1`) or strictly above `maxVal` (only if
 * `doMax == 1`). Otherwise it passes the incoming gradient `Grad` through.
 *
 * The `input` argument of the methods below is read as
 * `{ input, grad, minVal?, maxVal? }`, where `input` and `grad` expose `shape`,
 * `size` and `memory.width` / `memory.height`.
 */
export class ClipBackwardOperation extends Operation {
  constructor(tensorConstructor, dtype, allocator) {
    super(tensorConstructor, dtype, allocator);
  }

  /** GLSL source: per-element masking of `Grad` by the clip bounds applied to `X`. */
  // eslint-disable-next-line @typescript-eslint/no-unused-vars
  getFragmentShader(info) {
    return `
    float process(int index[${this.maxRank}]) {
      float val = _X(index);
      if (doMin == 1 && val < minVal) {
        return 0.0;
      }
      if (doMax == 1 && val > maxVal) {
        return 0.0;
      }
      return _Grad(index);
    }

    ${this.getDefaultMain()}
    `;
  }

  /** The two input textures sampled by the shader: the forward input and the incoming gradient. */
  getTextureNames() {
    return ['X', 'Grad'];
  }

  /** Declares the bound variables; each qualifier comes from `getVarModifier` (defined in `Operation`). */
  getVariables() {
    return `
    ${this.getVarModifier('minVal')} float minVal;
    ${this.getVarModifier('maxVal')} float maxVal;
    ${this.getVarModifier('doMin')} int doMin;
    ${this.getVarModifier('doMax')} int doMax;
    `;
  }

  /** Uniform descriptors; `doMin` / `doMax` carry no explicit type here. */
  getUniformAttrs() {
    return [
      { name: 'minVal', type: 'float' },
      { name: 'maxVal', type: 'float' },
      { name: 'doMin' },
      { name: 'doMax' },
    ];
  }

  /**
   * Runs the operation.
   *
   * When the operation is fully static and an output shape is already known,
   * no uniform values are passed. Presumably the bounds were baked in at compile
   * time via `getCompilationInfo`; confirm in `Operation`.
   */
  calc(input) {
    const textures = { X: input.input, Grad: input.grad };
    if (this.fullyStatic && this.outputShape !== undefined) {
      return this.compute(this.outputShape, textures);
    }
    return this.compute(input.input.shape, textures, clipBoundValues(input));
  }

  /** The gradient has the same shape as the forward input. */
  getOutputShape(input) {
    return input.input.shape;
  }

  /** Sizes the shader's index array from the rank of `X` when it is known, then delegates. */
  compile(info) {
    const { shapeX } = info;
    if (shapeX !== undefined) {
      this.maxRank = shapeX.length;
    }
    super.compile(info);
  }

  /** Collects the texture shapes and sizes, output allocation size and clip bounds. */
  getCompilationInfo(input) {
    const { input: x, grad } = input;
    const outputSize = defaultAllocator.getAllocationDimensions(x.size, this.dtype);
    const { minVal, maxVal, doMin, doMax } = clipBoundValues(input);
    return {
      shapeX: x.shape,
      widthX: x.memory.width,
      heightX: x.memory.height,
      shapeGrad: grad.shape,
      widthGrad: grad.memory.width,
      heightGrad: grad.memory.height,
      shapeOutput: x.shape,
      widthOutput: outputSize.width,
      heightOutput: outputSize.height,
      minVal,
      maxVal,
      doMin,
      doMax,
    };
  }

  /**
   * String built from the input shape and both bounds.
   *
   * Undefined bounds render as "undefined". It is presumably used by
   * `Operation` as a cache key for compiled variants; confirm there.
   */
  getInputInfoString(input) {
    const { input: x, minVal, maxVal } = input;
    return `${x.shape}-${minVal}-${maxVal}`;
  }
}
//# sourceMappingURL=clipBackward.js.map