onnxruntime-web
Version:
A Javascript library for running ONNX models on browsers
182 lines (181 loc) • 8.03 kB
JavaScript
;
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
Object.defineProperty(exports, '__esModule', { value: true });
exports.parseConvAttributes = exports.conv = exports.calculateOutputShape = void 0;
const attribute_with_cache_key_1 = require('../../../attribute-with-cache-key');
const util_1 = require('../../../util');
const conv_grouped_1 = require('./conv-grouped');
const conv_pack_1 = require('./conv-pack');
const dot_product_1 = require('./dot-product');
const fuse_utils_1 = require('./fuse-utils');
const im2col_1 = require('./im2col');
const matmul_1 = require('./matmul');
/**
 * Computes the output shape of a convolution.
 *
 * The first entry of `inputShape` is carried over as the batch size and the first entry of
 * `kernelShape` as the number of output channels; everything from index 2 onwards is treated
 * as a spatial axis.
 *
 * @param inputShape shape of the data tensor
 * @param kernelShape shape of the weight tensor
 * @param dilations one dilation factor per spatial axis
 * @param adjustPads pad amounts; the two values for spatial axis `i` are read from index `i`
 *     and index `i + spatialRank`
 * @param strides one stride per spatial axis
 * @returns `[batchSize, outChannels, ...outputSpatialDims]`
 */
const calculateOutputShape = (inputShape, kernelShape, dilations, adjustPads, strides) => {
  const inputSpatialDims = inputShape.slice(2);
  const kernelSpatialDims = kernelShape.slice(2);
  const spatialRank = inputSpatialDims.length;
  const outputShape = [inputShape[0], kernelShape[0]];
  for (let axis = 0; axis < spatialRank; ++axis) {
    const kernelExtent = kernelSpatialDims[axis];
    // extent covered by the kernel once the gaps introduced by dilation are counted
    const dilatedExtent = kernelExtent + (kernelExtent - 1) * (dilations[axis] - 1);
    const paddedExtent = inputSpatialDims[axis] + adjustPads[axis] + adjustPads[axis + spatialRank];
    outputShape.push(Math.floor((paddedExtent - dilatedExtent + strides[axis]) / strides[axis]));
  }
  return outputShape;
};
exports.calculateOutputShape = calculateOutputShape;
/**
 * Entry point of the Conv operator: validates the inputs, then runs the 2-D implementation.
 *
 * @param inferenceHandler handler forwarded to the convolution implementation
 * @param inputs data tensor, weight tensor and an optional bias tensor
 * @param attributes parsed Conv attributes
 * @returns whatever `conv2d` returns for the validated inputs
 */
const conv = (inferenceHandler, inputs, attributes) => {
  // validation throws for anything that is not a 2-D convolution, so only conv2d is needed below
  validateInputs(inputs, attributes);
  const outputs = conv2d(inferenceHandler, inputs, attributes);
  return outputs;
};
exports.conv = conv;
/**
 * Picks an implementation for a validated 2-D convolution and runs it.
 *
 * Selection order:
 *   1. `group > 1`: the unpacked grouped-convolution program.
 *   2. pack mode and a 1x1 kernel that qualifies for the matmul shortcut: `conv2DUnpackedPointwise`.
 *   3. pack mode, 4-D input with batch size 1 and a non-1x1 kernel: `conv2DPacked`.
 *   4. everything else: the im2col + dot-product path (`conv2DUnpacked`).
 *
 * @param inferenceHandler handler providing `session.pack` and `run`
 * @param inputs data tensor, weight tensor and an optional bias tensor
 * @param attributes parsed Conv attributes (adjusted here before use)
 * @returns a single-element array holding the output tensor
 */
