UNPKG

onnxruntime-web

Version:

A JavaScript library for running ONNX models in browsers

377 lines (330 loc) 12.3 kB
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../../../attribute-with-cache-key';
import { Graph } from '../../../graph';
import { OperatorImplementation, OperatorInitialization } from '../../../operators';
import { Tensor } from '../../../tensor';
import { getGlsl } from '../glsl-source';
import { WebGLInferenceHandler } from '../inference-handler';
import { ProgramInfo, TextureType } from '../types';

/**
 * Parsed attributes of an Upsample/Resize node, as produced by
 * {@link parseUpsampleAttributes}.
 */
export interface UpsampleAttributes extends AttributeWithCacheKey {
  /** Opset version the attributes were parsed for. */
  readonly opset: number;
  /** `true` when `opset >= 10`. */
  readonly isResize: boolean;
  /** One of 'nearest', 'linear' or (opset >= 11 only) 'cubic'. */
  readonly mode: string;
  /** Per-dimension scales read from the node attribute; empty when `opset >= 9`. */
  readonly scales: number[];
  readonly extrapolationValue: number;
  readonly coordinateTransformMode: string;
  /** Mirrors `needRoiInput`. */
  readonly useExtrapolation: boolean;
  /** `true` when the coordinate transformation mode is 'tf_crop_and_resize'. */
  readonly needRoiInput: boolean;
  /** Empty string unless `mode` is 'nearest' and `opset >= 11`. */
  readonly nearestMode: string;
  readonly cubicCoefficientA: number;
  readonly excludeOutside: boolean;
  readonly useNearest2xOptimization: boolean;
  /** Index of the roi input in the node inputs (0 when not applicable). */
  readonly roiInputIdx: number;
  /** Index of the scales input in the node inputs (0 when not applicable). */
  readonly scalesInputIdx: number;
  /** Index of the sizes input in the node inputs (0 when not applicable). */
  readonly sizesInputIdx: number;
}

/** Program metadata shared by every Upsample program instance. */
const upsampleProgramMetadata = {
  name: 'Upsample',
  inputNames: ['X'],
  inputTypes: [TextureType.unpacked],
};

/**
 * Operator entry point: validates the inputs, then runs the upsample program
 * through the inference handler. The program info is created lazily via `get`
 * and keyed by `attributes.cacheKey`.
 *
 * @returns a single-element array holding the output tensor.
 * @throws Error when {@link validateInputs} rejects the inputs.
 */
export const upsample: OperatorImplementation<UpsampleAttributes> = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: Tensor[],
  attributes: UpsampleAttributes,
): Tensor[] => {
  validateInputs(inputs, attributes);
  const output = inferenceHandler.run(
    {
      ...upsampleProgramMetadata,
      cacheHint: attributes.cacheKey,
      get: () => createUpsampleProgramInfo(inferenceHandler, inputs, attributes),
    },
    inputs,
  );
  return [output];
};

/** Parses node attributes using opset 7 rules. */
export const parseUpsampleAttributesV7: OperatorInitialization<UpsampleAttributes> = (
  node: Graph.Node,
): UpsampleAttributes => parseUpsampleAttributes(node, 7);

/** Parses node attributes using opset 9 rules. */
export const parseUpsampleAttributesV9: OperatorInitialization<UpsampleAttributes> = (
  node: Graph.Node,
): UpsampleAttributes => parseUpsampleAttributes(node, 9);

/**
 * Reads and validates the Upsample/Resize attributes of `node` for the given
 * opset and wraps them with a cache key.
 *
 * @param node graph node whose attributes (and input count) are inspected.
 * @param opset opset version that selects defaults and which attributes apply.
 * @returns the parsed attributes.
 * @throws Error on an unrecognized `mode`, coordinate transformation mode or
 *   nearest mode, on `exclude_outside` set with a non-cubic mode, or when the
 *   `scales` attribute (opset < 9) fails {@link scalesValidation}.
 */
