UNPKG

onnxruntime-web

Version:

A JavaScript library for running ONNX models in browsers

141 lines (139 loc) 5.31 kB
'use strict';
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
Object.defineProperty(exports, '__esModule', { value: true });
exports.getBiasForMatmul =
  exports.createMatmulProgramInfoLoader =
  exports.parseMatMulAttributes =
  exports.matMul =
    void 0;
const util_1 = require('../../../util');
const types_1 = require('../types');
const utils_1 = require('../utils');
const fuse_utils_1 = require('./fuse-utils');
const matmul_pack_1 = require('./matmul-pack');

/**
 * Runs MatMul on the WebGL inference handler.
 *
 * Validates the inputs, then uses the packed program loader from './matmul-pack'
 * when `inferenceHandler.session.pack` is truthy, otherwise the unpacked loader
 * defined in this module.
 *
 * @param inferenceHandler - WebGL inference handler; `session.pack` selects the kernel variant.
 * @param inputs - `[A, B]` tensors (exactly two; enforced by `validateInputs`).
 * @param attributes - fused-activation attributes (see `parseMatMulAttributes`).
 * @returns A one-element array containing the result of `inferenceHandler.run`.
 * @throws {Error} When `validateInputs` rejects the inputs.
 */
const matMul = (inferenceHandler, inputs, attributes) => {
  validateInputs(inputs);
  const programInfoLoader = inferenceHandler.session.pack
    ? (0, matmul_pack_1.createPackedMatmulProgramInfoLoader)(inferenceHandler, inputs, attributes)
    : createMatmulProgramInfoLoader(inputs, attributes);
  return [inferenceHandler.run(programInfoLoader, inputs)];
};
exports.matMul = matMul;

/**
 * Extracts the fused-activation attributes of a graph node by delegating to
 * `parseInternalActivationAttributes` from './fuse-utils'.
 *
 * @param node - graph node whose `attributes` are parsed.
 * @returns Whatever `parseInternalActivationAttributes` returns (used as `activationAttributes` below).
 */
const parseMatMulAttributes = (node) => (0, fuse_utils_1.parseInternalActivationAttributes)(node.attributes);
exports.parseMatMulAttributes = parseMatMulAttributes;

/**
 * Builds the program metadata for the unpacked MatMul kernel.
 *
 * @param hasBias - when true, a third input named 'Bias' is declared.
 * @param cacheHint - forwarded unchanged as the program's `cacheHint`.
 * @returns Metadata with all inputs declared as unpacked textures.
 */
const createMatmulProgramMetadata = (hasBias, cacheHint) => ({
  name: 'MatMul',
  inputNames: hasBias ? ['A', 'B', 'Bias'] : ['A', 'B'],
  inputTypes: hasBias
    ? [types_1.TextureType.unpacked, types_1.TextureType.unpacked, types_1.TextureType.unpacked]
    : [types_1.TextureType.unpacked, types_1.TextureType.unpacked],
  cacheHint,
});

/**
 * Generates the unpacked MatMul program.
 *
 * The shader's `process()` computes one output element as the sum over the
 * shared dimension (last dim of A, second-to-last dim of B). It then optionally
 * adds the bias and applies the fused activation snippet.
 *
 * NOTE(review): `bcastMatmulIndices_A/_B`, `_A`, `_B` are not defined here;
 * presumably they are injected by the WebGL program builder for inputs 'A'/'B' — confirm.
 *
 * @param metadata - result of `createMatmulProgramMetadata`.
 * @param inputs - `[A, B]` or `[A, B, Bias]`.
 * @param activationAttributes - passed to `getActivationSnippet`.
 * @returns Program info: metadata plus output descriptor and GLSL source.
 * @throws {Error} When `BroadcastUtil.calcShape` cannot produce an output shape.
 */
function createMatmulProgramInfo(metadata, inputs, activationAttributes) {
  const aShape = inputs[0].dims;
  const bShape = inputs[1].dims;
  const outputShape = util_1.BroadcastUtil.calcShape(aShape, bShape, true);
  if (!outputShape) {
    throw new Error("Can't use matmul on the given tensors");
  }
  const coordsDataType = (0, utils_1.getCoordsDataType)(outputShape.length);
  const allGlChannels = (0, utils_1.getGlChannels)();
  const { activationFunction, applyActivation } = (0, fuse_utils_1.getActivationSnippet)(activationAttributes);
  const hasBias = inputs.length > 2;
  const processBias = hasBias ? 'value += getBiasForMatmul();' : '';
  const getBiasForMatmulSnippet = hasBias
    ? getBiasForMatmul(coordsDataType, allGlChannels, inputs[2].dims, outputShape, false)
    : '';
  const rank = outputShape.length;
  const arank = aShape.length;
  const brank = bShape.length;
  const sharedDim = aShape[aShape.length - 1];
  // `value` is explicitly initialized to 0.0. GLSL leaves uninitialized locals
  // undefined, so accumulating with `+=` would otherwise start from garbage.
  const shaderSource = `
    ${activationFunction}
    ${getBiasForMatmulSnippet}
    float process(int indices[${rank}]) {
        int a[${arank}];
        int b[${brank}];
        bcastMatmulIndices_A(indices, a);
        bcastMatmulIndices_B(indices, b);

        float value = 0.0;
        for (int k=0; k<${sharedDim}; ++k) {
            a[${arank - 1}] = k;
            b[${brank - 2}] = k;
            value += _A(a) * _B(b);
        }
        ${processBias}
        ${applyActivation}
        return value;
    }`;
  return {
    ...metadata,
    output: { dims: outputShape, type: inputs[0].type, textureType: types_1.TextureType.unpacked },
    shaderSource,
  };
}

