onnxruntime-web
Version:
A JavaScript library for running ONNX models in browsers
209 lines (180 loc) • 7.12 kB
text/typescript
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../../../attribute-with-cache-key';
import { Graph } from '../../../graph';
import { NUMBER_TYPES, OperatorImplementation, OperatorInitialization } from '../../../operators';
import { Tensor } from '../../../tensor';
import { ShapeUtil } from '../../../util';
import { WebGLInferenceHandler } from '../inference-handler';
import { ProgramInfo, ProgramMetadata, TextureType } from '../types';
/**
 * Attributes consumed by the Reduce* operators defined in this file.
 */
export interface ReduceAttributes extends AttributeWithCacheKey {
// Axes to reduce over. The shader builder in this file treats an empty list as "reduce every axis".
readonly axes: number[];
// When true, every reduced axis stays in the output shape with extent 1; otherwise it is dropped.
readonly keepDims: boolean;
}
// A ReduceOp produces the GLSL snippets that are spliced into the generated shader.
// It returns them as [init ops, reduce ops, final ops]:
//   - init ops run once before the reduction loops,
//   - reduce ops form the innermost loop body,
//   - final ops run once after the loops.
type ReduceOp = (inputs: Tensor[], axes: number[]) => string[];
/**
 * Shared driver for every Reduce* operator in this file.
 *
 * Validates the inputs, then asks the inference handler to run a program whose
 * ProgramInfo is built lazily (through `get`) by `createReduceProgramInfo`.
 *
 * @param inferenceHandler handler used to run the program
 * @param inputs operator inputs; exactly one tensor is accepted
 * @param attributes parsed reduce attributes; `cacheKey` is forwarded as the cache hint
 * @param name program name
 * @param reduceOp factory for the GLSL snippets of the concrete reduction
 * @returns a single-element array holding the output tensor
 */
const reduce = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: Tensor[],
  attributes: ReduceAttributes,
  name: string,
  reduceOp: ReduceOp,
): Tensor[] => {
  validateInputs(inputs);

  const metadata = { name, inputNames: ['A'], inputTypes: [TextureType.unpacked] };

  // Deferred so the shader source is only generated when the handler asks for it.
  const buildProgramInfo = () =>
    createReduceProgramInfo(inferenceHandler, inputs, attributes, name, reduceOp, metadata);

  const result = inferenceHandler.run(
    { ...metadata, cacheHint: attributes.cacheKey, get: buildProgramInfo },
    inputs,
  );
  return [result];
};
/**
 * Reads the reduce attributes from a graph node.
 *
 * `axes` defaults to an empty list and `keepdims` defaults to 1; `keepDims` is
 * true only when the attribute value is exactly 1.
 */
export const parseReduceAttributes: OperatorInitialization<ReduceAttributes> = (node: Graph.Node): ReduceAttributes => {
  const nodeAttributes = node.attributes;
  return createAttributeWithCacheKey({
    axes: nodeAttributes.getInts('axes', []),
    keepDims: nodeAttributes.getInt('keepdims', 1) === 1,
  });
};
/**
 * Builds the ProgramInfo (output description and GLSL source) for one reduction.
 *
 * For every input axis:
 *   - a reduced axis contributes a `for` loop wrapped around the snippet built so far,
 *     and, when `keepDims` is set, an output extent of 1;
 *   - any other axis is copied from the output index into the input index and keeps
 *     its extent in the output shape.
 * An empty (normalized) axes list means every axis is reduced.
 *
 * @param _handler unused
 * @param inputs operator inputs; only `inputs[0]` is read here
 * @param attributes parsed reduce attributes (`axes`, `keepDims`)
 * @param _name unused
 * @param reduceOp factory returning [init ops, reduce ops, final ops]
 * @param reduceProgramMetadata metadata spread into the returned ProgramInfo
 * @returns the ProgramInfo describing the unpacked output and the shader source
 */
