@hoff97/tensor-js
Version:
PyTorch-like deep learning inference library
85 lines (84 loc) • 2.91 kB
JavaScript
import { defaultAllocator } from '../../../tensor/gpu/gl';
import { Operation } from '../operation';
/**
 * Resolves the clip bounds of an operation input into the four values the
 * shader reads: `minVal`, `maxVal`, `doMin`, `doMax`.
 *
 * A bound that is `undefined` becomes value 0 with its flag set to 0, so the
 * shader skips that comparison. Any other value is passed through unchanged
 * with its flag set to 1.
 *
 * @param {{minVal?: number, maxVal?: number}} input - operation input
 * @returns {{minVal: number, maxVal: number, doMin: number, doMax: number}}
 */
function clipUniforms(input) {
  const hasMin = input.minVal !== undefined;
  const hasMax = input.maxVal !== undefined;
  return {
    minVal: hasMin ? input.minVal : 0,
    maxVal: hasMax ? input.maxVal : 0,
    doMin: hasMin ? 1 : 0,
    doMax: hasMax ? 1 : 0,
  };
}

/**
 * GPU operation for the backward pass of clip.
 *
 * Inputs are objects `{ input, grad, minVal?, maxVal? }`. `input` is the
 * forward input `X` and `grad` is the incoming gradient. The output has the
 * shape of `input`. Each output element is the gradient at that index, or 0
 * where `X` was strictly below an active `minVal` or strictly above an active
 * `maxVal`. Values exactly equal to a bound keep their gradient.
 *
 * NOTE(review): `grad` is read at the output index, so it presumably has the
 * same shape as `input` — confirm against callers.
 */
export class ClipBackwardOperation extends Operation {
  constructor(tensorConstructor, dtype, allocator) {
    super(tensorConstructor, dtype, allocator);
  }

  /**
   * Builds the fragment shader. Its `process` function receives an index
   * array of length `this.maxRank` and returns the masked gradient there.
   *
   * @param info - unused
   * @returns {string} GLSL source
   */
  // eslint-disable-next-line @typescript-eslint/no-unused-vars
  getFragmentShader(info) {
    return `
float process(int index[${this.maxRank}]) {
float val = _X(index);
if (doMin == 1 && val < minVal) {
return 0.0;
}
if (doMax == 1 && val > maxVal) {
return 0.0;
}
return _Grad(index);
}
${this.getDefaultMain()}
`;
  }

  /** @returns {string[]} texture names, in the order `X`, `Grad` */
  getTextureNames() {
    return ['X', 'Grad'];
  }

  /**
   * Declares the clip bounds and their enable flags. Each declaration is
   * prefixed with the base class's `getVarModifier(name)`.
   *
   * @returns {string} GLSL declarations, one per line, wrapped in newlines
   */
  getVariables() {
    const declarations = [
      ['minVal', 'float'],
      ['maxVal', 'float'],
      ['doMin', 'int'],
      ['doMax', 'int'],
    ].map(([name, glslType]) => `${this.getVarModifier(name)} ${glslType} ${name};`);
    return `\n${declarations.join('\n')}\n`;
  }

  /**
   * Uniform attribute descriptors. `doMin` and `doMax` carry no `type` key.
   * NOTE(review): presumably the base class defaults them to int, matching
   * their `int` declarations — confirm in `Operation`.
   */
  getUniformAttrs() {
    const floatAttr = (name) => ({ name, type: 'float' });
    return [floatAttr('minVal'), floatAttr('maxVal'), { name: 'doMin' }, { name: 'doMax' }];
  }

  /**
   * Runs the operation on `{ input, grad, minVal?, maxVal? }`.
   *
   * If the operation is fully static and its output shape is already known,
   * `compute` is called without uniforms.
   * NOTE(review): in that case the bounds are presumably baked in at compile
   * time from `getCompilationInfo` — confirm in `Operation`.
   */
  calc(input) {
    const useStaticShape = this.fullyStatic && this.outputShape !== undefined;
    const textures = { X: input.input, Grad: input.grad };
    if (useStaticShape) {
      return this.compute(this.outputShape, textures);
    }
    return this.compute(input.input.shape, textures, clipUniforms(input));
  }

  /** The output has the same shape as the forward input `X`. */
  getOutputShape(input) {
    const { input: x } = input;
    return x.shape;
  }

  /**
   * When `shapeX` is present, records the rank of `X` as `maxRank`, which
   * sizes the shader's index array. Then delegates to the base class.
   */
  compile(info) {
    const { shapeX } = info;
    if (shapeX !== undefined) {
      this.maxRank = shapeX.length;
    }
    super.compile(info);
  }

  /**
   * Collects the shapes and texture dimensions of `X`, `Grad` and the output,
   * plus the resolved clip bounds. The output texture size is obtained from
   * `defaultAllocator` for `X`'s element count and this operation's dtype.
   */
  getCompilationInfo(input) {
    const x = input.input;
    const outputSize = defaultAllocator.getAllocationDimensions(x.size, this.dtype);
    const { grad } = input;
    const bounds = clipUniforms(input);
    return {
      shapeX: x.shape,
      widthX: x.memory.width,
      heightX: x.memory.height,
      shapeGrad: grad.shape,
      widthGrad: grad.memory.width,
      heightGrad: grad.memory.height,
      shapeOutput: x.shape,
      widthOutput: outputSize.width,
      heightOutput: outputSize.height,
      minVal: bounds.minVal,
      maxVal: bounds.maxVal,
      doMin: bounds.doMin,
      doMax: bounds.doMax,
    };
  }

  /**
   * Identifier string built from `X`'s shape and the raw bounds. An absent
   * bound appears as the text `undefined`.
   * NOTE(review): presumably used as a cache key — confirm in `Operation`.
   */
  getInputInfoString(input) {
    const { input: x, minVal, maxVal } = input;
    return `${x.shape}-${minVal}-${maxVal}`;
  }
}
//# sourceMappingURL=clipBackward.js.map