UNPKG

onnxruntime-web

Version:

A Javascript library for running ONNX models on browsers

117 lines (105 loc) 4.9 kB
// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. import { DataType } from '../../../wasm-common'; import { TensorView } from '../../tensor-view'; import { BroadcastUtil, ShapeUtil } from '../../util'; import { ComputeContext, ProgramInfo } from '../types'; import { createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper } from './common'; const createWhereOpProgramShader = ( shaderHelper: ShaderHelper, inputs: readonly TensorView[], dimsOutput: readonly number[], isBroadcast: boolean, typeOutput: number, ) => { const output = outputVariable('output_data', typeOutput, dimsOutput.length, 4); const a = inputVariable('a_data', inputs[1].dataType, inputs[1].dims.length, 4); const b = inputVariable('b_data', inputs[2].dataType, inputs[2].dims.length, 4); const c = inputVariable('c_data', inputs[0].dataType, inputs[0].dims.length, 4); let assignment: string; const expression = (a: string, b: string, c: string) => `select(${b}, ${a}, ${c})`; if (!isBroadcast) { assignment = output.setByOffset( 'global_idx', expression(a.getByOffset('global_idx'), b.getByOffset('global_idx'), c.getByOffset('global_idx')), ); } else { const singleAssignment = (resStr: string, x: number, typeCast = '') => { const expressionA = `a_data[index_a${x}][component_a${x}]`; const expressionB = `b_data[index_b${x}][component_b${x}]`; // eslint-disable-next-line no-bitwise const expressionC = `bool(c_data[index_c${x}] & (0xffu << (component_c${x} * 8)))`; return ` let output_indices${x} = ${output.offsetToIndices(`global_idx * 4u + ${x}u`)}; let offset_a${x} = ${a.broadcastedIndicesToOffset(`output_indices${x}`, output)}; let offset_b${x} = ${b.broadcastedIndicesToOffset(`output_indices${x}`, output)}; let offset_c${x} = ${c.broadcastedIndicesToOffset(`output_indices${x}`, output)}; let index_a${x} = offset_a${x} / 4u; let index_b${x} = offset_b${x} / 4u; let index_c${x} = offset_c${x} / 4u; let component_a${x} = offset_a${x} % 4u; let component_b${x} = offset_b${x} % 4u; let component_c${x} = offset_c${x} % 4u; ${resStr}[${x}] = ${typeCast}(${expression(expressionA, expressionB, expressionC)}); `; }; if (typeOutput === DataType.bool) { assignment = ` var data = vec4<u32>(0); ${singleAssignment('data', 0, 'u32')} ${singleAssignment('data', 1, 'u32')} ${singleAssignment('data', 2, 'u32')} ${singleAssignment('data', 3, 'u32')} output_data[global_idx] = dot(vec4<u32>(0x1, 0x100, 0x10000, 0x1000000), vec4<u32>(data));`; } else { assignment = ` ${singleAssignment('output_data[global_idx]', 0)} ${singleAssignment('output_data[global_idx]', 1)} ${singleAssignment('output_data[global_idx]', 2)} ${singleAssignment('output_data[global_idx]', 3)} `; } } return ` ${shaderHelper.registerUniform('vec_size', 'u32').declareVariables(c, a, b, output)} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.vec_size')} ${assignment} }`; }; const createWhereOpProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => { const dimsA = inputs[1].dims; const dimsB = inputs[2].dims; const dimsC = inputs[0].dims; const outputDataType = inputs[1].dataType; const isBroadcast = !(ShapeUtil.areEqual(dimsA, dimsB) && ShapeUtil.areEqual(dimsB, dimsC)); let outputShape = dimsA; let outputSize = ShapeUtil.size(dimsA); // TODO: deal with zero-sized tensors (eg. dims=[1,0]) if (isBroadcast) { const calculatedShape = BroadcastUtil.calcShape(BroadcastUtil.calcShape(dimsA, dimsB, false)!, dimsC, false); if (!calculatedShape) { throw new Error("Can't perform where op on the given tensors"); } outputShape = calculatedShape; outputSize = ShapeUtil.size(outputShape); } const vecSize = Math.ceil(outputSize / 4); return { name: 'Where', shaderCache: { inputDependencies: ['rank', 'rank', 'rank'] }, getShaderSource: (shaderHelper) => createWhereOpProgramShader(shaderHelper, inputs, outputShape, isBroadcast, outputDataType), getRunData: () => ({ outputs: [{ dims: outputShape, dataType: outputDataType }], dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* vec size */) }, programUniforms: [ { type: DataType.uint32, data: vecSize }, ...createTensorShapeVariables(dimsC, dimsA, dimsB, outputShape), ], }), }; }; export const where = (context: ComputeContext): void => { context.compute(createWhereOpProgramInfo(context.inputs)); };