onnxruntime-web
Version:
A JavaScript library for running ONNX models in browsers
267 lines (247 loc) • 9.31 kB
text/typescript
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import { env, InferenceSession } from 'onnxruntime-common';
import {
OrtWasmMessage,
SerializableInternalBuffer,
SerializableSessionMetadata,
SerializableTensorMetadata,
TensorMetadata,
} from './proxy-messages';
import * as core from './wasm-core-impl';
import { initializeWebAssembly } from './wasm-factory';
import {
importProxyWorker,
inferWasmPathPrefixFromScriptSrc,
isEsmImportMetaUrlHardcodedAsFileUri,
} from './wasm-utils-import';
/**
 * Tells whether calls should be routed through the proxy worker.
 *
 * @returns `true` only when `env.wasm.proxy` is truthy and a global `document` is defined.
 */
const isProxy = (): boolean => {
  const proxyRequested = Boolean(env.wasm.proxy);
  return proxyRequested && typeof document !== 'undefined';
};
// The worker used in proxy mode; assigned by initializeWebAssemblyAndOrtRuntime() once the worker is imported.
let proxyWorker: Worker | undefined;
// True from the start of initializeWebAssemblyAndOrtRuntime() until the initialization attempt reports back.
let initializing = false;
// True once an initialization attempt has completed successfully.
let initialized = false;
// True once an initialization attempt has failed; later initialization calls throw instead of retrying.
let aborted = false;
// Object URL handed back by importProxyWorker(), kept so it can be revoked when the 'init-wasm' response arrives.
let temporaryObjectUrl: string | undefined;
// A [resolve, reject] pair captured from a Promise executor.
type PromiseCallbacks<T = void> = [resolve: (result: T) => void, reject: (reason: unknown) => void];
// The callbacks of the in-flight proxy initialization; used by the 'init-wasm' response handler.
let initWasmCallbacks: PromiseCallbacks;
// Pending request callbacks per message type, in the order they were enqueued; a worker response of a given type
// consumes the oldest entry of that type's queue.
const queuedCallbacks: Map<OrtWasmMessage['type'], Array<PromiseCallbacks<unknown>>> = new Map();
/**
 * Appends a [resolve, reject] pair to the pending queue of the given message type,
 * creating that queue on first use.
 *
 * @param type - the message type the callbacks are waiting on.
 * @param callbacks - the [resolve, reject] pair to append.
 */
const enqueueCallbacks = (type: OrtWasmMessage['type'], callbacks: PromiseCallbacks<unknown>): void => {
  const pending = queuedCallbacks.get(type);
  if (pending === undefined) {
    // first request of this type: start a new queue.
    queuedCallbacks.set(type, [callbacks]);
    return;
  }
  pending.push(callbacks);
};
/**
 * Asserts that the proxy worker can accept requests.
 *
 * @throws Error when initialization is in progress, has not completed, has failed, or no worker exists.
 */
const ensureWorker = (): void => {
  const workerReady = !initializing && initialized && !aborted && !!proxyWorker;
  if (!workerReady) {
    throw new Error('worker not ready');
  }
};
/**
 * Handles a response message posted back by the proxy worker.
 *
 * - 'init-wasm': clears the `initializing` flag, settles the promise of the in-flight initialization
 *   (marking the module `aborted` on error or `initialized` on success) and revokes the temporary
 *   object URL if one was recorded.
 * - every other request type: takes the oldest pending [resolve, reject] pair queued for that type and
 *   settles it - rejecting with `err` when the worker reported one, resolving with `out` otherwise.
 *
 * @param ev - the message event received from the proxy worker.
 */
