onnxruntime-web
Version:
A JavaScript library for running ONNX models in browsers
505 lines (473 loc) • 19.4 kB
text/typescript
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// sampled from [@tensorflow/tfjs] tfjs-backend-webgpu/src/conv3d_naive_webgpu.ts
//
// modified to fit the needs of the project
import { DataType } from '../../../../wasm-common';
import { LOG_DEBUG } from '../../../log';
import { TensorView } from '../../../tensor-view';
import { ShapeUtil } from '../../../util';
import { ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform } from '../../types';
import {
createTensorShapeVariables,
getElementAt,
inputVariable,
outputVariable,
ShaderHelper,
tensorTypeToWsglStorageType,
UniformsArrayType,
} from '../common';
import { ConvAttributes } from '../conv';
import { appendActivationUniforms, appendActivationUniformsData, getActivationSnippet } from '../fuse-utils';
import { typeSnippet } from './activation_util';
/**
 * Multiplies all entries of `arr` together, left to right.
 *
 * @param arr - The factors to multiply.
 * @returns The product of every entry, or 1 when `arr` is empty.
 */
const arrayProduct = (arr: number[]) => {
  let result = 1;
  for (const factor of arr) {
    result *= factor;
  }
  return result;
};
/**
 * Normalizes a "scalar or triple" parameter into a triple.
 *
 * @param param - Either a single number or an existing 3-tuple.
 * @returns `[param, param, param]` for a number; otherwise the very same tuple
 *   instance that was passed in (no copy is made).
 */
const parse3TupleParam = (param: number | [number, number, number]): [number, number, number] => {
  if (typeof param !== 'number') {
    return param;
  }
  return [param, param, param];
};
/**
 * Computes the extent a filter covers along one axis once dilation is applied.
 *
 * @param filterSize - Number of filter taps along the axis.
 * @param dilation - Spacing between taps; values of 1 or less leave the size unchanged.
 * @returns `filterSize` when `dilation <= 1`, otherwise
 *   `filterSize + (filterSize - 1) * (dilation - 1)`.
 */
const getEffectiveFilterSize = (filterSize: number, dilation: number): number =>
  dilation <= 1 ? filterSize : filterSize + (filterSize - 1) * (dilation - 1);
/**
 * Derives a default symmetric padding amount from the first input dimension.
 *
 * Only `inputShape[0]` is read; the remaining entries are ignored.
 *
 * @param inputShape - Input extents; only the first entry participates.
 * @param fieldSize - Filter size along the axis, before dilation.
 * @param stride - Stride along the axis.
 * @param dilation - Dilation along the axis (defaults to 1).
 * @returns `floor((inputShape[0] * (stride - 1) - stride + effectiveFieldSize) / 2)`,
 *   where `effectiveFieldSize` is the dilated filter size.
 */
const computeDefaultPad = (
  inputShape: [number, number] | [number, number, number, number],
  fieldSize: number,
  stride: number,
  dilation = 1,
): number => {
  const dilatedFieldSize = getEffectiveFilterSize(fieldSize, dilation);
  const totalPadding = inputShape[0] * (stride - 1) - stride + dilatedFieldSize;
  return Math.floor(totalPadding / 2);
};
/**
 * Computes the output extents of a 3-D convolution for a single, uniform padding value.
 *
 * @param inShape - Input extents for the three spatial axes followed by a fourth entry
 *   that is not read here.
 * @param filterShape - Filter extents for the three spatial axes.
 * @param outChannels - Value copied verbatim into the last slot of the result.
 * @param strides - Stride for each of the three spatial axes.
 * @param zeroPad - Padding applied to both sides of every spatial axis. When `null` or
 *   `undefined`, a default is derived from the first axis via `computeDefaultPad`.
 * @returns `[out0, out1, out2, outChannels]`. An axis whose padded input is smaller than
 *   the filter keeps an output extent of 0.
 */
