onnxruntime-web
Version:
A JavaScript library for running ONNX models in browsers
300 lines (272 loc) • 13.9 kB
text/typescript
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types';
import {createTensorShapeVariables, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType, UniformsArrayType} from './common';
/**
 * Attributes of the InstanceNormalization operator as consumed by the WebGPU kernels below.
 */
export interface InstanceNormAttributes {
  /** Layout of the input tensor; `'NHWC'` selects the channels-last kernel, anything else the NCHW one. */
  format: 'NHWC' | 'NCHW';
  /** Constant added to the variance before taking its inverse square root. */
  epsilon: number;
}
/**
 * Builds the single-program implementation of instance normalization for NCHW (channels-first) input.
 *
 * One workgroup of 64 threads is dispatched per (batch, channel) pair (`normCount` workgroups). Each workgroup
 * reduces the `normSize` elements of its channel twice through workgroup memory: first to obtain the mean,
 * then the sum of squared deviations from that mean. It then writes
 * `x * channelScale + channelShift`, where `channelScale = scale * invStdDev` and
 * `channelShift = bias - mean * channelScale`, for every element of the channel.
 *
 * @param inputs `[x, scale, bias]`; x is viewed as `[N, C, normSize / components]` (see `inputShape`).
 * @param attributes supplies `epsilon`, which is baked into the shader text and therefore part of the cache hint.
 * @returns the program description; the caller (`instanceNorm`) dispatches it with `context.compute`.
 */
