UNPKG

onnxruntime-web

Version:

A JavaScript library for running ONNX models in browsers

297 lines (269 loc) 11.3 kB
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import type { Env } from 'onnxruntime-common';

import { calculateTensorSizeInBytes, DataType } from '../wasm-common';
import type { OrtWasmModule } from '../wasm-types';

import type { WebGpuBackend } from './backend-webgpu';
import { LOG_DEBUG } from './log';
import type { TensorView } from './tensor-view';
import { ShapeUtil } from './util';
import type {
  AdapterInfo,
  ComputeContext,
  ComputeContextInputsOutputsMapping,
  ProgramInfo,
} from './webgpu/types';
import { WebNNBackend } from './backend-webnn';

/* eslint-disable no-bitwise */

/**
 * A lightweight {@link TensorView} over tensor data identified by `data`.
 *
 * The typed-array getters interpret `data` as a byte offset into the WebAssembly heap and return
 * views that alias `module.HEAP8.buffer` (no copy is made). The same class is also used to wrap
 * values produced by `_JsepOutput` and GPU data IDs for temporary outputs (see
 * `ComputeContextImpl.compute`), in which case the getters are presumably not meant to be called.
 */
class TensorViewImpl implements TensorView {
  constructor(
    private module: OrtWasmModule,
    public readonly dataType: number,
    public readonly data: number,
    public readonly dims: readonly number[],
  ) {}

  /**
   * Returns a Float32Array view over this tensor's elements in the WebAssembly heap.
   * @throws Error if `dataType` is not `DataType.float`.
   */
  getFloat32Array(): Float32Array {
    if (this.dataType !== DataType.float) {
      throw new Error('Invalid data type');
    }
    const elementCount = ShapeUtil.size(this.dims);
    return elementCount === 0
      ? new Float32Array()
      : new Float32Array(this.module.HEAP8.buffer, this.data, elementCount);
  }

  /**
   * Returns a BigInt64Array view over this tensor's elements in the WebAssembly heap.
   * @throws Error if `dataType` is not `DataType.int64`.
   */
  getBigInt64Array(): BigInt64Array {
    if (this.dataType !== DataType.int64) {
      throw new Error('Invalid data type');
    }
    const elementCount = ShapeUtil.size(this.dims);
    return elementCount === 0
      ? new BigInt64Array()
      : new BigInt64Array(this.module.HEAP8.buffer, this.data, elementCount);
  }

  /**
   * Returns an Int32Array view over this tensor's elements in the WebAssembly heap.
   * @throws Error if `dataType` is not `DataType.int32`.
   */
  getInt32Array(): Int32Array {
    if (this.dataType !== DataType.int32) {
      throw new Error('Invalid data type');
    }
    const elementCount = ShapeUtil.size(this.dims);
    return elementCount === 0 ? new Int32Array() : new Int32Array(this.module.HEAP8.buffer, this.data, elementCount);
  }

  /**
   * Returns a Uint16Array view over this tensor's elements in the WebAssembly heap.
   * Used for both float16 (raw bit patterns) and uint16 tensors.
   * @throws Error if `dataType` is neither `DataType.float16` nor `DataType.uint16`.
   */
  getUint16Array(): Uint16Array {
    if (this.dataType !== DataType.float16 && this.dataType !== DataType.uint16) {
      throw new Error('Invalid data type');
    }
    const elementCount = ShapeUtil.size(this.dims);
    return elementCount === 0
      ? new Uint16Array()
      : new Uint16Array(this.module.HEAP8.buffer, this.data, elementCount);
  }

  /**
   * Returns a new view sharing the same `data` with different dims.
   * @throws Error if `newDims` does not describe the same number of elements.
   */
  reshape(newDims: readonly number[]): TensorView {
    if (ShapeUtil.size(newDims) !== ShapeUtil.size(this.dims)) {
      throw new Error('Invalid new shape');
    }
    return new TensorViewImpl(this.module, this.dataType, this.data, newDims);
  }
}