const computeOutputShape4D = (
  inShape: [number, number, number, number],
  filterShape: [number, number, number],
  outChannels: number,
  strides: [number, number, number],
  zeroPad?: number,
): [number, number, number, number] => {
  const padding = zeroPad ?? computeDefaultPad(inShape, filterShape[0], strides[0]);
  const result: [number, number, number, number] = [0, 0, 0, outChannels];
  for (const axis of [0, 1, 2]) {
    const fits = inShape[axis] + 2 * padding >= filterShape[axis];
    if (fits) {
      // Truncate toward zero so a partially covered trailing window is dropped.
      result[axis] = Math.trunc((inShape[axis] - filterShape[axis] + 2 * padding) / strides[axis] + 1);
    }
  }
  return result;
};
/**
 * Resolves a padding specification into per-face padding amounts and the resulting
 * output extents of a 3-D convolution.
 *
 * Accepted `pad` forms:
 * - `'VALID'`: treated exactly like the number 0.
 * - a number: applied to every face.
 * - an array: every entry must equal the first one; entries 0..5 are copied to
 *   top, bottom, left, right, front and back, and `pad[0]` drives the output size.
 * - `'SAME_UPPER'`: output extent is `ceil(in / stride)` per axis and any odd padding
 *   remainder goes to the trailing side (back / bottom / right).
 *
 * @param pad - Padding specification, see above.
 * @param inDepth - Input extent along depth.
 * @param inHeight - Input extent along height.
 * @param inWidth - Input extent along width.
 * @param strideDepth - Stride along depth.
 * @param strideHeight - Stride along height.
 * @param strideWidth - Stride along width.
 * @param filterDepth - Filter extent along depth.
 * @param filterHeight - Filter extent along height.
 * @param filterWidth - Filter extent along width.
 * @returns The resolved padding and the output depth, height and width.
 * @throws Error when an array `pad` has differing entries, or when `pad` is any other
 *   unrecognized value (including `'SAME_LOWER'`).
 */
const get3DPadAndOutInfo = (
  pad: number | string | number[],
  inDepth: number,
  inHeight: number,
  inWidth: number,
  strideDepth: number,
  strideHeight: number,
  strideWidth: number,
  filterDepth: number,
  filterHeight: number,
  filterWidth: number,
): { padInfo: PadInfo3D; outDepth: number; outHeight: number; outWidth: number } => {
  // 'VALID' is shorthand for "no padding at all".
  const padding = pad === 'VALID' ? 0 : pad;

  if (padding === 'SAME_UPPER') {
    // TODO: support 'SAME_LOWER'.
    const outDepth = Math.ceil(inDepth / strideDepth);
    const outHeight = Math.ceil(inHeight / strideHeight);
    const outWidth = Math.ceil(inWidth / strideWidth);
    const depthPadding = (outDepth - 1) * strideDepth + filterDepth - inDepth;
    const heightPadding = (outHeight - 1) * strideHeight + filterHeight - inHeight;
    const widthPadding = (outWidth - 1) * strideWidth + filterWidth - inWidth;
    // The leading side gets the rounded-down half; the trailing side gets the rest.
    const front = Math.floor(depthPadding / 2);
    const top = Math.floor(heightPadding / 2);
    const left = Math.floor(widthPadding / 2);
    return {
      padInfo: {
        top,
        bottom: heightPadding - top,
        left,
        right: widthPadding - left,
        front,
        back: depthPadding - front,
      },
      outDepth,
      outHeight,
      outWidth,
    };
  }

  // Both remaining supported forms boil down to one padding value shared by all faces.
  const withUniformPad = (padInfo: PadInfo3D, uniformPad: number) => {
    const [outDepth, outHeight, outWidth] = computeOutputShape4D(
      [inDepth, inHeight, inWidth, 1],
      [filterDepth, filterHeight, filterWidth],
      1,
      [strideDepth, strideHeight, strideWidth],
      uniformPad,
    );
    return { padInfo, outDepth, outHeight, outWidth };
  };

  if (typeof padding === 'number') {
    return withUniformPad(
      { top: padding, bottom: padding, left: padding, right: padding, front: padding, back: padding },
      padding,
    );
  }

  if (Array.isArray(padding)) {
    if (padding.some((val, _, arr) => val !== arr[0])) {
      throw Error(`Unsupported padding parameter: ${padding}`);
    }
    return withUniformPad(
      {
        top: padding[0],
        bottom: padding[1],
        left: padding[2],
        right: padding[3],
        front: padding[4],
        back: padding[5],
      },
      padding[0],
    );
  }

  throw Error(`Unknown padding parameter: ${padding}`);
};
/**
 * Padding amounts for the six faces of a 3-D input volume.
 *
 * As computed for 'SAME_UPPER' padding in this file, `front`/`back` pad the depth axis,
 * `top`/`bottom` pad the height axis and `left`/`right` pad the width axis.
 */