export const parseUpsampleAttributes = (node: Graph.Node, opset: number): UpsampleAttributes => {
  const isResize = opset >= 10;

  // processing node attributes
  const mode = node.attributes.getString('mode', 'nearest');
  // 'cubic' is only accepted from opset 11 on.
  if (mode !== 'nearest' && mode !== 'linear' && (opset < 11 || mode !== 'cubic')) {
    throw new Error(`unrecognized mode: ${mode}`);
  }

  // Only opsets below 9 carry scales as a node attribute; otherwise this stays empty.
  let scales: number[] = [];
  if (opset < 9) {
    scales = node.attributes.getFloats('scales');
    scalesValidation(scales, mode, isResize);
  }

  const extrapolationValue = node.attributes.getFloat('extrapolation_value', 0.0);

  const coordinateTransformMode =
    opset > 10 ? node.attributes.getString('coordinate_transformation_mode', 'half_pixel') : 'asymmetric';
  if (
    [
      'asymmetric',
      'pytorch_half_pixel',
      'tf_half_pixel_for_nn',
      'align_corners',
      'tf_crop_and_resize',
      'half_pixel',
    ].indexOf(coordinateTransformMode) === -1
  ) {
    throw new Error(`coordinate_transform_mode '${coordinateTransformMode}' is not supported`);
  }
  const needRoiInput = coordinateTransformMode === 'tf_crop_and_resize';
  const useExtrapolation = needRoiInput;

  const nearestMode =
    mode === 'nearest' && opset >= 11 ? node.attributes.getString('nearest_mode', 'round_prefer_floor') : '';
  if (['round_prefer_floor', 'round_prefer_ceil', 'floor', 'ceil', ''].indexOf(nearestMode) === -1) {
    throw new Error(`nearest_mode '${nearestMode}' is not supported`);
  }

  const cubicCoefficientA = node.attributes.getFloat('cubic_coeff_a', -0.75);
  const excludeOutside = node.attributes.getInt('exclude_outside', 0) !== 0;
  if (excludeOutside && mode !== 'cubic') {
    throw new Error('exclude_outside can be set to 1 only when mode is CUBIC.');
  }

  const useNearest2xOptimization =
    opset < 11 ? true : mode === 'nearest' && coordinateTransformMode === 'asymmetric' && nearestMode === 'floor';

  let roiInputIdx = 0;
  let scalesInputIdx = 0;
  let sizesInputIdx = 0;

  if (opset > 10) {
    // handle when roiInput is not given
    if (node.inputs.length > 2) {
      roiInputIdx = 1;
      scalesInputIdx = 2;
      sizesInputIdx = 3;
    } else {
      scalesInputIdx = 1;
      sizesInputIdx = 2;
    }
  } else if (opset === 9) {
    scalesInputIdx = 1;
  }

  return createAttributeWithCacheKey({
    opset,
    isResize,
    mode,
    scales,
    extrapolationValue,
    coordinateTransformMode,
    useExtrapolation,
    needRoiInput,
    nearestMode,
    cubicCoefficientA,
    excludeOutside,
    useNearest2xOptimization,
    roiInputIdx,
    scalesInputIdx,
    sizesInputIdx,
  });
};

/**
 * Builds the program info (GLSL source, output description and the `scales`
 * variable) for upsampling `inputs[0]` by `attributes.scales`.
 *
 * Three shader variants are generated: a generic N-D one for 'nearest' mode,
 * a 4-D bilinear one when the output rank is 4, and a 2-D bilinear one for
 * every other non-'nearest' case.
 *
 * @throws Error when `attributes.scales` does not provide exactly one scale per
 *   input dimension (for example when it is empty because the opset supplies
 *   scales as an input tensor), since the output shape cannot be derived.
 */