const createInstanceNormProgramInfo =
(inputs: readonly TensorView[], attributes: InstanceNormAttributes): ProgramInfo => {
const xShape = inputs[0].dims;
const outputShape = xShape;
// Dimensions before `axis` (N, C) enumerate independent normalization groups; all dimensions from
// `axis` onward are flattened into the reduced extent of size `normSize`.
const axis = 2;
const normCount = ShapeUtil.sizeToDimension(xShape, axis);
const normSize = ShapeUtil.sizeFromDimension(xShape, axis);
// Vector width used along the reduced extent (presumably the widest width getMaxComponents accepts for
// normSize — confirm); normPackedSize counts vector elements per channel.
const components = getMaxComponents(normSize);
const normPackedSize = normSize / components;
const inputShape = [xShape[0], xShape[1], normPackedSize];
// Shader-cache dependencies per input: x by rank, scale and bias by data type (assumed to be combined with
// shaderCache.hint into the program cache key — confirm).
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'type', 'type'];
// Uniform values in the order of the `uniforms` declaration below (normSize, normPackedSize), followed by the
// shape variables generated for the packed 3-D views of x and output.
const programUniforms: ProgramUniform[] =
[{type: DataType.uint32, data: normSize}, {type: DataType.uint32, data: normPackedSize}];
programUniforms.push(...createTensorShapeVariables(inputShape, inputShape));
const getShaderSource = (shaderHelper: ShaderHelper) => {
// x and output use the packed 3-D view with `components`-wide elements; scale and bias are scalar per channel
// and are only read by channel offset.
const x = inputVariable('x', inputs[0].dataType, inputShape.length, components);
const scale = inputVariable('scale', inputs[1].dataType, inputs[1].dims);
const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims);
const output = outputVariable('output', inputs[0].dataType, inputShape.length, components);
const variables = [x, scale, bias, output];
// WGSL value type of one x element, and the f32 accumulator type of the same vector width.
const dataType = x.type.value;
const f32Type = components === 1 ? 'f32' : `vec${components}<f32>`;
// Must match mainStart(workgroupSize) and the one-workgroup-per-channel dispatch. The shader's tree
// reductions halve currSize each step, so this must stay a power of two.
const workgroupSize = 64;
const uniforms: UniformsArrayType = [{name: 'normSize', type: 'u32'}, {name: 'normPackedSize', type: 'u32'}];
// Shader phases: (1) per-thread strided partial sums of x, tree-reduced to the channel mean; (2) per-thread
// partial sums of squared deviations, tree-reduced to the channel variance numerator; (3) every thread
// applies the per-channel affine transform to its strided slice of the channel.
return `
var<workgroup> meanShared : f32;
var<workgroup> squaredNormShared : f32;
var<workgroup> workgroupShared : array<${f32Type}, ${workgroupSize}>;
const workgroupSize = ${workgroupSize}u;
${shaderHelper.registerUniforms(uniforms).declareVariables(...variables)}
${shaderHelper.mainStart(workgroupSize)}
let norm = global_idx / workgroupSize;
let batch = norm / uniforms.x_shape[1];
let channel = norm % uniforms.x_shape[1];
let localIndex = local_id.x;
// initialize workgroup memory
var initial = ${f32Type}(0);
for (var h = localIndex; h < uniforms.normPackedSize; h += workgroupSize) {
initial = initial + ${f32Type}(${x.get('batch', 'channel', 'h')});
}
workgroupShared[localIndex] = initial;
workgroupBarrier();
// Calculate the mean of current channel data.
for (var currSize = workgroupSize >> 1; currSize > 0; currSize = currSize >> 1) {
if (localIndex < currSize) {
workgroupShared[localIndex] = workgroupShared[localIndex] + workgroupShared[localIndex + currSize];
}
workgroupBarrier();
}
if (localIndex == 0) {
meanShared = ${sumVector('workgroupShared[0]', components)} / f32(uniforms.normSize);
}
workgroupBarrier();
// reinitialize workgroup memory.
initial = ${f32Type}(0);
for (var h = localIndex; h < uniforms.normPackedSize; h += workgroupSize) {
let deviation = ${f32Type}(${x.get('batch', 'channel', 'h')}) - ${f32Type}(meanShared);
initial = initial + deviation * deviation;
}
workgroupShared[localIndex] = initial;
workgroupBarrier();
// Calculate the sum of square of deviation of current channel data.
for (var currSize = workgroupSize >> 1; currSize > 0; currSize = currSize >> 1) {
if (localIndex < currSize) {
workgroupShared[localIndex] = workgroupShared[localIndex] + workgroupShared[localIndex + currSize];
}
workgroupBarrier();
}
if (localIndex == 0) {
squaredNormShared = ${sumVector('workgroupShared[0]', components)};
}
workgroupBarrier();
let invStdDev = inverseSqrt(squaredNormShared / f32(uniforms.normSize) + f32(${attributes.epsilon}));
let channelScale = invStdDev * f32(${scale.getByOffset('channel')});
let channelShift = f32(${bias.getByOffset('channel')}) - meanShared * channelScale;
for (var h = localIndex; h < uniforms.normPackedSize; h += workgroupSize) {
let value = ${x.get('batch', 'channel', 'h')} * ${dataType}(${f32Type}(channelScale)) + ${dataType}(${
f32Type}(channelShift));
${output.set('batch', 'channel', 'h', 'value')};
}
}`;
};
return {
// Equivalent to a plain `name: 'InstanceNormalization'` property.
...{name: 'InstanceNormalization'},
// TODO: use epsilon as uniform. Currently epsilon as uniform fails test_instancenorm_epsilon.
shaderCache: {hint: `${attributes.epsilon};${components}`, inputDependencies},
getRunData: () => ({
outputs: [
{dims: outputShape, dataType: inputs[0].dataType},
],
// One workgroup per (batch, channel) pair.
dispatchGroup: {x: normCount},
programUniforms
}),
getShaderSource,
};
};
/**
 * Computes per-(image, channel) normalization coefficients for NHWC instance normalization.
 *
 * Despite its name, the result is not the mean itself. For every image and (packed) channel it is the pair
 * `(channelScale, channelShift)`, with `channelScale = scale / sqrt(variance + epsilon)` and
 * `channelShift = bias - mean * channelScale`, so that `output = input * channelScale + channelShift`.
 *
 * Two programs are dispatched:
 * 1. `InstanceNormComputeMean`: one workgroup of `WG` threads per (image, packed channel). Thread `k` sums the
 *    values and the squared values at spatial positions `[k * wgSize, min((k + 1) * wgSize, h))` and stores the
 *    pair in slot `k` of a `[n, c, WG, 2]` float buffer. Every slot is written; threads without positions
 *    store zeros.
 * 2. `InstanceNormComputeChannelScaleShift`: one thread per (image, packed channel) folds the first
 *    `min(WG, h)` slots into mean and variance (`E[x^2] - E[x]^2`, clamped at 0) and emits the coefficient pair.
 *
 * @param context compute context used to dispatch both programs.
 * @param input NHWC input; the shader addresses it as `[image][position][packed channel]`.
 * @param scale per-channel scale.
 * @param bias per-channel bias.
 * @param n number of images (batch size).
 * @param h number of spatial positions per image.
 * @param c number of channels.
 * @param epsilon added to the variance before the inverse square root.
 * @returns a float tensor of dims `[n, c, 2]` holding the coefficient pairs.
 */