type PadInfo3D = {
  top: number;
  left: number;
  right: number;
  bottom: number;
  front: number;
  back: number;
};
/**
 * Fully resolved description of a 3-D convolution, as produced by `computeConv3DInfo`.
 */
export type Conv3DInfo = {
  // Input extents, unpacked from `inShape` according to `dataFormat`.
  batchSize: number;
  inDepth: number;
  inHeight: number;
  inWidth: number;
  inChannels: number;
  // Output extents after padding and striding have been applied.
  outDepth: number;
  outHeight: number;
  outWidth: number;
  outChannels: number;
  // 'channelsLast' places channels at the end of the shape; 'channelsFirst' right after the batch.
  dataFormat: 'channelsFirst' | 'channelsLast';
  // Per-axis stride.
  strideDepth: number;
  strideHeight: number;
  strideWidth: number;
  // Per-axis dilation.
  dilationDepth: number;
  dilationHeight: number;
  dilationWidth: number;
  // Filter extents as stored in `filterShape`, before dilation.
  filterDepth: number;
  filterHeight: number;
  filterWidth: number;
  // Filter extents once dilation is taken into account.
  effectiveFilterDepth: number;
  effectiveFilterHeight: number;
  effectiveFilterWidth: number;
  // Resolved padding for every face of the input volume.
  padInfo: PadInfo3D;
  // The caller-supplied input shape, the computed output shape (laid out per `dataFormat`)
  // and the caller-supplied filter shape.
  inShape: [number, number, number, number, number];
  outShape: [number, number, number, number, number];
  filterShape: [number, number, number, number, number];
};
/**
 * Resolves shapes, strides, dilations and padding of a 3-D convolution into a
 * `Conv3DInfo` record.
 *
 * @param inShape - 5-D input shape: `[batch, depth, height, width, channels]` for
 *   'channelsLast', `[batch, channels, depth, height, width]` for 'channelsFirst'.
 * @param filterShape - 5-D filter shape. Entry 0 is used as the filter channel count,
 *   entry 1 is ignored, entries 2..4 are the filter depth, height and width.
 * @param strides - One stride for all axes, or a `[depth, height, width]` triple.
 * @param dilations - One dilation for all axes, or a `[depth, height, width]` triple.
 * @param pad - Padding specification forwarded to `get3DPadAndOutInfo`.
 * @param depthwise - When true, the output channel count is the filter channel count
 *   multiplied by the input channel count.
 * @param dataFormat - Layout of `inShape` and of the returned `outShape`.
 * @returns The resolved convolution description.
 * @throws Error when `dataFormat` is neither 'channelsLast' nor 'channelsFirst'.
 */