/**
 * Creates a lazy program-info loader for the unpacked MatMul kernel.
 *
 * @param inputs - `[A, B]` or `[A, B, Bias]`; a third input enables the bias path.
 * @param activationAttributes - provides `activationCacheKey` (used as the cache hint)
 *   and is forwarded to `createMatmulProgramInfo`.
 * @returns Metadata plus a `get()` that builds the full program info on demand.
 */
function createMatmulProgramInfoLoader(inputs, activationAttributes) {
  const metadata = createMatmulProgramMetadata(inputs.length > 2, activationAttributes.activationCacheKey);
  return { ...metadata, get: () => createMatmulProgramInfo(metadata, inputs, activationAttributes) };
}
exports.createMatmulProgramInfoLoader = createMatmulProgramInfoLoader;

/**
 * Checks MatMul's input preconditions.
 *
 * Requires exactly two inputs. The last dim of A must equal the second-to-last
 * dim of B. Both inputs must be 'float32' or 'float64', and of the same type.
 *
 * @param inputs - candidate input tensors.
 * @throws {Error} On the first violated precondition.
 */
const validateInputs = (inputs) => {
  if (!inputs || inputs.length !== 2) {
    throw new Error('MatMul requires 2 inputs.');
  }
  const [a, b] = inputs;
  if (a.dims[a.dims.length - 1] !== b.dims[b.dims.length - 2]) {
    throw new Error('shared dimension does not match.');
  }
  const isFloatTensor = (t) => t.type === 'float32' || t.type === 'float64';
  if (!isFloatTensor(a) || !isFloatTensor(b)) {
    throw new Error('inputs should be float type');
  }
  if (a.type !== b.type) {
    throw new Error('inputs types should match');
  }
};

/**
 * Generates the GLSL `getBiasForMatmul()` helper. It returns the bias value for
 * the current output coordinate.
 *
 * Bias dimensions are aligned with the trailing output dimensions
 * (offset `outRank - inRank`). Coordinates along dimensions the bias broadcasts
 * over (per `BroadcastUtil.getBroadcastDims`) are zeroed before sampling.
 *
 * NOTE(review): `getBias(...)` and `getOutputCoords()` are not defined here;
 * presumably the program builder generates them for the 'Bias' input and the output — confirm.
 *
 * @param coordsDataType - GLSL type name for the output coordinates (from `getCoordsDataType`).
 * @param allGlChannels - swizzle names; entry `d` addresses output dimension `d`.
 * @param inShape - bias shape.
 * @param outShape - output shape.
 * @param isPacked - when true emit the `vec4` (packed) variant, otherwise the `float` variant.
 * @returns GLSL source defining `getBiasForMatmul()`.
 */
function getBiasForMatmul(coordsDataType, allGlChannels, inShape, outShape, isPacked) {
  const inRank = inShape.length;
  const outRank = outShape.length;
  const rankDiff = outRank - inRank;
  // For outputs of rank < 2 the coordinates are passed as a whole instead of being swizzled.
  const biasCoords =
    outRank < 2 && inRank > 0
      ? 'coords'
      : inShape.map((_s, i) => `coords.${allGlChannels[i + rankDiff]}`).join(', ');
  const broadcastDims = util_1.BroadcastUtil.getBroadcastDims(inShape, outShape);
  const zeroBroadcastCoords = broadcastDims.map((d) => `coords.${allGlChannels[d + rankDiff]} = 0;`).join('\n');

  if (isPacked) {
    const isInputScalar = util_1.ShapeUtil.size(inShape) === 1;
    const output = isInputScalar ? 'vec4(outputValue.x)' : 'vec4(outputValue.xx, outputValue.yy)';
    return `
    vec4 getBiasForMatmul() {
      ${coordsDataType} coords = getOutputCoords();
      ${zeroBroadcastCoords}
      vec4 outputValue = getBias(${biasCoords});
      return ${output};
    }`;
  }
  // The unpacked variant samples with the same trailing-dim-aligned coordinates
  // as the packed one. Sampling with `coords.x` alone would always index along
  // output dimension 0, whatever the bias rank, and is not a valid swizzle when
  // coords is a scalar.
  return `
    float getBiasForMatmul() {
      ${coordsDataType} coords = getOutputCoords();
      ${zeroBroadcastCoords}
      return getBias(${biasCoords});
    }`;
}
exports.getBiasForMatmul = getBiasForMatmul;
//# sourceMappingURL=matmul.js.map