onnxruntime-web
A JavaScript library for running ONNX models in browsers
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../../../attribute-with-cache-key';
import { Graph } from '../../../graph';
import { OperatorImplementation, OperatorInitialization } from '../../../operators';
import { Tensor } from '../../../tensor';
import { ShapeUtil, SplitUtil } from '../../../util';
import { WebGLInferenceHandler } from '../inference-handler';
import { ProgramInfo, TextureType } from '../types';
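// Attributes of the ONNX Split node: the axis to split along, optional explicit
// per-output sizes, and the number of outputs declared on the node.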
export interface SplitAttributes extends AttributeWithCacheKey {
  readonly axis: number;
  readonly split: number[];
  readonly numOutputs: number;
}
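// Program metadata shared by every generated split program; a per-output cacheHint
// is added later so each compiled shader can be cached independently.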
const splitProgramMetadata = {
  name: 'Split',
  inputNames: ['A'],
  inputTypes: [TextureType.unpacked],
};
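// Operator implementation: runs one WebGL program per output tensor, each copying
// its slice of the single input along the (normalized) split axis.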
export const split: OperatorImplementation<SplitAttributes> = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: Tensor[],
  attributes: SplitAttributes,
): Tensor[] => {
  validateInputs(inputs);
  const axis = ShapeUtil.normalizeAxis(attributes.axis, inputs[0].dims.length);
  const count = getProgramCount(inferenceHandler, inputs, axis, attributes);
  const output: Tensor[] = [];
  for (let i = 0; i < count; ++i) {
    output.push(
      inferenceHandler.run(
        {
          ...splitProgramMetadata,
          cacheHint: `${attributes.cacheKey};${i}`,
          get: () => createSplitProgramInfo(inferenceHandler, inputs[0], attributes, axis, i),
        },
        inputs,
      ),
    );
  }
  return output;
};
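// Parses the node's 'axis' and 'split' attributes; the number of outputs is taken
// from the outputs declared on the node itself.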
export const parseSplitAttributes: OperatorInitialization<SplitAttributes> = (node: Graph.Node): SplitAttributes => {
  const axis = node.attributes.getInt('axis', 0);
  const split = node.attributes.getInts('split', []);
  const numOutputs = node.outputs.length;
  return createAttributeWithCacheKey({ axis, split, numOutputs });
};
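// The number of programs to run equals the number of output slices reported by
// SplitUtil.splitShape.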
const getProgramCount = (
  _inferenceHandler: WebGLInferenceHandler,
  inputs: Tensor[],
  axis: number,
  attributes: SplitAttributes,
): number => {
  const [, offsets] = SplitUtil.splitShape(inputs[0].dims, axis, attributes.split, attributes.numOutputs);
  return offsets.length;
};
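// Builds the ProgramInfo for output `index`. SplitUtil.splitShape yields the
// per-output shapes and their start offsets along the split axis; the generated
// GLSL shifts the output indices by that offset and samples the input texture.
// Illustrative example (assumed splitShape behavior): input dims [4, 6], axis 1,
// split [2, 4] would give shapes [[4, 2], [4, 4]] and offsets [0, 2].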
const createSplitProgramInfo = (
  _inferenceHandler: WebGLInferenceHandler,
  input: Tensor,
  attributes: SplitAttributes,
  axis: number,
  index: number,
): ProgramInfo => {
  const [shapes, offsets] = SplitUtil.splitShape(input.dims, axis, attributes.split, attributes.numOutputs);
  const offset = offsets[index];
  const outputShape = shapes[index];
  const rank = outputShape.length;
  const shaderSource = `
      float process(int indices[${rank}]) {
        indices[${axis}] += ${offset};
        return _A(indices);
      }
    `;
  return {
    ...splitProgramMetadata,
    cacheHint: `${attributes.cacheKey}:${index}`,
    output: { dims: outputShape, type: input.type, textureType: TextureType.unpacked },
    shaderSource,
  };
};
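// Split accepts exactly one input tensor of a numeric or bool type.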
const validateInputs = (inputs: Tensor[]): void => {
  if (!inputs || inputs.length !== 1) {
    throw new Error('Split requires one input.');
  }
  if (
    inputs[0].type !== 'int8' &&
    inputs[0].type !== 'uint8' &&
    inputs[0].type !== 'int16' &&
    inputs[0].type !== 'uint16' &&
    inputs[0].type !== 'int32' &&
    inputs[0].type !== 'uint32' &&
    inputs[0].type !== 'float32' &&
    inputs[0].type !== 'float64' &&
    inputs[0].type !== 'bool'
  ) {
    throw new Error('Invalid input type.');
  }
};