/**
 * The {@link ComputeContext} handed to a WebGPU kernel for one `jsepRun` invocation.
 *
 * The constructor decodes a packed, pointer-sized record written into the WebAssembly heap at
 * `contextDataOffset`, laid out as:
 *   [opKernelContext, inputCount, outputCount, customDataOffset (pointer), customDataSize,
 *    then for each input: dataType, data (pointer), rank, dim_0 … dim_{rank-1}]
 */
class ComputeContextImpl implements ComputeContext {
  readonly adapterInfo: AdapterInfo;
  readonly opKernelContext: number;
  readonly inputs: readonly TensorView[];
  readonly outputCount: number;

  /** Per-kernel scratch data owned by the backend for the kernel currently being run. */
  get kernelCustomData(): { [key: string]: unknown } {
    return this.backend.currentKernelCustomData;
  }

  /** A view (not a copy) of the custom data bytes referenced by the context record. */
  get customDataBuffer(): Uint8Array {
    return this.module.HEAPU8.subarray(this.customDataOffset, this.customDataOffset + this.customDataSize);
  }

  private customDataOffset = 0;
  private customDataSize = 0;

  constructor(
    private module: OrtWasmModule,
    private backend: WebGpuBackend,
    contextDataOffset: number,
  ) {
    this.adapterInfo = backend.adapterInfo;

    // extract context data
    const ptrSize = module.PTR_SIZE;
    let dataIndex = contextDataOffset / module.PTR_SIZE;
    // Each field is pointer-sized; 'i64' reads are presumably returned as bigint by getValue,
    // hence the Number() coercions below.
    const type = ptrSize === 4 ? 'i32' : 'i64';
    this.opKernelContext = Number(module.getValue(ptrSize * dataIndex++, type));
    const inputCount = Number(module.getValue(ptrSize * dataIndex++, type));
    this.outputCount = Number(module.getValue(ptrSize * dataIndex++, type));
    this.customDataOffset = Number(module.getValue(ptrSize * dataIndex++, '*'));
    this.customDataSize = Number(module.getValue(ptrSize * dataIndex++, type));

    const inputs: TensorView[] = [];
    for (let i = 0; i < inputCount; i++) {
      const dataType = Number(module.getValue(ptrSize * dataIndex++, type));
      const data = Number(module.getValue(ptrSize * dataIndex++, '*'));
      const dim = Number(module.getValue(ptrSize * dataIndex++, type));
      const dims: number[] = [];
      for (let d = 0; d < dim; d++) {
        dims.push(Number(module.getValue(ptrSize * dataIndex++, type)));
      }
      inputs.push(new TensorViewImpl(module, dataType, data, dims));
    }
    this.inputs = inputs;
  }

  /**
   * Runs `program` on the backend.
   *
   * @param program - the program to run.
   * @param inputsOutputsMapping - optional remapping. Numeric `inputs` entries select from this
   *   context's inputs, and other entries are used as-is. `outputs` is forwarded to the backend as
   *   the output indices (defaults to `[]`).
   * @returns the tensor views produced by `backend.run`.
   */
  compute(program: ProgramInfo, inputsOutputsMapping?: ComputeContextInputsOutputsMapping): TensorView[] {
    // prepare inputs. inputs should always be valid data.
    const mappedInputs =
      inputsOutputsMapping?.inputs?.map((i) => (typeof i === 'number' ? this.inputs[i] : i)) ?? this.inputs;
    // prepare outputs.
    const outputIndices = inputsOutputsMapping?.outputs ?? [];
    const createKernelOutput = (index: number, dataType: number, dims: readonly number[]): TensorView =>
      new TensorViewImpl(this.module, dataType, this.output(index, dims), dims);
    const createTemporaryOutput = (dataType: number, dims: readonly number[]): TensorView => {
      const bufferSize = calculateTensorSizeInBytes(dataType, dims);
      // Only a missing size means the data type is unsupported (presumably `undefined` is returned
      // for unknown types — verify in wasm-common). A size of 0 is a valid zero-element tensor and
      // must reach the branch below; the previous falsy check made that branch unreachable.
      if (bufferSize === undefined) {
        throw new Error(`Unsupported data type: ${dataType}`);
      }
      // Zero-sized tensors get no GPU allocation; 0 is used as the placeholder data ID.
      const gpuDataId = bufferSize > 0 ? this.backend.gpuDataManager.create(bufferSize).id : 0;
      return new TensorViewImpl(this.module, dataType, gpuDataId, dims);
    };
    return this.backend.run(
      program,
      mappedInputs,
      outputIndices,
      createKernelOutput,
      createTemporaryOutput,
      this.outputCount,
    );
  }

