// onnxruntime-web
// Version: (unknown)
// A JavaScript library for running ONNX models in browsers.
// 514 lines (465 loc) • 19 kB
// text/typescript
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import { env } from 'onnxruntime-common';
import { DataType } from '../../../wasm-common';
import { TensorView } from '../../tensor-view';
import { PoolConvUtil, ShapeUtil } from '../../util';
import { AttributeWithCacheKey } from '../attribute-with-cache-key';
import { ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform } from '../types';
import {
createTensorShapeVariables,
getElementAt,
IndicesHelper,
inputVariable,
outputVariable,
ShaderHelper,
UniformsArrayType,
} from './common';
// TODO: support:
// - ceil_mode "test_maxpool_2d_ceil"
// - storage_order "test_maxpool_with_argmax_2d_precomputed_strides"
// - [MaxPool] dilations "test_maxpool_2d_dilations"
// - [MaxPool] output[1] "test_maxpool_with_argmax_2d_precomputed_pads"
/**
 * Checks the operator inputs of every pooling op.
 *
 * The check only runs when `env.webgpu.validateInputContent` is enabled; in that case it throws unless exactly one
 * input tensor is supplied.
 */
const validateInputs = (inputs: readonly TensorView[]): void => {
  if (!env.webgpu.validateInputContent) {
    return;
  }
  const hasExactlyOneInput = !!inputs && inputs.length === 1;
  if (!hasExactlyOneInput) {
    throw new Error('Pool ops requires 1 input.');
  }
};
/**
 * Resolves the pooling attributes against the actual input shape and computes the output shape.
 *
 * The input shape is first expressed channels-first (for NHWC the trailing channel dim is moved to index 1) because
 * `PoolConvUtil` works in that layout. Copies of kernelShape/strides/pads (and dilations, when the attribute object
 * has them) are adjusted by `PoolConvUtil.adjustPoolAttributes`. The result is returned as a shallow copy of
 * `attributes` carrying the adjusted arrays and the original `cacheKey`.
 *
 * @returns `[adjustedAttributes, outputShape]`. The output shape is channels-last when `attributes.format` is 'NHWC'.
 */
const getAdjustedPoolAttributesAndOutputShape = <AttributeType extends AveragePoolAttributes | MaxPoolAttributes>(
  input: TensorView,
  attributes: AttributeType,
  isGlobalOperator: boolean,
): [AttributeType, number[]] => {
  const channelsLast = attributes.format === 'NHWC';

  // Channels-first view of the input shape: [N, C, spatial...].
  const nchwInputShape = input.dims.slice();
  if (channelsLast) {
    const channels = nchwInputShape.pop()!;
    nchwInputShape.splice(1, 0, channels);
  }

  // Only MaxPool-style attribute objects carry `dilations`.
  const hasDilations = Object.prototype.hasOwnProperty.call(attributes, 'dilations');
  const kernelShape = attributes.kernelShape.slice();
  const strides = attributes.strides.slice();
  const dilations: number[] = hasDilations ? (attributes as MaxPoolAttributes).dilations.slice() : [];
  const pads = attributes.pads.slice();

  // Mutates the four arrays in place.
  PoolConvUtil.adjustPoolAttributes(isGlobalOperator, nchwInputShape, kernelShape, strides, dilations, pads);

  const nchwOutputShape = PoolConvUtil.computePoolOutputShape(
    isGlobalOperator,
    nchwInputShape,
    strides,
    dilations,
    kernelShape,
    pads,
    attributes.autoPad,
  );

  const overrides: Partial<MaxPoolAttributes> = hasDilations
    ? { kernelShape, strides, pads, dilations, cacheKey: attributes.cacheKey }
    : { kernelShape, strides, pads, cacheKey: attributes.cacheKey };
  const adjustedAttributes = Object.assign({}, attributes, overrides);

  if (!channelsLast) {
    return [adjustedAttributes, nchwOutputShape];
  }
  // Move the channel dim (index 1) back to the end for NHWC.
  const nhwcOutputShape = nchwOutputShape.slice();
  const [channels] = nhwcOutputShape.splice(1, 1);
  nhwcOutputShape.push(channels);
  return [adjustedAttributes, nhwcOutputShape];
};
/**
 * Builds the uniforms shared by every pooling shader and reports which padding code paths the shader must emit.
 *
 * Every shader gets `outputSize` and `kernelSize`.
 * - 1-D/2-D kernels: the kernel size, stride and begin/end padding of the W axis (and of the H axis for 2-D) are
 *   passed as scalar u32 uniforms.
 * - Higher-rank kernels: kernel strides, pads and strides are passed as u32 uniform arrays. This path is NCHW only.
 *
 * @param outputShape - shape of the pooled output tensor.
 * @param attributes - pooling attributes already adjusted against the input shape.
 * @returns `[programUniforms, uniforms, hasPads, pwStartEndNotZero, phStartEndNotZero]`.
 *   - `hasPads` says whether the N-D shader must bounds-check input indices. It is always `true` on the 1-D/2-D path.
 *   - The last two flags say whether the W / H axis has any non-zero padding. Both are `false` on the N-D path.
 * @throws Error when the kernel has more than 2 spatial dims and the format is NHWC.
 */
