UNPKG

onnxruntime-web

Version:

A JavaScript library for running ONNX models in browsers

243 lines (204 loc) 8.13 kB
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import * as path from 'node:path';
import {Env} from 'onnxruntime-common';

import {OrtWasmModule} from './binding/ort-wasm';
import {OrtWasmThreadedModule} from './binding/ort-wasm-threaded';

/* eslint-disable @typescript-eslint/no-require-imports */
// Single-threaded factory, chosen at build time: training builds always use the SIMD training
// binary; otherwise the JSEP variant is used unless WebGPU is disabled.
let ortWasmFactory: EmscriptenModuleFactory<OrtWasmModule>;

if (!BUILD_DEFS.DISABLE_TRAINING) {
  ortWasmFactory = require('./binding/ort-training-wasm-simd.js');
} else {
  ortWasmFactory =
      BUILD_DEFS.DISABLE_WEBGPU ? require('./binding/ort-wasm.js') : require('./binding/ort-wasm-simd.jsep.js');
}

// Threaded factory. When wasm threads are disabled at build time it falls back to the
// single-threaded factory, so callers can choose between the two unconditionally.
const ortWasmFactoryThreaded: EmscriptenModuleFactory<OrtWasmModule> = !BUILD_DEFS.DISABLE_WASM_THREAD ?
    (BUILD_DEFS.DISABLE_WEBGPU ? require('./binding/ort-wasm-threaded.js') :
                                 require('./binding/ort-wasm-simd-threaded.jsep.js')) :
    ortWasmFactory;
/* eslint-enable @typescript-eslint/no-require-imports */

// Module-level state shared by initializeWebAssembly(), getInstance() and dispose().
let wasm: OrtWasmModule|undefined;
let initialized = false;
let initializing = false;
let aborted = false;

/**
 * Decides whether the multi-threaded WebAssembly build can be used for `numThreads` threads.
 *
 * Returns `false` when `numThreads` is 1, when `SharedArrayBuffer` is unavailable, when posting a
 * `SharedArrayBuffer` through a `MessageChannel` throws, or when the runtime rejects a small module
 * that uses thread instructions.
 */
const isMultiThreadSupported = (numThreads: number): boolean => {
  // WebAssembly threads are set to 1 (single thread).
  if (numThreads === 1) {
    return false;
  }

  // If 'SharedArrayBuffer' is not available, WebAssembly threads will not work.
  if (typeof SharedArrayBuffer === 'undefined') {
    if (typeof self !== 'undefined' && !self.crossOriginIsolated) {
      // eslint-disable-next-line no-console
      console.warn(
          'env.wasm.numThreads is set to ' + numThreads +
          ', but this will not work unless you enable crossOriginIsolated mode. ' +
          'See https://web.dev/cross-origin-isolation-guide/ for more info.');
    }
    return false;
  }

  // onnxruntime-web does not support multi-threads in Node.js.
  // NOTE(review): this branch only warns; it does not return false, so feature detection below
  // still runs and may enable threads in Node.js. Confirm this is intended.
  if (typeof process !== 'undefined' && process.versions && process.versions.node) {
    // eslint-disable-next-line no-console
    console.warn(
        'env.wasm.numThreads is set to ' + numThreads +
        ', however, currently onnxruntime-web does not support multi-threads in Node.js. ' +
        'Please consider using onnxruntime-node for performance critical scenarios.');
  }

  try {
    // Test for transferability of SABs (for browsers. needed for Firefox)
    // https://groups.google.com/forum/#!msg/mozilla.dev.platform/IHkBZlHETpA/dwsMNchWEQAJ
    if (typeof MessageChannel !== 'undefined') {
      new MessageChannel().port1.postMessage(new SharedArrayBuffer(1));
    }

    // Test for WebAssembly threads capability (for both browsers and Node.js)
    // This typed array is a WebAssembly program containing threaded instructions.
    return WebAssembly.validate(new Uint8Array([
      0, 97, 115, 109, 1, 0,  0,  0, 1, 4, 1,   96, 0, 0, 3, 2, 1,  0, 5,
      4, 1,  3,   1,   1, 10, 11, 1, 9, 0, 65, 0,  254, 16, 2, 0, 26, 11
    ]));
  } catch {
    return false;
  }
};

