onnxruntime-web
Version:
A Javascript library for running ONNX models on browsers
68 lines (56 loc) • 2.08 kB
text/typescript
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import { Tensor } from '../../../tensor';
import { getGlsl } from '../glsl-source';
import { WebGLInferenceHandler } from '../inference-handler';
import { ProgramInfo, ProgramInfoLoader, TextureType } from '../types';
import { getCoordsDataType } from '../utils';
import { getChannels, unpackFromChannel } from './packing-utils';
const unpackProgramMetadata = {
name: 'unpack',
inputNames: ['A'],
inputTypes: [TextureType.packed],
};
export const createUnpackProgramInfo = (handler: WebGLInferenceHandler, input: Tensor): ProgramInfo => {
const rank = input.dims.length;
const channels = getChannels('rc', rank);
const innerDims = channels.slice(-2);
const coordsDataType = getCoordsDataType(rank);
const unpackChannel = unpackFromChannel();
const isScalar = input.dims.length === 0;
const sourceCoords = isScalar ? '' : getSourceCoords(rank, channels);
const coords = rank <= 1 ? 'rc' : `vec2(${innerDims.join(',')})`;
const glsl = getGlsl(handler.session.backend.glContext.version);
const shaderSource = `
${unpackChannel}
void main() {
${coordsDataType} rc = getOutputCoords();
// Sample the texture with the coords to get the rgba channel value.
vec4 packedInput = getA(${sourceCoords});
${glsl.output} = vec4(getChannel(packedInput, ${coords}), 0, 0, 0);
}
`;
return {
...unpackProgramMetadata,
hasMain: true,
output: { dims: input.dims, type: input.type, textureType: TextureType.unpacked },
shaderSource,
};
};
export const createUnpackProgramInfoLoader = (handler: WebGLInferenceHandler, input: Tensor): ProgramInfoLoader => ({
...unpackProgramMetadata,
get: () => createUnpackProgramInfo(handler, input),
});
function getSourceCoords(rank: number, dims: string[]): string {
if (rank === 1) {
return 'rc';
}
let coords = '';
for (let i = 0; i < rank; i++) {
coords += dims[i];
if (i < rank - 1) {
coords += ',';
}
}
return coords;
}