UNPKG

onnxruntime-web

Version:

A JavaScript library for running ONNX models on browsers

180 lines (162 loc) 7.1 kB
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import { DataType } from '../../../wasm-common';
import { TensorView } from '../../tensor-view';
import { ShapeUtil } from '../../util';
import { AttributeWithCacheKey } from '../attribute-with-cache-key';
import { ComputeContext, ProgramUniform } from '../types';

import {
  createTensorShapeVariables,
  inputVariable,
  outputVariable,
  ShaderHelper,
  UniformsArrayType,
} from './common';

/** Attributes consumed by the GatherND kernel. */
export interface GatherNDAttributes extends AttributeWithCacheKey {
  /**
   * Value of the operator's `batch_dims` attribute. `gatherND` treats the input dimensions before
   * this axis as batches and the dimensions from this axis onwards as the per-batch data.
   */
  readonly batchDims: number;
}

/**
 * Dispatches a helper program that turns every index tuple of `indicesData` into one flat `u32`
 * element offset into the input tensor.
 *
 * For each `global_idx < numSlices` the shader reads `numSliceDims` consecutive index values, adds
 * the matching input dimension to negative ones, multiplies each by its entry in
 * `sizesFromSliceDimsData`, sums them and adds `batch_idx * inputBatchStride`.
 *
 * @param context compute context used to dispatch the program.
 * @param indicesData the indices tensor; it is the only input of the dispatched program.
 * @param sizesFromSliceDimsData per-slice-dimension multipliers (element strides) for the indices.
 * @param batchDims number of batch dimensions; offsets `dim_idx` into `inputDims`.
 * @param inputDims dimensions of the data tensor, used to wrap negative indices.
 * @param numSlices number of offsets to produce; also the length of the 1-D output.
 * @param numSlicesPerBatch number of consecutive slices that belong to the same batch.
 * @param inputBatchStride number of input elements per batch.
 * @param numSliceDims number of index values per slice.
 * @returns the first output tensor of the dispatched program (the slice offsets).
 */
const computeSliceOffsets = (
  context: ComputeContext,
  indicesData: TensorView,
  sizesFromSliceDimsData: number[],
  batchDims: number,
  inputDims: readonly number[],
  numSlices: number,
  numSlicesPerBatch: number,
  inputBatchStride: number,
  numSliceDims: number,
) => {
  // The order of these entries must match the order of `uniforms` declared in getShaderSource.
  const programUniforms: ProgramUniform[] = [
    { type: DataType.uint32, data: numSlices },
    { type: DataType.uint32, data: batchDims },
    { type: DataType.uint32, data: inputDims },
    { type: DataType.uint32, data: sizesFromSliceDimsData },
    { type: DataType.uint32, data: numSlicesPerBatch },
    { type: DataType.uint32, data: inputBatchStride },
    { type: DataType.uint32, data: numSliceDims },
  ];

  const outputShape = [numSlices];
  programUniforms.push(...createTensorShapeVariables(indicesData.dims, outputShape));

  const getShaderSource = (shaderHelper: ShaderHelper) => {
    const indices = inputVariable('indices_data', indicesData.dataType, indicesData.dims.length);
    const output = outputVariable('input_slice_offsets_data', DataType.uint32, 1, 1);
    const variables = [indices, output];
    const uniforms: UniformsArrayType = [
      { name: 'output_size', type: 'u32' },
      { name: 'batch_dims', type: 'u32' },
      { name: 'input_dims', type: 'u32', length: inputDims.length },
      { name: 'sizes_from_slice_dims_data', type: 'u32', length: sizesFromSliceDimsData.length },
      { name: 'num_slices_per_batch', type: 'u32' },
      { name: 'input_batch_stride', type: 'u32' },
      { name: 'num_slice_dims', type: 'u32' },
    ];

    // A uniform of length 1 is addressed without an index in the generated shader, so the two
    // statements below are specialised on the uniform lengths (which is why both lengths are part
    // of the shader cache hint).
    const wrapNegativeIndex =
      inputDims.length === 1
        ? 'index += i32(uniforms.input_dims);'
        : 'index += i32(uniforms.input_dims[input_dim_idx]);';
    const accumulateOffset =
      sizesFromSliceDimsData.length === 1
        ? 'relative_slice_offset += index * i32(uniforms.sizes_from_slice_dims_data);'
        : 'relative_slice_offset += index * i32(uniforms.sizes_from_slice_dims_data[dim_idx]);';

    // NOTE(review): `.x` reads only the first component of each indices element — presumably the
    // low 32 bits of an int64 index; confirm against inputVariable's type mapping.
    return `
  ${shaderHelper.registerUniforms(uniforms).declareVariables(...variables)}
  ${shaderHelper.mainStart()}
    ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
    let batch_idx = global_idx / uniforms.num_slices_per_batch;
    let base_offset = batch_idx * uniforms.input_batch_stride;

    let slice_indices_base_offset = global_idx * uniforms.num_slice_dims;
    var relative_slice_offset = 0;
    for (var dim_idx = 0u; dim_idx < uniforms.num_slice_dims; dim_idx ++) {
      var index = i32(indices_data[dim_idx + slice_indices_base_offset].x);
      let input_dim_idx = uniforms.batch_dims + dim_idx;
      if (index < 0) {
        ${wrapNegativeIndex}
      }
      ${accumulateOffset}
    }

    input_slice_offsets_data[global_idx] = base_offset + u32(relative_slice_offset);
  }`;
  };

  return context.compute(
    {
      name: 'computeSliceOffsets',
      shaderCache: { hint: `${inputDims.length}_${sizesFromSliceDimsData.length}`, inputDependencies: ['rank'] },
      getRunData: () => ({
        // NOTE(review): the shader writes u32 offsets, but the output tensor is declared with the
        // data type of the indices input — confirm this is intentional.
        outputs: [{ dims: outputShape, dataType: context.inputs[1].dataType }],
        dispatchGroup: { x: Math.ceil(numSlices / 64 /* workgroup size */) },
        programUniforms,
      }),
      getShaderSource,
    },
    // NOTE(review): output index -1 presumably requests an intermediate tensor rather than a
    // kernel output — confirm against ComputeContext.compute.
    { inputs: [indicesData], outputs: [-1] },
  )[0];
};