/**
 * Reports whether the runtime accepts WebAssembly SIMD instructions.
 *
 * Returns `false` when `WebAssembly.validate` rejects the probe module or throws.
 */
const isSimdSupported = (): boolean => {
  try {
    // Test for WebAssembly SIMD capability (for both browsers and Node.js)
    // This typed array is a WebAssembly program containing SIMD instructions.

    // The binary data is generated from the following code by wat2wasm:
    //
    // (module
    //   (type $t0 (func))
    //   (func $f0 (type $t0)
    //     (drop
    //       (i32x4.dot_i16x8_s
    //         (i8x16.splat
    //           (i32.const 0))
    //         (v128.const i32x4 0x00000000 0x00000000 0x00000000 0x00000000)))))

    return WebAssembly.validate(new Uint8Array([
      0,   97, 115, 109, 1, 0, 0, 0, 1, 4, 1, 96, 0, 0, 3, 2, 1, 0, 10, 30, 1, 28, 0, 65, 0,
      253, 15, 253, 12,  0, 0, 0, 0, 0, 0, 0, 0,  0, 0, 0, 0, 0, 0, 0,  0,  0, 0,  253, 186, 1, 26, 11
    ]));
  } catch {
    return false;
  }
};

/**
 * Maps the selected SIMD/threads combination to the base `.wasm` file name.
 *
 * Training builds always report the SIMD training binary when SIMD is requested. JSEP (WebGPU)
 * renaming is applied later, in `locateFile`.
 */
const getWasmFileName = (useSimd: boolean, useThreads: boolean) => {
  if (useSimd) {
    if (!BUILD_DEFS.DISABLE_TRAINING) {
      return 'ort-training-wasm-simd.wasm';
    }
    return useThreads ? 'ort-wasm-simd-threaded.wasm' : 'ort-wasm-simd.wasm';
  } else {
    return useThreads ? 'ort-wasm-threaded.wasm' : 'ort-wasm.wasm';
  }
};

/**
 * Loads and instantiates the ONNX Runtime WebAssembly module.
 *
 * Picks the SIMD and/or multi-threaded binary from `flags` plus runtime feature detection, then
 * resolves `.wasm` and worker-script locations (honoring `flags.wasmPaths`). When
 * `flags.initTimeout` is positive, instantiation is raced against that timeout.
 *
 * `flags.initTimeout`, `flags.numThreads` and `flags.simd` must already be populated by the caller.
 *
 * If the timeout wins, instantiation keeps running in the background. A late success still marks
 * the module as initialized.
 *
 * @param flags - WebAssembly backend flags.
 * @throws if another call is still initializing.
 * @throws if a previous call failed or the module was disposed.
 * @throws if instantiation rejects or does not finish within `flags.initTimeout` ms.
 */
