UNPKG

onnxruntime-web

Version:

A JavaScript library for running ONNX models in browsers

164 lines (143 loc) 6.16 kB
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import { DataType } from '../../../wasm-common';
import { TensorView } from '../../tensor-view';
import { ShapeUtil } from '../../util';
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key';
import { ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform } from '../types';

import { createTensorShapeVariables, IndicesHelper, inputVariable, outputVariable, ShaderHelper } from './common';

/**
 * Attributes of the Concat operator.
 */
export interface ConcatAttributes extends AttributeWithCacheKey {
  /** Axis to concatenate along. `concat` passes it through `ShapeUtil.normalizeAxis` before use. */
  readonly axis: number;
}

/**
 * Checks that all inputs can be concatenated along `axis`, using the first input as the reference.
 *
 * @param inputs Tensors to concatenate.
 * @param axis Concatenation axis, compared directly against dimension indices.
 * @throws Error if there are no inputs, if an input's data type or rank differs from the first
 *   input's, or if any dimension other than `axis` differs from the first input's.
 */
const validateInputs = (inputs: readonly TensorView[], axis: number): void => {
  if (!inputs || inputs.length < 1) {
    throw new Error('too few inputs');
  }
  const referenceIndex = 0;
  const referenceInput = inputs[referenceIndex];
  const inputType = referenceInput.dataType;
  const inputRank = referenceInput.dims.length;
  inputs.forEach((input, inputIndex) => {
    if (inputIndex === referenceIndex) {
      return;
    }
    // make sure types of all inputs match
    if (input.dataType !== inputType) {
      throw new Error('input tensors should be one type');
    }
    // make sure the dimensionality of all inputs are the same
    if (input.dims.length !== inputRank) {
      throw new Error('input tensors should have the same shape');
    }
    // every dimension except the concat axis must equal the reference input's dimension
    input.dims.forEach((dim, dimIndex) => {
      if (dimIndex !== axis && dim !== referenceInput.dims[dimIndex]) {
        throw new Error('non concat dimensions must match');
      }
    });
  });
};

/**
 * Builds the WGSL helper `calculateInputIndex(index)`.
 *
 * The generated function returns the first `i` for which `index < sizeInConcatAxis[i]`, or
 * `numberOfTensors` when no entry satisfies that.
 *
 * @param numberOfTensors Length of the generated `sizeInConcatAxis` array.
 * @param sizeInConcatAxisStr Comma-separated WGSL expressions used as the array's elements.
 * @returns WGSL source of the helper function.
 */
const calculateInputIndexImpl = (numberOfTensors: number, sizeInConcatAxisStr: string): string => `
  fn calculateInputIndex(index: u32) -> u32 {
    let sizeInConcatAxis = array<u32, ${numberOfTensors}u>(${sizeInConcatAxisStr});
    for (var i: u32 = 0u; i < ${numberOfTensors}; i += 1u ) {
      if (index < sizeInConcatAxis[i]) {
        return i;
      }
    }
    return ${numberOfTensors}u;
  }`;

/**
 * Builds the WGSL statements that write one output element from the selected input.
 *
 * A single input produces an unconditional store. Otherwise an `if` / `else if` / `else` chain keyed
 * on the shader variable `inputIndex` is produced, with the last input handled by the bare `else`.
 *
 * @param inputs Helpers for the input variables, in concatenation order.
 * @param output Helper for the output variable.
 * @returns The generated statements joined by newlines.
 */
const assignOutputData = (inputs: readonly IndicesHelper[], output: IndicesHelper) => {
  const numberOfTensors = inputs.length;
  const codeLines: string[] = [];
  for (let i = 0; i < numberOfTensors; ++i) {
    const returnSnippet = output.setByOffset('global_idx', inputs[i].getByIndices('indices'));
    if (numberOfTensors === 1) {
      codeLines.push(returnSnippet);
    } else if (i === 0) {
      codeLines.push(`if (inputIndex == ${i}u) { ${returnSnippet} }`);
    } else if (i === numberOfTensors - 1) {
      codeLines.push(`else { ${returnSnippet} }`);
    } else {
      codeLines.push(`else if (inputIndex == ${i}) { ${returnSnippet} }`);
    }
  }
  return codeLines.join('\n');
};

/**
 * Creates the program description for concatenating `inputs` along `adjustedAxis`.
 *
 * Program uniforms are pushed in this order: the output size, one running total of the concat-axis
 * extent per input, the shape variables of every input, then the shape variables of the output.
 *
 * @param inputs Tensors to concatenate, in order.
 * @param adjustedAxis Concatenation axis, used directly as an index into `dims`.
 * @param outputShape Shape of the concatenated result.
 * @param dataType Data type used for every input variable and for the output.
 * @returns The program info passed to `context.compute`.
 */
