UNPKG

onnxruntime-web

Version:

A JavaScript library for running ONNX models in browsers

173 lines (154 loc) 5.82 kB
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import { Graph } from '../../../graph';
import { OperatorImplementation, OperatorInitialization } from '../../../operators';
import { Tensor } from '../../../tensor';
import { getGlsl } from '../glsl-source';
import { WebGLInferenceHandler } from '../inference-handler';
import { ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType } from '../types';

/** Value used for `epsilon` when the node does not carry that attribute. */
const DEFAULT_EPSILON = 1e-5;

/**
 * Checks the operator inputs and throws on the first violated rule.
 *
 * Rules, in the order they are checked:
 *  1. exactly three tensors are supplied (X, scale, B);
 *  2. X has at least three dimensions (N, C, dim1, ..., dimn) while scale and B are 1-D;
 *  3. the length of scale and of B equals X's second dimension;
 *  4. every tensor is 'float32' or 'float64';
 *  5. X is exactly 4-D, because the shaders below index exactly four dimensions.
 *
 * @param inputs - Tensors handed to the operator.
 * @throws Error describing the first rule that fails.
 */
function validateInputs(inputs: Tensor[]): void {
  if (!inputs || inputs.length !== 3) {
    throw new Error('InstanceNormalization requires 3 inputs.');
  }

  const x = inputs[0];
  const scale = inputs[1];
  const bias = inputs[2];

  // The data tensor needs at least N, C and one spatial dimension;
  // scale and bias must each be a single vector.
  const ranksAreValid = x.dims.length >= 3 && scale.dims.length === 1 && bias.dims.length === 1;
  if (!ranksAreValid) {
    throw new Error('Invalid input shape.');
  }

  const channels = x.dims[1];
  if (scale.dims[0] !== channels || bias.dims[0] !== channels) {
    throw new Error('Input shapes are mismatched.');
  }

  const isFloatTensor = (tensor: Tensor): boolean => tensor.type === 'float32' || tensor.type === 'float64';
  if (![x, scale, bias].every(isFloatTensor)) {
    throw new Error('Invalid input type.');
  }

  if (x.dims.length !== 4) {
    throw new Error('Only support 4-D input shape.');
  }
}

/** Name and input layout of the first pass, which reduces X to a per-(dims[0], dims[1]) mean and variance. */
const meanAndVarianceProgramMetadata = {
  name: 'InstanceNormalization_MeanAndVariance',
  inputNames: ['X'],
  inputTypes: [TextureType.unpacked],
};

/**
 * Builds the program info for the statistics pass.
 *
 * For each pair of leading indices the generated shader walks dimensions 2 and 3 of X twice:
 * once to accumulate the sum (giving the mean) and once to accumulate squared deviations
 * from that mean. Both totals are divided by dims[2] * dims[3]. The mean is written to the
 * `r` component and the variance to the `g` component of the result.
 *
 * @param metadata - Program metadata spread into the returned program info.
 * @param input - The data tensor X; indices 0..3 of its dims are read.
 * @returns Program info whose output has dims `[dims[0], dims[1]]`.
 */
function createMeanAndVarianceProgramInfo(metadata: ProgramMetadata, input: Tensor): ProgramInfo {
  const dims = input.dims;
  const height = dims[2];
  const width = dims[3];
  const elementsPerChannel = height * width;

  const shaderSource = `
      vec4 process(int[2] indices) {
        vec4 v = vec4(0.0);
        int a[4];
        a[0] = indices[0];
        a[1] = indices[1];
        float temp = 0.0;
        for(int a2=0; a2<${height}; a2++) {
          a[2] = a2;
          for(int a3=0; a3<${width}; a3++) {
            a[3] = a3;
            float x = _X(a);
            temp += x;
          }
        }
        float mean = temp / float(${elementsPerChannel});
        temp = 0.0;
        for(int a2=0; a2<${height}; a2++) {
          a[2] = a2;
          for(int a3=0; a3<${width}; a3++) {
            a[3] = a3;
            float x = _X(a);
            temp += (x - mean) * (x - mean);
          }
        }
        v.r = mean;
        v.g = temp / float(${elementsPerChannel});

        return v;
      }`;

  return {
    ...metadata,
    output: { dims: [dims[0], dims[1]], type: input.type, textureType: TextureType.packedLastDimension },
    shaderSource,
  };
}

/**
 * Wraps the statistics pass in a loader so the program info is only built when `get` is called.
 *
 * @param input - The data tensor X.
 */
function createMeanAndVarianceProgramInfoLoader(input: Tensor): ProgramInfoLoader {
  return {
    ...meanAndVarianceProgramMetadata,
    get: () => createMeanAndVarianceProgramInfo(meanAndVarianceProgramMetadata, input),
  };
}

/** Name and input layout of the second pass, which applies the normalization to every element of X. */
const computeOutputProgramMetadata = {
  name: 'InstanceNormalization_ComputeOutput',
  inputNames: ['X', 'MeanAndVariance', 'Scale', 'B'],
  inputTypes: [TextureType.unpacked, TextureType.packedLastDimension, TextureType.unpacked, TextureType.unpacked],
};

