UNPKG

onnxruntime-web

Version:

A JavaScript library for running ONNX models in browsers

209 lines (180 loc) 7.12 kB
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../../../attribute-with-cache-key';
import { Graph } from '../../../graph';
import { NUMBER_TYPES, OperatorImplementation, OperatorInitialization } from '../../../operators';
import { Tensor } from '../../../tensor';
import { ShapeUtil } from '../../../util';
import { WebGLInferenceHandler } from '../inference-handler';
import { ProgramInfo, ProgramMetadata, TextureType } from '../types';

/** Attributes shared by every Reduce* operator implemented in this file. */
export interface ReduceAttributes extends AttributeWithCacheKey {
  /**
   * Axes to reduce over. An empty list means "reduce over every axis".
   * The list is passed through `ShapeUtil.normalizeAxes` before use.
   */
  readonly axes: number[];
  /** When true, each reduced axis is kept in the output shape with length 1; otherwise it is dropped. */
  readonly keepDims: boolean;
}

/**
 * Produces the op-specific GLSL fragments of a reduction as `[init, body, final]`:
 * - `init` is emitted once, after the non-reduced input indices are set and before the loops;
 * - `body` is emitted inside the innermost loop, once per visited input element;
 * - `final` is emitted once after the loops.
 * Every fragment reads/writes the shader locals `value` and `inputIdx`.
 */
type ReduceOp = (inputs: Tensor[], axes: number[]) => [string, string, string];

/** True when axis `k` is reduced: it is listed in `axes`, or `axes` is empty (reduce everything). */
const isReducedAxis = (axes: readonly number[], k: number): boolean => axes.length === 0 || axes.indexOf(k) >= 0;

/**
 * Runs one Reduce* operator on the WebGL backend.
 *
 * @param inferenceHandler handler used to build and run the WebGL program
 * @param inputs operator inputs; exactly one numeric tensor is accepted
 * @param attributes parsed reduce attributes; their cache key is used as the program cache hint
 * @param name program name (e.g. 'ReduceSum')
 * @param reduceOp generator of the op-specific GLSL fragments
 * @returns a one-element array holding the reduced tensor
 * @throws Error when `inputs` fails `validateInputs`
 */
const reduce = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: Tensor[],
  attributes: ReduceAttributes,
  name: string,
  reduceOp: ReduceOp,
): Tensor[] => {
  validateInputs(inputs);

  const reduceProgramMetadata = {
    name,
    inputNames: ['A'],
    inputTypes: [TextureType.unpacked],
  };

  const output = inferenceHandler.run(
    {
      ...reduceProgramMetadata,
      cacheHint: attributes.cacheKey,
      get: () => createReduceProgramInfo(inferenceHandler, inputs, attributes, name, reduceOp, reduceProgramMetadata),
    },
    inputs,
  );
  return [output];
};

/**
 * Reads the `axes` (default: empty, i.e. all axes) and `keepdims` (default: 1) attributes of a Reduce* node.
 */
export const parseReduceAttributes: OperatorInitialization<ReduceAttributes> = (node: Graph.Node): ReduceAttributes => {
  const axes = node.attributes.getInts('axes', []);
  const keepDims = node.attributes.getInt('keepdims', 1) === 1;
  return createAttributeWithCacheKey({ axes, keepDims });
};

/**
 * Builds the shader program for a reduction.
 *
 * For every output coordinate, the generated `process` function copies the non-reduced coordinates into
 * `inputIdx`. It then emits one nested loop per reduced axis, so higher-numbered axes become the outer loops.
 */
const createReduceProgramInfo = (
  _handler: WebGLInferenceHandler,
  inputs: Tensor[],
  attributes: ReduceAttributes,
  _name: string,
  reduceOp: ReduceOp,
  reduceProgramMetadata: ProgramMetadata,
): ProgramInfo => {
  const inputDims = inputs[0].dims;
  const outputShape: number[] = [];
  // `|| 1` keeps the GLSL array declaration non-empty for a rank-0 input (zero-length arrays are invalid GLSL).
  // NOTE(review): for a rank-0 input no element of `inputIdx` is ever assigned before `_A(inputIdx)` reads it;
  // confirm scalars never reach this path.
  const iRank = inputDims.length || 1;
  const idxCopy: string[] = []; // statements copying output indices into input indices
  const axes = ShapeUtil.normalizeAxes(attributes.axes, inputDims.length);
  const [initOps, reduceBody, finalOps] = reduceOp(inputs, axes);

  let reduceOps = reduceBody;
  for (let k = 0; k < inputDims.length; k++) {
    if (isReducedAxis(axes, k)) {
      if (attributes.keepDims) {
        outputShape.push(1);
      }
      // Reduced axis: dropped from the output unless keepDims. Wrap everything built so far in a loop over it.
      reduceOps = `
        for(int j${k} = 0; j${k} < ${inputDims[k]}; j${k}++) {
          inputIdx[${k}] = j${k};
          ${reduceOps}
        }`;
    } else {
      // Kept axis: its input index is the matching output coordinate.
      idxCopy.push(`inputIdx[${k}] = outputIdx[${outputShape.length}];`);
      outputShape.push(inputDims[k]);
    }
  }

  const oRank = outputShape.length || 1;

  // The GLSL `//` comments below rely on the newlines in this template; keep one statement per line.
  const shaderSource = `
      float process(int outputIdx[${oRank}]) {
        float value;                 // final result
        int inputIdx[${iRank}];      // addressing input data
        ${idxCopy.join('\n')}
        ${initOps}                   // init ops for reduce max/min
        ${reduceOps}
        ${finalOps}                  // final computation for reduce mean
        return value;
      }`;

  return {
    ...reduceProgramMetadata,
    output: { dims: outputShape, type: inputs[0].type, textureType: TextureType.unpacked },
    shaderSource,
  };
};