const getUniformAndPadInfo = <AttributeType extends AveragePoolAttributes | MaxPoolAttributes>(
  outputShape: readonly number[],
  attributes: AttributeType,
): [ProgramUniform[], UniformsArrayType, boolean, boolean, boolean] => {
  const isChannelsLast = attributes.format === 'NHWC';
  const outputSize = ShapeUtil.size(outputShape);
  const kernelSize = ShapeUtil.size(attributes.kernelShape);
  const programUniforms: ProgramUniform[] = [
    { type: DataType.uint32, data: outputSize },
    { type: DataType.uint32, data: kernelSize },
  ];
  const uniforms: UniformsArrayType = [
    { name: 'outputSize', type: 'u32' },
    { name: 'kernelSize', type: 'u32' },
  ];
  if (attributes.kernelShape.length <= 2) {
    // pads are laid out as [begin_0, ..., begin_k, end_0, ..., end_k]; W is the last spatial axis.
    const kw = attributes.kernelShape[attributes.kernelShape.length - 1];
    const sw = attributes.strides[attributes.strides.length - 1];
    const pwStart = attributes.pads[attributes.pads.length / 2 - 1];
    const pwEnd = attributes.pads[attributes.pads.length - 1];
    const pwStartEndNotZero = !!(pwStart + pwEnd);
    programUniforms.push(
      { type: DataType.uint32, data: kw },
      { type: DataType.uint32, data: sw },
      { type: DataType.uint32, data: pwStart },
      { type: DataType.uint32, data: pwEnd },
    );
    uniforms.push(
      { name: 'kw', type: 'u32' },
      { name: 'sw', type: 'u32' },
      { name: 'pwStart', type: 'u32' },
      { name: 'pwEnd', type: 'u32' },
    );
    let phStartEndNotZero = false;
    if (attributes.kernelShape.length === 2) {
      // H is the second-to-last spatial axis.
      const kh = attributes.kernelShape[attributes.kernelShape.length - 2];
      const sh = attributes.strides[attributes.strides.length - 2];
      const phStart = attributes.pads[attributes.pads.length / 2 - 2];
      const phEnd = attributes.pads[attributes.pads.length - 2];
      phStartEndNotZero = !!(phStart + phEnd);
      programUniforms.push(
        { type: DataType.uint32, data: kh },
        { type: DataType.uint32, data: sh },
        { type: DataType.uint32, data: phStart },
        { type: DataType.uint32, data: phEnd },
      );
      uniforms.push(
        { name: 'kh', type: 'u32' },
        { name: 'sh', type: 'u32' },
        { name: 'phStart', type: 'u32' },
        { name: 'phEnd', type: 'u32' },
      );
    }
    return [programUniforms, uniforms, true, pwStartEndNotZero, phStartEndNotZero];
  } else {
    if (isChannelsLast) {
      throw new Error('Pooling with kernelShape.length > 2 is not supported for NHWC format.');
    }
    const kernelStrides = ShapeUtil.computeStrides(attributes.kernelShape);
    programUniforms.push(
      { type: DataType.uint32, data: kernelStrides },
      { type: DataType.uint32, data: attributes.pads },
      { type: DataType.uint32, data: attributes.strides },
    );
    uniforms.push(
      { name: 'kernelStrides', type: 'u32', length: kernelStrides.length },
      { name: 'pads', type: 'u32', length: attributes.pads.length },
      { name: 'strides', type: 'u32', length: attributes.strides.length },
    );
    // Explicit boolean. Unlike an unseeded reduce(), this does not throw on an empty pads array.
    const hasPads = attributes.pads.some((pad) => pad !== 0);
    return [programUniforms, uniforms, hasPads, false, false];
  }
};
/**
 * Generates the WGSL source of a pooling shader.
 *
 * Each invocation handles one output element (`global_idx`):
 * 1. It walks the kernel window over input `x` and runs `op1` for every visited element (bound as `x_val`).
 * 2. When bounds checks are emitted, out-of-range positions are skipped and counted in `pad`.
 * 3. It then runs `op2` once and writes `value` to `output[global_idx]`.
 *
 * @param shaderHelper - helper used to declare uniforms/variables and the entry point.
 * @param x - indices helper of the input tensor.
 * @param rank - rank of the input tensor.
 * @param outputShapeRank - rank of the output tensor.
 * @param attributes - adjusted pooling attributes. Only `format` and the lengths of `kernelShape`/`pads` shape the code.
 * @param op1 - WGSL statement(s) executed per visited input element (`x_val` and `value` in scope).
 * @param op2 - WGSL statement(s) executed once after the window has been visited (`value` and `pad` in scope).
 * @param start - initial value of the accumulator `value`.
 * @param uniforms - uniform declarations produced by `getUniformAndPadInfo`.
 * @param hasPads - N-D path only: emit per-axis bounds checks.
 * @param pwStartEndNotZero - 1-D/2-D path: emit bounds checks on the W axis.
 * @param phStartEndNotZero - 2-D path: emit bounds checks on the H axis.
 * @throws Error when the kernel has more than 2 spatial dims and the format is NHWC.
 */
