@hoff97/tensor-js
Version:
PyTorch-like deep learning inference library
80 lines (79 loc) • 2.82 kB
JavaScript
import { defaultAllocator } from '../../../tensor/gpu/gl';
import { getSize } from '../../../util/shape';
import { Operation } from '../operation';
/**
 * WebGL operation that normalizes a tensor element-wise:
 *
 *   Y = (X - Mean) / sqrt(Variance + epsilon) * Scale + Bias
 *
 * All five inputs are sampled at the same output index inside the fragment
 * shader. NOTE(review): any broadcasting of Mean/Variance/Scale/Bias is
 * presumably handled by the `_Name(index)` samplers the base class generates —
 * confirm against `Operation`.
 */
export class NormalizeOperation extends Operation {
    /**
     * Forwards exactly these three arguments to `Operation` (extra arguments
     * are dropped).
     *
     * @param tensorConstructor - passed through to the base class.
     * @param dtype - passed through to the base class; `getCompilationInfo`
     *   later reads `this.dtype` to size the output allocation.
     * @param allocator - passed through to the base class.
     */
    constructor(tensorConstructor, dtype, allocator) {
        super(tensorConstructor, dtype, allocator);
    }

    /**
     * GLSL declaration of the `epsilon` variable. The qualifier placed in
     * front of it is whatever the base class's `getVarModifier('epsilon')`
     * returns.
     *
     * @returns {string} shader snippet surrounded by newlines.
     */
    getVariables() {
        const modifier = this.getVarModifier('epsilon');
        return ['', `${modifier} float epsilon;`, ''].join('\n');
    }

    /**
     * @returns {Array<{name: string, type: string}>} the single float
     *   uniform (`epsilon`) this operation binds.
     */
    getUniformAttrs() {
        const epsilonUniform = { name: 'epsilon', type: 'float' };
        return [epsilonUniform];
    }

    /**
     * Builds the fragment shader source. `process` evaluates the
     * normalization formula for one output index, and `main` is supplied by
     * the base class's `getDefaultMain()`.
     *
     * @param info - unused; kept for the base-class signature.
     * @returns {string} GLSL source.
     */
    // eslint-disable-next-line @typescript-eslint/no-unused-vars
    getFragmentShader(info) {
        const shaderLines = [
            '',
            `float process(int[${this.maxRank}] index) {`,
            // Center the input on the mean.
            'float result = _X(index) - _Mean(index);',
            // Divide by sqrt(variance + epsilon); epsilon keeps the denominator
            // non-zero for zero variance (assuming epsilon > 0).
            'result = result / sqrt(_Variance(index) + epsilon);',
            // Affine rescale and shift.
            'result = result * _Scale(index) + _Bias(index);',
            'return result;',
            '}',
            `${this.getDefaultMain()}`,
            '',
        ];
        return shaderLines.join('\n');
    }

    /** The output has the same shape as the `X` input. */
    getOutputShape(input) {
        const { X } = input;
        return X.shape;
    }

    /** Names of the input textures sampled by the shader, in binding order. */
    getTextureNames() {
        const textureNames = ['X', 'Mean', 'Variance', 'Scale', 'Bias'];
        return textureNames;
    }

    /**
     * Runs the operation on the given inputs.
     *
     * @param input - object holding the tensors `X`, `Mean`, `Variance`,
     *   `Scale`, `Bias` and the scalar `epsilon`.
     * @returns whatever the base class's `compute` returns for an output of
     *   `X`'s shape.
     */
    calc(input) {
        const { X, Mean, Variance, Scale, Bias, epsilon } = input;
        const textures = { X, Mean, Variance, Scale, Bias };
        return this.compute(X.shape, textures, { epsilon });
    }

    /**
     * Records the rank of `X` (used to size the shader's index array) when
     * `info.shapeX` is present, then defers to the base class.
     */
    compile(info) {
        const { shapeX } = info;
        if (shapeX !== undefined) {
            this.maxRank = shapeX.length;
        }
        super.compile(info);
    }

    /**
     * Collects the static information needed to compile a program for these
     * inputs. For each input tensor (in the order X, Bias, Mean, Scale,
     * Variance) the result has `shape<Name>`, `width<Name>` and `height<Name>`
     * keys. These are followed by the output's shape and allocation
     * dimensions, and finally `epsilon`.
     */
    getCompilationInfo(input) {
        const outputShape = this.getOutputShape(input);
        const outputSize = defaultAllocator.getAllocationDimensions(getSize(outputShape), this.dtype);
        const compilationInfo = {};
        for (const name of ['X', 'Bias', 'Mean', 'Scale', 'Variance']) {
            const tensor = input[name];
            compilationInfo[`shape${name}`] = tensor.shape;
            compilationInfo[`width${name}`] = tensor.memory.width;
            compilationInfo[`height${name}`] = tensor.memory.height;
        }
        compilationInfo.shapeOutput = outputShape;
        compilationInfo.widthOutput = outputSize.width;
        compilationInfo.heightOutput = outputSize.height;
        compilationInfo.epsilon = input.epsilon;
        return compilationInfo;
    }

    /**
     * Builds a string identifying these inputs: the five input shapes and
     * epsilon, separated by '-'. NOTE(review): presumably used by the base
     * class as a program cache key — confirm.
     *
     * Each part goes through `String`, which matches template-literal
     * interpolation: shapes render comma-joined and `undefined` renders as
     * "undefined". Assuming epsilon is a JS number, `String` already yields
     * the shortest decimal that round-trips, so distinct epsilons produce
     * distinct strings.
     */
    getInputInfoString(input) {
        const parts = [
            input.X.shape,
            input.Mean.shape,
            input.Variance.shape,
            input.Scale.shape,
            input.Bias.shape,
            input.epsilon,
        ];
        return parts.map(String).join('-');
    }
}
//# sourceMappingURL=normalize.js.map