/**
 * Checks that exactly one input of a numeric type was given.
 *
 * @throws Error when the input count or type is unsupported
 */
const validateInputs = (inputs: Tensor[]): void => {
  // TODO: support Reduce* operators with 2 inputs.
  if (!inputs || inputs.length !== 1) {
    throw new Error('Reduce op requires 1 input.');
  }

  if (NUMBER_TYPES.indexOf(inputs[0].type) === -1) {
    throw new Error('Invalid input type.');
  }
};

/**
 * Builds the ReduceOp shared by ReduceMax and ReduceMin.
 *
 * It seeds `value` with the element at index 0 along every reduced axis; the kept axes were already copied
 * from the output coordinate. It then folds each visited element with the GLSL builtin `fn`.
 */
const createMinMaxReduceOp =
  (fn: 'max' | 'min'): ReduceOp =>
  (opInputs, axes) => {
    const idxZero: string[] = [];
    for (let k = 0; k < opInputs[0].dims.length; k++) {
      if (isReducedAxis(axes, k)) {
        idxZero.push(`inputIdx[${k}] = 0;`); // first element
      }
    }
    return [`${idxZero.join('\n')}\nvalue = _A(inputIdx);`, `value = ${fn}(value, _A(inputIdx));`, ''];
  };

/** ReduceSum: sum of the elements along the reduced axes. */
export const reduceSum: OperatorImplementation<ReduceAttributes> = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: Tensor[],
  attributes: ReduceAttributes,
): Tensor[] => {
  const reduceOp: ReduceOp = () => ['value = 0.0;', 'value += _A(inputIdx);', ''];
  return reduce(inferenceHandler, inputs, attributes, 'ReduceSum', reduceOp);
};

/** ReduceMean: sum along the reduced axes divided by the number of reduced elements. */
export const reduceMean: OperatorImplementation<ReduceAttributes> = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: Tensor[],
  attributes: ReduceAttributes,
): Tensor[] => {
  const reduceOp: ReduceOp = (opInputs, axes) => {
    // Number of input elements folded into each output element.
    const size = opInputs[0].dims.reduce((acc, dim, k) => (isReducedAxis(axes, k) ? acc * dim : acc), 1);
    return ['value = 0.0;', 'value += _A(inputIdx);', `value /= ${size}.;`]; // ensure real number with `.`
  };
  return reduce(inferenceHandler, inputs, attributes, 'ReduceMean', reduceOp);
};

/** ReduceMax: maximum of the elements along the reduced axes. */
export const reduceMax: OperatorImplementation<ReduceAttributes> = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: Tensor[],
  attributes: ReduceAttributes,
): Tensor[] => reduce(inferenceHandler, inputs, attributes, 'ReduceMax', createMinMaxReduceOp('max'));

/** ReduceMin: minimum of the elements along the reduced axes. */
export const reduceMin: OperatorImplementation<ReduceAttributes> = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: Tensor[],
  attributes: ReduceAttributes,
): Tensor[] => reduce(inferenceHandler, inputs, attributes, 'ReduceMin', createMinMaxReduceOp('min'));

/** ReduceProd: product of the elements along the reduced axes. */
export const reduceProd: OperatorImplementation<ReduceAttributes> = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: Tensor[],
  attributes: ReduceAttributes,
): Tensor[] => {
  const reduceOp: ReduceOp = () => ['value = 1.0;', 'value *= _A(inputIdx);', ''];
  return reduce(inferenceHandler, inputs, attributes, 'ReduceProd', reduceOp);
};

/** ReduceLogSum: natural log of the sum of the elements along the reduced axes. */
export const reduceLogSum: OperatorImplementation<ReduceAttributes> = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: Tensor[],
  attributes: ReduceAttributes,
): Tensor[] => {
  const reduceOp: ReduceOp = () => ['value = 0.0;', 'value += _A(inputIdx);', 'value = log(value);'];
  return reduce(inferenceHandler, inputs, attributes, 'ReduceLogSum', reduceOp);
};

/**
 * Sum of squares of the elements along the reduced axes.
 *
 * NOTE(review): despite the name, no `log` is applied, so the result is the plain sum of squares
 * (ONNX ReduceSumSquare semantics). Confirm which ONNX op this is registered for before renaming it
 * or adding a log.
 */
export const reduceLogSumSquare: OperatorImplementation<ReduceAttributes> = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: Tensor[],
  attributes: ReduceAttributes,
): Tensor[] => {
  const reduceOp: ReduceOp = () => ['float t; value = 0.0;', 't = _A(inputIdx); value += t * t;', ''];
  return reduce(inferenceHandler, inputs, attributes, 'ReduceLogSumSquare', reduceOp);
};