onnxruntime-web
Version:
A JavaScript library for running ONNX models in browsers
212 lines (192 loc) • 7.98 kB
text/typescript
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import { DataType } from '../../../wasm-common';
import { TensorView } from '../../tensor-view';
import { ShapeUtil } from '../../util';
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key';
import { ComputeContext, ProgramInfo, ProgramUniform } from '../types';
import {
atomicOutputVariable,
createTensorShapeVariables,
inputVariable,
outputVariable,
ShaderHelper,
} from './common';
/**
 * Parsed attributes of the ScatterND operator.
 *
 * `reduction` selects how each update is combined with the existing output element; the
 * snippet generator in this file understands 'none', 'add', 'mul', 'max' and 'min'.
 */
export interface ScatterNDAttributes extends AttributeWithCacheKey {
  reduction: string;
}
/** WGSL scalar types for which reduction snippets are generated. */
type ReductionType = 'f32' | 'i32' | 'u32';
/**
 * Returns a WGSL statement that combines `v` into the location `ptr` according to `reduction`.
 *
 * - 'none' emits a plain assignment.
 * - 'add' / 'max' / 'min' on i32/u32 use the native WGSL atomic builtins.
 * - 'add' / 'max' / 'min' on f32, and 'mul' on every type, have no native atomic, so they are
 *   emulated with an `atomicCompareExchangeWeak` retry loop whose new value is bitcast to i32.
 *
 * @param reduction reduction name ('none', 'add', 'mul', 'max' or 'min').
 * @param ptr WGSL expression naming the destination element.
 * @param v WGSL expression for the incoming value.
 * @param type WGSL scalar type of the destination element.
 * @throws Error if a reduction other than 'none' is requested for an unsupported type, or if the
 *   reduction name is unknown.
 */
const atomicReductionSnippet = (reduction: string, ptr: string, v: string, type: ReductionType) => {
  const hasNativeAtomics = type === 'i32' || type === 'u32';
  if (reduction !== 'none' && !hasNativeAtomics && type !== 'f32') {
    throw new Error(`Input ${type} is not supported with reduction ${reduction}.`);
  }

  // Compare-and-swap loop: recompute `newValueExpr` from the last observed value until the
  // exchange succeeds.
  const casLoop = (newValueExpr: string): string =>
    [
      '{',
      'var oldValue = 0;',
      'loop {',
      `let newValueF32 =${newValueExpr};`,
      'let newValue = bitcast<i32>(newValueF32);',
      `let res = atomicCompareExchangeWeak(&${ptr}, oldValue, newValue);`,
      'if res.exchanged {',
      'break;',
      '}',
      'oldValue = res.old_value;',
      '}',
      '}',
    ].join('\n');

  if (reduction === 'none') {
    return `${ptr}=${v};`;
  }
  if (reduction === 'add') {
    return hasNativeAtomics
      ? `atomicAdd(&${ptr}, bitcast<${type}>(${v}));`
      : `\n${casLoop(`bitcast<${type}>(oldValue) + (${v})`)}`;
  }
  if (reduction === 'max') {
    return hasNativeAtomics
      ? `atomicMax(&${ptr}, bitcast<${type}>(${v}));`
      : `\n${casLoop(`max(bitcast<f32>(oldValue), (${v}))`)}`;
  }
  if (reduction === 'min') {
    return hasNativeAtomics
      ? `atomicMin(&${ptr}, bitcast<${type}>(${v}));`
      : casLoop(`min(bitcast<${type}>(oldValue), (${v}))`);
  }
  if (reduction === 'mul') {
    // There is no atomicMul for any type, so the CAS loop is used unconditionally.
    return casLoop(`(bitcast<${type}>(oldValue) * (${v}))`);
  }
  throw new Error(`Reduction ${reduction} is not supported.`);
};
/**
 * Emits WGSL that turns one component of an index tuple into an element offset and adds it to
 * `data_offset`.
 *
 * The emitted code expects these WGSL names in scope: `i` (loop counter over the tuple's
 * components), `index` (a mutable i32 holding the index value read for `i`), `data_offset`
 * (a mutable u32 accumulator), and, when `parallel` is true, `indices_start` (the flat position
 * of the tuple's first component in `indices`). Negative indices are wrapped by the dimension
 * size; indices >= size are clamped to size - 1 and indices < -size are clamped to 0.
 *
 * @param dataRank rank of the data/output tensor; for rank 1 the shape/stride uniforms are used
 *   without subscripting.
 * @param parallel true when `i` is a flat position in `indices` (one invocation per tuple);
 *   false when `i` is already the position within the tuple.
 */