export const computeConv3DInfo = (
  inShape: [number, number, number, number, number],
  filterShape: [number, number, number, number, number],
  strides: number | [number, number, number],
  dilations: number | [number, number, number],
  pad: number | string | number[],
  depthwise = false,
  dataFormat: 'channelsFirst' | 'channelsLast' = 'channelsLast',
): Conv3DInfo => {
  if (dataFormat !== 'channelsLast' && dataFormat !== 'channelsFirst') {
    throw new Error(`Unknown dataFormat ${dataFormat}`);
  }
  const channelsLast = dataFormat === 'channelsLast';

  // Bring the input extents into a canonical batch/depth/height/width/channels order.
  const canonicalInput: [number, number, number, number, number] = channelsLast
    ? inShape
    : [inShape[0], inShape[2], inShape[3], inShape[4], inShape[1]];
  const [batchSize, inDepth, inHeight, inWidth, inChannels] = canonicalInput;

  const filterChannels = filterShape[0];
  const filterDepth = filterShape[2];
  const filterHeight = filterShape[3];
  const filterWidth = filterShape[4];

  const [strideDepth, strideHeight, strideWidth] = parse3TupleParam(strides);
  const [dilationDepth, dilationHeight, dilationWidth] = parse3TupleParam(dilations);

  const effectiveFilterDepth = getEffectiveFilterSize(filterDepth, dilationDepth);
  const effectiveFilterHeight = getEffectiveFilterSize(filterHeight, dilationHeight);
  const effectiveFilterWidth = getEffectiveFilterSize(filterWidth, dilationWidth);

  // Padding and output extents are computed against the dilated filter size.
  const { padInfo, outDepth, outHeight, outWidth } = get3DPadAndOutInfo(
    pad,
    inDepth,
    inHeight,
    inWidth,
    strideDepth,
    strideHeight,
    strideWidth,
    effectiveFilterDepth,
    effectiveFilterHeight,
    effectiveFilterWidth,
  );

  const outChannels = depthwise ? filterChannels * inChannels : filterChannels;
  const outShape: [number, number, number, number, number] = channelsLast
    ? [batchSize, outDepth, outHeight, outWidth, outChannels]
    : [batchSize, outChannels, outDepth, outHeight, outWidth];

  return {
    batchSize,
    dataFormat,
    inDepth,
    inHeight,
    inWidth,
    inChannels,
    outDepth,
    outHeight,
    outWidth,
    outChannels,
    padInfo,
    strideDepth,
    strideHeight,
    strideWidth,
    filterDepth,
    filterHeight,
    filterWidth,
    effectiveFilterDepth,
    effectiveFilterHeight,
    effectiveFilterWidth,
    dilationDepth,
    dilationHeight,
    dilationWidth,
    inShape,
    outShape,
    filterShape,
  };
};
/**
 * Builds the WebGPU program description for a naive (non-tiled) 3-D convolution.
 *
 * The generated shader handles one output element per invocation: it turns the global
 * index into output coordinates, walks the filter window along depth/row/column, skips
 * taps that fall outside the input, and accumulates the dot product over input channels
 * four at a time with a 1-3 element tail.
 *
 * @param inputs - `[x, W]` or `[x, W, bias]`; a third tensor enables the bias add.
 *   `x` and `W` are indexed with five coordinates by the shader.
 * @param attributes - Convolution attributes; `strides`, `dilations`, `cacheKey` and the
 *   fused-activation settings are read here.
 * @param outputShape - Dimensions of the output tensor.
 * @param filterDims - Filter extents uploaded as `uniforms.filter_dims`; the shader reads
 *   entries 0, 1, 2 as the depth, row and column loop bounds.
 * @param pads - Leading padding uploaded as `uniforms.pads`; the shader subtracts it from
 *   the depth/row/column corner of the input window.
 * @param dataFormat - `'channelsLast'` selects `[batch, depth, row, col, channel]`
 *   indexing; any other value is treated as `[batch, channel, depth, row, col]`.
 * @returns The program info (name `Conv3DNaive`) with a single output of shape
 *   `outputShape` and the data type of `inputs[0]`.
 */