  /**
   * Requests kernel output `index` with shape `dims` from the native side via `_JsepOutput` and
   * returns the value it produces.
   *
   * The shape is marshalled on the Emscripten stack as pointer-sized integers
   * `[rank, dims[0], …, dims[rank - 1]]`. The stack is restored in `finally`.
   *
   * @throws Error wrapping any failure from `_JsepOutput`, with a hint about pre-allocated outputs.
   */
  output(index: number, dims: readonly number[]): number {
    const stack = this.module.stackSave();
    try {
      const ptrSize = this.module.PTR_SIZE;
      const type = ptrSize === 4 ? 'i32' : 'i64';
      const data = this.module.stackAlloc((1 + dims.length) * ptrSize /* sizeof(size_t) */);
      this.module.setValue(data, dims.length, type);
      for (let i = 0; i < dims.length; i++) {
        this.module.setValue(data + ptrSize * (i + 1), dims[i], type);
      }
      return this.module._JsepOutput!(this.opKernelContext, index, data);
    } catch (e) {
      throw new Error(
        `Failed to generate kernel's output[${index}] with dims [${dims}]. ` +
          'If you are running with pre-allocated output, please make sure the output type/dims are correct. ' +
          `Error: ${e}`,
      );
    } finally {
      this.module.stackRestore(stack);
    }
  }
}

/**
 * Initialize JSEP with the WebGPU or WebNN backend.
 *
 * This function will be called after the WebAssembly module is loaded and initialized ("_OrtInit" is called), once for
 * each of the following EPs if they are specified:
 *  - "webgpu"
 *  - "webnn"
 *
 * For WebGPU, this function expects:
 *  - WebGPU is enabled in build (BUILD_DEFS.DISABLE_JSEP === false).
 *  - WebGPU is available in current environment. (a valid GPUAdapter is passed in)
 * When `BUILD_DEFS.USE_WEBGPU_EP` is true, this function does nothing for "webgpu".
 *
 * For WebNN, this function expects:
 *  - WebNN is enabled in build (BUILD_DEFS.DISABLE_JSEP === false).
 *  - WebNN is available in current environment. (navigator.ml is not undefined)
 *
 * If the WebAssembly module is not built with JSEP support, this function will throw an error. This will invalidate
 * 'webgpu'/'webnn' backend.
 *
 * @param name - the name of the EP, either "webgpu" or "webnn"
 * @param module - the ORT WebAssembly module
 * @param env - the ORT environment variable (ort.env)
 * @param gpuAdapter - the pre-created GPU adapter (required for "webgpu")
 * @throws Error if `module.jsepInit` is missing.
 */