/**
 * Runs GatherND: `context.inputs[0]` is the data tensor and `context.inputs[1]` the indices tensor.
 *
 * The work is split into two programs. `computeSliceOffsets` first converts every index tuple into
 * a flat element offset; the `GatherND` program then copies `sliceSize` consecutive elements
 * starting at each offset into the output.
 *
 * @param context compute context providing the inputs and the dispatch entry point.
 * @param attributes operator attributes (`batchDims`).
 * @throws Error if `batchDims` plus the last dimension of the indices exceeds the input rank.
 */
export const gatherND = (context: ComputeContext, attributes: GatherNDAttributes): void => {
  const inputs = context.inputs;
  const inputShape = inputs[0].dims;
  const inputType = inputs[0].dataType;
  const indicesShape = inputs[1].dims;
  const numSliceDims = indicesShape[indicesShape.length - 1];

  // Validate before deriving anything from the shapes: every computation below (and the
  // computeSliceOffsets dispatch) uses `lastIndicesDimension` as an axis into `inputShape`.
  const lastIndicesDimension = attributes.batchDims + numSliceDims;
  if (lastIndicesDimension > inputShape.length) {
    throw new Error('last dimension of indices must not be larger than rank of input tensor');
  }

  const numSlices = ShapeUtil.sizeToDimension(indicesShape, indicesShape.length - 1);
  const sliceSize = ShapeUtil.sizeFromDimension(inputShape, lastIndicesDimension);
  const numBatches = ShapeUtil.sizeToDimension(inputShape, attributes.batchDims);
  const inputBatchStride = ShapeUtil.sizeFromDimension(inputShape, attributes.batchDims);
  const numSlicesPerBatch = numSlices / numBatches;

  // sizesFromSliceDims[k] is the number of input elements spanned by one step along the k-th
  // indexed dimension; it is built from the innermost indexed dimension outwards.
  const sizesFromSliceDims = new Array<number>(numSliceDims);
  let runningProduct = sliceSize;
  for (let i = 0; i < numSliceDims; ++i) {
    sizesFromSliceDims[numSliceDims - 1 - i] = runningProduct;
    runningProduct *= inputShape[lastIndicesDimension - 1 - i];
  }

  const inputSliceOffsets = computeSliceOffsets(
    context,
    inputs[1],
    sizesFromSliceDims,
    attributes.batchDims,
    inputShape,
    numSlices,
    numSlicesPerBatch,
    inputBatchStride,
    numSliceDims,
  );

  // Output shape: indices shape without its last dimension, followed by the un-indexed tail of the
  // input shape.
  const outputShape = indicesShape.slice(0, -1).concat(inputShape.slice(lastIndicesDimension));
  const outputSize = ShapeUtil.size(outputShape);

  const programUniforms: ProgramUniform[] = [
    { type: DataType.uint32, data: outputSize },
    { type: DataType.uint32, data: sliceSize },
    ...createTensorShapeVariables(inputs[0].dims, inputSliceOffsets.dims, outputShape),
  ];

  const getShaderSource = (shaderHelper: ShaderHelper) => {
    const input = inputVariable('data', inputs[0].dataType, inputs[0].dims.length);
    const indices = inputVariable('slice_offsets', DataType.uint32, inputSliceOffsets.dims.length);

    const output = outputVariable('output', inputs[0].dataType, outputShape.length);
    return `
          ${shaderHelper
            .registerUniform('output_size', 'u32')
            .registerUniform('slice_size', 'u32')
            .declareVariables(input, indices, output)}
            ${shaderHelper.mainStart()}
            ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
          let slice_offset = slice_offsets[global_idx / uniforms.slice_size];
          output[global_idx] = data[u32(slice_offset) + global_idx % uniforms.slice_size];
        }`;
  };

  context.compute(
    {
      name: 'GatherND',
      shaderCache: { hint: attributes.cacheKey, inputDependencies: ['rank', 'rank'] },
      getRunData: () => ({
        outputs: [{ dims: outputShape, dataType: inputType }],
        dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
        programUniforms,
      }),
      getShaderSource,
    },
    { inputs: [inputs[0], inputSliceOffsets] },
  );
};

/**
 * Builds `GatherNDAttributes` from the raw attribute record.
 *
 * @param attributes raw operator attributes; `batch_dims` is read and assumed to be a number.
 * @returns the parsed attributes with an empty `cacheKey`.
 */
export const parseGatherNDAttributes = (attributes: Record<string, unknown>): GatherNDAttributes => {
  const batchDims = attributes.batch_dims as number;
  return {
    batchDims,
    cacheKey: '',
  };
};