onnxruntime-web
Version:
A JavaScript library for running ONNX models in browsers
164 lines (143 loc) • 6.16 kB
text/typescript
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import { DataType } from '../../../wasm-common';
import { TensorView } from '../../tensor-view';
import { ShapeUtil } from '../../util';
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key';
import { ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform } from '../types';
import { createTensorShapeVariables, IndicesHelper, inputVariable, outputVariable, ShaderHelper } from './common';
/**
 * Attributes consumed by the Concat kernel in this file.
 */
export interface ConcatAttributes extends AttributeWithCacheKey {
/**
 * Raw axis attribute. `concat` normalizes it with `ShapeUtil.normalizeAxis` against the first
 * input's rank before use, so negative values are presumably accepted — confirm in ShapeUtil.
 */
readonly axis: number;
}
/**
 * Checks that the inputs can be concatenated along `axis`.
 *
 * Every input is compared against the first one:
 * - element types must match;
 * - ranks must match;
 * - every dimension other than `axis` must match.
 *
 * @param inputs - tensors to concatenate, in order.
 * @param axis - already-normalized concat axis; the only dimension allowed to differ.
 * @throws Error if there are no inputs, or if any input disagrees with the first one.
 */
const validateInputs = (inputs: readonly TensorView[], axis: number): void => {
  if (!inputs || inputs.length < 1) {
    throw new Error('too few inputs');
  }
  const reference = inputs[0];
  const inputType = reference.dataType;
  const inputRank = reference.dims.length;
  // Input 0 is the reference, so comparison starts at index 1.
  for (let inputIndex = 1; inputIndex < inputs.length; ++inputIndex) {
    const input = inputs[inputIndex];
    // make sure types of all inputs match
    if (input.dataType !== inputType) {
      throw new Error('input tensors should be one type');
    }
    // make sure the dimensionality (rank) of all inputs is the same
    if (input.dims.length !== inputRank) {
      throw new Error('input tensors should have the same rank');
    }
    // Only the concat axis may differ between inputs.
    for (let dimIndex = 0; dimIndex < input.dims.length; ++dimIndex) {
      if (dimIndex !== axis && input.dims[dimIndex] !== reference.dims[dimIndex]) {
        throw new Error('non concat dimensions must match');
      }
    }
  }
};
/**
 * Generates the WGSL helper `calculateInputIndex(index)`.
 *
 * The helper returns the first `i` for which `index < sizeInConcatAxis[i]`, or `numberOfTensors`
 * when no entry matches. The caller in this file passes running sums of the inputs' concat-axis
 * sizes, so the result is the input that owns a given output coordinate on that axis.
 *
 * @param numberOfTensors - length of the WGSL array literal.
 * @param sizeInConcatAxisStr - comma-separated WGSL expressions used to initialize the array.
 * @returns WGSL source, starting with a newline and ending with the closing brace.
 */
const calculateInputIndexImpl = (numberOfTensors: number, sizeInConcatAxisStr: string): string => {
  const count = `${numberOfTensors}`;
  const wgslLines = [
    '',
    'fn calculateInputIndex(index: u32) -> u32 {',
    `let sizeInConcatAxis = array<u32, ${count}u>(${sizeInConcatAxisStr});`,
    `for (var i: u32 = 0u; i < ${count}; i += 1u ) {`,
    'if (index < sizeInConcatAxis[i]) {',
    'return i;',
    '}',
    '}',
    `return ${count}u;`,
    '}',
  ];
  return wgslLines.join('\n');
};
/**
 * Emits WGSL statements that copy one element from the selected input into the output.
 *
 * The generated code reads the shader-scope names `inputIndex`, `indices` and `global_idx`, so the
 * enclosing shader must declare them. With a single input no branch is emitted. Otherwise an
 * if / else-if / else chain selects the input, and the last input is the fall-through `else`.
 *
 * @param inputs - helpers for each input tensor, in concat order.
 * @param output - helper for the output tensor.
 * @returns newline-separated WGSL source (empty string when `inputs` is empty).
 */