const createReduceProgramInfo = (
  _handler: WebGLInferenceHandler,
  inputs: Tensor[],
  attributes: ReduceAttributes,
  _name: string,
  reduceOp: ReduceOp,
  reduceProgramMetadata: ProgramMetadata,
): ProgramInfo => {
  const input = inputs[0];
  const inputDims = input.dims;
  // GLSL arrays cannot have length 0, so a rank-0 tensor is addressed with a 1-element index.
  const inRank = inputDims.length || 1;

  const axes = ShapeUtil.normalizeAxes(attributes.axes, inputDims.length);
  const [initOps, accumulateOps, finalOps] = reduceOp(inputs, axes);
  const reduceEveryAxis = axes.length === 0;

  const outputShape: number[] = [];
  const indexCopies: string[] = []; // copy output indexes to input indexes
  let loopNest = accumulateOps;

  inputDims.forEach((extent, axis) => {
    const isReduced = reduceEveryAxis || axes.indexOf(axis) !== -1;
    if (!isReduced) {
      indexCopies.push(`inputIdx[${axis}] = outputIdx[${outputShape.length}];`);
      outputShape.push(extent);
      return;
    }

    // A reduced axis either stays as extent 1 (keepDims) or is dropped from the output shape.
    if (attributes.keepDims) {
      outputShape.push(1);
    }

    // Wrap everything generated so far in a loop over this axis.
    const counter = `j${axis}`;
    loopNest = [
      '',
      `for(int ${counter} = 0; ${counter} < ${extent}; ${counter}++) {`,
      `inputIdx[${axis}] = ${counter};`,
      loopNest,
      '}',
    ].join('\n');
  });

  const outRank = outputShape.length || 1;
  const shaderSource = [
    '',
    `float process(int outputIdx[${outRank}]) {`,
    'float value; // final result',
    `int inputIdx[${inRank}]; // addressing input data`,
    indexCopies.join('\n'),
    `${initOps} // init ops for reduce max/min`,
    loopNest,
    `${finalOps} // final computation for reduce mean`,
    'return value;',
    '}',
  ].join('\n');

  return {
    ...reduceProgramMetadata,
    output: { dims: outputShape, type: input.type, textureType: TextureType.unpacked },
    shaderSource,
  };
};
/**
 * Guards the Reduce* operators: exactly one input of a numeric type is required.
 *
 * @throws Error when the input list is missing or does not hold exactly one tensor
 * @throws Error when the tensor type is not listed in NUMBER_TYPES
 */
const validateInputs = (inputs: Tensor[]): void => {
  // TODO: support Reduce* operators with 2 inputs.
  const hasSingleInput = Boolean(inputs) && inputs.length === 1;
  if (!hasSingleInput) {
    throw new Error('Reduce op requires 1 input.');
  }

  const isNumeric = NUMBER_TYPES.indexOf(inputs[0].type) >= 0;
  if (!isNumeric) {
    throw new Error('Invalid input type.');
  }
};
/**
 * ReduceSum: starts the accumulator at 0 and adds every element along the reduced axes.
 */
export const reduceSum: OperatorImplementation<ReduceAttributes> = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: Tensor[],
  attributes: ReduceAttributes,
): Tensor[] => {
  const sumOps: ReduceOp = (): string[] => {
    const init = 'value = 0.0;';
    const accumulate = 'value += _A(inputIdx);';
    return [init, accumulate, ''];
  };
  return reduce(inferenceHandler, inputs, attributes, 'ReduceSum', sumOps);
};
/**
 * ReduceMean: sums every element along the reduced axes, then divides by the
 * number of reduced elements (the product of the reduced extents).
 */
export const reduceMean: OperatorImplementation<ReduceAttributes> = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: Tensor[],
  attributes: ReduceAttributes,
): Tensor[] => {
  const meanOps: ReduceOp = (tensors: Tensor[], axes: number[]): string[] => {
    const reduceAll = axes.length === 0;
    // Product of the extents of all reduced axes; 1 when nothing is multiplied in.
    const elementCount = tensors[0].dims.reduce(
      (product, extent, axis) => (reduceAll || axes.indexOf(axis) >= 0 ? product * extent : product),
      1.0,
    );
    // The trailing `.` turns the integer literal into a GLSL float literal.
    const divide = `value /= ${elementCount}.;`;
    return ['value = 0.0;', 'value += _A(inputIdx);', divide];
  };
  return reduce(inferenceHandler, inputs, attributes, 'ReduceMean', meanOps);
};
/**
 * ReduceMax: seeds the accumulator with the first element along the reduced axes
 * (index 0 on each of them), then folds with GLSL `max`.
 */