const generatePoolingCode = <AttributeType extends AveragePoolAttributes | MaxPoolAttributes>(
  shaderHelper: ShaderHelper,
  x: IndicesHelper,
  rank: number,
  outputShapeRank: number,
  attributes: AttributeType,
  op1: string,
  op2: string,
  start: number,
  uniforms: UniformsArrayType,
  hasPads: boolean,
  pwStartEndNotZero: boolean,
  phStartEndNotZero: boolean,
): string => {
  const isChannelsLast = attributes.format === 'NHWC';
  const dataType = x.type.value;
  const output = outputVariable('output', x.type.tensor, outputShapeRank);
  if (attributes.kernelShape.length <= 2) {
    let codeW = '';
    let codeH = '';
    let codeHEnd = '';
    const dimIdxW = rank - (isChannelsLast ? 2 : 1);
    // The index arithmetic below is u32 (u32 uniforms and loop counters), so a position inside the leading padding
    // wraps around to a huge value. The `>= x_shape` comparison is what rejects it; the `< 0` test can never be true.
    if (pwStartEndNotZero) {
      codeW = `
      for (var i: u32 = 0u; i < uniforms.kw; i++) {
        xIndices[${dimIdxW}] = indices[${dimIdxW}] * uniforms.sw - uniforms.pwStart + i;
        if (xIndices[${dimIdxW}] < 0 || xIndices[${dimIdxW}]
            >= uniforms.x_shape[${dimIdxW}]) {
          pad++;
          continue;
        }
        let x_val = x[${x.indicesToOffset('xIndices')}];
        ${op1}
      }`;
    } else {
      codeW = `
      for (var i: u32 = 0u; i < uniforms.kw; i++) {
        xIndices[${dimIdxW}] = indices[${dimIdxW}] * uniforms.sw - uniforms.pwStart + i;
        let x_val = x[${x.indicesToOffset('xIndices')}];
        ${op1}
      }`;
    }
    if (attributes.kernelShape.length === 2) {
      const dimIdxH = rank - (isChannelsLast ? 3 : 2);
      if (phStartEndNotZero) {
        // A whole out-of-range row skips `kw` elements at once.
        codeH = `
      for (var j: u32 = 0u; j < uniforms.kh; j++) {
        xIndices[${dimIdxH}] = indices[${dimIdxH}] * uniforms.sh - uniforms.phStart + j;
        if (xIndices[${dimIdxH}] < 0 || xIndices[${dimIdxH}] >= uniforms.x_shape[${dimIdxH}]) {
          pad += i32(uniforms.kw);
          continue;
        }
      `;
      } else {
        codeH = `
      for (var j: u32 = 0u; j < uniforms.kh; j++) {
        xIndices[${dimIdxH}] = indices[${dimIdxH}] * uniforms.sh - uniforms.phStart + j;
      `;
      }
      // Closes the H loop opened by codeH around codeW.
      codeHEnd = `
      }
      `;
    }
    const poolingCode = `
    ${shaderHelper.registerUniforms(uniforms).declareVariables(x, output)}
    ${shaderHelper.mainStart()}
      ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')}
      let indices = ${output.offsetToIndices('global_idx')};
      var xIndices = ${output.offsetToIndices('global_idx')};
      var value = ${dataType}(${start});
      var pad = 0;
      ${codeH}
      ${codeW}
      ${codeHEnd}
      ${op2}
      output[global_idx] = value;
    }`;
    return poolingCode;
  } else {
    if (isChannelsLast) {
      throw new Error('Pooling with kernelShape.length > 2 is not supported for NHWC format.');
    }
    const stridesRank = attributes.kernelShape.length;
    const padsRank = attributes.pads.length;
    // Index of the first spatial axis in the input; used to map input axis `j` to its stride/offset/pad slot.
    const spatialOffset = rank - stridesRank;
    let padCode = '';
    // Both variants close the per-axis `for j` loop opened in the main template below.
    if (hasPads) {
      padCode = `
        if (xIndices[j] >= uniforms.x_shape[j]) {
          pad++;
          isPad = true;
          break;
        }
      }
      if (!isPad) {
        let x_val = x[${x.indicesToOffset('xIndices')}];
        ${op1}
      }`;
    } else {
      padCode = `
      }
      let x_val = x[${x.indicesToOffset('xIndices')}];
      ${op1}
    `;
    }
    const poolingCode = `
    ${shaderHelper.registerUniforms(uniforms).declareVariables(x, output)}
    ${shaderHelper.mainStart()}
      ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')}
      let indices = ${output.offsetToIndices('global_idx')};
      var xIndices = ${output.offsetToIndices('global_idx')};
      var offsets: array<u32, ${stridesRank}>;
      var value = ${dataType}(${start});
      var pad = 0;
      var isPad = false;
      for (var i: u32 = 0u; i < uniforms.kernelSize; i++) {
        var offset = i;
        for (var j = 0u; j < ${stridesRank - 1}u; j++) {
          offsets[j] = offset / ${getElementAt('uniforms.kernelStrides', 'j', stridesRank)};
          offset -= offsets[j] * ${getElementAt('uniforms.kernelStrides', 'j', stridesRank)};
        }
        offsets[${stridesRank - 1}] = offset;
        isPad = false;
        for (var j = ${spatialOffset}u; j < ${rank}u; j++) {
          xIndices[j] = indices[j] * ${getElementAt('uniforms.strides', `j - ${spatialOffset}u`, stridesRank)}
            + offsets[j - ${spatialOffset}u] - ${getElementAt('uniforms.pads', `j - ${spatialOffset}u`, padsRank)};
          ${padCode}
      }
      ${op2}
      output[global_idx] = value;
    }`;
    return poolingCode;
  }
};
/** Memory layout of the pooled tensor: channels-last ('NHWC') or channels-first ('NCHW'). */
export interface FormatAttributes {
readonly format: 'NHWC' | 'NCHW';
}
/**
 * Attributes shared by all pooling operators. For the non-global ops they are produced by `parsePoolCommonAttributes`.
 */