const calcDataOffsetSnippet = (dataRank: number, parallel: boolean) => {
  // Position k of the current component within its index tuple.
  const axis = parallel ? 'i - indices_start' : 'i';
  // Component k of an index tuple addresses data dimension k, so both the stride and the
  // clamp/wrap bound come from axis k. (The bound used to be read from
  // `output_shape[k + last_index_dimension]`, an unrelated or out-of-range dimension, which
  // clamped valid indices and wrapped negative indices by the wrong size.)
  const dimDeclarations =
    dataRank === 1
      ? `
    let element_count_dim = uniforms.output_strides;
    let dim_value = uniforms.output_shape;`
      : `
    let element_count_dim = uniforms.output_strides[${axis}];
    let dim_value = uniforms.output_shape[${axis}];`;
  return `${dimDeclarations}
    if (index >= 0) {
      if (index >= i32(dim_value)) {
        index = i32(dim_value - 1);
      }
    } else {
      if (index < -i32(dim_value)) {
        index = 0;
      } else {
        index += i32(dim_value);
      }
    }
    data_offset += u32((u32(index) * element_count_dim));`;
};
/**
 * Emits a WGSL loop that applies one index tuple's slice of `updates` to `output`, starting at
 * the WGSL variable `data_offset`, using the store/reduction selected by `attributes.reduction`.
 *
 * @param attributes parsed ScatterND attributes (only `reduction` is read).
 * @param outputTypeValue WGSL scalar type of the output elements.
 * @param parallel true when the tuple number is `global_idx`; false when it is the serial
 *   loop counter `idx`.
 */
const updateElementsSnippet = (attributes: ScatterNDAttributes, outputTypeValue: ReductionType, parallel: boolean) => {
  const tupleNumber = parallel ? 'global_idx' : 'idx';
  const store = atomicReductionSnippet(attributes.reduction, 'output[data_offset + i]', 'value', outputTypeValue);
  return [
    'for (var i = 0u; i < uniforms.num_updates_elements; i++) {',
    `let value = updates[uniforms.num_updates_elements * ${tupleNumber} + i];`,
    store,
    '}',
  ].join('\n');
};
/**
 * Builds the WebGPU program for ScatterND.
 *
 * One invocation handles one index tuple (one row along the last axis of `indices`) and writes
 * the `num_updates_elements` contiguous output elements that the tuple addresses. For reduction
 * 'none' with repeated index tuples, invocation 0 instead applies every tuple serially, in order,
 * so the last update wins deterministically. The shader only writes `output` and never reads the
 * data tensor. NOTE(review): this presumably relies on `output` already holding a copy of the
 * data before the program runs; confirm.
 *
 * @param inputs [data, indices, updates].
 * @param attributes parsed ScatterND attributes.
 */
