UNPKG

onnxruntime-web

Version:

A Javascript library for running ONNX models on browsers

219 lines (194 loc) 8.91 kB
// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../../../attribute-with-cache-key'; import { InferenceHandler } from '../../../backend'; import { Graph } from '../../../graph'; import { OperatorImplementation, OperatorInitialization } from '../../../operators'; import { Tensor } from '../../../tensor'; import { PoolConvUtil } from '../../../util'; import { WebGLInferenceHandler } from '../inference-handler'; import { createUnpackedGroupedConvProgramInfoLoader } from './conv-grouped'; import { conv2DPacked } from './conv-pack'; import { createDotProductProgramInfoLoader } from './dot-product'; import { InternalActivationAttributes, parseInternalActivationAttributes } from './fuse-utils'; import { createIm2ColProgramInfoLoader } from './im2col'; import { createMatmulProgramInfoLoader } from './matmul'; export const calculateOutputShape = ( inputShape: readonly number[], kernelShape: readonly number[], dilations: readonly number[], adjustPads: readonly number[], strides: readonly number[], ): number[] => { const batchSize = inputShape[0]; const inputSpatialShape = inputShape.slice(2); const spatialRank = inputSpatialShape.length; const outChannels = kernelShape[0]; const kernelSpatialShape = kernelShape.slice(2); const dilatedKernelShape = kernelSpatialShape.map((v, i) => v + (v - 1) * (dilations[i] - 1)); const inputSpatialShapeWithPad = inputSpatialShape.map((v, i) => v + adjustPads[i] + adjustPads[i + spatialRank]); const outputSpatialShape = inputSpatialShapeWithPad.map((v, i) => Math.floor((v - dilatedKernelShape[i] + strides[i]) / strides[i]), ); const outputShape = [batchSize, outChannels].concat(...outputSpatialShape); return outputShape; }; export interface ConvAttributes extends InternalActivationAttributes, AttributeWithCacheKey { readonly autoPad: string; readonly dilations: readonly number[]; readonly group: number; readonly kernelShape: readonly number[]; readonly pads: readonly number[]; readonly strides: readonly number[]; } export const conv: OperatorImplementation<ConvAttributes> = ( inferenceHandler: InferenceHandler, inputs: Tensor[], attributes: ConvAttributes, ): Tensor[] => { validateInputs(inputs, attributes); // currently will fail if not conv2D return conv2d(inferenceHandler, inputs, attributes); }; const conv2d: OperatorImplementation<ConvAttributes> = ( inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], attributes: ConvAttributes, ): Tensor[] => { const adjustedAttributes = getAdjustedConvAttributes(attributes, inputs); const packMode = inferenceHandler.session.pack; const isPointwise = adjustedAttributes.kernelShape[0] === 1 && adjustedAttributes.kernelShape[1] === 1; if (adjustedAttributes.group > 1) { const result = inferenceHandler.run( createUnpackedGroupedConvProgramInfoLoader(inferenceHandler, inputs, adjustedAttributes), inputs, ); return [result]; } else if (isPointwise && packMode) { return [conv2DUnpackedPointwise(inferenceHandler, inputs, adjustedAttributes)]; } else if (packMode && inputs[0].dims.length === 4 && inputs[0].dims[0] === 1 && !isPointwise) { return [conv2DPacked(inferenceHandler, inputs, adjustedAttributes)]; } else { return [conv2DUnpacked(inferenceHandler, inputs, adjustedAttributes)]; } }; const conv2DUnpackedPointwise = ( inferenceHandler: WebGLInferenceHandler, inputs: readonly Tensor[], attributes: ConvAttributes, ): Tensor => { const xshape = inputs[0].dims; const kshape = inputs[1].dims; const outputShape = calculateOutputShape(xshape, kshape, attributes.dilations, attributes.pads, attributes.strides); const reshapedX = inferenceHandler.reshapeUnpacked(inputs[0], [xshape[1], xshape[2] * xshape[3]]); const reshapedK = inferenceHandler.reshapeUnpacked(inputs[1], [kshape[0], kshape[1]]); const matmulInputs = inputs.length > 2 ? [reshapedK, reshapedX, inputs[2]] : [reshapedK, reshapedX]; const matmulOutput = inferenceHandler.run(createMatmulProgramInfoLoader(matmulInputs, attributes), matmulInputs); return inferenceHandler.reshapeUnpacked(matmulOutput, outputShape); }; const conv2DUnpacked = ( inferenceHandler: WebGLInferenceHandler, inputs: readonly Tensor[], attributes: ConvAttributes, ): Tensor => { const xshape = inputs[0].dims; const kshape = inputs[1].dims; const outputShape = calculateOutputShape(xshape, kshape, attributes.dilations, attributes.pads, attributes.strides); const xIm2Col = inferenceHandler.run( createIm2ColProgramInfoLoader(inferenceHandler, inputs[0], inputs[1], outputShape, attributes), [inputs[0]], ); const dotProductInputs = inputs.length === 3 ? [xIm2Col, inputs[1], inputs[2]] : [xIm2Col, inputs[1]]; const output = inferenceHandler.run( createDotProductProgramInfoLoader(inferenceHandler, inputs, outputShape, attributes), dotProductInputs, ); return output; }; const getAdjustedConvAttributes = <T extends ConvAttributes>(attributes: T, inputs: Tensor[]): T => { const kernelShape = attributes.kernelShape.slice(); // if kernelShape is not specified in the attributes of this op, infer it from the weight tensor dims if (attributes.kernelShape.length === 0) { for (let i = 2; i < inputs[1].dims.length; ++i) { kernelShape.push(inputs[1].dims[i]); } } const pads = attributes.pads.slice(); PoolConvUtil.adjustPadsBasedOnAutoPad( inputs[0].dims, attributes.strides, attributes.dilations, kernelShape, pads, attributes.autoPad, ); // always return a new object so does not modify the original attributes const newAttributes: T = Object.assign({}, attributes); Object.assign(newAttributes, { kernelShape, pads, cacheKey: attributes.cacheKey }); return newAttributes; }; export const parseConvAttributes: OperatorInitialization<ConvAttributes> = (node: Graph.Node): ConvAttributes => { const attributes = node.attributes; const activationAttributes = parseInternalActivationAttributes(attributes); // TODO : Make this generic enough to compute default attributes for multi-dimensional conv const autoPad = attributes.getString('auto_pad', 'NOTSET'); const dilations = attributes.getInts('dilations', [1, 1]); const group = attributes.getInt('group', 1); const kernelShape = attributes.getInts('kernel_shape', []); const pads = attributes.getInts('pads', [0, 0, 0, 0]); const strides = attributes.getInts('strides', [1, 1]); return createAttributeWithCacheKey({ autoPad, dilations, group, kernelShape, pads, strides, ...activationAttributes, }); }; const validateInputs = (inputs: Tensor[], attributes: ConvAttributes): void => { // Refer to the below link for all input checks // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Conv if (!inputs || (inputs.length !== 2 && inputs.length !== 3)) { throw new Error('Conv requires 2 or 3 inputs'); } // TODO : Need to add support for multi-dimensional conv if (inputs[0].dims.length !== 4 || inputs[1].dims.length !== 4) { throw new Error('currently only support 2-dimensional conv'); } // FILTER_IN_CHANNEL should be equal to DATA_CHANNEL const dataChannel = inputs[0].dims[1]; const filterInChannel = inputs[1].dims[1] * attributes.group; if (dataChannel !== filterInChannel) { throw new Error('FILTER_IN_CHANNEL should be equal to DATA_CHANNEL'); } // if bias is provided it should be 1D and the number of elements should be equal to the number of feature maps if (inputs.length === 3 && (inputs[2].dims.length !== 1 || inputs[1].dims[0] !== inputs[2].dims[0])) { throw new Error('invalid bias'); } const spatialRank = inputs[0].dims.length - 2; // wrong dilations dimension if (attributes.dilations.length !== spatialRank) { throw new Error(`dilations should be ${spatialRank}D`); } // Wrong strides dimension if (attributes.strides.length !== spatialRank) { throw new Error(`strides should be ${spatialRank}D`); } // Wrong pads dimension if (attributes.pads.length !== spatialRank * 2) { throw new Error(`pads should be ${spatialRank * 2}D`); } // if kernelShape is specified, it's data length must be 2 less than dims length of the weights tensor // (the first 2 dims are batch_size and channels) if (attributes.kernelShape.length !== 0 && attributes.kernelShape.length !== inputs[1].dims.length - 2) { throw new Error('invalid kernel shape'); } // TODO : Need to add support for float64 if (inputs[0].type !== 'float32' || inputs[1].type !== 'float32') { throw new Error('Conv input(X,W) should be float tensor'); } if (inputs.length === 3 && inputs[2].type !== 'float32') { throw new Error('Conv input(bias) should be float tensor'); } };