UNPKG

onnxruntime-web

Version:

A JavaScript library for running ONNX models in browsers

80 lines (74 loc) 3.66 kB
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import { DataType } from '../../../wasm-common';
import { TensorView } from '../../tensor-view';
import { ShapeUtil } from '../../util';
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key';
import { ComputeContext, ProgramInfo } from '../types';

import { createTensorShapeVariables, getElementAt, inputVariable, outputVariable, ShaderHelper } from './common';

/**
 * Attributes of the CumSum operator.
 */
export interface CumSumAttributes extends AttributeWithCacheKey {
  /** When true, the element at the current position is left out of its own running sum. */
  readonly exclusive: boolean;
  /** When true, each sum runs from the current position towards the end of the axis instead of from the start. */
  readonly reverse: boolean;
}

/**
 * Reads the scalar axis value from the axis tensor.
 *
 * An int32 tensor is read as int32; any other data type is read as int64.
 *
 * @param axisInput tensor holding the axis.
 * @returns the (not yet normalized) axis as a JS number.
 * @throws Error when the tensor contains no element.
 */
const readAxisValue = (axisInput: TensorView): number => {
  const values = axisInput.dataType === DataType.int32 ? axisInput.getInt32Array() : axisInput.getBigInt64Array();
  if (values.length === 0) {
    // Without this check `undefined` would be fed into the axis normalization.
    throw new Error('CumSum: the axis input must contain exactly one element.');
  }
  return Number(values[0]);
};

/**
 * Builds the program info for a cumulative sum along one axis.
 *
 * The output has the same shape and data type as the input. The axis is passed to the shader as a
 * uniform, so the generated shader text depends only on the rank and on `attributes`.
 *
 * @param inputType data type of the input (and output) tensor.
 * @param inputShape dims of the input (and output) tensor.
 * @param axisInput tensor holding the axis; negative values are normalized against the rank.
 * @param attributes `exclusive` / `reverse` flags; their cache key is used as the shader cache hint.
 * @returns the program info consumed by `context.compute`.
 */
const createCumsumProgramInfo = (
  inputType: number,
  inputShape: readonly number[],
  axisInput: TensorView,
  attributes: CumSumAttributes,
): ProgramInfo => {
  const outputSize = ShapeUtil.size(inputShape); // outputShape is same as inputShape.
  const rank = inputShape.length; // input/output rank
  const input = inputVariable('input', inputType, rank);
  const output = outputVariable('output', inputType, rank);
  const axis = ShapeUtil.normalizeAxis(readAxisValue(axisInput), rank);

  const getShaderSource = (shaderHelper: ShaderHelper) => {
    // Position of the current element along the axis, as i32.
    const index = ` i32(${input.indicesGet('inputIndices', 'uniforms.axis')}) `;
    // Extent of the axis. Wrapped in i32(...) because it is assigned to an i32 below and WGSL has no
    // implicit integer conversions. NOTE(review): the shape uniform is presumably u32 - confirm in
    // './common'. The conversion is an identity if it is already i32.
    const axisExtent = `i32(${getElementAt('uniforms.input_shape', 'uniforms.axis', rank)})`;

    // Half-open range [first, last) of axis positions that contribute to the sum:
    //   forward:  [0, index]      or [0, index)      when exclusive
    //   reverse:  [index, extent) or (index, extent) when exclusive
    const lowerLimit = attributes.reverse ? index + (attributes.exclusive ? ' + 1' : '') : '0';
    const upperLimit = attributes.reverse ? axisExtent : index + (attributes.exclusive ? '' : ' + 1');

    return `
      ${shaderHelper
        .registerUniform('outputSize', 'u32')
        .registerUniform('axis', 'u32')
        .declareVariables(input, output)}
      ${shaderHelper.mainStart()}
        ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')}
        var inputIndices = ${output.offsetToIndices('global_idx')};
        var sum = ${output.type.value}(0);
        let first : i32 = ${lowerLimit};
        let last : i32 = ${upperLimit};
        for (var i : i32 = first; i < last; i++) {
          ${input.indicesSet('inputIndices', 'uniforms.axis', 'u32(i)')};
          sum = sum + ${input.getByIndices('inputIndices')};
        }
        ${output.setByOffset('global_idx', 'sum')};
      }`;
  };

  return {
    name: 'CumSum',
    shaderCache: { hint: attributes.cacheKey, inputDependencies: ['rank'] },
    getRunData: () => ({
      outputs: [{ dims: inputShape, dataType: inputType }],
      dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
      // Order must match the uniform registration order in the shader: outputSize, axis, then shapes.
      programUniforms: [
        { type: DataType.uint32, data: outputSize },
        { type: DataType.uint32, data: axis },
        ...createTensorShapeVariables(inputShape, inputShape),
      ],
    }),
    getShaderSource,
  };
};

/**
 * Runs CumSum on the compute context.
 *
 * Input 0 is the data tensor. Input 1 holds the axis: it is read on the host and only input 0 is
 * bound to the program.
 *
 * @param context compute context providing the inputs.
 * @param attributes parsed CumSum attributes.
 * @throws Error when fewer than two inputs are provided.
 */
export const cumsum = (context: ComputeContext, attributes: CumSumAttributes): void => {
  if (context.inputs.length < 2) {
    throw new Error('CumSum requires 2 inputs: the data tensor and the axis tensor.');
  }
  const inputShape = context.inputs[0].dims;
  const inputType = context.inputs[0].dataType;
  const axis = context.inputs[1];
  context.compute(createCumsumProgramInfo(inputType, inputShape, axis, attributes), { inputs: [0] });
};

/**
 * Converts the raw operator attributes into {@link CumSumAttributes}.
 *
 * A flag is true only when its raw value is exactly `1`; anything else (including a missing
 * attribute) yields false.
 *
 * @param attributes raw attribute bag.
 * @returns the parsed attributes with a cache key attached.
 */
export const parseCumSumAttributes = (attributes: Record<string, unknown>): CumSumAttributes => {
  const exclusive = (attributes.exclusive as number) === 1;
  const reverse = (attributes.reverse as number) === 1;
  return createAttributeWithCacheKey({ exclusive, reverse });
};