UNPKG

@hoff97/tensor-js

Version:

PyTorch-like deep learning inference library

105 lines (96 loc) 3.86 kB
import { getSize } from '../../../util/shape';
import { defaultAllocator } from '../../../tensor/gpu/gl';
import { Operation } from '../../../ops/gpu/operation';
import { Dispatcher } from '../../../ops/gpu/dispatcher';
import { gpuConstructor } from '../../../tensor/gpu/tensor';

/**
 * GPU operation that produces an updated copy of `Value` from per-element
 * moment estimates stored in `Moments`.
 *
 * For every element the fragment shader computes
 *
 *     value - alpha * (m1Corr / (sqrt(m2Corr) + epsilon))
 *
 * which has the form of the Adam parameter step using corrected first
 * (m1Corr) and second (m2Corr) moment estimates.
 *
 * Inputs (see `calc`):
 *   - Value:   tensor being updated; the output has the same shape.
 *   - Moments: tensor holding 4 consecutive values per element of Value; the
 *              shader reads offset 1 as m1Corr and offset 3 as m2Corr.
 *              NOTE(review): offsets 0 and 2 are never read here — presumably
 *              the uncorrected moments; confirm against the producer of Moments.
 *   - alpha:   step size, passed as a float uniform.
 *   - epsilon: term added to sqrt(m2Corr) to keep the denominator non-zero,
 *              passed as a float uniform.
 */
export class UpdateValueOperation extends Operation {
  constructor(tensorConstructor, dtype, allocator) {
    super(tensorConstructor, dtype, allocator);
    // NOTE(review): read by the base Operation; its precise meaning is not
    // visible in this file.
    this.maxIterations = 1000000;
  }

  /**
   * Declares the scalar shader inputs `alpha` and `epsilon`; the modifier for
   * each is chosen by the base class via `getVarModifier`.
   *
   * @returns {string} GLSL variable declarations.
   */
  getVariables() {
    return `
    ${this.getVarModifier('alpha')} float alpha;
    ${this.getVarModifier('epsilon')} float epsilon;
    `;
  }

  /**
   * Builds the fragment shader body.
   *
   * Each output texel holds four consecutive output values (r, g, b, a map to
   * flat positions pos .. pos+3). `updatedValueAtPos(p)` computes the updated
   * value for flat index p: it reads m1Corr / m2Corr from Moments at
   * p*4+1 / p*4+3 and the current value from Value at p, then applies
   * `newVal`. Factoring this into one helper replaces four copy-pasted
   * read/update sequences that previously mutated `pos` in between.
   *
   * @param {*} info - compilation info; not used by this shader.
   * @returns {string} GLSL source for the helpers and `main`.
   */
  // eslint-disable-next-line @typescript-eslint/no-unused-vars
  getFragmentShader(info) {
    return `
    float newVal(float m1Corr, float m2Corr, float value) {
      return value - alpha*(m1Corr/(sqrt(m2Corr)+epsilon));
    }

    float updatedValueAtPos(int p) {
      float m1Corr = getValueAtPos(p*4+1, widthMoments, heightMoments, Moments);
      float m2Corr = getValueAtPos(p*4+3, widthMoments, heightMoments, Moments);
      float value = getValueAtPos(p, widthValue, heightValue, Value);
      return newVal(m1Corr, m2Corr, value);
    }

    void main() {
      initVars();

      int pos = coordinateToPos(uv, widthOutput, heightOutput);

      gl_FragColor = vec4(
        updatedValueAtPos(pos),
        updatedValueAtPos(pos + 1),
        updatedValueAtPos(pos + 2),
        updatedValueAtPos(pos + 3)
      );
    }
    `;
  }

  /** @returns {string[]} Names of the input textures sampled by the shader. */
  getTextureNames() {
    return ['Value', 'Moments'];
  }

  /** @returns {{name: string, type: string}[]} Scalar uniforms of the shader. */
  getUniformAttrs() {
    return [
      { name: 'alpha', type: 'float' },
      { name: 'epsilon', type: 'float' },
    ];
  }

  /**
   * Runs the update.
   *
   * @param {{Value: *, Moments: *, alpha: number, epsilon: number}} input
   * @returns {*} Result of `this.compute`, with output shape `input.Value.shape`.
   */
  calc(input) {
    const textures = { Value: input.Value, Moments: input.Moments };
    const uniforms = { alpha: input.alpha, epsilon: input.epsilon };
    return this.compute(input.Value.shape, textures, uniforms);
  }

  /** The output has exactly the shape of `Value`. */
  getOutputShape(input) {
    return input.Value.shape;
  }

  /**
   * Compiles the shader, using the rank of Moments as `maxRank` when it is
   * known. NOTE(review): presumably Moments has the higher rank of the two
   * inputs (4 values per Value element) — confirm.
   */
  compile(info) {
    if (info.shapeMoments !== undefined) {
      this.maxRank = info.shapeMoments.length;
    }
    super.compile(info);
  }

  /**
   * Collects the shapes, texture dimensions and scalar values describing this
   * particular input, plus the allocation dimensions of the output.
   */
  getCompilationInfo(input) {
    const outputShape = this.getOutputShape(input);
    const outputSize = defaultAllocator.getAllocationDimensions(getSize(outputShape), this.dtype);
    return {
      shapeValue: input.Value.shape,
      widthValue: input.Value.memory.width,
      heightValue: input.Value.memory.height,
      shapeMoments: input.Moments.shape,
      widthMoments: input.Moments.memory.width,
      heightMoments: input.Moments.memory.height,
      shapeOutput: outputShape,
      widthOutput: outputSize.width,
      heightOutput: outputSize.height,
      alpha: input.alpha,
      epsilon: input.epsilon,
    };
  }

  /**
   * String identifying this input configuration (both shapes, alpha and
   * epsilon). NOTE(review): presumably used by the base class / dispatcher as
   * a cache key — confirm.
   */
  getInputInfoString(input) {
    return `${input.Value.shape}-${input.Moments.shape}-${input.alpha}-${input.epsilon}`;
  }
}

// Dispatcher whose factory builds an UpdateValueOperation for a given dtype.
export const defaultUpdateValueD = new Dispatcher((dtype) => new UpdateValueOperation(gpuConstructor, dtype));
//# sourceMappingURL=updateParams.js.map