export interface PoolCommonAttributes extends FormatAttributes {
/** Name of the auto-pad mode ('NOTSET', 'VALID', 'SAME_UPPER' or 'SAME_LOWER'); '' for the global ops. */
readonly autoPad: string;
/** Non-zero requests ceil() output-shape rounding; the parse functions currently reject it. */
readonly ceilMode: number;
/** Kernel extent per spatial axis. */
readonly kernelShape: readonly number[];
/** Stride per spatial axis. */
readonly strides: readonly number[];
/** Padding as all begin values followed by all end values, one pair per spatial axis. */
readonly pads: readonly number[];
}
/** Cache-key fragment covering the attributes that every pooling shader depends on. */
const createShaderKeyFromAttributes = (attributes: PoolCommonAttributes): string => {
  const { format, ceilMode, autoPad, kernelShape } = attributes;
  return `${format};${ceilMode};${autoPad};${kernelShape.length}`;
};
/** Cache key for AveragePool: the common fragment plus `countIncludePad`. */
const createAveragePoolShaderKeyFromAttributes = (attributes: AveragePoolAttributes): string => {
  const commonKey = createShaderKeyFromAttributes(attributes);
  return `${commonKey};${attributes.countIncludePad}`;
};
/** Cache key for MaxPool: the common fragment plus `storageOrder` and `dilations`. */
const createMaxPoolShaderKeyFromAttributes = (attributes: MaxPoolAttributes): string => {
  const { storageOrder, dilations } = attributes;
  return `${createShaderKeyFromAttributes(attributes)};${storageOrder};${dilations}`;
};
/**
 * Maps the raw operator attributes (snake_case keys) onto `PoolCommonAttributes`.
 * `auto_pad` arrives as an index into the list of auto-pad mode names.
 */
