UNPKG

onnxruntime-web

Version:

A JavaScript library for running ONNX models in browsers

154 lines (143 loc) 6.7 kB
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import { env } from 'onnxruntime-common';

import { DataType } from '../../../wasm-common';
import { TensorView } from '../../tensor-view';
import { ShapeUtil } from '../../util';
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key';
import { ComputeContext, ProgramInfo } from '../types';

import { createTensorShapeVariables, getMaxComponents, inputVariable, outputVariable, ShaderHelper } from './common';

/**
 * Attributes of the BatchNormalization operator as consumed by this WebGPU kernel.
 * `outputCount` is injected by {@link batchNorm} from the compute context.
 */
export interface BatchNormAttributes extends AttributeWithCacheKey {
  readonly epsilon: number;
  readonly momentum: number;
  readonly spatial: boolean;
  readonly trainingMode: boolean;
  readonly format: 'NHWC' | 'NCHW';
  readonly outputCount: number;
}

/**
 * Checks that BatchNormalization received exactly 5 inputs (X, scale, B, mean, var) and that the four
 * per-channel tensors have the shape implied by X's layout and the `spatial` attribute.
 *
 * @throws Error if the input count is wrong or any per-channel tensor has an unexpected shape.
 */
const validateInputs = (inputs: readonly TensorView[], attributes: BatchNormAttributes): void => {
  if (!inputs || inputs.length !== 5) {
    throw new Error('BatchNormalization requires 5 inputs');
  }

  const checkShapeEqual = (actual: readonly number[], expected: readonly number[], message: string) => {
    const r = expected.length;
    if (r !== actual.length) {
      throw new Error(`${message}: num dimensions != ${r}`);
    }
    expected.forEach((v, i) => {
      if (v !== actual[i]) {
        throw new Error(`${message}: dim[${i}] do not match`);
      }
    });
  };

  if (inputs[0].dims.length > 1) {
    // Expected shape of scale/B/mean/var:
    //  - NHWC X = [N, D1..Dn, C]: [C] when spatial, otherwise [C, D1..Dn].
    //  - NCHW X = [N, C, D1..Dn]: [C] when spatial, otherwise [C, D1..Dn].
    const shape =
      attributes.format === 'NHWC'
        ? attributes.spatial
          ? inputs[0].dims.slice(-1)
          : inputs[0].dims.slice(-1).concat(inputs[0].dims.slice(1, inputs[0].dims.length - 1))
        : inputs[0].dims.slice(1, attributes.spatial ? 2 : undefined);
    checkShapeEqual(inputs[1].dims, shape, 'Invalid input scale');
    checkShapeEqual(inputs[2].dims, shape, 'Invalid input B');
    checkShapeEqual(inputs[3].dims, shape, 'Invalid input mean');
    checkShapeEqual(inputs[4].dims, shape, 'Invalid input var');
  } else {
    // A rank-0/1 X has a single "channel": every parameter tensor must be [1].
    checkShapeEqual(inputs[1].dims, [1], 'Invalid input scale');
    checkShapeEqual(inputs[2].dims, [1], 'Invalid input B');
    checkShapeEqual(inputs[3].dims, [1], 'Invalid input mean');
    checkShapeEqual(inputs[4].dims, [1], 'Invalid input var');
  }
};

/**
 * Builds the inference-mode program computing
 * `y = (x - mean) * inverseSqrt(var + epsilon) * scale + bias` element-wise, where the per-channel
 * parameters are looked up from each output element's indices.
 */