const computeMean =
    (context: ComputeContext, input: TensorView, scale: TensorView, bias: TensorView, n: number, h: number, c: number,
     epsilon: number) => {
      const components = getMaxComponents(c);
      const WG = 64;
      // we will store channel scale and channel shift in [2, components] matrix
      // or in vec2 when components == 1
      const outputType = components === 1 ? 'vec2f' : `mat2x${components}f`;
      const sumCastType = components === 1 ? 'f32' : `vec${components}f`;
      const setOutputValue = (var1: string, var2: string) => `${outputType}(${var1}, ${var2})`;
      const unitsOfWork = n * c / components;
      // Number of spatial positions each of the WG threads accumulates in the first pass.
      const wgSize = Math.ceil(h / WG);

      const meanInputDependencies: ProgramInputTensorInfoDependency[] = ['type'];
      // Order must match `struct Uniforms {wg_size, H, C, image_size}` in the mean shader.
      const meanProgramUniforms: ProgramUniform[] = [
        {type: DataType.uint32, data: wgSize}, {type: DataType.uint32, data: h},
        {type: DataType.uint32, data: Math.floor(c / components)},
        {type: DataType.uint32, data: Math.floor(h * c / components)}
      ];
      const getMeanShaderSource = (shaderHelper: ShaderHelper) => {
        const inputHelper = inputVariable('input', input.dataType, input.dims, components);
        // WGSL resource variables require explicit @group/@binding attributes. `input` gets binding 0 from
        // declareVariables; the hand-declared output takes 1 and the uniform buffer 2 (assumed to be bound right
        // after all inputs and outputs, as in ORT's program manager — confirm).
        return `
  ${shaderHelper.declareVariables(inputHelper)}
  @group(0) @binding(1) var<storage, read_write> output : array<${outputType}>;
  struct Uniforms {wg_size:u32, H:u32, C:u32, image_size:u32};
  @group(0) @binding(2) var<uniform> uniforms: Uniforms;

  ${shaderHelper.mainStart(WG)}
  let currentImageNumber = global_idx / ${WG} / uniforms.C;
  let currentChannelNumber = (global_idx / ${WG}) % uniforms.C;
  let wgId = global_idx % ${WG};
  let wgOffset = wgId * uniforms.wg_size;
  // No early return for threads whose slice starts at or past H: their loop runs zero times and they
  // store a zero pair. The next pass reads min(WG, H) slots, which can exceed the number of threads that
  // have positions (H = 100 -> wg_size = 2 -> only 50 threads), so every slot must be written.
  let wgMax = min(wgOffset + uniforms.wg_size, uniforms.H);

  let offset = currentImageNumber * uniforms.image_size + currentChannelNumber;
  var sum = ${fillVector('f32', components)};
  var squaredSum = ${fillVector('f32', components)};
  for (var i: u32 = wgOffset; i < wgMax; i++) {
    let value = ${sumCastType}(input[offset + i * uniforms.C]);
    sum += value;
    squaredSum += value * value;
  }
  output[global_idx] = ${setOutputValue('sum', 'squaredSum')};
}`;
      };

      const meanValues = context.compute(
          {
            name: 'InstanceNormComputeMean',
            shaderCache: {hint: `${components}`, inputDependencies: meanInputDependencies},
            getRunData: () => ({
              outputs: [
                {dims: [n, c, WG, 2], dataType: DataType.float},
              ],
              // One workgroup per (image, packed channel).
              dispatchGroup: {x: n * c / components},
              programUniforms: meanProgramUniforms
            }),
            getShaderSource: getMeanShaderSource,
          },
          {inputs: [input], outputs: [-1]})[0];

      // Order must match `struct Uniforms {units_of_work, H, C, image_size}` in the scale/shift shader;
      // here image_size is the number of partial-sum slots per image (WG per packed channel).
      const programUniforms: ProgramUniform[] = [
        {type: DataType.uint32, data: unitsOfWork}, {type: DataType.uint32, data: h},
        {type: DataType.uint32, data: Math.floor(c / components)},
        {type: DataType.uint32, data: Math.floor(WG * c / components)}
      ];
      const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type', 'type'];
      const getShaderSource = (shaderHelper: ShaderHelper) => {
        const scaleHelper = inputVariable('scale', scale.dataType, scale.dims, components);
        const biasHelper = inputVariable('bias', bias.dataType, bias.dims, components);
        // Explicit bindings: partial sums = 0, scale = 1, bias = 2, output = 3, uniforms = 4.
        return `
  @group(0) @binding(0) var<storage, read> input : array<${outputType}>;
  @group(0) @binding(1) var<storage, read> scale : array<${scaleHelper.type.storage}>;
  @group(0) @binding(2) var<storage, read> bias : array<${biasHelper.type.storage}>;
  @group(0) @binding(3) var<storage, read_write> output : array<${outputType}>;
  struct Uniforms {units_of_work : u32, H: u32, C : u32, image_size : u32};
  @group(0) @binding(4) var<uniform> uniforms: Uniforms;

  ${shaderHelper.mainStart()}
  ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.units_of_work')}
  let currentImageNumber = global_idx / uniforms.C;
  let currentChannelNumber = global_idx % uniforms.C;

  let offset = currentImageNumber * uniforms.image_size;
  var sum = ${fillVector('f32', components)};
  var squaredSum = ${fillVector('f32', components)};
  for (var i: u32 = 0; i < min(${WG}, uniforms.H); i++) {
    let value = input[offset + i + currentChannelNumber * ${WG}];
    sum += value[0];
    squaredSum += value[1];
  }
  sum = sum / f32(uniforms.H);
  squaredSum = squaredSum / f32(uniforms.H);
  // E[x^2] - E[x]^2 can come out slightly negative through floating-point cancellation; clamp it so a
  // small epsilon cannot turn it into inverseSqrt(negative) = NaN.
  let variance = max(squaredSum - sum * sum, ${fillVector('f32', components)});
  let invStdDev = inverseSqrt(variance + f32(${epsilon}));
  let channelScale = invStdDev * ${sumCastType}(scale[currentChannelNumber]);
  let channelShift = ${sumCastType}(bias[currentChannelNumber]) - sum * channelScale;
  output[global_idx] = ${setOutputValue('channelScale', 'channelShift')};
}`;
      };

      return context.compute(
          {
            name: 'InstanceNormComputeChannelScaleShift',
            // TODO: use epsilon as uniform. Currently epsilon as uniform fails test_instancenorm_epsilon.
            shaderCache: {hint: `${components};${epsilon}`, inputDependencies},
            getRunData: () => ({
              outputs: [
                {dims: [n, c, 2], dataType: DataType.float},
              ],
              dispatchGroup: {x: Math.ceil(unitsOfWork / 64 /* workgroup size */)},
              programUniforms
            }),
            getShaderSource,
          },
          {inputs: [meanValues, scale, bias], outputs: [-1]})[0];
    };
