@hoff97/tensor-js
Version:
PyTorch-like deep learning inference library
105 lines (96 loc) • 3.86 kB
JavaScript
import { getSize } from '../../../util/shape';
import { defaultAllocator } from '../../../tensor/gpu/gl';
import { Operation } from '../../../ops/gpu/operation';
import { Dispatcher } from '../../../ops/gpu/dispatcher';
import { gpuConstructor } from '../../../tensor/gpu/tensor';
/**
 * GPU operation that applies an Adam-style parameter update.
 *
 * For every element `i` of the `Value` tensor the fragment shader computes
 *
 *   value[i] - alpha * (m1Corr[i] / (sqrt(m2Corr[i]) + epsilon))
 *
 * where `m1Corr[i]` is read from `Moments` at flat position `4*i + 1` and
 * `m2Corr[i]` at flat position `4*i + 3`.
 * NOTE(review): presumably `Moments` stores four values per element of
 * `Value` (shape `[...Value.shape, 4]`), with bias-corrected first/second
 * moments at offsets 1 and 3. Confirm against the op that produces it.
 *
 * Expected input: `{ Value, Moments, alpha, epsilon }`. The output has the
 * same shape as `Value`.
 */
export class UpdateValueOperation extends Operation {
  /**
   * @param tensorConstructor - Constructor for result tensors, forwarded to `Operation`.
   * @param dtype - Data type handled by this operation instance.
   * @param allocator - Optional allocator, forwarded to `Operation`.
   */
  constructor(tensorConstructor, dtype, allocator) {
    super(tensorConstructor, dtype, allocator);
    // NOTE(review): not read in this class; presumably consumed by the
    // Operation base class. Confirm its meaning there.
    this.maxIterations = 1000000;
  }

  /**
   * Declares the scalar hyper-parameters used by the shader. `getVarModifier`
   * decides how each one is declared (e.g. as a uniform).
   */
  getVariables() {
    return `
    ${this.getVarModifier('alpha')} float alpha;
    ${this.getVarModifier('epsilon')} float epsilon;
    `;
  }

  /**
   * Builds the fragment shader body.
   *
   * Each output texel packs four consecutive elements into its r, g, b and a
   * channels. The per-element update lives in `updatedValueAt`, so the four
   * channels share one implementation instead of four copies.
   */
  // eslint-disable-next-line @typescript-eslint/no-unused-vars
  getFragmentShader(info) {
    return `
    float newVal(float m1Corr, float m2Corr, float value) {
      return value - alpha*(m1Corr/(sqrt(m2Corr)+epsilon));
    }

    // Updated value for the element at flat position p.
    float updatedValueAt(int p) {
      float m1Corr = getValueAtPos(p*4+1, widthMoments, heightMoments, Moments);
      float m2Corr = getValueAtPos(p*4+3, widthMoments, heightMoments, Moments);
      float value = getValueAtPos(p, widthValue, heightValue, Value);
      return newVal(m1Corr, m2Corr, value);
    }

    void main() {
      initVars();

      int pos = coordinateToPos(uv, widthOutput, heightOutput);

      gl_FragColor = vec4(
        updatedValueAt(pos),
        updatedValueAt(pos + 1),
        updatedValueAt(pos + 2),
        updatedValueAt(pos + 3)
      );
    }
    `;
  }

  /** Names of the input tensors sampled by the shader. */
  getTextureNames() {
    return ['Value', 'Moments'];
  }

  /** Scalar uniforms passed to the shader at execution time. */
  getUniformAttrs() {
    return [
      { name: 'alpha', type: 'float' },
      { name: 'epsilon', type: 'float' },
    ];
  }

  /**
   * Runs the update.
   *
   * @param input - `{ Value, Moments, alpha, epsilon }`.
   * @returns Whatever `this.compute` returns for an output of `Value`'s shape.
   */
  calc(input) {
    const { Value, Moments, alpha, epsilon } = input;
    // Route through getOutputShape so the shape logic lives in one place,
    // matching getCompilationInfo.
    return this.compute(this.getOutputShape(input), { Value, Moments }, { alpha, epsilon });
  }

  /** The updated tensor has the same shape as `Value`. */
  getOutputShape(input) {
    return input.Value.shape;
  }

  /**
   * Compiles the shader for the given compilation info.
   *
   * When the Moments shape is known, its rank is used as `maxRank`.
   * NOTE(review): presumably because Moments has the highest rank of the
   * inputs. Confirm.
   */
  compile(info) {
    if (info.shapeMoments !== undefined) {
      this.maxRank = info.shapeMoments.length;
    }
    super.compile(info);
  }

  /**
   * Collects the shapes, texture dimensions and hyper-parameters that
   * parameterize the compiled shader.
   */
  getCompilationInfo(input) {
    const { Value, Moments, alpha, epsilon } = input;
    const outputShape = this.getOutputShape(input);
    const outputSize = defaultAllocator.getAllocationDimensions(getSize(outputShape), this.dtype);
    return {
      shapeValue: Value.shape,
      widthValue: Value.memory.width,
      heightValue: Value.memory.height,
      shapeMoments: Moments.shape,
      widthMoments: Moments.memory.width,
      heightMoments: Moments.memory.height,
      shapeOutput: outputShape,
      widthOutput: outputSize.width,
      heightOutput: outputSize.height,
      alpha,
      epsilon,
    };
  }

  /** Key describing the input configuration: both shapes plus the hyper-parameters. */
  getInputInfoString(input) {
    const { Value, Moments, alpha, epsilon } = input;
    return `${Value.shape}-${Moments.shape}-${alpha}-${epsilon}`;
  }
}
/**
 * Factory handed to the dispatcher: creates an UpdateValueOperation for the
 * requested dtype, always backed by the GPU tensor constructor.
 */
const createUpdateValueOperation = (dtype) => new UpdateValueOperation(gpuConstructor, dtype);

/** Shared dispatcher for the update-value GPU operation. */
export const defaultUpdateValueD = new Dispatcher(createUpdateValueOperation);
//# sourceMappingURL=updateParams.js.map