const onProxyWorkerMessage = (ev: MessageEvent<OrtWasmMessage>): void => {
  switch (ev.data.type) {
    case 'init-wasm':
      initializing = false;
      if (ev.data.err) {
        aborted = true;
        initWasmCallbacks[1](ev.data.err);
      } else {
        initialized = true;
        initWasmCallbacks[0]();
      }
      // the object URL is only needed while the worker starts up.
      if (temporaryObjectUrl) {
        URL.revokeObjectURL(temporaryObjectUrl);
        temporaryObjectUrl = undefined;
      }
      break;
    case 'init-ep':
    case 'copy-from':
    case 'create':
    case 'release':
    case 'run':
    case 'end-profiling': {
      // take the oldest pending request of this type and settle its promise; without invoking one of
      // the callbacks the caller's promise would never complete.
      const callbacks = queuedCallbacks.get(ev.data.type)!;
      const [resolve, reject] = callbacks.shift()!;
      if (ev.data.err) {
        reject(ev.data.err);
      } else {
        resolve(ev.data.out);
      }
      break;
    }
    default:
  }
};
/**
 * Initializes WebAssembly and the ORT runtime.
 *
 * When the proxy is not compiled out (`BUILD_DEFS.DISABLE_WASM_PROXY`) and `isProxy()` holds, a proxy worker is
 * imported and sent an 'init-wasm' message; the returned promise is settled by the 'init-wasm' response handler
 * (or rejected by the worker's `onerror`). Otherwise initialization runs in the current thread.
 *
 * Returns immediately when a previous call already completed successfully.
 *
 * @throws Error when an initialization is already in progress, or when a previous attempt failed.
 *
 * NOTE(review): on the proxy path, `initializing` is only reset by the 'init-wasm' response handler. If
 * `importProxyWorker()` rejects, the setup code throws, or the worker raises `onerror` first, the promise is
 * rejected but the flag appears to stay set - confirm whether that is intended.
 */
export const initializeWebAssemblyAndOrtRuntime = async (): Promise<void> => {
// nothing to do when a previous call already succeeded.
if (initialized) {
return;
}
if (initializing) {
throw new Error("multiple calls to 'initWasm()' detected.");
}
if (aborted) {
throw new Error("previous call to 'initWasm()' failed.");
}
initializing = true;
if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) {
return new Promise<void>((resolve, reject) => {
// drop any worker left over from an earlier attempt.
proxyWorker?.terminate();
void importProxyWorker().then(([objectUrl, worker]) => {
try {
proxyWorker = worker;
proxyWorker.onerror = (ev: ErrorEvent) => reject(ev);
proxyWorker.onmessage = onProxyWorkerMessage;
// the 'init-wasm' response handler settles this promise through these callbacks.
initWasmCallbacks = [resolve, reject];
// `message.in` is the shared `env` object itself, so the `wasmPaths` assignments below also update
// `env.wasm.wasmPaths`.
const message: OrtWasmMessage = { type: 'init-wasm', in: env };
// if the proxy worker is loaded from a blob URL, we need to make sure the path information is not lost.
//
// when `env.wasm.wasmPaths` is not set, we need to pass the path information to the worker.
//
if (!BUILD_DEFS.ENABLE_BUNDLE_WASM_JS && !message.in!.wasm.wasmPaths && objectUrl) {
// for a build that does not bundle the wasm JS, we need to pass the path prefix to the worker.
// the path prefix will be used to resolve the path to both the wasm JS and the wasm file.
const inferredWasmPathPrefix = inferWasmPathPrefixFromScriptSrc();
if (inferredWasmPathPrefix) {
message.in!.wasm.wasmPaths = inferredWasmPathPrefix;
}
}
if (
BUILD_DEFS.IS_ESM &&
BUILD_DEFS.ENABLE_BUNDLE_WASM_JS &&
!message.in!.wasm.wasmPaths &&
(objectUrl || isEsmImportMetaUrlHardcodedAsFileUri)
) {
// for a build that bundles the wasm JS, if either of the following conditions is met:
// - the proxy worker is loaded from a blob URL
// - `import.meta.url` is a file URL, which means it was overwritten by the bundler.
//
// in either case, the path information is lost, so we need to pass the path of the .wasm file to the worker.
// we need to use the bundler preferred URL format:
// new URL('filename', import.meta.url)
// so that the bundler can handle the file using corresponding loaders.
message.in!.wasm.wasmPaths = {
wasm: !BUILD_DEFS.DISABLE_JSEP
? new URL('ort-wasm-simd-threaded.jsep.wasm', BUILD_DEFS.ESM_IMPORT_META_URL).href
: new URL('ort-wasm-simd-threaded.wasm', BUILD_DEFS.ESM_IMPORT_META_URL).href,
};
}
proxyWorker.postMessage(message);
// remembered so the 'init-wasm' response handler can revoke it.
temporaryObjectUrl = objectUrl;
} catch (e) {
reject(e);
}
}, reject);
});
} else {
// non-proxy path: initialize in the current thread and record the outcome in the module flags.
try {
await initializeWebAssembly(env.wasm);
await core.initRuntime(env);
initialized = true;
} catch (e) {
aborted = true;
throw e;
} finally {
initializing = false;
}
}
};
/**
 * Initializes an execution provider, either directly or through the proxy worker.
 *
 * In proxy mode the promise's callbacks are queued under 'init-ep' and an 'init-ep' message carrying the
 * provider name and `env` is posted to the worker.
 *
 * @param epName - name of the execution provider to initialize.
 * @throws Error (as a rejection) when proxy mode is active but the worker is not ready.
 */
