UNPKG

onnxruntime-web

Version:

A JavaScript library for running ONNX models in browsers

59 lines (49 loc) 2.34 kB
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import { Tensor } from '../../../tensor';
import { WebGLInferenceHandler } from '../inference-handler';

import { calculateOutputShape, ConvAttributes } from './conv';
import { createPackedIm2ColProgramInfoLoader } from './im2col-pack';
import { createPackedMatmulProgramInfoLoader } from './matmul-pack';

/**
 * Runs the packed MatMul program on `kernel` x `data`.
 *
 * When `inputs` has a third entry (presumably the Conv bias `B`), it is forwarded as an extra MatMul input. Both packed
 * Conv2D paths go through this helper so they treat that optional input identically.
 *
 * @param inferenceHandler WebGL handler used to run the MatMul program.
 * @param inputs The original Conv inputs; only `inputs[2]` (if present) is read here.
 * @param kernel The reshaped kernel, used as the left MatMul operand.
 * @param data The reshaped / im2col-expanded input, used as the right MatMul operand.
 * @param attributes Conv attributes, passed through to the MatMul program loader unchanged.
 * @returns The packed MatMul output.
 */
const runPackedConvMatmul = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: readonly Tensor[],
  kernel: Tensor,
  data: Tensor,
  attributes: ConvAttributes,
): Tensor => {
  const matmulInputs = inputs.length > 2 ? [kernel, data, inputs[2]] : [kernel, data];
  return inferenceHandler.run(
    createPackedMatmulProgramInfoLoader(inferenceHandler, matmulInputs, attributes),
    matmulInputs,
  );
};

/**
 * General packed Conv2D: im2col followed by a packed MatMul.
 *
 * The output shape is computed from the input/kernel dims and the dilations, pads and strides. The im2col program
 * expands `inputs[0]` for that output shape, and the kernel is flattened to `[kshape[0], kshape[1] * kshape[2] * kshape[3]]`.
 * The two are multiplied, and the MatMul result is reshaped back to the Conv output shape.
 *
 * NOTE(review): the final reshape maps the MatMul result straight onto the full output shape. This looks like it only
 * holds for a batch size of 1 and group === 1 — confirm the caller guarantees both.
 *
 * @param inferenceHandler WebGL handler used to run programs and reshape packed tensors.
 * @param inputs `[X, W]` or `[X, W, B]`; `inputs[2]`, when present, is forwarded to the MatMul.
 * @param attributes Conv attributes (dilations, pads, strides, plus whatever the MatMul loader reads).
 * @returns The packed Conv output, shaped as computed by `calculateOutputShape`.
 */
export const conv2DPacked = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: readonly Tensor[],
  attributes: ConvAttributes,
): Tensor => {
  const xshape = inputs[0].dims;
  const kshape = inputs[1].dims;
  const outputShape = calculateOutputShape(xshape, kshape, attributes.dilations, attributes.pads, attributes.strides);

  // im2col honours strides/pads/dilations, so this path is valid for any kernel size.
  const im2colOutput = inferenceHandler.run(
    createPackedIm2ColProgramInfoLoader(inferenceHandler, inputs[0], inputs[1], outputShape, attributes),
    [inputs[0]],
  );

  const kernelReshaped = inferenceHandler.reshapePacked(inputs[1], [kshape[0], kshape[1] * kshape[2] * kshape[3]]);
  const matmulOutput = runPackedConvMatmul(inferenceHandler, inputs, kernelReshaped, im2colOutput, attributes);

  return inferenceHandler.reshapePacked(matmulOutput, outputShape);
};

/**
 * Packed Conv2D fast path for pointwise (1x1) convolutions: a single packed MatMul with no im2col step.
 *
 * With a 1x1 kernel, unit strides and zero padding, each output pixel is the dot product of a kernel row with the
 * channel vector at the same input pixel. So `W[kshape[0], kshape[1]]` x `X[xshape[1], xshape[2] * xshape[3]]` is
 * exactly the output. Dilation is irrelevant for a 1x1 kernel.
 *
 * The MatMul shortcut itself never reads strides or pads. When the kernel is not 1x1, or strides/pads are not the
 * identity, the shortcut's result would not match the computed output shape. This function therefore delegates those
 * cases to {@link conv2DPacked}, which handles them through im2col.
 *
 * NOTE(review): X is reshaped to `[xshape[1], xshape[2] * xshape[3]]`, dropping dimension 0. This assumes a batch size of
 * 1, and group === 1 (the kernel's second dim must equal the input channel count) — confirm the caller guarantees this.
 *
 * @param inferenceHandler WebGL handler used to run programs and reshape packed tensors.
 * @param inputs `[X, W]` or `[X, W, B]`; `inputs[2]`, when present, is forwarded to the MatMul.
 * @param attributes Conv attributes (dilations, pads, strides, plus whatever the MatMul loader reads).
 * @returns The packed Conv output, shaped as computed by `calculateOutputShape`.
 */
export const conv2DPackedPointwise = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: readonly Tensor[],
  attributes: ConvAttributes,
): Tensor => {
  const xshape = inputs[0].dims;
  const kshape = inputs[1].dims;

  const isOneByOneKernel = kshape[2] === 1 && kshape[3] === 1;
  const hasUnitStrides = attributes.strides.every((stride) => stride === 1);
  const hasNoPadding = attributes.pads.every((pad) => pad === 0);
  if (!isOneByOneKernel || !hasUnitStrides || !hasNoPadding) {
    // The direct MatMul below would ignore strides/pads; use the general im2col path instead.
    return conv2DPacked(inferenceHandler, inputs, attributes);
  }

  const outputShape = calculateOutputShape(xshape, kshape, attributes.dilations, attributes.pads, attributes.strides);
  const reshapedX = inferenceHandler.reshapePacked(inputs[0], [xshape[1], xshape[2] * xshape[3]]);
  const reshapedK = inferenceHandler.reshapePacked(inputs[1], [kshape[0], kshape[1]]);
  const matmulOutput = runPackedConvMatmul(inferenceHandler, inputs, reshapedK, reshapedX, attributes);

  return inferenceHandler.reshapePacked(matmulOutput, outputShape);
};