UNPKG

onnxruntime-web

Version:

A Javascript library for running ONNX models on browsers

111 lines (99 loc) 3.42 kB
// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../../../attribute-with-cache-key'; import { Graph } from '../../../graph'; import { OperatorImplementation, OperatorInitialization } from '../../../operators'; import { Tensor } from '../../../tensor'; import { ShapeUtil, SplitUtil } from '../../../util'; import { WebGLInferenceHandler } from '../inference-handler'; import { ProgramInfo, TextureType } from '../types'; export interface SplitAttributes extends AttributeWithCacheKey { readonly axis: number; readonly split: number[]; readonly numOutputs: number; } const splitProgramMetadata = { name: 'Split', inputNames: ['A'], inputTypes: [TextureType.unpacked], }; export const split: OperatorImplementation<SplitAttributes> = ( inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], attributes: SplitAttributes, ): Tensor[] => { validateInputs(inputs); const axis = ShapeUtil.normalizeAxis(attributes.axis, inputs[0].dims.length); const count = getProgramCount(inferenceHandler, inputs, axis, attributes); const output: Tensor[] = []; for (let i = 0; i < count; ++i) { output.push( inferenceHandler.run( { ...splitProgramMetadata, cacheHint: `${attributes.cacheKey};${i}`, get: () => createSplitProgramInfo(inferenceHandler, inputs[0], attributes, axis, i), }, inputs, ), ); } return output; }; export const parseSplitAttributes: OperatorInitialization<SplitAttributes> = (node: Graph.Node): SplitAttributes => { const axis = node.attributes.getInt('axis', 0); const split = node.attributes.getInts('split', []); const numOutputs = node.outputs.length; return createAttributeWithCacheKey({ axis, split, numOutputs }); }; const getProgramCount = ( _inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], axis: number, attributes: SplitAttributes, ): number => { const [, offsets] = SplitUtil.splitShape(inputs[0].dims, axis, attributes.split, attributes.numOutputs); return offsets.length; }; const createSplitProgramInfo = ( _inferenceHandler: WebGLInferenceHandler, input: Tensor, attributes: SplitAttributes, axis: number, index: number, ): ProgramInfo => { const [shapes, offsets] = SplitUtil.splitShape(input.dims, axis, attributes.split, attributes.numOutputs); const offset = offsets[index]; const outputShape = shapes[index]; const rank = outputShape.length; const shaderSource = ` float process(int indices[${rank}]) { indices[${axis}] += ${offset}; return _A(indices); } `; return { ...splitProgramMetadata, cacheHint: `${attributes.cacheKey}:${index}`, output: { dims: outputShape, type: input.type, textureType: TextureType.unpacked }, shaderSource, }; }; const validateInputs = (inputs: Tensor[]): void => { if (!inputs || inputs.length !== 1) { throw new Error('Split requires one input.'); } if ( inputs[0].type !== 'int8' && inputs[0].type !== 'uint8' && inputs[0].type !== 'int16' && inputs[0].type !== 'uint16' && inputs[0].type !== 'int32' && inputs[0].type !== 'uint32' && inputs[0].type !== 'float32' && inputs[0].type !== 'float64' && inputs[0].type !== 'bool' ) { throw new Error('Invalid input type.'); } };