UNPKG

onnxruntime-web

Version:

A JavaScript library for running ONNX models in browsers

71 lines (62 loc) 2.18 kB
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import { NUMBER_TYPES } from '../../../operators';
import { Tensor } from '../../../tensor';
import { WebGLInferenceHandler } from '../inference-handler';
import { ProgramInfo, ProgramMetadata, TextureType } from '../types';

/**
 * Entry point for the WebGL `Tile` operator.
 *
 * Validates the inputs, then hands the handler a program description whose
 * `get` callback builds the full program info on demand.
 *
 * @param inferenceHandler - handler used to run the program.
 * @param inputs - `[data, repeats]`; see `validateInputs` for the accepted shapes/types.
 * @returns a single-element array holding the tensor returned by the handler.
 * @throws Error when `validateInputs` rejects the inputs.
 */
export const tile = (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => {
  validateInputs(inputs);

  const metadata = {
    name: 'Tile',
    inputNames: ['A'],
    inputTypes: [TextureType.unpacked],
  };
  // The program info is produced lazily through `get`.
  const buildProgramInfo = (): ProgramInfo => createTileProgramInfo(inferenceHandler, inputs, metadata);

  const result = inferenceHandler.run({ ...metadata, get: buildProgramInfo }, inputs);
  return [result];
};

/**
 * Builds the program info (output description + shader source) for `Tile`.
 *
 * Each output dimension is the matching input dimension multiplied by the
 * repeat count read from `inputs[1].numberData`. The generated shader maps an
 * output index back to an input index by taking it modulo the input dimension.
 *
 * @param _handler - unused.
 * @param inputs - `[data, repeats]`.
 * @param tileProgramMetadata - metadata spread into the returned program info.
 * @returns the program info with an unpacked output of the same type as `inputs[0]`.
 */
const createTileProgramInfo = (
  _handler: WebGLInferenceHandler,
  inputs: Tensor[],
  tileProgramMetadata: ProgramMetadata,
): ProgramInfo => {
  const sourceDims = inputs[0].dims.slice();
  const tiledDims: number[] = [];
  const indexStatements: string[] = [];

  sourceDims.forEach((extent, axis) => {
    tiledDims.push(extent * inputs[1].numberData[axis]);
    // Wrap the output coordinate back into the input's extent along this axis.
    indexStatements.push(`inputIdx[${axis}] = int(mod(float(outputIdx[${axis}]), ${extent}.));`);
  });

  const rank = tiledDims.length;
  const shaderSource = `
      float process(int outputIdx[${rank}]) {
        int inputIdx[${rank}];
        ${indexStatements.join('\n')}
        return _A(inputIdx);
      }
    `;

  return {
    ...tileProgramMetadata,
    output: { dims: tiledDims, type: inputs[0].type, textureType: TextureType.unpacked },
    shaderSource,
  };
};

/**
 * Checks that `inputs` is a valid `[data, repeats]` pair for `Tile`.
 *
 * @param inputs - tensors passed to the operator.
 * @throws Error if there are not exactly two inputs, if `repeats` is not
 *   one-dimensional, if its length differs from the rank of `data`, if the
 *   type of `data` is not in `NUMBER_TYPES`, or if the type of `repeats` is
 *   neither `int32` nor `int16`.
 */
const validateInputs = (inputs: Tensor[]): void => {
  if (!(inputs && inputs.length === 2)) {
    throw new Error('Tile requires 2 input.');
  }

  const [data, repeats] = inputs;

  if (repeats.dims.length !== 1) {
    throw new Error('The second input shape must 1 dimension.');
  }
  if (repeats.dims[0] !== data.dims.length) {
    throw new Error('Invalid input shape.');
  }
  if (NUMBER_TYPES.indexOf(data.type) < 0) {
    throw new Error('Invalid input type.');
  }

  const repeatsTypeOk = repeats.type === 'int32' || repeats.type === 'int16';
  if (!repeatsTypeOk) {
    throw new Error('Invalid repeat type.');
  }
};