onnxruntime-web
Version:
A Javascript library for running ONNX models on browsers
80 lines (74 loc) • 3.66 kB
text/typescript
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import { DataType } from '../../../wasm-common';
import { TensorView } from '../../tensor-view';
import { ShapeUtil } from '../../util';
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key';
import { ComputeContext, ProgramInfo } from '../types';
import { createTensorShapeVariables, getElementAt, inputVariable, outputVariable, ShaderHelper } from './common';
export interface CumSumAttributes extends AttributeWithCacheKey {
readonly exclusive: boolean;
readonly reverse: boolean;
}
const createCumsumProgramInfo = (
inputType: number,
inputShape: readonly number[],
axisInput: TensorView,
attributes: CumSumAttributes,
): ProgramInfo => {
const outputSize = ShapeUtil.size(inputShape); // outputShape is same as inputShape.
const rank = inputShape.length; // input/output rank
const input = inputVariable('input', inputType, rank);
const output = outputVariable('output', inputType, rank);
const axisValue =
axisInput.dataType === DataType.int32 ? axisInput.getInt32Array()[0] : Number(axisInput.getBigInt64Array()[0]);
const axis = ShapeUtil.normalizeAxis(axisValue, rank);
const getShaderSource = (shaderHelper: ShaderHelper) => {
const index = ` i32(${input.indicesGet('inputIndices', 'uniforms.axis')}) `;
const max = getElementAt('uniforms.input_shape', 'uniforms.axis', rank);
const lowerLimit = attributes.reverse ? index + (attributes.exclusive ? ' + 1' : '') : '0';
const upperLimit = attributes.reverse ? max : index + (attributes.exclusive ? '' : ' + 1');
return `
${shaderHelper
.registerUniform('outputSize', 'u32')
.registerUniform('axis', 'u32')
.declareVariables(input, output)}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')}
var inputIndices = ${output.offsetToIndices('global_idx')};
var sum = ${output.type.value}(0);
let first : i32 = ${lowerLimit};
let last : i32 = ${upperLimit};
for (var i : i32 = first; i < last; i++) {
${input.indicesSet('inputIndices', 'uniforms.axis', 'u32(i)')};
sum = sum + ${input.getByIndices('inputIndices')};
}
${output.setByOffset('global_idx', 'sum')};
}`;
};
return {
name: 'CumSum',
shaderCache: { hint: attributes.cacheKey, inputDependencies: ['rank'] },
getRunData: () => ({
outputs: [{ dims: inputShape, dataType: inputType }],
dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
programUniforms: [
{ type: DataType.uint32, data: outputSize },
{ type: DataType.uint32, data: axis },
...createTensorShapeVariables(inputShape, inputShape),
],
}),
getShaderSource,
};
};
export const cumsum = (context: ComputeContext, attributes: CumSumAttributes): void => {
const inputShape = context.inputs[0].dims;
const inputType = context.inputs[0].dataType;
const axis = context.inputs[1];
context.compute(createCumsumProgramInfo(inputType, inputShape, axis, attributes), { inputs: [0] });
};
export const parseCumSumAttributes = (attributes: Record<string, unknown>): CumSumAttributes => {
const exclusive = (attributes.exclusive as number) === 1;
const reverse = (attributes.reverse as number) === 1;
return createAttributeWithCacheKey({ exclusive, reverse });
};