const conv2d = (inferenceHandler, inputs, attributes) => {
  const adjustedAttributes = getAdjustedConvAttributes(attributes, inputs);
  const packMode = inferenceHandler.session.pack;
  const inputDims = inputs[0].dims;
  const isSingleBatch = inputDims[0] === 1;
  const isPointwise = adjustedAttributes.kernelShape[0] === 1 && adjustedAttributes.kernelShape[1] === 1;

  if (adjustedAttributes.group > 1) {
    const result = inferenceHandler.run(
      (0, conv_grouped_1.createUnpackedGroupedConvProgramInfoLoader)(inferenceHandler, inputs, adjustedAttributes),
      inputs,
    );
    return [result];
  }

  // conv2DUnpackedPointwise flattens X to [C, H * W] (the batch dimension is dropped) and reshapes
  // the matmul result to the convolution output shape. The element counts only agree when the
  // batch size is 1 and the output keeps the input's spatial size, i.e. unit strides and zero pads.
  // Any other 1x1 convolution is sent to the general unpacked path instead.
  // NOTE(review): assumes the matmul program yields an [outChannels, H * W] result - confirm.
  const canUseMatmulShortcut =
    isPointwise &&
    isSingleBatch &&
    adjustedAttributes.strides.every((stride) => stride === 1) &&
    adjustedAttributes.pads.every((pad) => pad === 0);
  if (packMode && canUseMatmulShortcut) {
    return [conv2DUnpackedPointwise(inferenceHandler, inputs, adjustedAttributes)];
  }

  // 1x1 kernels are deliberately kept away from the packed program, as before.
  if (packMode && !isPointwise && inputDims.length === 4 && isSingleBatch) {
    return [(0, conv_pack_1.conv2DPacked)(inferenceHandler, inputs, adjustedAttributes)];
  }

  return [conv2DUnpacked(inferenceHandler, inputs, adjustedAttributes)];
};
/**
 * Runs a convolution as a single matrix multiplication.
 *
 * The data tensor is viewed as a `[dims[1], dims[2] * dims[3]]` matrix and the weight tensor as a
 * `[dims[0], dims[1]]` matrix; the matmul program is run on (weights, data[, bias]) and its
 * result is reshaped to the convolution output shape.
 *
 * @param inferenceHandler handler providing `reshapeUnpacked` and `run`
 * @param inputs data tensor, weight tensor and an optional bias tensor
 * @param attributes Conv attributes; `dilations`, `pads` and `strides` feed the output shape
 * @returns the reshaped matmul output
 */
const conv2DUnpackedPointwise = (inferenceHandler, inputs, attributes) => {
  const data = inputs[0];
  const weights = inputs[1];
  const dataDims = data.dims;
  const weightDims = weights.dims;
  const { dilations, pads, strides } = attributes;
  const outputShape = (0, exports.calculateOutputShape)(dataDims, weightDims, dilations, pads, strides);

  const dataMatrix = inferenceHandler.reshapeUnpacked(data, [dataDims[1], dataDims[2] * dataDims[3]]);
  const weightMatrix = inferenceHandler.reshapeUnpacked(weights, [weightDims[0], weightDims[1]]);

  // weights go first: the product is weightMatrix x dataMatrix, with the bias appended when present
  const matmulInputs = [weightMatrix, dataMatrix];
  if (inputs.length > 2) {
    matmulInputs.push(inputs[2]);
  }

  const matmulLoader = (0, matmul_1.createMatmulProgramInfoLoader)(matmulInputs, attributes);
  const product = inferenceHandler.run(matmulLoader, matmulInputs);
  return inferenceHandler.reshapeUnpacked(product, outputShape);
};
/**
 * General unpacked convolution: an im2col program followed by a dot-product program.
 *
 * @param inferenceHandler handler providing `run`
 * @param inputs data tensor, weight tensor and an optional bias tensor
 * @param attributes Conv attributes; `dilations`, `pads` and `strides` feed the output shape
 * @returns the output of the dot-product program
 */
const conv2DUnpacked = (inferenceHandler, inputs, attributes) => {
  const data = inputs[0];
  const weights = inputs[1];
  const { dilations, pads, strides } = attributes;
  const outputShape = (0, exports.calculateOutputShape)(data.dims, weights.dims, dilations, pads, strides);

  // step 1: run the im2col program on the data tensor only
  const im2ColLoader = (0, im2col_1.createIm2ColProgramInfoLoader)(
    inferenceHandler,
    data,
    weights,
    outputShape,
    attributes,
  );
  const columns = inferenceHandler.run(im2ColLoader, [data]);

  // step 2: run the dot-product program on the im2col result, the weights and, if given, the bias
  const dotProductInputs = [columns, weights];
  if (inputs.length === 3) {
    dotProductInputs.push(inputs[2]);
  }
  const dotProductLoader = (0, dot_product_1.createDotProductProgramInfoLoader)(
    inferenceHandler,
    inputs,
    outputShape,
    attributes,
  );
  return inferenceHandler.run(dotProductLoader, dotProductInputs);
};
/**
 * Builds a copy of the Conv attributes with `kernelShape` and `pads` resolved for the given inputs.
 *
 * The incoming attributes object and its arrays are left untouched; only copies are modified.
 *
 * @param attributes parsed Conv attributes
 * @param inputs data tensor followed by the weight tensor
 * @returns a new attributes object carrying the resolved `kernelShape` and `pads` and the
 *     original `cacheKey`
 */
