UNPKG

@hoff97/tensor-js

Version:

PyTorch-like deep learning inference library

80 lines (79 loc) 2.82 kB
import { defaultAllocator } from '../../../tensor/gpu/gl';
import { getSize } from '../../../util/shape';
import { Operation } from '../operation';

/**
 * GPU shader operation that normalizes `X` element-wise:
 *
 *   out = (X - Mean) / sqrt(Variance + epsilon) * Scale + Bias
 *
 * (the same formula as inference-time batch normalization).
 *
 * The `input` object consumed by the methods below carries the tensors
 * `X`, `Mean`, `Variance`, `Scale` and `Bias` plus a numeric `epsilon`.
 * The shader samples every tensor at the output index. How smaller
 * Mean/Variance/Scale/Bias tensors map onto that index is decided by the
 * `Operation` base class (NOTE(review): presumably broadcasting — confirm
 * there; it is not visible in this file).
 */
export class NormalizeOperation extends Operation {
  constructor(tensorConstructor, dtype, allocator) {
    super(tensorConstructor, dtype, allocator);
  }

  /**
   * Declares the `epsilon` shader variable. The qualifier placed in front
   * of it is whatever the base class's `getVarModifier` returns.
   * @returns {string} GLSL declaration snippet.
   */
  getVariables() {
    const modifier = this.getVarModifier('epsilon');
    return ` ${modifier} float epsilon; `;
  }

  /**
   * @returns {{name: string, type: string}[]} the single float uniform
   *   `epsilon`.
   */
  getUniformAttrs() {
    const epsilonUniform = { name: 'epsilon', type: 'float' };
    return [epsilonUniform];
  }

  /**
   * Builds the GLSL source. Its `process` function evaluates the
   * normalization formula at one output index. The base class's default
   * `main` is appended after it.
   * @param info compilation info (not used by this operation).
   * @returns {string} fragment shader source.
   */
  // eslint-disable-next-line @typescript-eslint/no-unused-vars
  getFragmentShader(info) {
    // Joined with single spaces and wrapped in one leading and one trailing
    // space, so the emitted source stays on a single line exactly as before.
    const shaderLines = [
      `float process(int[${this.maxRank}] index) {`,
      'float result = _X(index) - _Mean(index);',
      'result = result / sqrt(_Variance(index) + epsilon);',
      'result = result * _Scale(index) + _Bias(index);',
      'return result;',
      '}',
      `${this.getDefaultMain()}`,
    ];
    return ` ${shaderLines.join(' ')} `;
  }

  /**
   * The result has exactly the shape of `X`.
   * @returns the shape of `input.X`.
   */
  getOutputShape(input) {
    const { X } = input;
    return X.shape;
  }

  /** @returns {string[]} names of the input textures read by the shader. */
  getTextureNames() {
    return ['X', 'Mean', 'Variance', 'Scale', 'Bias'];
  }

  /**
   * Runs the operation. `X.shape` is passed as the shape argument of
   * `compute`, the five tensors as textures, and `epsilon` as a uniform.
   */
  calc(input) {
    const { X, Mean, Variance, Scale, Bias, epsilon } = input;
    const textures = { X, Mean, Variance, Scale, Bias };
    return this.compute(X.shape, textures, { epsilon });
  }

  /**
   * When `info.shapeX` is present, stores its length in `maxRank`. The
   * fragment shader uses that value as the length of its `index` array.
   * Then delegates to the base class.
   */
  compile(info) {
    const { shapeX } = info;
    if (shapeX !== undefined) {
      this.maxRank = shapeX.length;
    }
    super.compile(info);
  }

  /**
   * Collects the values the base class compiles against:
   * shape/width/height for X, Bias, Mean, Scale and Variance (keys such as
   * `shapeX`, `widthBias`, `heightMean`, in that tensor order), followed by
   * `shapeOutput`, `widthOutput`, `heightOutput` and `epsilon`.
   */
  getCompilationInfo(input) {
    const outputShape = this.getOutputShape(input);
    const outputSize = defaultAllocator.getAllocationDimensions(getSize(outputShape), this.dtype);

    const info = {};
    for (const name of ['X', 'Bias', 'Mean', 'Scale', 'Variance']) {
      const tensor = input[name];
      info[`shape${name}`] = tensor.shape;
      info[`width${name}`] = tensor.memory.width;
      info[`height${name}`] = tensor.memory.height;
    }
    info.shapeOutput = outputShape;
    info.widthOutput = outputSize.width;
    info.heightOutput = outputSize.height;
    info.epsilon = input.epsilon;
    return info;
  }

  /**
   * String describing the input configuration: the shapes of X, Mean,
   * Variance, Scale and Bias, then epsilon, separated by '-'.
   * NOTE(review): presumably used by the base class as a cache key.
   *
   * The default JavaScript Number-to-String conversion produces a string
   * that parses back to the same double. Different epsilon values
   * therefore give different strings; the one exception is 0 vs -0, which
   * both print as "0".
   */
  getInputInfoString(input) {
    const parts = ['X', 'Mean', 'Variance', 'Scale', 'Bias'].map(
      (name) => `${input[name].shape}`,
    );
    parts.push(`${input.epsilon}`);
    return parts.join('-');
  }
}
//# sourceMappingURL=normalize.js.map