const createUpsampleProgramInfo = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: Tensor[],
  attributes: UpsampleAttributes,
): ProgramInfo => {
  // Without this guard a missing scale yields `dim * undefined`, i.e. a NaN output shape.
  if (attributes.scales.length !== inputs[0].dims.length) {
    throw new Error(
      `Upsample requires one scale per input dimension: got ${attributes.scales.length} scale(s) ` +
        `for a ${inputs[0].dims.length}-D input.`,
    );
  }

  const glsl = getGlsl(inferenceHandler.session.backend.glContext.version);
  const [inputWidth, inputHeight] = inferenceHandler.calculateTextureWidthAndHeight(
    inputs[0].dims,
    TextureType.unpacked,
  );

  const outputShape = inputs[0].dims.map((dim, i) => Math.floor(dim * attributes.scales[i]));
  const [outputWidth, outputHeight] = inferenceHandler.calculateTextureWidthAndHeight(
    outputShape,
    TextureType.unpacked,
  );
  const dim = outputShape.length;

  // Pitches are computed on the host and emitted into the shader as literal assignments.
  const outputPitches = new Array<number>(dim);
  const inputPitches = new Array<number>(dim);
  let precalculatedPitches = `
        int output_pitches[${dim}];
        int input_pitches[${dim}];
        `;
  for (let d = dim - 1; d >= 0; d--) {
    outputPitches[d] = d === dim - 1 ? 1 : outputPitches[d + 1] * outputShape[d + 1];
    inputPitches[d] = d === dim - 1 ? 1 : inputPitches[d + 1] * inputs[0].dims[d + 1];

    precalculatedPitches += `
        output_pitches[${d}] = ${outputPitches[d]};
        input_pitches[${d}] = ${inputPitches[d]};
        `;
  }

  // Reads one value of X by flat element index.
  const getInputFloatFunction = `
      float getInputFloat(int index) {
        vec2 coords = offsetToCoords(index, ${inputWidth}, ${inputHeight});
        float value = getColorAsFloat(${glsl.texture2D}(X, coords));
        return value;
      }
      `;

  const shaderSource =
    attributes.mode === 'nearest'
      ? // nearest
        `
      ${getInputFloatFunction}
      float process(int indices[${dim}]) {
        int input_index = 0;
        int output_index = coordsToOffset(TexCoords, ${outputWidth}, ${outputHeight});

        ${precalculatedPitches}

        int d, m;
        for (int dim = 0; dim < ${dim}; ++dim) {
          d = output_index / output_pitches[dim];
          m = output_index - d * output_pitches[dim];
          output_index = m;

          if (scales[dim] != 1 && d > 0) {
            int d2 = d / scales[dim];
            m = d - d2 * scales[dim];
            d = d2;
          }
          input_index += input_pitches[dim] * d;
        }

        return getInputFloat(input_index);
      }`
      : dim === 4
        ? // bilinear 4D
          `
      ${getInputFloatFunction}
      float process(int indices[4]) {
        int input_index = 0;
        int output_index = coordsToOffset(TexCoords, ${outputWidth}, ${outputHeight});

        ${precalculatedPitches}

        int m;
        int index_of_dim0, index_of_dim1, index_of_dim2, index_of_dim3;
        index_of_dim0 = output_index / output_pitches[0];
        m = output_index - index_of_dim0 * output_pitches[0];
        index_of_dim1 = m / output_pitches[1];
        m = m - index_of_dim1 * output_pitches[1];
        index_of_dim2 = m / output_pitches[2];
        m = m - index_of_dim2 * output_pitches[2];
        index_of_dim3 = m;

        int index_of_input_dim2, index_of_input_dim3, x_offset, y_offset;
        index_of_input_dim2 = index_of_dim2 / scales[2];
        y_offset = index_of_dim2 - index_of_input_dim2 * scales[2];
        index_of_input_dim3 = index_of_dim3 / scales[3];
        x_offset = index_of_dim3 - index_of_input_dim3 * scales[3];

        input_index = index_of_dim0 * input_pitches[0] +
              index_of_dim1 * input_pitches[1] +
              index_of_input_dim2 * input_pitches[2] +
              index_of_input_dim3;

        float x00 = getInputFloat(input_index);
        float x10, x01, x11;

        bool end_of_dim2 = false;
        if (index_of_input_dim2 == (${inputs[0].dims[2]} - 1)) {
          // It's the end in dimension 2
          x01 = x00;
          end_of_dim2 = true;
        } else {
          x01 = getInputFloat(input_index + input_pitches[2]);
        }

        if (index_of_input_dim3 == (input_pitches[2] - 1)) {
          // It's the end in dimension 3
          x10 = x00;
          x11 = x01;
        }
        else {
          x10 = getInputFloat(input_index + 1);
          x11 = end_of_dim2 ? x10 : getInputFloat(input_index + input_pitches[2] + 1);
        }

        float y0 = x00 + float(y_offset) * (x01 - x00) / float(scales[2]);
        float y1 = x10 + float(y_offset) * (x11 - x10) / float(scales[2]);
        return y0 + float(x_offset) * (y1 - y0) / float(scales[3]);
      }`
        : // bilinear 2D
          `
      ${getInputFloatFunction}
      float process(int indices[2]) {
        int input_index = 0;
        int output_index = coordsToOffset(TexCoords, ${outputWidth}, ${outputHeight});

        ${precalculatedPitches}

        int m;
        int index_of_dim0, index_of_dim1;
        index_of_dim0 = output_index / output_pitches[0];
        m = output_index - index_of_dim0 * output_pitches[0];
        index_of_dim1 = m;

        int index_of_input_dim0, index_of_input_dim1, x_offset, y_offset;
        index_of_input_dim0 = index_of_dim0 / scales[0];
        y_offset = index_of_dim0 - index_of_input_dim0 * scales[0];
        index_of_input_dim1 = index_of_dim1 / scales[1];
        x_offset = index_of_dim1 - index_of_input_dim1 * scales[1];

        input_index = index_of_input_dim0 * input_pitches[0] + index_of_input_dim1;

        float x00 = getInputFloat(input_index);
        float x10, x01, x11;

        bool end_of_dim0 = false;
        if (index_of_input_dim0 == (${inputs[0].dims[0]} - 1)) {
          // It's the end in dimension 0
          x01 = x00;
          end_of_dim0 = true;
        } else {
          x01 = getInputFloat(input_index + input_pitches[0]);
        }

        if (index_of_input_dim1 == (input_pitches[0] - 1)) {
          // It's the end in dimension 1
          x10 = x00;
          x11 = x01;
        }
        else {
          x10 = getInputFloat(input_index + 1);
          x11 = end_of_dim0 ? x10 : getInputFloat(input_index + input_pitches[0] + 1);
        }

        float y0 = x00 + float(y_offset) * (x01 - x00) / float(scales[0]);
        float y1 = x10 + float(y_offset) * (x11 - x10) / float(scales[0]);
        return y0 + float(x_offset) * (y1 - y0) / float(scales[1]);
      }`;

  return {
    ...upsampleProgramMetadata,
    output: { dims: outputShape, type: inputs[0].type, textureType: TextureType.unpacked },
    shaderSource,
    variables: [
      {
        name: 'scales',
        type: 'int',
        arrayLength: attributes.scales.length,
        // The shader does integer arithmetic on scales, so each one is rounded up.
        data: attributes.scales.map((x) => Math.ceil(x)),
      },
    ],
  };
};