const getAdjustedConvAttributes = (attributes, inputs) => {
  const kernelShape = attributes.kernelShape.slice();
  if (kernelShape.length === 0) {
    // no kernel_shape attribute on the op: take it from the weight tensor dims after the first two
    kernelShape.push(...inputs[1].dims.slice(2));
  }

  // the helper receives a copy of the pads so the caller's array is never written to
  const pads = attributes.pads.slice();
  util_1.PoolConvUtil.adjustPadsBasedOnAutoPad(
    inputs[0].dims,
    attributes.strides,
    attributes.dilations,
    kernelShape,
    pads,
    attributes.autoPad,
  );

  return { ...attributes, kernelShape, pads, cacheKey: attributes.cacheKey };
};
/**
 * Reads the Conv attributes of a graph node and wraps them with a cache key.
 *
 * Missing attributes fall back to: `auto_pad` 'NOTSET', `dilations` [1, 1], `group` 1,
 * `kernel_shape` [], `pads` [0, 0, 0, 0], `strides` [1, 1].
 *
 * @param node graph node whose `attributes` object exposes `getString`, `getInt` and `getInts`
 * @returns the result of `createAttributeWithCacheKey` for the Conv attributes merged with the
 *     internal activation attributes
 */
const parseConvAttributes = (node) => {
  const { attributes } = node;
  const activationAttributes = (0, fuse_utils_1.parseInternalActivationAttributes)(attributes);
  // TODO : Make this generic enough to compute default attributes for multi-dimensional conv
  const convAttributes = {
    autoPad: attributes.getString('auto_pad', 'NOTSET'),
    dilations: attributes.getInts('dilations', [1, 1]),
    group: attributes.getInt('group', 1),
    kernelShape: attributes.getInts('kernel_shape', []),
    pads: attributes.getInts('pads', [0, 0, 0, 0]),
    strides: attributes.getInts('strides', [1, 1]),
  };
  return (0, attribute_with_cache_key_1.createAttributeWithCacheKey)({
    ...convAttributes,
    ...activationAttributes,
  });
};
exports.parseConvAttributes = parseConvAttributes;
/**
 * Throws an `Error` when the inputs or attributes cannot be handled by this Conv implementation.
 *
 * Checks run in a fixed order: input count, tensor rank (4-D only), channel agreement, bias
 * shape, attribute lengths, kernel shape length and finally element types (float32 only).
 * The first failing check determines the error message.
 *
 * Input checks follow the ONNX Conv operator description:
 * https://github.com/onnx/onnx/blob/main/docs/Operators.md#Conv
 *
 * @param inputs data tensor (X), weight tensor (W) and an optional bias tensor
 * @param attributes Conv attributes (`group`, `dilations`, `strides`, `pads`, `kernelShape`)
 */
const validateInputs = (inputs, attributes) => {
  if (!inputs || ![2, 3].includes(inputs.length)) {
    throw new Error('Conv requires 2 or 3 inputs');
  }
  const data = inputs[0];
  const weights = inputs[1];
  const hasBias = inputs.length === 3;

  // TODO : Need to add support for multi-dimensional conv
  if (!(data.dims.length === 4 && weights.dims.length === 4)) {
    throw new Error('currently only support 2-dimensional conv');
  }

  // the channel count of X must equal the per-group channel count of W times the group count
  if (weights.dims[1] * attributes.group !== data.dims[1]) {
    throw new Error('FILTER_IN_CHANNEL should be equal to DATA_CHANNEL');
  }

  // a bias must be 1-D with one element per feature map (first dim of W)
  if (hasBias && (inputs[2].dims.length !== 1 || weights.dims[0] !== inputs[2].dims[0])) {
    throw new Error('invalid bias');
  }

  // dilations and strides need one entry per spatial axis, pads need two
  const spatialRank = data.dims.length - 2;
  const expectedLengths = [
    ['dilations', spatialRank],
    ['strides', spatialRank],
    ['pads', spatialRank * 2],
  ];
  for (const [name, expected] of expectedLengths) {
    if (attributes[name].length !== expected) {
      throw new Error(`${name} should be ${expected}D`);
    }
  }

  // when kernelShape is given, it must have one entry per weight-tensor dim after the first two
  const kernelRank = attributes.kernelShape.length;
  if (kernelRank !== 0 && kernelRank !== weights.dims.length - 2) {
    throw new Error('invalid kernel shape');
  }

  // TODO : Need to add support for float64
  if ([data, weights].some((tensor) => tensor.type !== 'float32')) {
    throw new Error('Conv input(X,W) should be float tensor');
  }
  if (hasBias && inputs[2].type !== 'float32') {
    throw new Error('Conv input(bias) should be float tensor');
  }
};
//# sourceMappingURL=conv.js.map