export const initializeOrtEp = async (epName: string): Promise<void> => {
  if (BUILD_DEFS.DISABLE_WASM_PROXY || !isProxy()) {
    await core.initEp(env, epName);
    return;
  }

  ensureWorker();
  return new Promise<void>((resolve, reject) => {
    enqueueCallbacks('init-ep', [resolve, reject]);
    const request: OrtWasmMessage = { type: 'init-ep', in: { epName, env } };
    proxyWorker!.postMessage(request);
  });
};
/**
 * Copies an external buffer into the backend, either directly or through the proxy worker.
 *
 * In proxy mode the promise's callbacks are queued under 'copy-from' and the buffer is posted to the worker
 * with its underlying `buffer.buffer` listed as a transferable.
 *
 * @param buffer - the bytes to copy.
 * @returns the internal buffer descriptor produced by the backend.
 * @throws Error (as a rejection) when proxy mode is active but the worker is not ready.
 */
export const copyFromExternalBuffer = async (buffer: Uint8Array): Promise<SerializableInternalBuffer> => {
  if (BUILD_DEFS.DISABLE_WASM_PROXY || !isProxy()) {
    return core.copyFromExternalBuffer(buffer);
  }

  ensureWorker();
  return new Promise<SerializableInternalBuffer>((resolve, reject) => {
    enqueueCallbacks('copy-from', [resolve, reject]);
    const request: OrtWasmMessage = { type: 'copy-from', in: { buffer } };
    proxyWorker!.postMessage(request, [buffer.buffer]);
  });
};
/**
 * Creates an inference session, either directly or through the proxy worker.
 *
 * In proxy mode the promise's callbacks are queued under 'create' and a 'create' message with the model and a
 * shallow copy of the options is posted; when the model is a `Uint8Array`, its underlying buffer is listed as
 * a transferable.
 *
 * @param model - the model, as an internal buffer descriptor or raw bytes.
 * @param options - optional session options.
 * @returns the metadata of the created session.
 * @throws Error (as a rejection) when proxy mode is active and `preferredOutputLocation` is set, or the worker
 * is not ready.
 */
export const createSession = async (
  model: SerializableInternalBuffer | Uint8Array,
  options?: InferenceSession.SessionOptions,
): Promise<SerializableSessionMetadata> => {
  if (BUILD_DEFS.DISABLE_WASM_PROXY || !isProxy()) {
    return core.createSession(model, options);
  }

  // reject options the proxy cannot honor before touching the worker.
  if (options?.preferredOutputLocation) {
    throw new Error('session option "preferredOutputLocation" is not supported for proxy.');
  }
  ensureWorker();
  return new Promise<SerializableSessionMetadata>((resolve, reject) => {
    enqueueCallbacks('create', [resolve, reject]);
    const request: OrtWasmMessage = { type: 'create', in: { model, options: { ...options } } };
    const transferList: Transferable[] = model instanceof Uint8Array ? [model.buffer] : [];
    proxyWorker!.postMessage(request, transferList);
  });
};
/**
 * Releases a session, either directly or through the proxy worker.
 *
 * In proxy mode the promise's callbacks are queued under 'release' and a 'release' message carrying the
 * session id is posted to the worker.
 *
 * @param sessionId - id of the session to release.
 * @throws Error (as a rejection) when proxy mode is active but the worker is not ready.
 */