const parsePoolCommonAttributes = (attributes: Record<string, unknown>): PoolCommonAttributes => {
  const autoPadModes = ['NOTSET', 'VALID', 'SAME_UPPER', 'SAME_LOWER'];
  const autoPadIndex = attributes.auto_pad as number;
  return {
    format: attributes.format as FormatAttributes['format'],
    autoPad: autoPadModes[autoPadIndex],
    ceilMode: attributes.ceil_mode as number,
    kernelShape: attributes.kernel_shape as [number, number],
    strides: attributes.strides as [number, number],
    pads: attributes.pads as [number, number, number, number],
  };
};
/** Attributes of AveragePool / GlobalAveragePool. */
export interface AveragePoolAttributes extends PoolCommonAttributes, AttributeWithCacheKey {
/** When true, the sum is divided by the full kernel size; otherwise by the kernel size minus skipped padded positions. */
readonly countIncludePad: boolean;
}
/**
 * Builds the WebGPU program for AveragePool (or GlobalAveragePool when `isGlobalOperator` is true).
 *
 * Each window is summed. The sum is then divided either by the full kernel size (`countIncludePad`) or by the
 * kernel size minus the padded positions the shader skipped.
 *
 * @param name - program name.
 * @param input - tensor to pool.
 * @param isGlobalOperator - whether the kernel spans the whole spatial extent of `input`.
 * @param attributes - parsed AveragePool attributes; adjusted against the input shape before use.
 * @returns program info with one output of the pooled shape and the input's data type.
 */
const createAveragePoolProgramInfo = (
  name: string,
  input: TensorView,
  isGlobalOperator: boolean,
  attributes: AveragePoolAttributes,
): ProgramInfo => {
  const [adjusted, outputShape] = getAdjustedPoolAttributesAndOutputShape(input, attributes, isGlobalOperator);
  const x = inputVariable('x', input.dataType, input.dims.length);
  const elementType = x.type.value;

  const accumulate = 'value += x_val;';
  const finalize = adjusted.countIncludePad
    ? `value /= ${elementType}(uniforms.kernelSize);`
    : `value /= ${elementType}(i32(uniforms.kernelSize) - pad);`;

  const padInfo = getUniformAndPadInfo(outputShape, adjusted);
  const [programUniforms, uniforms, hasPads, pwStartEndNotZero, phStartEndNotZero] = padInfo;
  programUniforms.push(...createTensorShapeVariables(input.dims, outputShape));

  const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank'];
  const hint = `${attributes.cacheKey};${hasPads};${pwStartEndNotZero};${phStartEndNotZero}`;

  return {
    name,
    shaderCache: { hint, inputDependencies },
    getRunData: () => ({
      outputs: [{ dims: outputShape, dataType: input.dataType }],
      dispatchGroup: { x: Math.ceil(ShapeUtil.size(outputShape) / 64 /* workgroup size */) },
      programUniforms,
    }),
    getShaderSource: (shaderHelper) =>
      generatePoolingCode(
        shaderHelper,
        x,
        input.dims.length,
        outputShape.length,
        adjusted,
        accumulate,
        finalize,
        0.0,
        uniforms,
        hasPads,
        pwStartEndNotZero,
        phStartEndNotZero,
      ),
  };
};
/**
 * Parses AveragePool attributes and attaches the shader cache key.
 *
 * @throws Error when `ceil_mode` is non-zero (not yet supported).
 */