/**
 * Checks the number, rank and type of the operator inputs.
 *
 * Expected input count: exactly 1 for opset < 9, exactly 2 for opsets 9-10,
 * at least 2 for opset >= 11.
 *
 * @throws Error when the input count is wrong, when `attribute.scales` is
 *   non-empty and its length differs from the rank of `inputs[0]`, or when
 *   `inputs[0]` is a string tensor.
 */
export const validateInputs = (inputs: Tensor[], attribute: UpsampleAttributes): void => {
  if (
    !inputs ||
    (attribute.opset < 9 && inputs.length !== 1) ||
    (attribute.opset >= 9 && attribute.opset < 11 && inputs.length !== 2) ||
    (attribute.opset >= 11 && inputs.length < 2)
  ) {
    throw new Error('invalid inputs.');
  }

  if (attribute.scales.length > 0 && inputs[0].dims.length !== attribute.scales.length) {
    throw new Error('Invalid input shape.');
  }

  if (inputs[0].type === 'string') {
    throw new Error('Invalid input tensor types.');
  }
};

/**
 * Validates scale values for Upsample (`isResize` false) or Resize
 * (`isResize` true).
 *
 * @param scales per-dimension scale factors.
 * @param mode interpolation mode; 'linear' and 'cubic' add a rank restriction.
 * @param isResize selects the bound: Upsample requires every scale >= 1,
 *   Resize requires every scale > 0.
 * @throws Error when a scale is out of range, or when `mode` is 'linear' or
 *   'cubic' and `scales` is neither 2 entries long nor 4 entries long with the
 *   first two equal to 1.
 */
export const scalesValidation = (scales: number[], mode: string, isResize: boolean): void => {
  if (!isResize) {
    for (const scale of scales) {
      if (scale < 1) {
        throw new Error('Scale value should be greater than or equal to 1.');
      }
    }
  } else {
    for (const scale of scales) {
      if (scale <= 0) {
        throw new Error('Scale value should be greater than 0.');
      }
    }
  }
  if (mode === 'linear' || mode === 'cubic') {
    if (scales.length !== 2 && (scales.length !== 4 || scales[0] !== 1 || scales[1] !== 1)) {
      throw new Error(
        `'Linear' mode and 'Cubic' mode only support 2-D inputs ('Bilinear', 'Bicubic') ` +
          'or 4-D inputs with the corresponding outermost 2 scale values being 1 ' +
          `in the ${isResize ? 'Resize' : 'Upsample'} operator.`,
      );
    }
  }
};