onnxruntime-web
Version:
A JavaScript library for running ONNX models in browsers
150 lines (145 loc) • 5.57 kB
JavaScript
;
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
Object.defineProperty(exports, '__esModule', { value: true });
exports.parseInstanceNormalizationAttributes = exports.instanceNormalization = void 0;
const glsl_source_1 = require('../glsl-source');
const types_1 = require('../types');
/**
 * Runs InstanceNormalization as two chained programs on the inference handler.
 *
 * The first program reduces the input to per-(batch, channel) mean/variance;
 * the second consumes that result together with the scale and bias inputs to
 * produce the normalized tensor.
 *
 * @param inferenceHandler Handler whose `run(programLoader, inputs)` executes a program.
 * @param inputs `[X, scale, B]`; checked by `validateInputs` before anything runs.
 * @param epsilon Value forwarded to the output program (added to the variance there).
 * @returns A single-element array holding the output of the second program.
 */
const instanceNormalization = (inferenceHandler, inputs, epsilon) => {
  validateInputs(inputs);

  const x = inputs[0];

  // Pass 1: statistics. The full input list is handed to the handler unchanged.
  const statisticsLoader = createMeanAndVarianceProgramInfoLoader(x);
  const meanAndVariance = inferenceHandler.run(statisticsLoader, inputs);

  // Pass 2: normalization. Input order must match the output program's inputNames.
  const outputLoader = createComputeOutputProgramInfoLoader(inferenceHandler, x, epsilon, meanAndVariance.dims);
  const normalized = inferenceHandler.run(outputLoader, [x, meanAndVariance, inputs[1], inputs[2]]);

  return [normalized];
};
exports.instanceNormalization = instanceNormalization;
/**
 * Reads the operator's `epsilon` attribute from a graph node.
 *
 * @param node Node exposing `attributes.getFloat(name, defaultValue)`.
 * @returns Whatever `getFloat` returns for `'epsilon'`, called with a default of 1e-5.
 */
const parseInstanceNormalizationAttributes = (node) => {
  const { attributes } = node;
  return attributes.getFloat('epsilon', 1e-5);
};
exports.parseInstanceNormalizationAttributes = parseInstanceNormalizationAttributes;
// Static metadata for the first-pass program that computes mean and variance.
// It is spread into both the program-info loader and the program info itself.
const meanAndVarianceProgramMetadata = {
name: 'InstanceNormalization_MeanAndVariance',
// Single input, read in the shader through the `_X(...)` accessor.
inputNames: ['X'],
// One texture type per entry of inputNames, in the same order.
inputTypes: [types_1.TextureType.unpacked],
};
/**
 * Builds the program info for the statistics pass.
 *
 * The generated shader is invoked per (indices[0], indices[1]) pair and walks
 * dims[2] x dims[3] of the input twice: once to accumulate the mean, once to
 * accumulate the squared deviation from that mean. The mean is written to the
 * `.r` channel and the variance (sum of squared deviations divided by the
 * element count) to the `.g` channel of the returned vec4.
 *
 * @param metadata Program metadata that is spread into the result.
 * @param input Input tensor; `dims[0..3]` and `type` are read.
 * @returns Program info with output dims `[dims[0], dims[1]]`, the input's
 *     type, and the `packedLastDimension` texture type.
 */
const createMeanAndVarianceProgramInfo = (metadata, input) => {
  const [batchCount, channelCount, height, width] = input.dims.slice();
  const elementsPerChannel = height * width;

  // Emits the nested loop over the last two dimensions; only the accumulate
  // statement differs between the two passes of the shader.
  const loopOverChannel = (accumulate) => `for(int a2=0; a2<${height}; a2++) {
a[2] = a2;
for(int a3=0; a3<${width}; a3++) {
a[3] = a3;
float x = _X(a);
${accumulate}
}
}`;

  const shaderSource = `
vec4 process(int[2] indices) {
vec4 v = vec4(0.0);
int a[4];
a[0] = indices[0];
a[1] = indices[1];
float temp = 0.0;
${loopOverChannel('temp += x;')}
float mean = temp / float(${elementsPerChannel});
temp = 0.0;
${loopOverChannel('temp += (x - mean) * (x - mean);')}
v.r = mean;
v.g = temp / float(${elementsPerChannel});
return v;
}`;

  const output = {
    dims: [batchCount, channelCount],
    type: input.type,
    textureType: types_1.TextureType.packedLastDimension,
  };
  return { ...metadata, output, shaderSource };
};
/**
 * Wraps the statistics program in a lazy loader.
 *
 * @param input Input tensor forwarded to `createMeanAndVarianceProgramInfo`.
 * @returns The statistics metadata plus a `get()` that builds the program
 *     info on each call.
 */