export const parseAveragePoolAttributes = (attributes: Record<string, unknown>): AveragePoolAttributes => {
  const countIncludePad = (attributes.count_include_pad as number) !== 0;
  const common = parsePoolCommonAttributes(attributes);
  // TODO: support attribute 'ceil_mode'
  if (common.ceilMode !== 0) {
    throw new Error('using ceil() in shape computation is not yet supported for AveragePool');
  }
  const withoutKey = { countIncludePad, ...common, cacheKey: '' };
  const cacheKey = createAveragePoolShaderKeyFromAttributes(withoutKey);
  return { ...withoutKey, cacheKey };
};
/** Runs AveragePool on the single input of `context`. */
export const averagePool = (context: ComputeContext, attributes: AveragePoolAttributes): void => {
  validateInputs(context.inputs);
  const [input] = context.inputs;
  const programInfo = createAveragePoolProgramInfo('AveragePool', input, false, attributes);
  context.compute(programInfo);
};
/**
 * Base attributes for GlobalAveragePool / GlobalMaxPool (combined with `format` by the parse functions below).
 * kernelShape, strides and pads are empty here. They are passed to `PoolConvUtil.adjustPoolAttributes` with
 * `isGlobalOperator = true`, which presumably fills them from the input shape — NOTE(review): confirm in PoolConvUtil.
 */
const globalPoolAttributes = {
autoPad: '',
ceilMode: 0,
countIncludePad: false,
kernelShape: [],
strides: [],
pads: [],
storageOrder: 0,
dilations: [],
};
/** Builds GlobalAveragePool attributes; only `format` is read from the operator. */
export const parseGlobalAveragePoolAttributes = (attributes: Record<string, unknown>): AveragePoolAttributes => {
  const layout = attributes.format as FormatAttributes['format'];
  // Everything else is fixed, so the layout alone serves as the cache key.
  return { format: layout, ...globalPoolAttributes, cacheKey: layout };
};
/** Runs GlobalAveragePool on the single input of `context`. */
export const globalAveragePool = (context: ComputeContext, attributes: AveragePoolAttributes): void => {
  validateInputs(context.inputs);
  const [input] = context.inputs;
  const programInfo = createAveragePoolProgramInfo('GlobalAveragePool', input, true, attributes);
  context.compute(programInfo);
};
/** Attributes of MaxPool / GlobalMaxPool. */
export interface MaxPoolAttributes extends PoolCommonAttributes, AttributeWithCacheKey {
/** `storage_order` attribute; parseMaxPoolAttributes only accepts 0. */
readonly storageOrder: number;
/** Per-axis dilations. They feed the output-shape computation; the generated shader does not read them. */
readonly dilations: number[];
}
/**
 * Builds the WebGPU program for MaxPool (or GlobalMaxPool when `isGlobalOperator` is true).
 *
 * @param name - program name.
 * @param input - tensor to pool.
 * @param isGlobalOperator - whether the kernel spans the whole spatial extent of `input`.
 * @param attributes - parsed MaxPool attributes; adjusted against the input shape before use.
 * @returns program info with one output of the pooled shape and the input's data type.
 */