export const releaseSession = async (sessionId: number): Promise<void> => {
  if (BUILD_DEFS.DISABLE_WASM_PROXY || !isProxy()) {
    core.releaseSession(sessionId);
    return;
  }

  ensureWorker();
  return new Promise<void>((resolve, reject) => {
    enqueueCallbacks('release', [resolve, reject]);
    const request: OrtWasmMessage = { type: 'release', in: sessionId };
    proxyWorker!.postMessage(request);
  });
};
/**
 * Runs inference on a session, either directly or through the proxy worker.
 *
 * In proxy mode every input must have `'cpu'` at index 3 of its metadata tuple and no output may be
 * pre-allocated. The promise's callbacks are queued under 'run' and a 'run' message is posted with the
 * buffers returned by `core.extractTransferableBuffers` as the transfer list.
 *
 * @param sessionId - id of the session to run.
 * @param inputIndices - indices of the provided inputs.
 * @param inputs - metadata of the input tensors.
 * @param outputIndices - indices of the requested outputs.
 * @param outputs - pre-allocated output tensors, or `null` entries.
 * @param options - run options.
 * @returns metadata of the output tensors.
 * @throws Error (as a rejection) when proxy mode is active and an input is not on CPU, an output is
 * pre-allocated, or the worker is not ready.
 */
export const run = async (
  sessionId: number,
  inputIndices: number[],
  inputs: TensorMetadata[],
  outputIndices: number[],
  outputs: Array<TensorMetadata | null>,
  options: InferenceSession.RunOptions,
): Promise<TensorMetadata[]> => {
  if (BUILD_DEFS.DISABLE_WASM_PROXY || !isProxy()) {
    return core.run(sessionId, inputIndices, inputs, outputIndices, outputs, options);
  }

  // inputs must all live on CPU.
  const allInputsOnCpu = inputs.every((tensor) => tensor[3] === 'cpu');
  if (!allInputsOnCpu) {
    throw new Error('input tensor on GPU is not supported for proxy.');
  }
  // outputs must not be pre-allocated.
  const hasPreAllocatedOutput = outputs.some((tensor) => !!tensor);
  if (hasPreAllocatedOutput) {
    throw new Error('pre-allocated output tensor is not supported for proxy.');
  }
  ensureWorker();
  return new Promise<SerializableTensorMetadata[]>((resolve, reject) => {
    enqueueCallbacks('run', [resolve, reject]);
    // safe to narrow: the check above established that every input is on CPU.
    const cpuInputs = inputs as SerializableTensorMetadata[];
    const request: OrtWasmMessage = {
      type: 'run',
      in: { sessionId, inputIndices, inputs: cpuInputs, outputIndices, options },
    };
    proxyWorker!.postMessage(request, core.extractTransferableBuffers(cpuInputs));
  });
};
/**
 * Ends profiling for a session, either directly or through the proxy worker.
 *
 * In proxy mode the promise's callbacks are queued under 'end-profiling' and an 'end-profiling' message
 * carrying the session id is posted to the worker.
 *
 * @param sessionId - id of the session whose profiling should end.
 * @throws Error (as a rejection) when proxy mode is active but the worker is not ready.
 */
export const endProfiling = async (sessionId: number): Promise<void> => {
  if (BUILD_DEFS.DISABLE_WASM_PROXY || !isProxy()) {
    core.endProfiling(sessionId);
    return;
  }

  ensureWorker();
  return new Promise<void>((resolve, reject) => {
    enqueueCallbacks('end-profiling', [resolve, reject]);
    const request: OrtWasmMessage = { type: 'end-profiling', in: sessionId };
    proxyWorker!.postMessage(request);
  });
};