export const createConv3DNaiveProgramInfo = (
  inputs: readonly TensorView[],
  attributes: ConvAttributes,
  outputShape: readonly number[],
  filterDims: readonly number[],
  pads: readonly number[],
  dataFormat: string,
): ProgramInfo => {
  const isChannelLast = dataFormat === 'channelsLast';
  // The input is 5-D, so the channel axis is the last one (index 4) for channels-last
  // and index 1 for channels-first. Index 3 would be the width axis of an NDHWC tensor.
  const inChannels = isChannelLast ? inputs[0].dims[4] : inputs[0].dims[1];
  // TODO: enable vec4.
  const isVec4 = false;
  const workGroupSize: [number, number, number] = [64, 1, 1];
  // All output dimensions are flattened onto the x dispatch axis.
  const dispatchLayout = { x: outputShape.map((_, i) => i) };
  const dispatch = [Math.ceil(arrayProduct(dispatchLayout.x.map((d) => outputShape[d])) / workGroupSize[0]), 1, 1];
  LOG_DEBUG('verbose', () => `[conv3d_naive_webgpu] dispatch = ${dispatch}`);
  const innerElementSize = isVec4 ? (isChannelLast && inChannels % 4 !== 0 ? 3 : 4) : 1;
  const outputSize = ShapeUtil.size(outputShape);
  // Uniform values are listed in the same order as the uniform declarations registered
  // in getShaderSource below, followed by the tensor shape variables.
  const programUniforms: ProgramUniform[] = [
    { type: DataType.uint32, data: outputSize },
    { type: DataType.uint32, data: filterDims },
    { type: DataType.uint32, data: pads },
    { type: DataType.uint32, data: attributes.strides },
    { type: DataType.uint32, data: attributes.dilations },
  ];
  appendActivationUniformsData(attributes, programUniforms);
  programUniforms.push(...createTensorShapeVariables(inputs[0].dims, inputs[1].dims));
  const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank'];
  const hasBias = inputs.length === 3;
  if (hasBias) {
    programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
    inputDependencies.push('rank');
  }
  programUniforms.push(...createTensorShapeVariables(outputShape));
  const getShaderSource = (shaderHelper: ShaderHelper) => {
    const uniforms: UniformsArrayType = [
      { name: 'output_size', type: 'u32' },
      { name: 'filter_dims', type: 'u32', length: filterDims.length },
      { name: 'pads', type: 'u32', length: pads.length },
      { name: 'strides', type: 'u32', length: attributes.strides.length },
      { name: 'dilations', type: 'u32', length: attributes.dilations.length },
    ];
    appendActivationUniforms(attributes, uniforms);
    // TODO: support component 2, 3.
    const components = isVec4 ? 4 : 1;
    const t = tensorTypeToWsglStorageType(inputs[0].dataType);
    const x = inputVariable(
      'x',
      inputs[0].dataType,
      inputs[0].dims.length,
      innerElementSize === 3 ? 1 : innerElementSize,
    );
    const w = inputVariable('W', inputs[1].dataType, inputs[1].dims.length, components);
    const inputVariables = [x, w];
    const output = outputVariable('result', inputs[0].dataType, outputShape.length, components);
    let declareFunctions = '';
    if (hasBias) {
      const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components);
      inputVariables.push(bias);
      // The bias is indexed by the output channel coordinate only.
      declareFunctions += `
fn getBiasByOutputCoords(coords : array<u32, 5>) -> ${isVec4 ? `vec4<${t}>` : t} {
return bias[${isChannelLast ? getElementAt('coords', 4, 5) : getElementAt('coords', 1, 5)}${
isVec4 ? '/ 4' : ''
}];
}`;
    }
    const resType = typeSnippet(innerElementSize, t);
    const applyActivation = getActivationSnippet(attributes, resType, t);
    // NOTE(review): getX/getW, the accumulator and the final store are hard-coded to f32
    // while the bias helper uses the input's storage type `t`; this looks like it only
    // compiles for f32 inputs - confirm before relying on other element types.
    // The shader text below is kept verbatim.
    return `
${declareFunctions}
fn getX(d0 : u32, d1 : u32, d2 : u32, d3 : u32, d4 : u32) -> f32 {
let aIndices = array<u32, 5>(d0, d1, d2, d3, d4);
return ${x.getByIndices('aIndices')};
}
fn getW(d0 : u32, d1 : u32, d2 : u32, d3 : u32, d4 : u32) -> f32 {
let aIndices = array<u32, 5>(d0, d1, d2, d3, d4);
return ${w.getByIndices('aIndices')};
}
${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
let coords = ${output.offsetToIndices('global_idx')};
let batch = ${getElementAt('coords', 0, x.rank)};
let d2 = ${
isChannelLast ? getElementAt('coords', x.rank - 1, x.rank) : getElementAt('coords', 1, x.rank)
};
let xFRCCorner = vec3<u32>(${
isChannelLast ? getElementAt('coords', 1, x.rank) : getElementAt('coords', 2, x.rank)
},
${isChannelLast ? getElementAt('coords', 2, x.rank) : getElementAt('coords', 3, x.rank)},
${
isChannelLast ? getElementAt('coords', 3, x.rank) : getElementAt('coords', 4, x.rank)
}) * uniforms.strides - uniforms.pads;
let xFCorner = xFRCCorner.x;
let xRCorner = xFRCCorner.y;
let xCCorner = xFRCCorner.z;
let xShapeY = ${
isChannelLast
? getElementAt('uniforms.x_shape', 1, x.rank)
: getElementAt('uniforms.x_shape', 2, x.rank)
};
let xShapeZ = ${
isChannelLast
? getElementAt('uniforms.x_shape', 2, x.rank)
: getElementAt('uniforms.x_shape', 3, x.rank)
};
let xShapeW = ${
isChannelLast
? getElementAt('uniforms.x_shape', 3, x.rank)
: getElementAt('uniforms.x_shape', 4, x.rank)
};
let xShapeU = ${
isChannelLast
? getElementAt('uniforms.x_shape', 4, x.rank)
: getElementAt('uniforms.x_shape', 1, x.rank)
};
let inputDepthNearestVec4 = (xShapeU / 4) * 4;
let inputDepthVec4Remainder = xShapeU % 4;
var value = 0.0;
for (var wF = 0u; wF < uniforms.filter_dims[0]; wF++) {
let xF = xFCorner + wF * uniforms.dilations[0];
if (xF < 0 || xF >= xShapeY) {
continue;
}
for (var wR = 0u; wR < uniforms.filter_dims[1]; wR++) {
let xR = xRCorner + wR * uniforms.dilations[1];
if (xR < 0 || xR >= xShapeZ) {
continue;
}
for (var wC = 0u; wC < uniforms.filter_dims[2]; wC++) {
let xC = xCCorner + wC * uniforms.dilations[2];
if (xC < 0 || xC >= xShapeW) {
continue;
}
for (var d1 = 0u; d1 < inputDepthNearestVec4; d1 += 4) {
${
isChannelLast
? `let xValues = vec4<f32>(
getX(batch, xF, xR, xC, d1),
getX(batch, xF, xR, xC, d1 + 1),
getX(batch, xF, xR, xC, d1 + 2),
getX(batch, xF, xR, xC, d1 + 3));
`
: `let xValues = vec4<f32>(
getX(batch, d1, xF, xR, xC),
getX(batch, d1 + 1, xF, xR, xC),
getX(batch, d1 + 2, xF, xR, xC),
getX(batch, d1 + 3, xF, xR, xC));
`
}
let wValues = vec4<f32>(
getW(d2, d1, wF, wR, wC),
getW(d2, d1 + 1, wF, wR, wC),
getW(d2, d1 + 2, wF, wR, wC),
getW(d2, d1 + 3, wF, wR, wC));
value += dot(xValues, wValues);
}
if (inputDepthVec4Remainder == 1) {
${
isChannelLast
? `value += getX(batch, xF, xR, xC, inputDepthNearestVec4)
* getW(d2, inputDepthNearestVec4, wF, wR, wC);`
: `value += getX(batch, inputDepthNearestVec4, xF, xR, xC)
* getW(d2, inputDepthNearestVec4, wF, wR, wC);`
}
} else if (inputDepthVec4Remainder == 2) {
${
isChannelLast
? `let xValues = vec2<f32>(
getX(batch, xF, xR, xC, inputDepthNearestVec4),
getX(batch, xF, xR, xC, inputDepthNearestVec4 + 1));
`
: `let xValues = vec2<f32>(
getX(batch, inputDepthNearestVec4, xF, xR, xC),
getX(batch, inputDepthNearestVec4 + 1, xF, xR, xC));
`
}
let wValues = vec2<f32>(
getW(d2, inputDepthNearestVec4, wF, wR, wC),
getW(d2, inputDepthNearestVec4 + 1, wF, wR, wC));
value += dot(xValues, wValues);
} else if (inputDepthVec4Remainder == 3) {
${
isChannelLast
? `let xValues = vec3<f32>(
getX(batch, xF, xR, xC, inputDepthNearestVec4),
getX(batch, xF, xR, xC, inputDepthNearestVec4 + 1),
getX(batch, xF, xR, xC, inputDepthNearestVec4 + 2));
`
: `let xValues = vec3<f32>(
getX(batch, inputDepthNearestVec4, xF, xR, xC),
getX(batch, inputDepthNearestVec4 + 1, xF, xR, xC),
getX(batch, inputDepthNearestVec4 + 2, xF, xR, xC));
`
}
let wValues = vec3<f32>(
getW(d2, inputDepthNearestVec4, wF, wR, wC),
getW(d2, inputDepthNearestVec4 + 1, wF, wR, wC),
getW(d2, inputDepthNearestVec4 + 2, wF, wR, wC));
value += dot(xValues, wValues);
}
}
}
}
${hasBias ? 'value = value + getBiasByOutputCoords(coords)' : ''};
${applyActivation}
result[global_idx] = f32(value);
}`;
  };
  return {
    name: 'Conv3DNaive',
    shaderCache: { hint: `${attributes.cacheKey};${isChannelLast};${innerElementSize};${hasBias}`, inputDependencies },
    getRunData: () => ({
      outputs: [{ dims: outputShape, dataType: inputs[0].dataType }],
      dispatchGroup: { x: dispatch[0], y: dispatch[1], z: dispatch[2] },
      programUniforms,
    }),
    getShaderSource,
  };
};