const createConcatProgramInfo = (
  inputs: readonly TensorView[],
  adjustedAxis: number,
  outputShape: number[],
  dataType: DataType,
): ProgramInfo => {
  const outputSize = ShapeUtil.size(outputShape);

  const sizeInConcatAxis = new Array<number>(inputs.length);
  const inputVars = new Array<IndicesHelper>(inputs.length);

  let previousSum = 0;
  const inputDependencies: ProgramInputTensorInfoDependency[] = [];
  const inputRanks: number[] = [];
  const programUniforms: ProgramUniform[] = [{ type: DataType.uint32, data: outputSize }];
  for (let i = 0; i < inputs.length; ++i) {
    // sizeInConcatAxis[i] is the running total of concat-axis extents of inputs 0..i
    previousSum += inputs[i].dims[adjustedAxis];
    sizeInConcatAxis[i] = previousSum;
    inputRanks.push(inputs[i].dims.length);
    inputVars[i] = inputVariable(`input${i}`, dataType, inputRanks[i]);
    inputDependencies.push('rank');
    programUniforms.push({ type: DataType.uint32, data: sizeInConcatAxis[i] });
  }
  for (let i = 0; i < inputs.length; ++i) {
    programUniforms.push(...createTensorShapeVariables(inputs[i].dims));
  }
  programUniforms.push(...createTensorShapeVariables(outputShape));

  const output = outputVariable('output', dataType, outputShape.length);
  const indicesAxis = output.indicesGet('indices', adjustedAxis);
  // "uniforms.sizeInConcatAxis0,uniforms.sizeInConcatAxis1,..." - one entry per input
  const sizeInConcatAxisStr = Array.from(Array(sizeInConcatAxis.length).keys())
    .map((i) => `uniforms.sizeInConcatAxis${i}`)
    .join(',');
  const getShaderSource = (shaderHelper: ShaderHelper) => `

  ${(() => {
    // the uniforms are registered first, then the variables are declared
    shaderHelper.registerUniform('outputSize', 'u32');
    for (let i = 0; i < inputs.length; i++) {
      shaderHelper.registerUniform(`sizeInConcatAxis${i}`, 'u32');
    }
    return shaderHelper.declareVariables(...inputVars, output);
  })()}

  ${calculateInputIndexImpl(sizeInConcatAxis.length, sizeInConcatAxisStr)}

  ${shaderHelper.mainStart()}
    ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')}

    var indices = ${output.offsetToIndices('global_idx')};

    let inputIndex = calculateInputIndex(${indicesAxis});
    if (inputIndex != 0u) {
      let sizeInConcatAxis = array<u32, ${sizeInConcatAxis.length}u>(${sizeInConcatAxisStr});
      ${indicesAxis} -= sizeInConcatAxis[inputIndex - 1u];
    }

    ${assignOutputData(inputVars, output)}
  }`;

  return {
    name: 'Concat',
    shaderCache: { hint: `${adjustedAxis}`, inputDependencies },
    getRunData: () => ({
      outputs: [{ dims: outputShape, dataType }],
      dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
      programUniforms,
    }),
    getShaderSource,
  };
};

/**
 * Runs the Concat operator on `context.inputs`.
 *
 * The output shape is the first input's shape with the concat-axis extent replaced by the sum of
 * that extent over all inputs. Inputs for which `ShapeUtil.size(dims)` is 0 are not passed to the
 * program.
 *
 * @param context Compute context providing the inputs and the `compute` entry point.
 * @param attributes Operator attributes; `axis` is normalized against the first input's rank.
 * @throws Error if there are no inputs or if `validateInputs` rejects them.
 */
export const concat = (context: ComputeContext, attributes: ConcatAttributes): void => {
  const inputs = context.inputs;
  // Check the input count before reading inputs[0], so an empty input list reports
  // 'too few inputs' rather than failing on the property access below.
  if (!inputs || inputs.length < 1) {
    throw new Error('too few inputs');
  }
  const inputShape = inputs[0].dims;
  const adjustedAxis = ShapeUtil.normalizeAxis(attributes.axis, inputShape.length);
  validateInputs(inputs, adjustedAxis);
  const outputShape = inputShape.slice();
  outputShape[adjustedAxis] = inputs.reduce(
    (sum, input) => sum + (input.dims.length > adjustedAxis ? input.dims[adjustedAxis] : 0),
    0,
  );
  // 0 length tensors are valid for concat, remove them
  // NOTE(review): if every input is empty, the program is created with zero inputs - confirm the
  // backend handles that case.
  const nonEmptyInputs = inputs.filter((input) => ShapeUtil.size(input.dims) > 0);
  context.compute(createConcatProgramInfo(nonEmptyInputs, adjustedAxis, outputShape, inputs[0].dataType), {
    inputs: nonEmptyInputs,
  });
};

/**
 * Converts raw operator attributes into `ConcatAttributes`.
 *
 * @param attributes Raw attribute record; `attributes.axis` is taken as a number without validation.
 * @returns The attributes wrapped by `createAttributeWithCacheKey`.
 */
export const parseConcatAttributes = (attributes: Record<string, unknown>): ConcatAttributes =>
  createAttributeWithCacheKey({ axis: attributes.axis as number });