const assignOutputData = (inputs: readonly IndicesHelper[], output: IndicesHelper): string => {
  const count = inputs.length;
  return inputs
    .map((input, i) => {
      const copy = output.setByOffset('global_idx', input.getByIndices('indices'));
      if (count === 1) {
        return copy;
      }
      if (i === 0) {
        return `if (inputIndex == ${i}u) { ${copy} }`;
      }
      if (i === count - 1) {
        return `else { ${copy} }`;
      }
      // `inputIndex` is a u32 (returned by calculateInputIndex); use the same `u` literal suffix as
      // the first branch for consistency.
      return `else if (inputIndex == ${i}u) { ${copy} }`;
    })
    .join('\n');
};
/**
 * Builds the WebGPU program that concatenates `inputs` along `adjustedAxis`.
 *
 * Each shader invocation handles one output element:
 * 1. converts `global_idx` to output indices;
 * 2. finds which input owns that coordinate on the concat axis (via `calculateInputIndex`);
 * 3. rebases the coordinate into the input's local range;
 * 4. copies the element.
 *
 * @param inputs - tensors to concatenate (the `concat` caller in this file passes only non-empty ones).
 * @param adjustedAxis - concat axis, already normalized by the caller.
 * @param outputShape - shape of the concatenated result.
 * @param dataType - element type used for every input variable and the output.
 * @returns ProgramInfo to be passed to `ComputeContext.compute`.
 *
 * NOTE(review): with an empty `inputs` list the shader would declare `array<u32, 0u>(...)`, which
 * WGSL presumably rejects (fixed-size arrays need a positive element count). Confirm callers never
 * reach this with zero inputs (e.g. when every input tensor is empty).
 */