const createBatchNormInferenceProgramInfo = (
  inputs: readonly TensorView[],
  attributes: BatchNormAttributes,
): ProgramInfo => {
  const { epsilon, spatial, format } = attributes;
  const yShape = inputs[0].dims;
  const rank = yShape.length;
  // In spatial mode the innermost axis of X/Y is vectorized into `components` lanes.
  const components = spatial ? getMaxComponents(yShape[rank - 1]) : 1;
  // The channel axis is the innermost axis for NHWC (rank > 1) and also for a rank-2 NCHW input ([N, C]).
  // Then the lanes of one vectorized element belong to *different* channels, so the per-channel tensors
  // must be read with the same vector width. Otherwise (NCHW rank > 2) all lanes share one channel.
  const channelIsInnermost = rank > 1 && (format === 'NHWC' || rank === 2);
  const cComponents = channelIsInnermost ? components : 1;
  const outputSize = ShapeUtil.size(yShape) / components;
  // Only support uniforms for opset version >= 9 (spatial = true).
  const useShapesUniforms = spatial;
  const shapeOrRank = useShapesUniforms ? rank : yShape;
  const x = inputVariable('x', inputs[0].dataType, inputs[0].dims, components);
  const scale = inputVariable('scale', inputs[1].dataType, inputs[1].dims, cComponents);
  const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims, cComponents);
  const inputMean = inputVariable('inputMean', inputs[3].dataType, inputs[3].dims, cComponents);
  const inputVar = inputVariable('inputVar', inputs[4].dataType, inputs[4].dims, cComponents);
  const y = outputVariable('y', inputs[0].dataType, shapeOrRank, components);
  // TODO: support inputs with different data type. Current we need to make sure all inputs have the same data type.
  // Otherwise, the shader compilation will fail.

  /** Emits WGSL that defines `cOffset`, the offset of the current element's parameters in scale/B/mean/var. */
  const calcCOffset = (): string => {
    let cOffset = '';
    if (spatial) {
      // `outputIndices` holds full-shape indices of the element's first lane, so an innermost channel
      // index is divided by the vector width to address the vectorized parameter tensors.
      const channelOffset =
        rank === 1 ? '0u' : channelIsInnermost ? `outputIndices[${rank - 1}] / ${components}` : 'outputIndices[1]';
      cOffset = `let cOffset = ${channelOffset};`;
    } else {
      if (format === 'NCHW') {
        // Parameters have shape [C, D1..Dn] = Y's shape without the batch dim: zero N, then reuse Y's strides.
        cOffset = `
          ${y.indicesSet('outputIndices', '0', '0')}
          let cOffset = ${y.indicesToOffset('outputIndices')};`;
      } else {
        // update C channel.
        cOffset = `var cIndices = ${scale.type.indices}(0);
                   cIndices[0] = outputIndices[${rank - 1}];`;
        // update D1 x ... x Dn channels.
        for (let i = 1; i < scale.rank; i++) {
          cOffset += `cIndices[${i}] = outputIndices[${i}];`;
        }
        cOffset += `let cOffset = ${scale.indicesToOffset('cIndices')};`;
      }
    }
    return cOffset;
  };

  const getInferenceModeShaderSource = (helper: ShaderHelper) => `
  const epsilon = ${epsilon};
  ${helper.registerUniform('outputSize', 'u32').declareVariables(x, scale, bias, inputMean, inputVar, y)}
  ${helper.mainStart()}
  ${helper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')}
    var outputIndices = ${y.offsetToIndices(`global_idx * ${components}`)};
    ${calcCOffset()}
    let scale = ${scale.getByOffset('cOffset')};
    let bias = ${bias.getByOffset('cOffset')};
    let inputMean = ${inputMean.getByOffset('cOffset')};
    let inputVar = ${inputVar.getByOffset('cOffset')};
    let x = ${x.getByOffset('global_idx')};
    let value = (x - inputMean) * inverseSqrt(inputVar + epsilon) * scale + bias;
    ${y.setByOffset('global_idx', 'value')}
  }`;

  return {
    name: 'BatchNormalization',
    shaderCache: {
      hint: `${attributes.epsilon}_${attributes.format}_${spatial}_${components}`,
      inputDependencies: useShapesUniforms ? ['rank', 'type', 'type', 'type', 'type'] : undefined,
    },
    getShaderSource: getInferenceModeShaderSource,
    getRunData: () => ({
      outputs: [{ dims: inputs[0].dims, dataType: inputs[0].dataType }],
      dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
      programUniforms: useShapesUniforms
        ? [{ type: DataType.uint32, data: outputSize }, ...createTensorShapeVariables(yShape)]
        : [{ type: DataType.uint32, data: outputSize }],
    }),
  };
};

/** Wraps raw operator attributes into a cache-keyed {@link BatchNormAttributes} (no validation is done here). */
export const parseBatchNormAttributes = (attributes: Record<string, unknown>): BatchNormAttributes =>
  createAttributeWithCacheKey(attributes as Omit<BatchNormAttributes, keyof AttributeWithCacheKey>);

/**
 * Runs BatchNormalization on the WebGPU backend. Input validation only happens when
 * `env.webgpu.validateInputContent` is enabled.
 *
 * @throws Error if training mode is requested (unsupported), or if enabled validation fails.
 */
export const batchNorm = (context: ComputeContext, attributes: Record<string, unknown>): void => {
  const { inputs, outputCount } = context;
  const updatedAttributes = parseBatchNormAttributes({ ...attributes, outputCount });
  if (env.webgpu.validateInputContent) {
    validateInputs(inputs, updatedAttributes);
  }
  if (updatedAttributes.trainingMode) {
    throw new Error('BatchNormalization trainingMode is not supported yet.');
  }
  context.compute(createBatchNormInferenceProgramInfo(inputs, updatedAttributes));
};