export const init = async (
  name: 'webgpu' | 'webnn',
  module: OrtWasmModule,
  env: Env,
  gpuAdapter?: GPUAdapter,
): Promise<void> => {
  const jsepInit = module.jsepInit;
  if (!jsepInit) {
    throw new Error('Failed to initialize JSEP. The WebAssembly module is not built with JSEP support.');
  }

  if (name === 'webgpu') {
    if (!BUILD_DEFS.USE_WEBGPU_EP) {
      // eslint-disable-next-line @typescript-eslint/no-require-imports, @typescript-eslint/no-var-requires
      const webGpuBackendImpl = require('./backend-webgpu').WebGpuBackend;
      const backend = new webGpuBackendImpl();
      await backend.initialize(env, gpuAdapter!);

      // Numeric arguments arriving from the WebAssembly side are coerced with Number() before use,
      // because they may not be plain JS numbers (e.g. bigint in 64-bit builds — confirm).
      jsepInit('webgpu', [
        // backend
        backend,

        // jsepAlloc()
        (size: number) => backend.alloc(Number(size)),

        // jsepFree()
        (ptr: number) => backend.free(Number(ptr)),

        // jsepCopy(src, dst, size, isSourceGpu)
        (src: number, dst: number, size: number, isSourceGpu = false) => {
          if (isSourceGpu) {
            LOG_DEBUG(
              'verbose',
              () => `[WebGPU] jsepCopyGpuToGpu: src=${Number(src)}, dst=${Number(dst)}, size=${Number(size)}`,
            );
            backend.memcpy(Number(src), Number(dst));
          } else {
            LOG_DEBUG(
              'verbose',
              () =>
                `[WebGPU] jsepCopyCpuToGpu: dataOffset=${Number(src)}, gpuDataId=${Number(dst)}, size=${Number(size)}`,
            );
            // Coerce before the unsigned shift: `>>>` throws a TypeError on a bigint operand.
            const srcOffset = Number(src) >>> 0;
            const data = module.HEAPU8.subarray(srcOffset, srcOffset + Number(size));
            backend.upload(Number(dst), data);
          }
        },

        // jsepCopyAsync(src, dst, size)
        async (gpuDataId: number, dataOffset: number, size: number): Promise<void> => {
          LOG_DEBUG(
            'verbose',
            () => `[WebGPU] jsepCopyGpuToCpu: gpuDataId=${gpuDataId}, dataOffset=${dataOffset}, size=${size}`,
          );

          await backend.download(Number(gpuDataId), () =>
            module.HEAPU8.subarray(Number(dataOffset) >>> 0, Number(dataOffset + size) >>> 0),
          );
        },

        // jsepCreateKernel
        (kernelType: string, kernelId: number, attribute: unknown) =>
          backend.createKernel(
            kernelType,
            Number(kernelId),
            attribute,
            module.UTF8ToString(module._JsepGetNodeName!(Number(kernelId))),
          ),

        // jsepReleaseKernel
        (kernel: number) => backend.releaseKernel(Number(kernel)),

        // jsepRun
        (kernel: number, contextDataOffset: number, sessionHandle: number, errors: Array<Promise<string | null>>) => {
          LOG_DEBUG(
            'verbose',
            () =>
              `[WebGPU] jsepRun: sessionHandle=${sessionHandle}, kernel=${kernel}, contextDataOffset=${contextDataOffset}`,
          );
          const context = new ComputeContextImpl(module, backend, Number(contextDataOffset));
          return backend.computeKernel(Number(kernel), context, errors);
        },
        // jsepCaptureBegin
        () => backend.captureBegin(),
        // jsepCaptureEnd
        () => backend.captureEnd(),
        // jsepReplay
        () => backend.replay(),
      ]);
    }
  } else {
    const backend = new WebNNBackend(env);
    jsepInit('webnn', [
      backend,
      // jsepReserveTensorId
      () => backend.reserveTensorId(),
      // jsepReleaseTensorId
      (tensorId: number) => backend.releaseTensorId(tensorId),
      // jsepEnsureTensor
      async (sessionId: number | undefined, tensorId: number, onnxDataType: number, shape: number[], copyOld) =>
        backend.ensureTensor(sessionId, tensorId, onnxDataType, shape, copyOld),
      // jsepUploadTensor
      (tensorId: number, data: Uint8Array) => {
        backend.uploadTensor(tensorId, data);
      },
      // jsepDownloadTensor
      async (tensorId: number, dstBuffer: ArrayBufferView | ArrayBuffer) =>
        backend.downloadTensor(tensorId, dstBuffer),
    ]);
  }
};