const createScatterNDProgramInfo = (inputs: readonly TensorView[], attributes: ScatterNDAttributes): ProgramInfo => {
  const inputShape = inputs[0].dims;
  const indicesShape = inputs[1].dims;
  const outputShape = inputShape;
  // TODO: support bool with components 4.
  const components = 1;
  const lastIndexDimension = indicesShape[indicesShape.length - 1];
  // Elements written per index tuple: product of the data dims after the indexed prefix.
  const numUpdatesElements = ShapeUtil.sizeFromDimension(inputShape, lastIndexDimension);
  // Number of index tuples.
  const numIndicesElements = ShapeUtil.sizeFromDimension(indicesShape, 0) / lastIndexDimension;
  // One invocation per index tuple. Sizing this by the number of scalar indices would launch
  // `lastIndexDimension` times too many invocations, and the surplus ones would read `indices`
  // and `updates` (and write `output`) past their ends.
  const outputSize = numIndicesElements;
  const programUniforms: ProgramUniform[] = [
    { type: DataType.uint32, data: outputSize },
    { type: DataType.uint32, data: lastIndexDimension },
    { type: DataType.uint32, data: numUpdatesElements },
    ...createTensorShapeVariables(inputs[1].dims, inputs[2].dims, outputShape),
  ];

  const getShaderSource = (shaderHelper: ShaderHelper) => {
    const indices = inputVariable('indices', inputs[1].dataType, inputs[1].dims.length);
    const updates = inputVariable('updates', inputs[2].dataType, inputs[2].dims.length, components);
    const output =
      attributes.reduction !== 'none' && attributes.reduction !== ''
        ? atomicOutputVariable('output', inputs[0].dataType, outputShape.length)
        : outputVariable('output', inputs[0].dataType, outputShape.length, components);
    const outputType = output.type.value as ReductionType;
    const isNoneReduction = attributes.reduction === 'none';
    // `uniforms.output_size` is the number of index tuples. It is read from a uniform rather than
    // baked into the shader text, because the shader cache hint does not include the tuple count.
    return `
  ${shaderHelper
    .registerUniform('output_size', 'u32')
    .registerUniform('last_index_dimension', 'u32')
    .registerUniform('num_updates_elements', 'u32')
    .declareVariables(indices, updates, output)}
  ${shaderHelper.mainStart()}
    ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
    // Duplicate detection compares whole index tuples (all last_index_dimension components).
    // NOTE(review): every invocation repeats this O(tuples^2) scan.
    var hasDuplicates = false;
    if (${isNoneReduction}) {
      for (var i = 0u; i < uniforms.output_size && !hasDuplicates; i++) {
        for (var j = i + 1u; j < uniforms.output_size; j++) {
          var sameTuple = true;
          for (var k = 0u; k < uniforms.last_index_dimension; k++) {
            let index_i = i32(indices[i * uniforms.last_index_dimension + k].x);
            let index_j = i32(indices[j * uniforms.last_index_dimension + k].x);
            if (index_i != index_j) {
              sameTuple = false;
              break;
            }
          }
          if (sameTuple) {
            hasDuplicates = true;
            break;
          }
        }
      }
    }
    if (${isNoneReduction} && hasDuplicates) {
      if (global_idx != 0u) {
        return;
      }
      // Process each index-update pair individually when duplicates exist
      for (var idx = 0u; idx < uniforms.output_size; idx++) {
        var data_offset = 0u;
        for (var i = 0u; i < uniforms.last_index_dimension; i++) {
          var index = i32(indices[idx * uniforms.last_index_dimension + i].x);
          ${calcDataOffsetSnippet(inputShape.length, false)}
        }
        ${updateElementsSnippet(attributes, outputType, false)}
      }
      return;
    }
    var data_offset = 0u;
    let indices_start = uniforms.last_index_dimension * global_idx;
    let indices_end = indices_start + uniforms.last_index_dimension;
    for (var i = indices_start; i < indices_end; i++) {
      var index = i32(indices[i].x);
      ${calcDataOffsetSnippet(inputShape.length, true)}
    }
    ${updateElementsSnippet(attributes, outputType, true)}
  }`;
  };

  return {
    name: 'ScatterND',
    shaderCache: {
      // Only indices and updates are program inputs, so `inputDependencies` does not capture the
      // data/output rank and data type that the generated shader depends on; add them to the hint.
      hint: `${attributes.cacheKey}_${attributes.reduction}_${outputShape.length}_${inputs[0].dataType}`,
      inputDependencies: ['rank', 'rank'],
    },
    getRunData: () => ({
      outputs: [{ dims: outputShape, dataType: inputs[0].dataType }],
      dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
      programUniforms,
    }),
    getShaderSource,
  };
};
/**
 * Extracts the ScatterND `reduction` attribute.
 *
 * A missing or empty value is normalized to 'none', the ONNX default. Otherwise an empty string
 * would select a non-atomic output buffer but then be rejected by the snippet generator as an
 * unknown reduction.
 *
 * @param attributes raw attribute map; only `reduction` is read.
 * @returns attributes with a cache key derived from the normalized reduction.
 */
export const parseScatterNDAttributes = (attributes: Record<string, unknown>): ScatterNDAttributes => {
  // `||` rather than `??` on purpose: the empty string must also fall back to the default.
  const reduction = (attributes.reduction as string | undefined) || 'none';
  return createAttributeWithCacheKey({ reduction });
};
/**
 * Runs ScatterND on the WebGPU backend.
 *
 * Only `indices` (input 1) and `updates` (input 2) are bound as program inputs; the data tensor
 * (input 0) contributes only its shape and type through the program info.
 * NOTE(review): `outputs: []` is passed through unchanged; confirm how the output buffer gets
 * its initial copy of the data tensor.
 */
export const scatterND = (context: ComputeContext, attributes: ScatterNDAttributes): void => {
  const programInfo = createScatterNDProgramInfo(context.inputs, attributes);
  const [, indices, updates] = context.inputs;
  context.compute(programInfo, { inputs: [indices, updates], outputs: [] });
};