const createMaxPoolProgramInfo = (
  name: string,
  input: TensorView,
  isGlobalOperator: boolean,
  attributes: MaxPoolAttributes,
): ProgramInfo => {
  const [adjustedAttributes, outputShape] = getAdjustedPoolAttributesAndOutputShape(
    input,
    attributes,
    isGlobalOperator,
  );
  const op1 = 'value = max(x_val, value);';
  const op2 = '';
  const x = inputVariable('x', input.dataType, input.dims.length);
  const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank'];
  const [programUniforms, uniforms, hasPads, pwStartEndNotZero, phStartEndNotZero] = getUniformAndPadInfo(
    outputShape,
    adjustedAttributes,
  );
  programUniforms.push(...createTensorShapeVariables(input.dims, outputShape));

  // Seed of the running maximum. It must not exceed any finite input; otherwise a window whose elements are all below
  // the seed reports the seed instead of its true maximum (the old f32 seed of -1e5 did exactly that).
  // Other element types keep the previous seed.
  let start: number;
  if (input.dataType === DataType.float16) {
    start = -65504; // lowest finite f16
  } else if (input.dataType === DataType.float) {
    start = -3.4028234663852886e38; // lowest finite f32
  } else {
    start = -1e5;
  }

  return {
    name,
    shaderCache: {
      hint: `${attributes.cacheKey};${hasPads};${pwStartEndNotZero};${phStartEndNotZero}`,
      inputDependencies,
    },
    getRunData: () => ({
      outputs: [{ dims: outputShape, dataType: input.dataType }],
      dispatchGroup: { x: Math.ceil(ShapeUtil.size(outputShape) / 64 /* workgroup size */) },
      programUniforms,
    }),
    getShaderSource: (shaderHelper) =>
      generatePoolingCode(
        shaderHelper,
        x,
        input.dims.length,
        outputShape.length,
        adjustedAttributes,
        op1,
        op2,
        start,
        uniforms,
        hasPads,
        pwStartEndNotZero,
        phStartEndNotZero,
      ),
  };
};
/** Runs MaxPool on the single input of `context`. */
export const maxPool = (context: ComputeContext, attributes: MaxPoolAttributes): void => {
  validateInputs(context.inputs);
  const [input] = context.inputs;
  const programInfo = createMaxPoolProgramInfo('MaxPool', input, false, attributes);
  context.compute(programInfo);
};
/**
 * Parses MaxPool attributes and attaches the shader cache key.
 *
 * @throws Error when `storage_order` is non-zero, `ceil_mode` is non-zero, or any dilation differs from 1.
 *   None of these are supported by the generated shader yet.
 */
export const parseMaxPoolAttributes = (attributes: Record<string, unknown>): MaxPoolAttributes => {
  const storageOrder = attributes.storage_order as number;
  const dilations = attributes.dilations as [number, number];
  const attr = parsePoolCommonAttributes(attributes);
  // TODO: support attribute 'ceil_mode', 'storage_order' and 'dilations'
  if (storageOrder !== 0) {
    throw new Error('column major storage order is not yet supported for MaxPool');
  }
  if (attr.ceilMode !== 0) {
    throw new Error('using ceil() in shape computation is not yet supported for MaxPool');
  }
  // Dilations are used to compute the output shape, but the shader walks the window with unit spacing. A non-unit
  // dilation would therefore silently produce wrong values, so reject it like the other unsupported attributes.
  if (Array.isArray(dilations) && dilations.some((d) => d !== 1)) {
    throw new Error('dilations other than 1 are not yet supported for MaxPool');
  }
  const maxPoolAttributes = { storageOrder, dilations, ...attr, cacheKey: '' };
  return { ...maxPoolAttributes, cacheKey: createMaxPoolShaderKeyFromAttributes(maxPoolAttributes) };
};
/** Builds GlobalMaxPool attributes; only `format` is read from the operator. */
export const parseGlobalMaxPoolAttributes = (attributes: Record<string, unknown>): MaxPoolAttributes => {
  const layout = attributes.format as FormatAttributes['format'];
  // Everything else is fixed, so the layout alone serves as the cache key.
  return { format: layout, ...globalPoolAttributes, cacheKey: layout };
};
/** Runs GlobalMaxPool on the single input of `context`. */
export const globalMaxPool = (context: ComputeContext, attributes: MaxPoolAttributes): void => {
  validateInputs(context.inputs);
  const [input] = context.inputs;
  const programInfo = createMaxPoolProgramInfo('GlobalMaxPool', input, true, attributes);
  context.compute(programInfo);
};