@huggingface/transformers

State-of-the-art Machine Learning for the web. Run 🤗 Transformers directly in your browser, with no need for a server!

/**
 * @file Handler file for choosing the correct version of ONNX Runtime, based on the environment.
 * Ideally, we could import the `onnxruntime-web` and `onnxruntime-node` packages only when needed,
 * but dynamic imports don't seem to work with the current webpack version and/or configuration.
 * This is possibly due to the experimental nature of top-level await statements.
 * So, we just import both packages, and use the appropriate one based on the environment:
 * - When running in node, we use `onnxruntime-node`.
 * - When running in the browser, we use `onnxruntime-web` (`onnxruntime-node` is not bundled).
 *
 * This module is not directly exported, but can be accessed through the environment variables:
 * ```javascript
 * import { env } from '@huggingface/transformers';
 * console.log(env.backends.onnx);
 * ```
 *
 * @module backends/onnx
 */

import { env, apis } from '../env.js';

// NOTE: Import order matters here. We need to import `onnxruntime-node` before `onnxruntime-web`.
// In either case, we select the default export if it exists, otherwise we use the named export.
import * as ONNX_NODE from 'onnxruntime-node';
import * as ONNX_WEB from 'onnxruntime-web';

export { Tensor } from 'onnxruntime-common';

/**
 * @typedef {import('onnxruntime-common').InferenceSession.ExecutionProviderConfig} ONNXExecutionProviders
 */

/** @type {Record<import("../utils/devices.js").DeviceType, ONNXExecutionProviders>} */
const DEVICE_TO_EXECUTION_PROVIDER_MAPPING = Object.freeze({
    auto: null, // Auto-detect based on device and environment
    gpu: null, // Auto-detect GPU
    cpu: 'cpu', // CPU
    wasm: 'wasm', // WebAssembly
    webgpu: 'webgpu', // WebGPU
    cuda: 'cuda', // CUDA
    dml: 'dml', // DirectML

    webnn: { name: 'webnn', deviceType: 'cpu' }, // WebNN (default)
    'webnn-npu': { name: 'webnn', deviceType: 'npu' }, // WebNN NPU
    'webnn-gpu': { name: 'webnn', deviceType: 'gpu' }, // WebNN GPU
    'webnn-cpu': { name: 'webnn', deviceType: 'cpu' }, // WebNN CPU
});

/**
 * The list of supported devices, sorted by priority/performance.
 * @type {import("../utils/devices.js").DeviceType[]}
 */
const supportedDevices = [];

/** @type {ONNXExecutionProviders[]} */
let defaultDevices;
let ONNX;
const ORT_SYMBOL = Symbol.for('onnxruntime');
if (ORT_SYMBOL in globalThis) {
    // If the JS runtime exposes their own ONNX runtime, use it
    ONNX = globalThis[ORT_SYMBOL];

} else if (apis.IS_NODE_ENV) {
    ONNX = ONNX_NODE.default ?? ONNX_NODE;

    // Updated as of ONNX Runtime 1.20.1
    // The following table lists the supported versions of ONNX Runtime Node.js binding provided with pre-built binaries.
    // | EPs/Platforms | Windows x64 | Windows arm64 | Linux x64         | Linux arm64 | MacOS x64 | MacOS arm64 |
    // | ------------- | ----------- | ------------- | ----------------- | ----------- | --------- | ----------- |
    // | CPU           | ✔️          | ✔️            | ✔️                | ✔️          | ✔️        | ✔️          |
    // | DirectML      | ✔️          | ✔️            | ❌                | ❌          | ❌        | ❌          |
    // | CUDA          | ❌          | ❌            | ✔️ (CUDA v11.8)   | ❌          | ❌        | ❌          |
    switch (process.platform) {
        case 'win32': // Windows x64 and Windows arm64
            supportedDevices.push('dml');
            break;
        case 'linux': // Linux x64 and Linux arm64
            if (process.arch === 'x64') {
                supportedDevices.push('cuda');
            }
            break;
        case 'darwin': // MacOS x64 and MacOS arm64
            break;
    }

    supportedDevices.push('cpu');
    defaultDevices = ['cpu'];
} else {
    ONNX = ONNX_WEB;

    if (apis.IS_WEBNN_AVAILABLE) {
        // TODO: Only push supported providers (depending on available hardware)
        supportedDevices.push('webnn-npu', 'webnn-gpu', 'webnn-cpu', 'webnn');
    }

    if (apis.IS_WEBGPU_AVAILABLE) {
        supportedDevices.push('webgpu');
    }

    supportedDevices.push('wasm');
    defaultDevices = ['wasm'];
}

// @ts-ignore
const InferenceSession = ONNX.InferenceSession;

/**
 * Map a device to the execution providers to use for the given device.
 * @param {import("../utils/devices.js").DeviceType|"auto"|null} [device=null] (Optional) The device to run the inference on.
 * @returns {ONNXExecutionProviders[]} The execution providers to use for the given device.
 */
export function deviceToExecutionProviders(device = null) {
    // Use the default execution providers if the user hasn't specified anything
    if (!device) return defaultDevices;

    // Handle overloaded cases
    switch (device) {
        case "auto":
            return supportedDevices;
        case "gpu":
            return supportedDevices.filter(x =>
                ["webgpu", "cuda", "dml", "webnn-gpu"].includes(x),
            );
    }

    if (supportedDevices.includes(device)) {
        return [DEVICE_TO_EXECUTION_PROVIDER_MAPPING[device] ?? device];
    }

    throw new Error(`Unsupported device: "${device}". Should be one of: ${supportedDevices.join(', ')}.`)
}

/**
 * To prevent multiple calls to `initWasm()`, we store the first call in a Promise
 * that is resolved when the first InferenceSession is created. Subsequent calls
 * will wait for this Promise to resolve before creating their own InferenceSession.
 * @type {Promise<any>|null}
 */
let wasmInitPromise = null;

/**
 * Create an ONNX inference session.
 * @param {Uint8Array|string} buffer_or_path The ONNX model buffer or path.
 * @param {import('onnxruntime-common').InferenceSession.SessionOptions} session_options ONNX inference session options.
 * @param {Object} session_config ONNX inference session configuration.
 * @returns {Promise<import('onnxruntime-common').InferenceSession & { config: Object }>} The ONNX inference session.
 */
export async function createInferenceSession(buffer_or_path, session_options, session_config) {
    if (wasmInitPromise) {
        // A previous session has already initialized the WASM runtime,
        // so we wait for it to resolve before creating this new session.
        await wasmInitPromise;
    }

    const sessionPromise = InferenceSession.create(buffer_or_path, session_options);
    wasmInitPromise ??= sessionPromise;
    const session = await sessionPromise;
    session.config = session_config;
    return session;
}

/**
 * Check if an object is an ONNX tensor.
 * @param {any} x The object to check
 * @returns {boolean} Whether the object is an ONNX tensor.
 */
export function isONNXTensor(x) {
    return x instanceof ONNX.Tensor;
}

/** @type {import('onnxruntime-common').Env} */
// @ts-ignore
const ONNX_ENV = ONNX?.env;
if (ONNX_ENV?.wasm) {
    // Initialize wasm backend with suitable default settings.
    // (Optional) Set path to wasm files. This will override the default path search behavior of onnxruntime-web.
    // By default, we only do this if we are not in a service worker and the wasmPaths are not already set.
    if (
        // @ts-ignore Cannot find name 'ServiceWorkerGlobalScope'.ts(2304)
        !(typeof ServiceWorkerGlobalScope !== 'undefined' && self instanceof ServiceWorkerGlobalScope)
        && !ONNX_ENV.wasm.wasmPaths
    ) {
        ONNX_ENV.wasm.wasmPaths = `https://cdn.jsdelivr.net/npm/@huggingface/transformers@${env.version}/dist/`;
    }

    // TODO: Add support for loading WASM files from cached buffer when we upgrade to onnxruntime-web@1.19.0
    // https://github.com/microsoft/onnxruntime/pull/21534

    // Users may wish to proxy the WASM backend to prevent the UI from freezing;
    // however, this is not necessary when using WebGPU, so we default to false.
    ONNX_ENV.wasm.proxy = false;
}

if (ONNX_ENV?.webgpu) {
    ONNX_ENV.webgpu.powerPreference = 'high-performance';
}

/**
 * Check if ONNX's WASM backend is being proxied.
 * @returns {boolean} Whether ONNX's WASM backend is being proxied.
 */
export function isONNXProxy() {
    // TODO: Update this when allowing non-WASM backends.
    return ONNX_ENV?.wasm?.proxy;
}

// Expose ONNX environment variables to `env.backends.onnx`
env.backends.onnx = ONNX_ENV;
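
Because this module is internal, application code reaches the settings it configures through the `env.backends.onnx` handle exported at the bottom of the file. The sketch below is an illustration of that access path, not part of the package source: it assumes a locally served `/ort-wasm/` directory containing the onnxruntime-web `.wasm` binaries, and the model id shown is only a placeholder.

```javascript
// Minimal usage sketch: override the WASM defaults set in backends/onnx.js,
// then load a model. Paths and the model id are illustrative assumptions.
import { env, pipeline } from '@huggingface/transformers';

// `env.backends.onnx` is the onnxruntime `Env` object configured above.
env.backends.onnx.wasm.wasmPaths = '/ort-wasm/'; // serve the WASM files yourself instead of the jsDelivr CDN
env.backends.onnx.wasm.proxy = true;             // run the WASM backend in a worker to keep the UI responsive

const extractor = await pipeline(
    'feature-extraction',
    'Xenova/all-MiniLM-L6-v2',   // placeholder model id
    { device: 'wasm' },          // any device mapped above, e.g. 'webgpu', 'wasm', or 'auto'
);
console.log(await extractor('Hello world', { pooling: 'mean', normalize: true }));
```

Setting `wasm.proxy` before the first session is created matters because, as `createInferenceSession` shows, the WASM runtime is initialized once by the first `InferenceSession.create` call and reused afterwards.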