export const initializeWebAssembly = async(flags: Env.WebAssemblyFlags): Promise<void> => {
  if (initialized) {
    return;
  }
  if (initializing) {
    throw new Error('multiple calls to \'initializeWebAssembly()\' detected.');
  }
  if (aborted) {
    throw new Error('previous call to \'initializeWebAssembly()\' failed.');
  }

  initializing = true;

  // wasm flags are already initialized
  const timeout = flags.initTimeout!;
  const numThreads = flags.numThreads!;
  const simd = flags.simd!;

  const useThreads = isMultiThreadSupported(numThreads);
  const useSimd = simd && isSimdSupported();

  const wasmPaths = flags.wasmPaths;
  const wasmPrefixOverride = typeof wasmPaths === 'string' ? wasmPaths : undefined;
  const wasmFileName = getWasmFileName(useSimd, useThreads);
  const wasmPathOverride = typeof wasmPaths === 'object' ? wasmPaths[wasmFileName] : undefined;

  let isTimeout = false;
  let timeoutHandle: ReturnType<typeof setTimeout>|undefined;

  const tasks: Array<Promise<void>> = [];

  // promise for timeout
  if (timeout > 0) {
    tasks.push(new Promise((resolve) => {
      timeoutHandle = setTimeout(() => {
        isTimeout = true;
        resolve();
      }, timeout);
    }));
  }

  // promise for module initialization
  tasks.push(new Promise((resolve, reject) => {
    const factory = useThreads ? ortWasmFactoryThreaded : ortWasmFactory;
    const config: Partial<OrtWasmModule> = {
      locateFile: (fileName: string, scriptDirectory: string) => {
        if (!BUILD_DEFS.DISABLE_WASM_THREAD && useThreads && fileName.endsWith('.worker.js') &&
            typeof Blob !== 'undefined') {
          // Serve the threaded worker script from a Blob URL built from its bundled source,
          // instead of loading a separate file.
          return URL.createObjectURL(new Blob(
              [
                // This require() function is handled by esbuild plugin to load file content as string.
                // eslint-disable-next-line @typescript-eslint/no-require-imports
                require('./binding/ort-wasm-threaded.worker.js')
              ],
              {type: 'text/javascript'}));
        }

        if (fileName.endsWith('.wasm')) {
          // An explicit per-file path from env.wasm.wasmPaths takes precedence over any prefix.
          if (wasmPathOverride) {
            return wasmPathOverride;
          }

          const prefix = wasmPrefixOverride ?? scriptDirectory;

          // WebGPU-enabled builds load the JSEP variants of the SIMD binaries.
          if (!BUILD_DEFS.DISABLE_WEBGPU) {
            if (wasmFileName === 'ort-wasm-simd.wasm') {
              return prefix + 'ort-wasm-simd.jsep.wasm';
            } else if (wasmFileName === 'ort-wasm-simd-threaded.wasm') {
              return prefix + 'ort-wasm-simd-threaded.jsep.wasm';
            }
          }

          return prefix + wasmFileName;
        }

        return scriptDirectory + fileName;
      }
    };

    if (!BUILD_DEFS.DISABLE_WASM_THREAD && useThreads) {
      config.numThreads = numThreads;
      // mainScriptUrlOrBlob tells Emscripten where thread workers load the main script from.
      // Without Blob (e.g. older Node.js), point at the on-disk script; otherwise serialize the
      // factory into a Blob.
      if (typeof Blob === 'undefined') {
        config.mainScriptUrlOrBlob = path.join(__dirname, 'ort-wasm-threaded.js');
      } else {
        const scriptSourceCode = `var ortWasmThreaded=${factory.toString()};`;
        config.mainScriptUrlOrBlob = new Blob([scriptSourceCode], {type: 'text/javascript'});
      }
    }

    factory(config).then(
        // wasm module initialized successfully
        module => {
          initializing = false;
          initialized = true;
          wasm = module;
          resolve();
        },
        // wasm module failed to initialize
        (what) => {
          initializing = false;
          aborted = true;
          reject(what);
        });
  }));

  try {
    await Promise.race(tasks);
  } finally {
    // Cancel the pending timer once the race settles. Otherwise it stays armed for the rest of the
    // timeout and can keep the event loop (e.g. a Node.js process) alive after initialization.
    if (timeoutHandle !== undefined) {
      clearTimeout(timeoutHandle);
    }
  }

  if (isTimeout) {
    throw new Error(`WebAssembly backend initializing failed due to timeout: ${timeout}ms`);
  }
};

/**
 * Returns the initialized WebAssembly module.
 *
 * @throws if initializeWebAssembly() has not completed successfully, or dispose() was called.
 */
export const getInstance = (): OrtWasmModule => {
  if (initialized && wasm) {
    return wasm;
  }

  throw new Error('WebAssembly is not initialized yet.');
};

/**
 * Releases the WebAssembly module.
 *
 * Does nothing unless the module is fully initialized. Otherwise it:
 * - terminates the module's threads when it exposes `PThread`;
 * - drops the module reference;
 * - marks the backend as aborted, so later calls to initializeWebAssembly() throw.
 */
export const dispose = (): void => {
  if (initialized && !initializing && !aborted) {
    initializing = true;

    (wasm as OrtWasmThreadedModule).PThread?.terminateAllThreads();
    wasm = undefined;

    initializing = false;
    initialized = false;
    aborted = true;
  }
};