const createConcatProgramInfo = (
inputs: readonly TensorView[],
adjustedAxis: number,
outputShape: number[],
dataType: DataType,
): ProgramInfo => {
const outputSize = ShapeUtil.size(outputShape);
// sizeInConcatAxis[i] holds the running sum of concat-axis sizes of inputs 0..i, i.e. the
// exclusive end coordinate of input i along the concat axis in the output.
const sizeInConcatAxis = new Array<number>(inputs.length);
const inputVars = new Array<IndicesHelper>(inputs.length);
let previousSum = 0;
const inputDependencies: ProgramInputTensorInfoDependency[] = [];
const inputRanks = [];
// Uniform push order must match the order in which getShaderSource registers/declares them:
// outputSize, sizeInConcatAxis0..N-1, then (presumably via declareVariables) the shape variables
// of each input followed by the output — verify against ShaderHelper if this order changes.
const programUniforms: ProgramUniform[] = [{ type: DataType.uint32, data: outputSize }];
for (let i = 0; i < inputs.length; ++i) {
previousSum += inputs[i].dims[adjustedAxis];
sizeInConcatAxis[i] = previousSum;
inputRanks.push(inputs[i].dims.length);
inputVars[i] = inputVariable(`input${i}`, dataType, inputRanks[i]);
// Input variables are generated from the rank only (concrete dims are passed as uniforms), so
// the shader cache is keyed on rank.
inputDependencies.push('rank');
programUniforms.push({ type: DataType.uint32, data: sizeInConcatAxis[i] });
}
for (let i = 0; i < inputs.length; ++i) {
programUniforms.push(...createTensorShapeVariables(inputs[i].dims));
}
programUniforms.push(...createTensorShapeVariables(outputShape));
const output = outputVariable('output', dataType, outputShape.length);
// WGSL expression addressing the concat-axis component of the shader's `indices` variable.
const indicesAxis = output.indicesGet('indices', adjustedAxis);
// "uniforms.sizeInConcatAxis0,uniforms.sizeInConcatAxis1,..." — initializer for the WGSL array.
const sizeInConcatAxisStr = Array.from(Array(sizeInConcatAxis.length).keys())
.map((i) => `uniforms.sizeInConcatAxis${i}`)
.join(',');
// Shader outline: uniforms are registered before declareVariables is called (inside the IIFE).
// In main, the concat-axis coordinate is rebased by subtracting the running sum of all preceding
// inputs before assignOutputData copies from the selected input.
const getShaderSource = (shaderHelper: ShaderHelper) => `
${(() => {
shaderHelper.registerUniform('outputSize', 'u32');
for (let i = 0; i < inputs.length; i++) {
shaderHelper.registerUniform(`sizeInConcatAxis${i}`, 'u32');
}
return shaderHelper.declareVariables(...inputVars, output);
})()}
${calculateInputIndexImpl(sizeInConcatAxis.length, sizeInConcatAxisStr)}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')}
var indices = ${output.offsetToIndices('global_idx')};
let inputIndex = calculateInputIndex(${indicesAxis});
if (inputIndex != 0u) {
let sizeInConcatAxis = array<u32, ${sizeInConcatAxis.length}u>(${sizeInConcatAxisStr});
${indicesAxis} -= sizeInConcatAxis[inputIndex - 1u];
}
${assignOutputData(inputVars, output)}
}`;
return {
name: 'Concat',
shaderCache: { hint: `${adjustedAxis}`, inputDependencies },
getRunData: () => ({
outputs: [{ dims: outputShape, dataType }],
// One invocation per output element. The divisor assumes mainStart() uses its default
// workgroup size of 64 — TODO confirm if mainStart is ever called with a different size.
dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
programUniforms,
}),
getShaderSource,
};
};
/**
 * Kernel entry point for Concat.
 *
 * Validates the inputs, computes the output shape, and dispatches the concat program on the
 * non-empty inputs.
 *
 * @param context - supplies the input tensors and runs the generated program.
 * @param attributes - parsed attributes; `axis` is normalized here via `ShapeUtil.normalizeAxis`
 *   against the first input's rank.
 * @throws Error when there are no inputs, or when `validateInputs` rejects them.
 */
export const concat = (context: ComputeContext, attributes: ConcatAttributes): void => {
  const inputs = context.inputs;
  // Check before reading inputs[0]. validateInputs performs the same check, but it can only run
  // after the axis is normalized against inputs[0]'s rank; without this guard an empty input
  // list would surface as a TypeError instead of this error.
  if (!inputs || inputs.length < 1) {
    throw new Error('too few inputs');
  }
  const inputShape = inputs[0].dims;
  const adjustedAxis = ShapeUtil.normalizeAxis(attributes.axis, inputShape.length);
  validateInputs(inputs, adjustedAxis);
  const outputShape = inputShape.slice();
  outputShape[adjustedAxis] = inputs.reduce(
    (sum, input) => sum + (input.dims.length > adjustedAxis ? input.dims[adjustedAxis] : 0),
    0,
  );
  // 0 length tensors are valid for concat, remove them.
  // NOTE(review): if every input is empty, nonEmptyInputs is [] and the program is built with no
  // inputs — confirm context.compute skips zero-sized outputs in that case.
  const nonEmptyInputs = inputs.filter((input) => ShapeUtil.size(input.dims) > 0);
  context.compute(createConcatProgramInfo(nonEmptyInputs, adjustedAxis, outputShape, inputs[0].dataType), {
    inputs: nonEmptyInputs,
  });
};
/**
 * Converts the raw attribute record into `ConcatAttributes` using `createAttributeWithCacheKey`.
 *
 * NOTE(review): `axis` is cast, not validated. A missing or non-numeric attribute is passed
 * through unchanged.
 */
export const parseConcatAttributes = (attributes: Record<string, unknown>): ConcatAttributes => {
  const axis = attributes.axis as number;
  return createAttributeWithCacheKey({ axis });
};