onnxruntime-web
A JavaScript library for running ONNX models in browsers
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// WebNN API currently does not have a TypeScript definition file. This file is a workaround with types generated from
// the WebNN API specification.
// https://github.com/webmachinelearning/webnn/issues/677
/// <reference path="jsep/webnn/webnn.d.ts" />
import { Env, InferenceSession, Tensor } from 'onnxruntime-common';
import {
SerializableInternalBuffer,
SerializableSessionMetadata,
SerializableTensorMetadata,
TensorMetadata,
} from './proxy-messages';
import { setRunOptions } from './run-options';
import { setSessionOptions } from './session-options';
import {
calculateTensorSizeInBytes,
dataLocationStringToEnum,
isGpuBufferSupportedType,
isMLTensorSupportedType,
logLevelStringToEnum,
tensorDataTypeEnumToString,
tensorDataTypeStringToEnum,
tensorTypeToTypedArrayConstructor,
} from './wasm-common';
import { getInstance } from './wasm-factory';
import { allocWasmString, checkLastError } from './wasm-utils';
import { loadFile } from './wasm-utils-load-file';
// #region Initializations
/**
* There are 4 different "initialization" steps for ORT. They happen in different places and at different times.
*
* 1. JavaScript initialization for onnxruntime-common and onnxruntime-web.
* This is the first initialization step. In this step, onnxruntime-web calls onnxruntime-common's registerBackend()
* function multiple times to register all the available backends. The backend registration is very fast. It only
* registers the backend name with the uninitialized backend object. No heavy initialization is done in this step.
* Refer to web/lib/index.ts for the backend registration.
*
* 2. WebAssembly artifact initialization.
* This happens when any registered wasm backend is used for the first time (i.e. `ort.InferenceSession.create()` is
* called). In this step, onnxruntime-web does the following:
* - create a proxy worker and make sure the proxy worker is ready to receive messages, if proxy is enabled.
* - perform feature detection, locate correct WebAssembly artifact path and call the Emscripten generated
* JavaScript code to initialize the WebAssembly runtime.
* - if proxy is enabled, this step happens in the proxy worker using message 'init-wasm'.
* - downloading the 'ort-wasm{...}.wasm' file is done in this step.
* - if multi-threading is enabled, one or more web workers will be created to initialize the PThread thread pool.
*
* 3. ORT environment initialization.
* This happens after step 2. In this step, onnxruntime-web performs ONNX Runtime environment initialization.
* Function `_OrtInit()` is called in this step.
* - if proxy is enabled, this step happens in the proxy worker using message 'init-ort'.
* - logging level (ort.env.logLevel) and thread number (ort.env.wasm.numThreads) are set in this step.
*
* 4. Session initialization.
* This happens when `ort.InferenceSession.create()` is called. Unlike the first 3 steps (each of which happens only
* once), this step is performed for each session. In this step, onnxruntime-web does the following:
* If the parameter is a URL:
* - download the model data from the URL.
* - copy the model data to the WASM heap. (proxy: 'copy-from')
* - drop the reference to the model buffer. This allows the original ArrayBuffer to be garbage collected.
* - call `_OrtCreateSession()` to create the session. (proxy: 'create')
*
* If the parameter is a Uint8Array object:
* - copy the model data to the WASM heap. (proxy: 'copy-from')
* - call `_OrtCreateSession()` to create the session. (proxy: 'create')
*
*/
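// For illustration, a minimal consumer-side sketch (not part of this module) of the calls that
// drive the 4 steps above; the model URL is hypothetical:
//
//   import * as ort from 'onnxruntime-web';  // step 1 runs at import time (registerBackend)
//   ort.env.wasm.numThreads = 4;             // consumed by step 3 (_OrtInit)
//   ort.env.logLevel = 'warning';            // consumed by step 3 (_OrtInit)
//   const session = await ort.InferenceSession.create('model.onnx');  // triggers steps 2-4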
/**
* initialize ORT environment.
*
* @param numThreads value passed to SetGlobalIntraOpNumThreads(numThreads)
* @param loggingLevel value passed to CreateEnv(static_cast<OrtLoggingLevel>(logging_level))
*/
const initOrt = (numThreads: number, loggingLevel: number): void => {
const errorCode = getInstance()._OrtInit(numThreads, loggingLevel);
if (errorCode !== 0) {
checkLastError("Can't initialize onnxruntime.");
}
};
/**
* initialize the runtime environment.
* @param env the environment config object passed in.
*/
export const initRuntime = async (env: Env): Promise<void> => {
// init ORT
initOrt(env.wasm.numThreads!, logLevelStringToEnum(env.logLevel));
};
/**
* perform EP-specific initialization.
*
* @param env the environment config object.
* @param epName the name of the execution provider to initialize, e.g. 'webgpu' or 'webnn'.
*/
export const initEp = async (env: Env, epName: string): Promise<void> => {
// initialize ASYNCIFY support
getInstance().asyncInit?.();
if (epName === 'webgpu' && BUILD_DEFS.USE_WEBGPU_EP) {
getInstance().webgpuInit!((device) => {
env.webgpu.device = device;
});
}
if (!BUILD_DEFS.DISABLE_JSEP) {
// eslint-disable-next-line @typescript-eslint/no-require-imports, @typescript-eslint/no-var-requires
const initJsep = require('./jsep/init').init;
if (epName === 'webgpu' && !BUILD_DEFS.USE_WEBGPU_EP) {
// perform WebGPU availability check
if (typeof navigator === 'undefined' || !navigator.gpu) {
throw new Error('WebGPU is not supported in the current environment');
}
let adapter = env.webgpu.adapter as GPUAdapter | null;
if (!adapter) {
// if adapter is not set, request a new adapter.
const powerPreference = env.webgpu.powerPreference;
if (
powerPreference !== undefined &&
powerPreference !== 'low-power' &&
powerPreference !== 'high-performance'
) {
throw new Error(`Invalid powerPreference setting: "${powerPreference}"`);
}
const forceFallbackAdapter = env.webgpu.forceFallbackAdapter;
if (forceFallbackAdapter !== undefined && typeof forceFallbackAdapter !== 'boolean') {
throw new Error(`Invalid forceFallbackAdapter setting: "${forceFallbackAdapter}"`);
}
adapter = await navigator.gpu.requestAdapter({ powerPreference, forceFallbackAdapter });
if (!adapter) {
throw new Error(
'Failed to get GPU adapter. ' +
'You may need to enable flag "--enable-unsafe-webgpu" if you are using Chrome.',
);
}
} else {
// if adapter is set, validate it.
if (
typeof adapter.limits !== 'object' ||
typeof adapter.features !== 'object' ||
typeof adapter.requestDevice !== 'function'
) {
throw new Error('Invalid GPU adapter set in `env.webgpu.adapter`. It must be a GPUAdapter object.');
}
}
await initJsep('webgpu', getInstance(), env, adapter);
}
if (epName === 'webnn') {
// perform WebNN availability check
if (typeof navigator === 'undefined' || !(navigator as unknown as { ml: unknown }).ml) {
throw new Error('WebNN is not supported in the current environment');
}
await initJsep('webnn', getInstance(), env);
}
}
};
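// For illustration only: how a consumer typically influences the WebGPU branch above via
// `ort.env` (which is the `env` parameter passed to initEp); values are hypothetical.
//
//   ort.env.webgpu.powerPreference = 'high-performance';  // or 'low-power'
//   ort.env.webgpu.forceFallbackAdapter = false;
//   await ort.InferenceSession.create('model.onnx', { executionProviders: ['webgpu'] });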
// #endregion Initializations
/**
* valid data locations for input/output tensors.
*/
type SupportedTensorDataLocationForInputOutput =
| 'cpu'
| 'cpu-pinned'
| 'gpu-buffer'
| 'ml-tensor'
// Use 'ml-tensor' during inference, but output a tensor located on the CPU.
| 'ml-tensor-cpu-output';
type IOBindingState = {
/**
* the handle of IO binding.
*/
readonly handle: number;
/**
* the preferred location for each output tensor.
*
* value is one of 'cpu', 'cpu-pinned', 'gpu-buffer', 'ml-tensor', 'ml-tensor-cpu-output'.
*/
readonly outputPreferredLocations: readonly SupportedTensorDataLocationForInputOutput[];
/**
* enum value of the preferred location for each output tensor.
*/
readonly outputPreferredLocationsEncoded: readonly number[];
};
/**
* tuple elements are: InferenceSession ID; inputNamesUTF8Encoded; outputNamesUTF8Encoded; bindingState;
* enableGraphCapture; inputOutputBound
*/
type SessionMetadata = [
inferenceSessionId: number,
inputNamesUTF8Encoded: number[],
outputNamesUTF8Encoded: number[],
bindingState: IOBindingState | null,
enableGraphCapture: boolean,
inputOutputBound: boolean,
];
const activeSessions = new Map<number, SessionMetadata>();
/**
* get the input/output count of the session.
* @param sessionHandle the handle representing the session. Should be non-zero.
* @returns a tuple of 2 numbers, [inputCount, outputCount].
*/
const getSessionInputOutputCount = (sessionHandle: number): [number, number] => {
const wasm = getInstance();
const stack = wasm.stackSave();
try {
const ptrSize = wasm.PTR_SIZE;
const dataOffset = wasm.stackAlloc(2 * ptrSize);
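// two pointer-sized out-params written by the native call: [0] input count, [1] output count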
const errorCode = wasm._OrtGetInputOutputCount(sessionHandle, dataOffset, dataOffset + ptrSize);
if (errorCode !== 0) {
checkLastError("Can't get session input/output count.");
}
const type = ptrSize === 4 ? 'i32' : 'i64';
return [Number(wasm.getValue(dataOffset, type)), Number(wasm.getValue(dataOffset + ptrSize, type))];
} finally {
wasm.stackRestore(stack);
}
};
const getSessionInputOutputMetadata = (
sessionHandle: number,
index: number,
): [nameOffset: number, elementType: number, dims?: Array<number | string>] => {
const wasm = getInstance();
const stack = wasm.stackSave();
let metadataOffset = 0;
try {
const ptrSize = wasm.PTR_SIZE;
const dataOffset = wasm.stackAlloc(2 * ptrSize);
const errorCode = wasm._OrtGetInputOutputMetadata(sessionHandle, index, dataOffset, dataOffset + ptrSize);
if (errorCode !== 0) {
checkLastError("Can't get session input/output metadata.");
}
const nameOffset = Number(wasm.getValue(dataOffset, '*'));
metadataOffset = Number(wasm.getValue(dataOffset + ptrSize, '*'));
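// Layout of the native metadata struct, as read below:
//   offset 0:                             element type (i32; 0 means non-tensor)
//   offset 4:                             dims count (u32)
//   offset 8 + i * ptrSize:               pointer to the i-th symbolic dim name (0 if the dim is numeric)
//   offset 8 + (dimsCount + i) * ptrSize: the i-th numeric dim value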
// get element type
const elementType = wasm.HEAP32[metadataOffset / 4];
if (elementType === 0) {
return [nameOffset, 0]; // non-tensor
}
// get dims count
const dimsCount = wasm.HEAPU32[metadataOffset / 4 + 1];
// get dims
const dims: Array<number | string> = [];
for (let i = 0; i < dimsCount; i++) {
const symbolicDimNameOffset = Number(wasm.getValue(metadataOffset + 8 + i * ptrSize, '*'));
dims.push(
symbolicDimNameOffset !== 0
? wasm.UTF8ToString(symbolicDimNameOffset)
: Number(wasm.getValue(metadataOffset + 8 + (i + dimsCount) * ptrSize, '*')),
);
}
return [nameOffset, elementType, dims];
} finally {
wasm.stackRestore(stack);
if (metadataOffset !== 0) {
wasm._OrtFree(metadataOffset);
}
}
};
/**
* allocate the memory and memcpy the external buffer.
*
* @param model - the external buffer containing the model data. Must not be the same buffer as the WASM heap.
* @returns a 2-element tuple - the pointer and size of the allocated buffer
*/
export const copyFromExternalBuffer = (model: Uint8Array): [number, number] => {
const wasm = getInstance();
const modelDataOffset = wasm._malloc(model.byteLength);
if (modelDataOffset === 0) {
throw new Error(`Can't create a session. Failed to allocate a buffer of size ${model.byteLength}.`);
}
wasm.HEAPU8.set(model, modelDataOffset);
return [modelDataOffset, model.byteLength];
};
/**
* create an inference session from a model data buffer.
*
* @param modelData - either a Uint8Array object representing the model data, or a 2-element tuple containing the
* pointer and size of the model data buffer.
* @param options an optional session options object.
* @returns a 5-element tuple containing [session handle, input names, output names, input metadata, output metadata]
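*
* @example
* // Illustrative only - a hypothetical caller; `modelBytes` is assumed to hold a valid ONNX model.
* const modelBytes = new Uint8Array(await (await fetch('model.onnx')).arrayBuffer());
* const [handle, inputNames, outputNames] = await createSession(modelBytes, {
*   executionProviders: ['wasm'],
* });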
*/
export const createSession = async (
modelData: Uint8Array | SerializableInternalBuffer,
options?: InferenceSession.SessionOptions,
): Promise<SerializableSessionMetadata> => {
let modelDataOffset: number, modelDataLength: number;
const wasm = getInstance();
if (Array.isArray(modelData)) {
// if model data is an array, it must be a 2-elements tuple containing the pointer and size of the model data
[modelDataOffset, modelDataLength] = modelData;
} else if (modelData.buffer === wasm.HEAPU8.buffer) {
// if model data uses the same buffer as the WASM heap, we don't need to copy it.
[modelDataOffset, modelDataLength] = [modelData.byteOffset, modelData.byteLength];
} else {
// otherwise, copy the model data to the WASM heap.
[modelDataOffset, modelDataLength] = copyFromExternalBuffer(modelData);
}
let sessionHandle = 0;
let sessionOptionsHandle = 0;
let ioBindingHandle = 0;
let allocs: number[] = [];
const inputNamesUTF8Encoded = [];
const outputNamesUTF8Encoded = [];
try {
[sessionOptionsHandle, allocs] = await setSessionOptions(options);
if (options?.externalData && wasm.mountExternalData) {
const loadingPromises = [];
for (const file of options.externalData) {
const path = typeof file === 'string' ? file : file.path;
loadingPromises.push(
loadFile(typeof file === 'string' ? file : file.data).then((data) => {
wasm.mountExternalData(path, data);
}),
);
}
// wait for all external data files to be loaded
await Promise.all(loadingPromises);
}
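// Scan the EP list for 'webnn' so an MLContext can be created and registered before the session
// is created. For illustration, a typical WebNN EP option object (hypothetical values):
//   { name: 'webnn', deviceType: 'gpu', powerPreference: 'high-performance' }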
for (const provider of options?.executionProviders ?? []) {
const providerName = typeof provider === 'string' ? provider : provider.name;
if (providerName === 'webnn') {
wasm.shouldTransferToMLTensor = false;
if (typeof provider !== 'string') {
const webnnOptions = provider as InferenceSession.WebNNExecutionProviderOption;
const context = (webnnOptions as InferenceSession.WebNNOptionsWithMLContext)?.context;
const gpuDevice = (webnnOptions as InferenceSession.WebNNOptionsWebGpu)?.gpuDevice;
const deviceType = (webnnOptions as InferenceSession.WebNNContextOptions)?.deviceType;
const powerPreference = (webnnOptions as InferenceSession.WebNNContextOptions)?.powerPreference;
if (context) {
wasm.currentContext = context as MLContext;
} else if (gpuDevice) {
wasm.currentContext = await wasm.webnnCreateMLContext!(gpuDevice);
} else {
wasm.currentContext = await wasm.webnnCreateMLContext!({ deviceType, powerPreference });
}
} else {
wasm.currentContext = await wasm.webnnCreateMLContext!();
}
break;
}
}
sessionHandle = await wasm._OrtCreateSession(modelDataOffset, modelDataLength, sessionOptionsHandle);
if (sessionHandle === 0) {
checkLastError("Can't create a session.");
}
// only notify the WebGPU EP after the session handle has been validated
wasm.webgpuOnCreateSession?.(sessionHandle);
wasm.jsepOnCreateSession?.();
// register the MLContext with the session, then clear the current context
if (wasm.currentContext) {
wasm.webnnRegisterMLContext!(sessionHandle, wasm.currentContext);
wasm.currentContext = undefined;
wasm.shouldTransferToMLTensor = true;
}
const [inputCount, outputCount] = getSessionInputOutputCount(sessionHandle);
const enableGraphCapture = !!options?.enableGraphCapture;
const inputNames = [];
const outputNames = [];
const inputMetadata: InferenceSession.ValueMetadata[] = [];
const outputMetadata: InferenceSession.ValueMetadata[] = [];
const outputPreferredLocations: SupportedTensorDataLocationForInputOutput[] = [];
for (let i = 0; i < inputCount; i++) {
const [nameOffset, elementType, shape] = getSessionInputOutputMetadata(sessionHandle, i);
if (nameOffset === 0) {
checkLastError("Can't get an input name.");
}
inputNamesUTF8Encoded.push(nameOffset);
const name = wasm.UTF8ToString(nameOffset);
inputNames.push(name);
inputMetadata.push(
elementType === 0
? { name, isTensor: false }
: { name, isTensor: true, type: tensorDataTypeEnumToString(elementType), shape: shape! },
);
}
for (let i = 0; i < outputCount; i++) {
const [nameOffset, elementType, shape] = getSessionInputOutputMetadata(sessionHandle, i + inputCount);
if (nameOffset === 0) {
checkLastError("Can't get an output name.");
}
outputNamesUTF8Encoded.push(nameOffset);
const nameString = wasm.UTF8ToString(nameOffset);
outputNames.push(nameString);
outputMetadata.push(
elementType === 0
? { name: nameString, isTensor: false }
: { name: nameString, isTensor: true, type: tensorDataTypeEnumToString(elementType), shape: shape! },
);
if (!BUILD_DEFS.DISABLE_JSEP) {
if (enableGraphCapture && options?.preferredOutputLocation === undefined) {
outputPreferredLocations.push('gpu-buffer');
continue;
}
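// `preferredOutputLocation` may be a single location string applied to all outputs, or a
// per-output-name map. For illustration (hypothetical output name):
//   { preferredOutputLocation: 'gpu-buffer' }
//   { preferredOutputLocation: { output_0: 'gpu-buffer' } }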
const location =
typeof options?.preferredOutputLocation === 'string'
? options.preferredOutputLocation
: (options?.preferredOutputLocation?.[nameString] ?? 'cpu');
const isGraphOutput = wasm.webnnIsGraphOutput;
if (location === 'cpu' && isGraphOutput && isGraphOutput(sessionHandle, nameString)) {
outputPreferredLocations.push('ml-tensor-cpu-output');
continue;
}
if (location !== 'cpu' && location !== 'cpu-pinned' && location !== 'gpu-buffer' && location !== 'ml-tensor') {
throw new Error(`Unsupported preferred output location: ${location}.`);
}
if (enableGraphCapture && location !== 'gpu-buffer') {
throw new Error(
`Unsupported preferred output location: ${location}. Only 'gpu-buffer' is supported when enableGraphCapture is true.`,
);
}
outputPreferredLocations.push(location);
}
}
// use IO binding only when at least one output is preferred to be on GPU.
let bindingState: IOBindingState | null = null;
if (
!BUILD_DEFS.DISABLE_JSEP &&
outputPreferredLocations.some((l) => l === 'gpu-buffer' || l === 'ml-tensor' || l === 'ml-tensor-cpu-output')
) {
ioBindingHandle = wasm._OrtCreateBinding(sessionHandle);
if (ioBindingHandle === 0) {
checkLastError("Can't create IO binding.");
}
bindingState = {
handle: ioBindingHandle,
outputPreferredLocations,
outputPreferredLocationsEncoded: outputPreferredLocations
// 'ml-tensor-cpu-output' is treated as 'ml-tensor' for the purpose of IO binding.
.map((l) => (l === 'ml-tensor-cpu-output' ? 'ml-tensor' : l))
.map((l) => dataLocationStringToEnum(l)),
};
}
activeSessions.set(sessionHandle, [
sessionHandle,
inputNamesUTF8Encoded,
outputNamesUTF8Encoded,
bindingState,
enableGraphCapture,
false,
]);
return [sessionHandle, inputNames, outputNames, inputMetadata, outputMetadata];
} catch (e) {
inputNamesUTF8Encoded.forEach((buf) => wasm._OrtFree(buf));
outputNamesUTF8Encoded.forEach((buf) => wasm._OrtFree(buf));
if (ioBindingHandle !== 0) {
if (wasm._OrtReleaseBinding(ioBindingHandle) !== 0) {
checkLastError("Can't release IO binding.");
}
}
if (sessionHandle !== 0) {
if (wasm._OrtReleaseSession(sessionHandle) !== 0) {
checkLastError("Can't release session.");
}
}
throw e;
} finally {
wasm._free(modelDataOffset);
if (sessionOptionsHandle !== 0) {
if (wasm._OrtReleaseSessionOptions(sessionOptionsHandle) !== 0) {
checkLastError("Can't release session options.");
}
}
allocs.forEach((alloc) => wasm._free(alloc));
// unmount external data if necessary
wasm.unmountExternalData?.();
}
};
export const releaseSession = (sessionId: number): void => {
const wasm = getInstance();
const session = activeSessions.get(sessionId);
if (!session) {
throw new Error(`cannot release session. invalid session id: ${sessionId}`);
}
const [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState, enableGraphCapture] = session;
if (ioBindingState) {
if (enableGraphCapture) {
if (wasm._OrtClearBoundOutputs(ioBindingState.handle) !== 0) {
checkLastError("Can't clear bound outputs.");
}
}
if (wasm._OrtReleaseBinding(ioBindingState.handle) !== 0) {
checkLastError("Can't release IO binding.");
}
}
wasm.jsepOnReleaseSession?.(sessionId);
wasm.webnnOnReleaseSession?.(sessionId);
wasm.webgpuOnReleaseSession?.(sessionId);
inputNamesUTF8Encoded.forEach((buf) => wasm._OrtFree(buf));
outputNamesUTF8Encoded.forEach((buf) => wasm._OrtFree(buf));
if (wasm._OrtReleaseSession(sessionHandle) !== 0) {
checkLastError("Can't release session.");
}
activeSessions.delete(sessionId);
};
export const prepareInputOutputTensor = async (
tensor: TensorMetadata | null,
tensorHandles: number[],
allocs: number[],
sessionId: number,
tensorNameUTF8Encoded: number,
index: number,
enableGraphCapture = false,
): Promise<void> => {
if (!tensor) {
tensorHandles.push(0);
return;
}
const wasm = getInstance();
const ptrSize = wasm.PTR_SIZE;
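// TensorMetadata is a tuple: [0] data type string, [1] dims, [2] data (a typed array, a string
// array, or a { gpuBuffer } / { mlTensor } holder), [3] data location.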
const dataType = tensor[0];
const dims = tensor[1];
const location = tensor[3];
let actualLocation = location;
let rawData: number;
let dataByteLength: number;
if (dataType === 'string' && (location === 'gpu-buffer' || location === 'ml-tensor')) {
throw new Error('String tensor is not supported on GPU.');
}
if (enableGraphCapture && location !== 'gpu-buffer') {
throw new Error(
`External buffer must be provided for input/output index ${index} when enableGraphCapture is true.`,
);
}
if (location === 'gpu-buffer') {
const gpuBuffer = tensor[2].gpuBuffer;
dataByteLength = calculateTensorSizeInBytes(tensorDataTypeStringToEnum(dataType), dims)!;
if (BUILD_DEFS.USE_WEBGPU_EP) {
const registerBuffer = wasm.webgpuRegisterBuffer;
if (!registerBuffer) {
throw new Error('Tensor location "gpu-buffer" is not supported without using WebGPU.');
}
rawData = registerBuffer(gpuBuffer, sessionId);
} else {
const registerBuffer = wasm.jsepRegisterBuffer;
if (!registerBuffer) {
throw new Error('Tensor location "gpu-buffer" is not supported without using WebGPU.');
}
rawData = registerBuffer(sessionId, index, gpuBuffer, dataByteLength);
}
} else if (location === 'ml-tensor') {
const mlTensor = tensor[2].mlTensor as MLTensor;
dataByteLength = calculateTensorSizeInBytes(tensorDataTypeStringToEnum(dataType), dims)!;
const registerMLTensor = wasm.webnnRegisterMLTensor;
if (!registerMLTensor) {
throw new Error('Tensor location "ml-tensor" is not supported without using WebNN.');
}
rawData = registerMLTensor(sessionId, mlTensor, tensorDataTypeStringToEnum(dataType), dims);
} else {
const data = tensor[2];
if (Array.isArray(data)) {
// string tensor
dataByteLength = ptrSize * data.length;
rawData = wasm._malloc(dataByteLength);
allocs.push(rawData);
for (let i = 0; i < data.length; i++) {
if (typeof data[i] !== 'string') {
throw new TypeError(`tensor data at index ${i} is not a string`);
}
wasm.setValue(rawData + i * ptrSize, allocWasmString(data[i], allocs), '*');
}
} else {
const isGraphInput = wasm.webnnIsGraphInput;
const isGraphOutput = wasm.webnnIsGraphOutput;
if (dataType !== 'string' && isGraphInput && isGraphOutput) {
const tensorName = wasm.UTF8ToString(tensorNameUTF8Encoded);
// Promote the tensor to 'ml-tensor' if it is a WebNN graph input or output.
if (isGraphInput(sessionId, tensorName) || isGraphOutput(sessionId, tensorName)) {
const dataTypeEnum = tensorDataTypeStringToEnum(dataType);
dataByteLength = calculateTensorSizeInBytes(dataTypeEnum, dims)!;
actualLocation = 'ml-tensor';
const createTemporaryTensor = wasm.webnnCreateTemporaryTensor;
const uploadTensor = wasm.webnnUploadTensor;
if (!createTemporaryTensor || !uploadTensor) {
throw new Error('Tensor location "ml-tensor" is not supported without using WebNN.');
}
const tensorId = await createTemporaryTensor(sessionId, dataTypeEnum, dims as number[]);
uploadTensor(tensorId, new Uint8Array(data.buffer, data.byteOffset, data.byteLength));
rawData = tensorId;
} else {
dataByteLength = data.byteLength;
rawData = wasm._malloc(dataByteLength);
allocs.push(rawData);
wasm.HEAPU8.set(new Uint8Array(data.buffer, data.byteOffset, dataByteLength), rawData);
}
} else {
dataByteLength = data.byteLength;
rawData = wasm._malloc(dataByteLength);
allocs.push(rawData);
wasm.HEAPU8.set(new Uint8Array(data.buffer, data.byteOffset, dataByteLength), rawData);
}
}
}
const stack = wasm.stackSave();
// allocate one pointer-sized slot per dim; values are written at ptrSize stride below
const dimsOffset = wasm.stackAlloc(ptrSize * dims.length);
try {
dims.forEach((d, index) => wasm.setValue(dimsOffset + index * ptrSize, d, ptrSize === 4 ? 'i32' : 'i64'));
const tensor = wasm._OrtCreateTensor(
tensorDataTypeStringToEnum(dataType),
rawData,
dataByteLength,
dimsOffset,
dims.length,
dataLocationStringToEnum(actualLocation),
);
if (tensor === 0) {
checkLastError(`Can't create tensor for input/output. session=${sessionId}, index=${index}.`);
}
tensorHandles.push(tensor);
} finally {
wasm.stackRestore(stack);
}
};
/**
* perform an inference run.
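*
* @example
* // Illustrative only (hypothetical values; `sessionId` comes from a prior createSession call):
* // feed one CPU input tensor and let ORT allocate the single output - a `null` entry in
* // `outputTensors` means "not pre-allocated".
* const input: TensorMetadata = ['float32', [1, 2], new Float32Array([0.5, 1.5]), 'cpu'];
* const [output] = await run(sessionId, [0], [input], [0], [null], {});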
*/
export const run = async (
sessionId: number,
inputIndices: number[],
inputTensors: TensorMetadata[],
outputIndices: number[],
outputTensors: Array<TensorMetadata | null>,
options: InferenceSession.RunOptions,
): Promise<TensorMetadata[]> => {
const wasm = getInstance();
const ptrSize = wasm.PTR_SIZE;
const session = activeSessions.get(sessionId);
if (!session) {
throw new Error(`cannot run inference. invalid session id: ${sessionId}`);
}
const sessionHandle = session[0];
const inputNamesUTF8Encoded = session[1];
const outputNamesUTF8Encoded = session[2];
const ioBindingState = session[3];
const enableGraphCapture = session[4];
const inputOutputBound = session[5];
const inputCount = inputIndices.length;
const outputCount = outputIndices.length;
let runOptionsHandle = 0;
let runOptionsAllocs: number[] = [];
const inputTensorHandles: number[] = [];
const outputTensorHandles: number[] = [];
const inputOutputAllocs: number[] = [];
const beforeRunStack = wasm.stackSave();
const inputValuesOffset = wasm.stackAlloc(inputCount * ptrSize);
const inputNamesOffset = wasm.stackAlloc(inputCount * ptrSize);
const outputValuesOffset = wasm.stackAlloc(outputCount * ptrSize);
const outputNamesOffset = wasm.stackAlloc(outputCount * ptrSize);
try {
[runOptionsHandle, runOptionsAllocs] = setRunOptions(options);
// create input tensors
for (let i = 0; i < inputCount; i++) {
await prepareInputOutputTensor(
inputTensors[i],
inputTensorHandles,
inputOutputAllocs,
sessionId,
inputNamesUTF8Encoded[inputIndices[i]],
inputIndices[i],
enableGraphCapture,
);
}
// create output tensors
for (let i = 0; i < outputCount; i++) {
await prepareInputOutputTensor(
outputTensors[i],
outputTensorHandles,
inputOutputAllocs,
sessionId,
outputNamesUTF8Encoded[outputIndices[i]],
inputCount + outputIndices[i],
enableGraphCapture,
);
}
for (let i = 0; i < inputCount; i++) {
wasm.setValue(inputValuesOffset + i * ptrSize, inputTensorHandles[i], '*');
wasm.setValue(inputNamesOffset + i * ptrSize, inputNamesUTF8Encoded[inputIndices[i]], '*');
}
for (let i = 0; i < outputCount; i++) {
wasm.setValue(outputValuesOffset + i * ptrSize, outputTensorHandles[i], '*');
wasm.setValue(outputNamesOffset + i * ptrSize, outputNamesUTF8Encoded[outputIndices[i]], '*');
}
if (!BUILD_DEFS.DISABLE_JSEP && ioBindingState && !inputOutputBound) {
const { handle, outputPreferredLocations, outputPreferredLocationsEncoded } = ioBindingState;
if (inputNamesUTF8Encoded.length !== inputCount) {
throw new Error(
`input count from feeds (${inputCount}) is expected to always equal the model's input count (${inputNamesUTF8Encoded.length}).`,
);
}
// process inputs
for (let i = 0; i < inputCount; i++) {
const index = inputIndices[i];
const errorCode = await wasm._OrtBindInput(handle, inputNamesUTF8Encoded[index], inputTensorHandles[i]);
if (errorCode !== 0) {
checkLastError(`Can't bind input[${i}] for session=${sessionId}.`);
}
}
// process pre-allocated outputs
for (let i = 0; i < outputCount; i++) {
const index = outputIndices[i];
const location = outputTensors[i]?.[3]; // undefined means output is not pre-allocated.
if (location) {
// output is pre-allocated. bind the tensor.
const errorCode = wasm._OrtBindOutput(handle, outputNamesUTF8Encoded[index], outputTensorHandles[i], 0);
if (errorCode !== 0) {
checkLastError(`Can't bind pre-allocated output[${i}] for session=${sessionId}.`);
}
} else {
// output is not pre-allocated. reset preferred location.
const errorCode = wasm._OrtBindOutput(
handle,
outputNamesUTF8Encoded[index],
0,
outputPreferredLocationsEncoded[index],
);
if (errorCode !== 0) {
checkLastError(`Can't bind output[${i}] to ${outputPreferredLocations[index]} for session=${sessionId}.`);
}
}
}
activeSessions.set(sessionId, [
sessionHandle,
inputNamesUTF8Encoded,
outputNamesUTF8Encoded,
ioBindingState,
enableGraphCapture,
true,
]);
}
wasm.jsepOnRunStart?.(sessionHandle);
wasm.webnnOnRunStart?.(sessionHandle);
let errorCode: number;
if (!BUILD_DEFS.DISABLE_JSEP && ioBindingState) {
errorCode = await wasm._OrtRunWithBinding(
sessionHandle,
ioBindingState.handle,
outputCount,
outputValuesOffset,
runOptionsHandle,
);
} else {
errorCode = await wasm._OrtRun(
sessionHandle,
inputNamesOffset,
inputValuesOffset,
inputCount,
outputNamesOffset,
outputCount,
outputValuesOffset,
runOptionsHandle,
);
}
if (errorCode !== 0) {
checkLastError('failed to call OrtRun().');
}
const output: TensorMetadata[] = [];
const outputPromises: Array<Promise<[number, Tensor.DataType]>> = [];
for (let i = 0; i < outputCount; i++) {
const tensor = Number(wasm.getValue(outputValuesOffset + i * ptrSize, '*'));
if (tensor === outputTensorHandles[i]) {
// output tensor is pre-allocated. no need to copy data.
output.push(outputTensors[i]!);
continue;
}
const beforeGetTensorDataStack = wasm.stackSave();
// stack allocate 4 pointer-sized values
const tensorDataOffset = wasm.stackAlloc(4 * ptrSize);
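// four pointer-sized out-params written by the native call:
//   [0] data type, [1] data pointer, [2] dims pointer, [3] dims length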
let keepOutputTensor = false;
let type: Tensor.Type | undefined,
dataOffset = 0;
try {
const errorCode = wasm._OrtGetTensorData(
tensor,
tensorDataOffset,
tensorDataOffset + ptrSize,
tensorDataOffset + 2 * ptrSize,
tensorDataOffset + 3 * ptrSize,
);
if (errorCode !== 0) {
checkLastError(`Can't access output tensor data on index ${i}.`);
}
const valueType = ptrSize === 4 ? 'i32' : 'i64';
const dataType = Number(wasm.getValue(tensorDataOffset, valueType));
dataOffset = wasm.getValue(tensorDataOffset + ptrSize, '*');
const dimsOffset = wasm.getValue(tensorDataOffset + ptrSize * 2, '*');
const dimsLength = Number(wasm.getValue(tensorDataOffset + ptrSize * 3, valueType));
const dims = [];
for (let i = 0; i < dimsLength; i++) {
dims.push(Number(wasm.getValue(dimsOffset + i * ptrSize, valueType)));
}
if (wasm._OrtFree(dimsOffset) !== 0) {
checkLastError("Can't free memory for tensor dims.");
}
const size = dims.reduce((a, b) => a * b, 1);
type = tensorDataTypeEnumToString(dataType);
const preferredLocation = ioBindingState?.outputPreferredLocations[outputIndices[i]];
if (type === 'string') {
if (preferredLocation === 'gpu-buffer' || preferredLocation === 'ml-tensor') {
throw new Error('String tensor is not supported on GPU.');
}
const stringData: string[] = [];
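// `dataOffset` points at `size` pointers to NUL-terminated UTF-8 strings; consecutive pointers
// bound each string's byte length, and the last string is read up to its NUL terminator.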
for (let i = 0; i < size; i++) {
const offset = wasm.getValue(dataOffset + i * ptrSize, '*');
const nextOffset = wasm.getValue(dataOffset + (i + 1) * ptrSize, '*');
const maxBytesToRead = i === size - 1 ? undefined : nextOffset - offset;
stringData.push(wasm.UTF8ToString(offset, maxBytesToRead));
}
output.push([type, dims, stringData, 'cpu']);
} else {
// If a certain output's preferred location is GPU but the tensor is empty, we still need to create a CPU
// tensor for it. There is no mapped GPU buffer for an empty tensor.
if (preferredLocation === 'gpu-buffer' && size > 0) {
const getBuffer = BUILD_DEFS.USE_WEBGPU_EP ? wasm.webgpuGetBuffer : wasm.jsepGetBuffer;
if (!getBuffer) {
throw new Error('preferredLocation "gpu-buffer" is not supported without using WebGPU.');
}
const gpuBuffer = getBuffer(dataOffset);
const bufferSize = calculateTensorSizeInBytes(dataType, size);
if (bufferSize === undefined || !isGpuBufferSupportedType(type)) {
throw new Error(`Unsupported data type: ${type}`);
}
// do not release the tensor right now. it will be released when user calls tensor.dispose().
keepOutputTensor = true;
if (BUILD_DEFS.USE_WEBGPU_EP) {
wasm.webgpuRegisterBuffer!(gpuBuffer, sessionId, dataOffset);
const downloadDataFunction = wasm.webgpuCreateDownloader!(gpuBuffer, bufferSize, sessionId);
output.push([
type,
dims,
{
gpuBuffer,
download: async () => {
const arrayBuffer = await downloadDataFunction();
const data = new (tensorTypeToTypedArrayConstructor(type!))(arrayBuffer);
return data as Tensor.DataTypeMap[Tensor.GpuBufferDataTypes];
},
dispose: () => {
if (wasm._OrtReleaseTensor(tensor) !== 0) {
checkLastError("Can't release tensor.");
}
},
},
'gpu-buffer',
]);
} else {
output.push([
type,
dims,
{
gpuBuffer,
download: wasm.jsepCreateDownloader!(gpuBuffer, bufferSize, type),
dispose: () => {
if (wasm._OrtReleaseTensor(tensor) !== 0) {
checkLastError("Can't release tensor.");
}
},
},
'gpu-buffer',
]);
}
} else if (preferredLocation === 'ml-tensor' && size > 0) {
const ensureTensor = wasm.webnnEnsureTensor;
const isGraphInputOutputTypeSupported = wasm.webnnIsGraphInputOutputTypeSupported;
if (!ensureTensor || !isGraphInputOutputTypeSupported) {
throw new Error('preferredLocation "ml-tensor" is not supported without using WebNN.');
}
const tensorSize = calculateTensorSizeInBytes(dataType, size);
if (tensorSize === undefined || !isMLTensorSupportedType(type)) {
throw new Error(`Unsupported data type: ${type}`);
}
if (!isGraphInputOutputTypeSupported(sessionId, type, false)) {
throw new Error(
`preferredLocation "ml-tensor" for ${type} output is not supported by current WebNN Context.`,
);
}
// If the graph has been partitioned, the output tensor may not have been created yet. For this
// reason, we use ensureTensor to get or create the MLTensor; if a new tensor was created, there is
// no existing data to copy.
const mlTensor = await ensureTensor(sessionId, dataOffset, dataType, dims, false);
// do not release the tensor right now. it will be released when user calls tensor.dispose().
keepOutputTensor = true;
output.push([
type,
dims,
{
mlTensor,
download: wasm.webnnCreateMLTensorDownloader!(dataOffset, type),
dispose: () => {
wasm.webnnReleaseTensorId!(dataOffset);
wasm._OrtReleaseTensor(tensor);
},
},
'ml-tensor',
]);
} else if (preferredLocation === 'ml-tensor-cpu-output' && size > 0) {
const data = wasm.webnnCreateMLTensorDownloader!(dataOffset, type as Tensor.MLTensorDataTypes)();
const index = output.length;
// Delay downloading the data and releasing the tensor until all output tensors can be awaited together.
keepOutputTensor = true;
outputPromises.push(
(async () => {
const result: [number, Tensor.DataType] = [index, await data];
wasm.webnnReleaseTensorId!(dataOffset);
wasm._OrtReleaseTensor(tensor);
return result;
})(),
);
output.push([type, dims, [], 'cpu']);
} else {
const typedArrayConstructor = tensorTypeToTypedArrayConstructor(type);
const data = new typedArrayConstructor(size);
new Uint8Array(data.buffer, data.byteOffset, data.byteLength).set(
wasm.HEAPU8.subarray(dataOffset, dataOffset + data.byteLength),
);
output.push([type, dims, data, 'cpu']);
}
}
} finally {
wasm.stackRestore(beforeGetTensorDataStack);
if (type === 'string' && dataOffset) {
wasm._free(dataOffset);
}
if (!keepOutputTensor) {
wasm._OrtReleaseTensor(tensor);
}
}
}
if (ioBindingState && !enableGraphCapture) {
if (wasm._OrtClearBoundOutputs(ioBindingState.handle) !== 0) {
checkLastError("Can't clear bound outputs.");
}
activeSessions.set(sessionId, [
sessionHandle,
inputNamesUTF8Encoded,
outputNamesUTF8Encoded,
ioBindingState,
enableGraphCapture,
false,
]);
}
// Wait for all output tensor data to be downloaded.
for (const [index, data] of await Promise.all(outputPromises)) {
output[index][2] = data;
}
return output;
} finally {
wasm.webnnOnRunEnd?.(sessionHandle);
wasm.stackRestore(beforeRunStack);
if (BUILD_DEFS.USE_WEBGPU_EP) {
inputTensors.forEach((t) => {
if (t && t[3] === 'gpu-buffer') {
wasm.webgpuUnregisterBuffer!(t[2].gpuBuffer);
}
});
outputTensors.forEach((t) => {
if (t && t[3] === 'gpu-buffer') {
wasm.webgpuUnregisterBuffer!(t[2].gpuBuffer);
}
});
}
inputTensorHandles.forEach((v) => wasm._OrtReleaseTensor(v));
outputTensorHandles.forEach((v) => wasm._OrtReleaseTensor(v));
inputOutputAllocs.forEach((p) => wasm._free(p));
if (runOptionsHandle !== 0) {
wasm._OrtReleaseRunOptions(runOptionsHandle);
}
runOptionsAllocs.forEach((p) => wasm._free(p));
}
};
/**
* end profiling
*/
export const endProfiling = (sessionId: number): void => {
const wasm = getInstance();
const session = activeSessions.get(sessionId);
if (!session) {
throw new Error('invalid session id');
}
const sessionHandle = session[0];
// profile file name is not used yet, but it must be freed.
const profileFileName = wasm._OrtEndProfiling(sessionHandle);
if (profileFileName === 0) {
checkLastError("Can't get an profile file name.");
}
wasm._OrtFree(profileFileName);
};
export const extractTransferableBuffers = (tensors: readonly SerializableTensorMetadata[]): ArrayBufferLike[] => {
const buffers: ArrayBufferLike[] = [];
for (const tensor of tensors) {
const data = tensor[2];
if (!Array.isArray(data) && 'buffer' in data) {
buffers.push(data.buffer);
}
}
return buffers;
};