UNPKG

onnxruntime-web

Version:

A JavaScript library for running ONNX models in browsers

85 lines (74 loc) 3.05 kB
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import { env } from 'onnxruntime-common';

import { DataType } from '../../../wasm-common';
import { ComputeContext, ProgramInfo, ProgramUniform } from '../types';

import {
  createTensorShapeVariables,
  outputVariable,
  ShaderHelper,
  UniformDataElementType,
  UniformsArrayType,
} from './common';

/**
 * Rejects `start` / `limit` / `delta` combinations that cannot describe a usable range.
 *
 * A combination is rejected when any of the following holds:
 *  - `start === limit`;
 *  - `delta === 0` (the element count is computed as `(limit - start) / delta`, so a zero
 *    step would produce an infinite count);
 *  - the step points away from `limit` (ascending bounds with a negative step, or
 *    descending bounds with a positive step).
 *
 * @param start first value of the range
 * @param limit bound of the range
 * @param delta step between consecutive values
 * @throws Error when the combination is rejected
 */
const validateInputsContent = (start: number, limit: number, delta: number): void => {
  const sameStartLimit = start === limit;
  // Guard the division performed in createRangeProgramInfo.
  const zeroStep = delta === 0;
  const increasingRangeNegativeStep = start < limit && delta < 0;
  const decreasingRangePositiveStep = start > limit && delta > 0;
  if (sameStartLimit || zeroStep || increasingRangeNegativeStep || decreasingRangePositiveStep) {
    throw new Error("Range these inputs' contents are invalid.");
  }
};

/**
 * Builds the program description for the `Range` operator.
 *
 * The output is one-dimensional with `|ceil((limit - start) / delta)|` elements. `start` and
 * `delta` are handed to the shader as uniforms typed with `dataType`; the shader writes
 * `start + T(global_idx) * delta` to `output[global_idx]`, where `T` is the WGSL value type
 * reported by the output variable helper.
 *
 * @param start first value of the range
 * @param limit bound of the range; only used to derive the element count
 * @param delta step between consecutive values
 * @param dataType element type of the output tensor and of the `start` / `delta` uniforms
 * @returns the program info consumed by `context.compute`
 */
const createRangeProgramInfo = (start: number, limit: number, delta: number, dataType: DataType): ProgramInfo => {
  const numElements = Math.abs(Math.ceil((limit - start) / delta));
  const outputShape: number[] = [numElements];
  const outputSize = numElements;

  // Order must match the uniform declarations registered in getShaderSource below.
  const programUniforms: ProgramUniform[] = [
    { type: DataType.uint32, data: outputSize },
    { type: dataType, data: start },
    { type: dataType, data: delta },
    ...createTensorShapeVariables(outputShape),
  ];

  const getShaderSource = (shaderHelper: ShaderHelper) => {
    const output = outputVariable('output', dataType, outputShape.length);
    const wgslType = output.type.value;
    const uniforms: UniformsArrayType = [
      { name: 'outputSize', type: 'u32' },
      { name: 'start', type: wgslType as UniformDataElementType },
      { name: 'delta', type: wgslType as UniformDataElementType },
    ];
    // NOTE(review): the guard helper presumably makes invocations with
    // global_idx >= outputSize return early - confirm against ShaderHelper.
    return `
        ${shaderHelper.registerUniforms(uniforms).declareVariables(output)}
        ${shaderHelper.mainStart()}
        ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')}
        output[global_idx] = uniforms.start + ${wgslType}(global_idx) * uniforms.delta;
      }`;
  };

  return {
    name: 'Range',
    // The generated shader text varies only with the data type, so that is the cache hint.
    shaderCache: { hint: `${dataType}` },
    getShaderSource,
    getRunData: () => ({
      outputs: [{ dims: outputShape, dataType }],
      dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
      programUniforms,
    }),
  };
};

/**
 * Entry point of the `Range` operator.
 *
 * Reads the first element of each of the three inputs (start, limit, delta) according to the
 * data type of input 0, optionally validates them, and runs the range program.
 *
 * Only `int32` and `float` inputs are read; for any other data type `start`, `limit` and
 * `delta` all stay at 0.
 *
 * @param context compute context supplying the inputs and executing the program
 * @throws Error when `env.webgpu.validateInputContent` is set and the values are rejected by
 *   `validateInputsContent`
 */
export const range = (context: ComputeContext): void => {
  let start = 0;
  let limit = 0;
  let delta = 0;
  if (context.inputs[0].dataType === DataType.int32) {
    start = context.inputs[0].getInt32Array()[0];
    limit = context.inputs[1].getInt32Array()[0];
    delta = context.inputs[2].getInt32Array()[0];
  } else if (context.inputs[0].dataType === DataType.float) {
    start = context.inputs[0].getFloat32Array()[0];
    limit = context.inputs[1].getFloat32Array()[0];
    delta = context.inputs[2].getFloat32Array()[0];
  }
  if (env.webgpu.validateInputContent) {
    validateInputsContent(start, limit, delta);
  }

  // The scalar values travel as uniforms, so the program is run with an empty input list.
  context.compute(createRangeProgramInfo(start, limit, delta, context.inputs[0].dataType), { inputs: [] });
};