/**
 * Runs instance normalization for NHWC (channels-last) input.
 *
 * Despite its name, this does not return a ProgramInfo. It dispatches both stages itself:
 * 1. `computeMean` produces, for every (image, packed channel), the pair `(channelScale, channelShift)`.
 * 2. `InstanceNormalizationNHWC` applies `output = fma(input, channelScale, channelShift)` element-wise.
 *
 * @param context compute context used to dispatch both stages.
 * @param inputs `[x, scale, bias]`, where x has channels as its last dimension.
 * @param attributes supplies `epsilon`, forwarded to the scale/shift computation.
 */
const createInstanceNormNHWCProgramInfo =
    (context: ComputeContext, inputs: readonly TensorView[], attributes: InstanceNormAttributes) => {
      const xShape = inputs[0].dims;
      const outputShape = xShape;
      const N = xShape[0];
      const C = xShape[xShape.length - 1];
      // All dimensions between N and C collapse into a single spatial extent.
      const H = ShapeUtil.sizeFromDimension(xShape, 1) / C;
      const components = getMaxComponents(C);
      // Number of (possibly vectorized) output elements, i.e. the number of invocations that have work.
      const outputSize = ShapeUtil.size(outputShape) / components;
      // Order must match `struct Uniforms {H, C, output_size}` in the shader below.
      const programUniforms: ProgramUniform[] = [
        {type: DataType.uint32, data: H}, {type: DataType.uint32, data: Math.floor(C / components)},
        {type: DataType.uint32, data: outputSize}
      ];
      const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type'];
      // First compute the per-channel scale/shift pairs (float tensor of dims [N, C, 2]).
      const channelScaleShift = computeMean(context, inputs[0], inputs[1], inputs[2], N, H, C, attributes.epsilon);

      const getShaderSource = (shaderHelper: ShaderHelper) => {
        const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
        const scaleType = components === 1 ? 'vec2f' : `mat2x${components}f`;
        const scaleCastType = components === 1 ? dataType : `vec${components}<${dataType}>`;
        const inputHelper = inputVariable('input', inputs[0].dataType, inputs[0].dims, components);
        const outputHelper = outputVariable('output', inputs[0].dataType, outputShape, components);
        // Hand-declared WGSL resources need explicit bindings: input = 0, scaleInput = 1, output = 2, and the
        // uniform buffer after all inputs and outputs = 3 (assumed ORT program-manager layout — confirm).
        return `
  @group(0) @binding(0) var<storage, read> input : array<${inputHelper.type.storage}>;
  @group(0) @binding(1) var<storage, read> scaleInput : array<${scaleType}>;
  @group(0) @binding(2) var<storage, read_write> output : array<${outputHelper.type.storage}>;
  struct Uniforms {H: u32, C : u32, output_size : u32};
  @group(0) @binding(3) var<uniform> uniforms: Uniforms;

  ${shaderHelper.mainStart()}
  // The dispatch is rounded up to whole workgroups; surplus invocations must not touch the buffers.
  ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
  let currentImageNumber = global_idx / (uniforms.C * uniforms.H);
  let currentChannelNumber = global_idx % uniforms.C;
  let scaleOffset = currentImageNumber * uniforms.C + currentChannelNumber;
  let scale = scaleInput[scaleOffset];
  output[global_idx] = fma(input[global_idx], ${scaleCastType}(scale[0]), ${scaleCastType}(scale[1]));
}`;
      };

      context.compute(
          {
            name: 'InstanceNormalizationNHWC',
            shaderCache: {hint: `${components}`, inputDependencies},
            getRunData: () => ({
              outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
              dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
              programUniforms
            }),
            getShaderSource,
          },
          {inputs: [inputs[0], channelScaleShift]});
    };
/**
 * WebGPU entry point for InstanceNormalization. Selects the channels-last or the channels-first
 * implementation from `attributes.format`.
 */
export const instanceNorm = (context: ComputeContext, attributes: InstanceNormAttributes): void => {
  const {inputs} = context;
  if (attributes.format !== 'NHWC') {
    // Channels-first layout: one program performs both the reduction and the normalization.
    context.compute(createInstanceNormProgramInfo(inputs, attributes));
    return;
  }
  // Channels-last layout: a multi-pass pipeline that dispatches its own programs.
  createInstanceNormNHWCProgramInfo(context, inputs, attributes);
};