UNPKG

onnxruntime-web

Version:

A Javascript library for running ONNX models on browsers

68 lines (56 loc) 2.08 kB
// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. import { Tensor } from '../../../tensor'; import { getGlsl } from '../glsl-source'; import { WebGLInferenceHandler } from '../inference-handler'; import { ProgramInfo, ProgramInfoLoader, TextureType } from '../types'; import { getCoordsDataType } from '../utils'; import { getChannels, unpackFromChannel } from './packing-utils'; const unpackProgramMetadata = { name: 'unpack', inputNames: ['A'], inputTypes: [TextureType.packed], }; export const createUnpackProgramInfo = (handler: WebGLInferenceHandler, input: Tensor): ProgramInfo => { const rank = input.dims.length; const channels = getChannels('rc', rank); const innerDims = channels.slice(-2); const coordsDataType = getCoordsDataType(rank); const unpackChannel = unpackFromChannel(); const isScalar = input.dims.length === 0; const sourceCoords = isScalar ? '' : getSourceCoords(rank, channels); const coords = rank <= 1 ? 'rc' : `vec2(${innerDims.join(',')})`; const glsl = getGlsl(handler.session.backend.glContext.version); const shaderSource = ` ${unpackChannel} void main() { ${coordsDataType} rc = getOutputCoords(); // Sample the texture with the coords to get the rgba channel value. vec4 packedInput = getA(${sourceCoords}); ${glsl.output} = vec4(getChannel(packedInput, ${coords}), 0, 0, 0); } `; return { ...unpackProgramMetadata, hasMain: true, output: { dims: input.dims, type: input.type, textureType: TextureType.unpacked }, shaderSource, }; }; export const createUnpackProgramInfoLoader = (handler: WebGLInferenceHandler, input: Tensor): ProgramInfoLoader => ({ ...unpackProgramMetadata, get: () => createUnpackProgramInfo(handler, input), }); function getSourceCoords(rank: number, dims: string[]): string { if (rank === 1) { return 'rc'; } let coords = ''; for (let i = 0; i < rank; i++) { coords += dims[i]; if (i < rank - 1) { coords += ','; } } return coords; }