export const reduceMax: OperatorImplementation<ReduceAttributes> = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: Tensor[],
  attributes: ReduceAttributes,
): Tensor[] => {
  const maxOps: ReduceOp = (tensors: Tensor[], axes: number[]): string[] => {
    const reduceAll = axes.length === 0;
    const resetLines: string[] = [];
    tensors[0].dims.forEach((_extent, axis) => {
      if (reduceAll || axes.indexOf(axis) >= 0) {
        resetLines.push(`inputIdx[${axis}] = 0;`); // first element
      }
    });
    const seed = `${resetLines.join('\n')}\nvalue = _A(inputIdx);`;
    return [seed, 'value = max(value, _A(inputIdx));', ''];
  };
  return reduce(inferenceHandler, inputs, attributes, 'ReduceMax', maxOps);
};
/**
 * ReduceMin: seeds the accumulator with the first element along the reduced axes
 * (index 0 on each of them), then folds with GLSL `min`.
 */
export const reduceMin: OperatorImplementation<ReduceAttributes> = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: Tensor[],
  attributes: ReduceAttributes,
): Tensor[] => {
  const minOps: ReduceOp = (tensors: Tensor[], axes: number[]): string[] => {
    const isReduced = (axis: number): boolean => axes.length === 0 || axes.indexOf(axis) >= 0;
    // One `inputIdx[k] = 0;` line per reduced axis, in axis order.
    const resetLines = tensors[0].dims
      .map((_extent, axis) => axis)
      .filter(isReduced)
      .map((axis) => `inputIdx[${axis}] = 0;`); // first element
    const seed = `${resetLines.join('\n')}\nvalue = _A(inputIdx);`;
    return [seed, 'value = min(value, _A(inputIdx));', ''];
  };
  return reduce(inferenceHandler, inputs, attributes, 'ReduceMin', minOps);
};
/**
 * ReduceProd: starts the accumulator at 1 and multiplies in every element along the reduced axes.
 */
export const reduceProd: OperatorImplementation<ReduceAttributes> = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: Tensor[],
  attributes: ReduceAttributes,
): Tensor[] => {
  const prodOps: ReduceOp = (): string[] => {
    const init = 'value = 1.0;';
    const accumulate = 'value *= _A(inputIdx);';
    return [init, accumulate, ''];
  };
  return reduce(inferenceHandler, inputs, attributes, 'ReduceProd', prodOps);
};
/**
 * ReduceLogSum: sums every element along the reduced axes, then applies GLSL `log` to the total.
 */
export const reduceLogSum: OperatorImplementation<ReduceAttributes> = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: Tensor[],
  attributes: ReduceAttributes,
): Tensor[] => {
  const logSumOps: ReduceOp = (): string[] => {
    const init = 'value = 0.0;';
    const accumulate = 'value += _A(inputIdx);';
    const finish = 'value = log(value);';
    return [init, accumulate, finish];
  };
  return reduce(inferenceHandler, inputs, attributes, 'ReduceLogSum', logSumOps);
};
/**
 * Accumulates the sum of squares of every element along the reduced axes.
 *
 * NOTE(review): despite the name (and the 'ReduceLogSumSquare' program name), the
 * final-ops slot is empty, so the generated shader never applies `log`; the result
 * is the plain sum of squares. Confirm which operator this implementation is
 * registered for before changing it.
 */
export const reduceLogSumSquare: OperatorImplementation<ReduceAttributes> = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: Tensor[],
  attributes: ReduceAttributes,
): Tensor[] => {
  const sumSquareOps: ReduceOp = (): string[] => {
    // `t` is declared once in the init ops and reused as scratch inside the loop body.
    const init = 'float t; value = 0.0;';
    const accumulate = 't = _A(inputIdx); value += t * t;';
    return [init, accumulate, ''];
  };
  return reduce(inferenceHandler, inputs, attributes, 'ReduceLogSumSquare', sumSquareOps);
};