const createMeanAndVarianceProgramInfoLoader = (input) => {
  const buildProgramInfo = () => createMeanAndVarianceProgramInfo(meanAndVarianceProgramMetadata, input);
  return {
    ...meanAndVarianceProgramMetadata,
    get: buildProgramInfo,
  };
};
// Static metadata for the second-pass program that produces the normalized output.
// inputNames and inputTypes are parallel arrays: entry i of one describes entry i of the other.
const computeOutputProgramMetadata = {
name: 'InstanceNormalization_ComputeOutput',
inputNames: ['X', 'MeanAndVariance', 'Scale', 'B'],
inputTypes: [
// X
types_1.TextureType.unpacked,
// MeanAndVariance: same texture type the mean/variance program declares for its output
types_1.TextureType.packedLastDimension,
// Scale
types_1.TextureType.unpacked,
// B
types_1.TextureType.unpacked,
],
};
/**
 * Builds the program info for the normalization pass.
 *
 * The shader looks up the mean (`.r`) and variance (`.g`) for the element's
 * (indices[0], indices[1]) pair, fetches scale and bias by indices[1], and
 * returns `scale * (x - mean) / sqrt(variance + epsilon) + b`.
 *
 * @param inferenceHandler Supplies the GL context version and
 *     `calculateTextureWidthAndHeight`.
 * @param metadata Program metadata that is spread into the result.
 * @param input Input tensor; its `dims` and `type` become the output's.
 * @param epsilon Passed to the shader through the `variables` list as a float
 *     named `epsilon` (presumably bound as a uniform; confirm in the program manager).
 * @param meanAndVarianceShape Dims of the statistics tensor, used to size its texture.
 * @returns Program info with an unpacked output shaped like the input.
 */
const createComputeOutputProgramInfo = (inferenceHandler, metadata, input, epsilon, meanAndVarianceShape) => {
  const { getGlsl } = glsl_source_1;
  const glsl = getGlsl(inferenceHandler.session.backend.glContext.version);

  const textureSize = inferenceHandler.calculateTextureWidthAndHeight(
    meanAndVarianceShape,
    types_1.TextureType.packedLastDimension,
  );
  // The reported width is divided by 4 before being used as the coordinate width.
  const meanAndVarianceWidth = textureSize[0] / 4;
  const meanAndVarianceHeight = textureSize[1];

  const shaderSource = `
vec4 get_MeanAndVariance(int[2] mv) {
int offset = indicesToOffset_MeanAndVariance(mv);
vec2 coords = offsetToCoords(offset, ${meanAndVarianceWidth}, ${meanAndVarianceHeight});
return ${glsl.texture2D}(MeanAndVariance, coords);
}
float process(int[4] indices) {
int mv[2];
mv[0] = indices[0];
mv[1] = indices[1];
vec4 mean_and_variance = get_MeanAndVariance(mv);
float mean = mean_and_variance.r;
float variance = mean_and_variance.g;
int sb[1];
sb[0] = indices[1];
float scale = _Scale(sb);
float b = _B(sb);
return scale * (_X(indices) - mean) / sqrt(variance + epsilon) + b;
}`;

  const output = { dims: input.dims, type: input.type, textureType: types_1.TextureType.unpacked };
  const variables = [{ name: 'epsilon', type: 'float', data: epsilon }];
  return { ...metadata, output, variables, shaderSource };
};
/**
 * Wraps the normalization program in a lazy loader.
 *
 * The metadata gets a `cacheHint` equal to the string form of `epsilon`; the
 * same metadata object (without `get`) is what the program-info builder receives.
 *
 * @param inferenceHandler Forwarded to `createComputeOutputProgramInfo`.
 * @param input Input tensor, forwarded unchanged.
 * @param epsilon Normalization epsilon, forwarded and used for the cache hint.
 * @param meanAndVarianceShape Dims of the statistics tensor, forwarded unchanged.
 * @returns The metadata plus a `get()` that builds the program info on each call.
 */
const createComputeOutputProgramInfoLoader = (inferenceHandler, input, epsilon, meanAndVarianceShape) => {
  const cacheHint = `${epsilon}`;
  const metadata = { ...computeOutputProgramMetadata, cacheHint };
  const buildProgramInfo = () =>
    createComputeOutputProgramInfo(inferenceHandler, metadata, input, epsilon, meanAndVarianceShape);
  return { ...metadata, get: buildProgramInfo };
};
/**
 * Validates the `[X, scale, B]` inputs, throwing on the first failed check.
 *
 * Checks run in this order: input count, ranks, channel-size agreement,
 * element types, and finally the 4-D restriction on X.
 *
 * @param inputs Tensors exposing `dims` and `type`.
 * @throws {Error} With a message naming the violated constraint.
 */
const validateInputs = (inputs) => {
  if (!inputs || inputs.length !== 3) {
    throw new Error('InstanceNormalization requires 3 inputs.');
  }

  const x = inputs[0];
  const scale = inputs[1];
  const bias = inputs[2];
  const isFloatTensor = (tensor) => tensor.type === 'float32' || tensor.type === 'float64';

  // X needs at least N, C and one more dimension; scale and bias are 1-D.
  const ranksAreValid = x.dims.length >= 3 && scale.dims.length === 1 && bias.dims.length === 1;
  if (!ranksAreValid) {
    throw new Error('Invalid input shape.');
  }

  // scale and bias must each have one entry per channel of X (dims[1]).
  const lengthsMatchChannels = scale.dims[0] === x.dims[1] && bias.dims[0] === x.dims[1];
  if (!lengthsMatchChannels) {
    throw new Error('Input shapes are mismatched.');
  }

  if (!(isFloatTensor(x) && isFloatTensor(scale) && isFloatTensor(bias))) {
    throw new Error('Invalid input type.');
  }

  // Checked last, so rank-3 and rank-5+ inputs only fail here.
  if (x.dims.length !== 4) {
    throw new Error('Only support 4-D input shape.');
  }
};
//# sourceMappingURL=instance-normalization.js.map