/**
 * Builds the program info for the normalization pass.
 *
 * The generated shader looks up the mean/variance texel for the element's first two indices,
 * fetches scale and bias by the second index, and returns
 * `scale * (x - mean) / sqrt(variance + epsilon) + b`. `epsilon` is supplied to the shader
 * through the `variables` entry of the returned program info.
 *
 * @param inferenceHandler - Provides the GL context version and the texture size calculation.
 * @param metadata - Program metadata spread into the returned program info.
 * @param input - The data tensor X; the output reuses its dims and type.
 * @param epsilon - Value added to the variance before the square root.
 * @param meanAndVarianceShape - Dims of the tensor produced by the statistics pass.
 * @returns Program info for an unpacked output with the same dims as `input`.
 */
function createComputeOutputProgramInfo(
  inferenceHandler: WebGLInferenceHandler,
  metadata: ProgramMetadata,
  input: Tensor,
  epsilon: number,
  meanAndVarianceShape: readonly number[],
): ProgramInfo {
  const glsl = getGlsl(inferenceHandler.session.backend.glContext.version);

  const textureSize = inferenceHandler.calculateTextureWidthAndHeight(
    meanAndVarianceShape,
    TextureType.packedLastDimension,
  );
  // NOTE(review): the width is divided by 4, presumably because the packed layout
  // stores four values per texel - confirm against calculateTextureWidthAndHeight.
  const statsWidth = textureSize[0] / 4;
  const statsHeight = textureSize[1];

  const shaderSource = `
      vec4 get_MeanAndVariance(int[2] mv) {
        int offset = indicesToOffset_MeanAndVariance(mv);
        vec2 coords = offsetToCoords(offset, ${statsWidth}, ${statsHeight});
        return ${glsl.texture2D}(MeanAndVariance, coords);
      }

      float process(int[4] indices) {
        int mv[2];
        mv[0] = indices[0];
        mv[1] = indices[1];
        vec4 mean_and_variance = get_MeanAndVariance(mv);
        float mean = mean_and_variance.r;
        float variance = mean_and_variance.g;

        int sb[1];
        sb[0] = indices[1];
        float scale = _Scale(sb);
        float b = _B(sb);

        return scale * (_X(indices) - mean) / sqrt(variance + epsilon) + b;
      }`;

  return {
    ...metadata,
    output: { dims: input.dims, type: input.type, textureType: TextureType.unpacked },
    variables: [{ name: 'epsilon', type: 'float', data: epsilon }],
    shaderSource,
  };
}

/**
 * Wraps the normalization pass in a loader. The stringified epsilon is attached as `cacheHint`.
 *
 * @param inferenceHandler - Forwarded to the program info builder.
 * @param input - The data tensor X.
 * @param epsilon - Value added to the variance before the square root.
 * @param meanAndVarianceShape - Dims of the tensor produced by the statistics pass.
 */
function createComputeOutputProgramInfoLoader(
  inferenceHandler: WebGLInferenceHandler,
  input: Tensor,
  epsilon: number,
  meanAndVarianceShape: readonly number[],
): ProgramInfoLoader {
  const metadata = { ...computeOutputProgramMetadata, cacheHint: String(epsilon) };
  return {
    ...metadata,
    get: () => createComputeOutputProgramInfo(inferenceHandler, metadata, input, epsilon, meanAndVarianceShape),
  };
}

/**
 * WebGL implementation of InstanceNormalization.
 *
 * Validates the inputs, runs the statistics pass to obtain mean and variance, then runs
 * the normalization pass on X together with those statistics, scale and B.
 *
 * @param inferenceHandler - Handler used to run both programs.
 * @param inputs - `[X, scale, B]`; see `validateInputs` for the accepted shapes and types.
 * @param epsilon - Value added to the variance before the square root.
 * @returns A single-element array holding the normalized tensor.
 * @throws Error when the inputs fail validation.
 */
export const instanceNormalization: OperatorImplementation<number> = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: Tensor[],
  epsilon: number,
): Tensor[] => {
  validateInputs(inputs);

  const x = inputs[0];
  const scale = inputs[1];
  const bias = inputs[2];

  // Pass 1: reduce X to its mean and variance.
  const statsLoader = createMeanAndVarianceProgramInfoLoader(x);
  const meanAndVariance = inferenceHandler.run(statsLoader, inputs);

  // Pass 2: normalize each element using the statistics from pass 1.
  const outputLoader = createComputeOutputProgramInfoLoader(inferenceHandler, x, epsilon, meanAndVariance.dims);
  const output = inferenceHandler.run(outputLoader, [x, meanAndVariance, scale, bias]);

  return [output];
};

/**
 * Reads the `epsilon` float attribute from the graph node, falling back to 1e-5 when absent.
 *
 * @param node - Graph node carrying the operator attributes.
 * @returns The epsilon to pass to `instanceNormalization`.
 */
export const parseInstanceNormalizationAttributes: OperatorInitialization<number> = (node: Graph.Node): number => {
  return node.attributes.getFloat('epsilon', DEFAULT_EPSILON);
};