onnxruntime-web
Version:
A JavaScript library for running ONNX models in browsers
272 lines (242 loc) • 9.75 kB
text/typescript
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import { Env } from 'onnxruntime-common';
import { calculateTensorSizeInBytes, DataType } from '../wasm-common';
import type { OrtWasmModule } from '../wasm-types';
import { WebGpuBackend } from './backend-webgpu';
import { LOG_DEBUG } from './log';
import { TensorView } from './tensor-view';
import { ShapeUtil } from './util';
import { AdapterInfo, ComputeContext, ComputeContextInputsOutputsMapping, ProgramInfo } from './webgpu/types';
/* eslint-disable no-bitwise */
/**
 * Concrete `TensorView`: a data type, a data handle and a shape, bound to a WebAssembly module.
 *
 * The typed-array getters interpret `data` as a byte offset into the module's heap buffer and return a view
 * (not a copy) covering exactly `ShapeUtil.size(dims)` elements.
 */
class TensorViewImpl implements TensorView {
  constructor(
    private module: OrtWasmModule,
    public readonly dataType: number,
    public readonly data: number,
    public readonly dims: readonly number[],
  ) {}

  /**
   * @returns a `Float32Array` view over the heap, or an empty array for a zero-element tensor.
   * @throws Error if the data type is not `DataType.float`.
   */
  getFloat32Array(): Float32Array {
    this.requireDataType(DataType.float);
    const length = ShapeUtil.size(this.dims);
    if (length === 0) {
      return new Float32Array();
    }
    return new Float32Array(this.module.HEAP8.buffer, this.data, length);
  }

  /**
   * @returns a `BigInt64Array` view over the heap, or an empty array for a zero-element tensor.
   * @throws Error if the data type is not `DataType.int64`.
   */
  getBigInt64Array(): BigInt64Array {
    this.requireDataType(DataType.int64);
    const length = ShapeUtil.size(this.dims);
    if (length === 0) {
      return new BigInt64Array();
    }
    return new BigInt64Array(this.module.HEAP8.buffer, this.data, length);
  }

  /**
   * @returns an `Int32Array` view over the heap, or an empty array for a zero-element tensor.
   * @throws Error if the data type is not `DataType.int32`.
   */
  getInt32Array(): Int32Array {
    this.requireDataType(DataType.int32);
    const length = ShapeUtil.size(this.dims);
    if (length === 0) {
      return new Int32Array();
    }
    return new Int32Array(this.module.HEAP8.buffer, this.data, length);
  }

  /**
   * @returns a `Uint16Array` view over the heap, or an empty array for a zero-element tensor.
   * @throws Error if the data type is neither `DataType.float16` nor `DataType.uint16`.
   */
  getUint16Array(): Uint16Array {
    this.requireDataType(DataType.float16, DataType.uint16);
    const length = ShapeUtil.size(this.dims);
    if (length === 0) {
      return new Uint16Array();
    }
    return new Uint16Array(this.module.HEAP8.buffer, this.data, length);
  }

  /**
   * Creates a new view over the same data with a different shape.
   *
   * @param newDims - the new shape; must describe the same number of elements as the current shape.
   * @throws Error if the element counts differ.
   */
  reshape(newDims: readonly number[]): TensorView {
    const requestedSize = ShapeUtil.size(newDims);
    const currentSize = ShapeUtil.size(this.dims);
    if (requestedSize === currentSize) {
      return new TensorViewImpl(this.module, this.dataType, this.data, newDims);
    }
    throw new Error('Invalid new shape');
  }

  /** Throws unless this tensor's data type is one of `accepted`. */
  private requireDataType(...accepted: number[]): void {
    if (accepted.every((candidate) => candidate !== this.dataType)) {
      throw new Error('Invalid data type');
    }
  }
}
/**
 * Compute context handed to a WebGPU kernel for a single kernel invocation.
 *
 * The constructor decodes a context record from the WebAssembly heap. The record is read as consecutive
 * unsigned 32-bit words starting at `contextDataOffset`:
 *
 *   opKernelContext, inputCount, outputCount, customDataOffset, customDataSize,
 *   then for every input: dataType, data, rank, dims[0] ... dims[rank - 1]
 */
class ComputeContextImpl implements ComputeContext {
  readonly adapterInfo: AdapterInfo;
  readonly opKernelContext: number;
  readonly inputs: readonly TensorView[];
  readonly outputCount: number;

  /** Custom data of the kernel currently tracked by the backend. */
  get kernelCustomData(): { [key: string]: unknown } {
    return this.backend.currentKernelCustomData;
  }

  /** A view (not a copy) of the `customDataSize` heap bytes starting at `customDataOffset`. */
  get customDataBuffer(): Uint8Array {
    return this.module.HEAPU8.subarray(this.customDataOffset, this.customDataOffset + this.customDataSize);
  }

  private customDataOffset = 0;
  private customDataSize = 0;

  /**
   * @param module - the ORT WebAssembly module whose heap holds the context record
   * @param backend - the WebGPU backend that executes programs for this context
   * @param contextDataOffset - byte offset of the context record in the WebAssembly heap
   */
  constructor(
    private module: OrtWasmModule,
    private backend: WebGpuBackend,
    contextDataOffset: number,
  ) {
    this.adapterInfo = backend.adapterInfo;
    const heapU32 = module.HEAPU32;

    // extract context data
    // Unsigned shift: converts the byte offset into a HEAPU32 index and stays non-negative even when the
    // offset arrives as a negative signed 32-bit value.
    let dataIndex = contextDataOffset >>> 2;
    this.opKernelContext = heapU32[dataIndex++];
    const inputCount = heapU32[dataIndex++];
    this.outputCount = heapU32[dataIndex++];
    this.customDataOffset = heapU32[dataIndex++];
    this.customDataSize = heapU32[dataIndex++];

    const inputs: TensorView[] = [];
    for (let i = 0; i < inputCount; i++) {
      const dataType = heapU32[dataIndex++];
      const data = heapU32[dataIndex++];
      const dim = heapU32[dataIndex++];
      const dims: number[] = [];
      for (let d = 0; d < dim; d++) {
        dims.push(heapU32[dataIndex++]);
      }
      inputs.push(new TensorViewImpl(module, dataType, data, dims));
    }
    this.inputs = inputs;
  }

  /** Returns the device limits `[maxComputeWorkgroupSizeX, maxComputeWorkgroupSizeY, maxComputeWorkgroupSizeZ]`. */
  getMaxComputeWorkgroupSizes(): [number, number, number] {
    return [
      this.backend.device.limits.maxComputeWorkgroupSizeX,
      this.backend.device.limits.maxComputeWorkgroupSizeY,
      this.backend.device.limits.maxComputeWorkgroupSizeZ,
    ];
  }

  /** Returns the device limit `maxComputeWorkgroupStorageSize`. */
  getMaxComputeWorkgroupStoragesize(): number {
    return this.backend.device.limits.maxComputeWorkgroupStorageSize;
  }

  /**
   * Runs a program through the backend.
   *
   * @param program - the program to run
   * @param inputsOutputsMapping - optional mapping. Each entry of `inputs` is either an index into this context's
   *     inputs or a tensor view used as-is; when `inputs` is omitted all context inputs are used. `outputs` is
   *     forwarded to the backend unchanged and defaults to an empty array.
   * @returns whatever `backend.run` returns for the program
   * @throws Error if a temporary output is requested for a data type without a known byte size
   */
  compute(program: ProgramInfo, inputsOutputsMapping?: ComputeContextInputsOutputsMapping): TensorView[] {
    // prepare inputs. inputs should always be valid data.
    const mappedInputs =
      inputsOutputsMapping?.inputs?.map((i) => (typeof i === 'number' ? this.inputs[i] : i)) ?? this.inputs;

    // prepare outputs.
    const outputIndices = inputsOutputsMapping?.outputs ?? [];

    const createKernelOutput = (index: number, dataType: number, dims: readonly number[]): TensorView =>
      new TensorViewImpl(this.module, dataType, this.output(index, dims), dims);

    const createTemporaryOutput = (dataType: number, dims: readonly number[]): TensorView => {
      const bufferSize = calculateTensorSizeInBytes(dataType, dims);
      // Only a missing size is treated as an unsupported data type. A size of 0 is a valid zero-sized tensor
      // and must reach the branch below; a plain falsy check (`!bufferSize`) rejected it by mistake.
      // NOTE(review): assumes calculateTensorSizeInBytes() reports unsupported types as `undefined` - confirm.
      if (bufferSize === undefined) {
        throw new Error(`Unsupported data type: ${dataType}`);
      }
      // A zero-sized tensor gets no GPU allocation; its data ID is 0.
      const gpuDataId = bufferSize > 0 ? this.backend.gpuDataManager.create(bufferSize).id : 0;
      return new TensorViewImpl(this.module, dataType, gpuDataId, dims);
    };

    return this.backend.run(
      program,
      mappedInputs,
      outputIndices,
      createKernelOutput,
      createTemporaryOutput,
      this.outputCount,
    );
  }

  /**
   * Requests a kernel output of the given shape from the WebAssembly side.
   *
   * Writes `[dims.length, ...dims]` as 32-bit words into a temporary stack allocation and passes it to
   * `_JsepOutput` together with the kernel context and the output index. The stack is restored afterwards,
   * whether or not the call succeeds.
   *
   * @param index - the index of the kernel output
   * @param dims - the shape of the output
   * @returns the value returned by `_JsepOutput`
   * @throws Error wrapping any failure, including a hint about pre-allocated outputs
   */
  output(index: number, dims: readonly number[]): number {
    const stack = this.module.stackSave();
    try {
      const data = this.module.stackAlloc((1 + dims.length) * 4 /* sizeof(size_t) */);
      // Unsigned shift, consistent with the constructor: a pointer at or above 2 GiB may arrive as a negative
      // signed 32-bit number, and `>> 2` would turn it into a negative (invalid) HEAPU32 index.
      let offset = data >>> 2;
      this.module.HEAPU32[offset++] = dims.length;
      for (let i = 0; i < dims.length; i++) {
        this.module.HEAPU32[offset++] = dims[i];
      }
      return this.module._JsepOutput!(this.opKernelContext, index, data);
    } catch (e) {
      throw new Error(
        `Failed to generate kernel's output[${index}] with dims [${dims}]. ` +
          'If you are running with pre-allocated output, please make sure the output type/dims are correct. ' +
          `Error: ${e}`,
      );
    } finally {
      this.module.stackRestore(stack);
    }
  }
}
/**
 * Initialize JSEP on a loaded ORT WebAssembly module.
 *
 * This function is called after the WebAssembly module is loaded and initialized ("_OrtInit" is called), once
 * for each of the following EPs if they are specified:
 * - "webgpu"
 * - "webnn"
 *
 * For WebGPU, this function expects:
 * - WebGPU is enabled in build (BUILD_DEFS.DISABLE_JSEP === false).
 * - WebGPU is available in current environment. (a valid GPUAdapter is passed in)
 *
 * For WebNN, this function expects:
 * - WebNN is enabled in build (BUILD_DEFS.DISABLE_JSEP === false).
 * - WebNN is available in current environment. (navigator.ml is not undefined)
 *
 * For "webgpu", a `WebGpuBackend` is created and initialized with `env` and `gpuAdapter`, and then registered
 * with the module together with the callbacks the module uses to reach it (alloc, free, copy, async copy,
 * kernel create/release/run, capture begin/end, replay). For "webnn", only the EP name is registered.
 *
 * @param name - the name of the EP, either "webgpu" or "webnn"
 * @param module - the ORT WebAssembly module
 * @param env - the ORT environment variable (ort.env)
 * @param gpuAdapter - the pre-created GPU adapter
 * @throws Error if the WebAssembly module is not built with JSEP support (`module.jsepInit` is missing). This
 *     will invalidate the 'webgpu'/'webnn' backend.
 */
export const init = async (
  name: 'webgpu' | 'webnn',
  module: OrtWasmModule,
  env: Env,
  gpuAdapter?: GPUAdapter,
): Promise<void> => {
  // Keep the function in a local: it is invoked unbound, exactly as extracted from the module.
  const initJsep = module.jsepInit;
  if (!initJsep) {
    throw new Error('Failed to initialize JSEP. The WebAssembly module is not built with JSEP support.');
  }

  if (name !== 'webgpu') {
    initJsep('webnn');
    return;
  }

  const backend = new WebGpuBackend();
  await backend.initialize(env, gpuAdapter!);

  const jsepAlloc = (size: number) => backend.alloc(size);

  const jsepFree = (ptr: number) => backend.free(ptr);

  // jsepCopy(src, dst, size, isSourceGpu): GPU-to-GPU copy when the source is on GPU, otherwise an upload of
  // `size` heap bytes starting at `src` (treated as an unsigned offset).
  const jsepCopy = (src: number, dst: number, size: number, isSourceGpu = false) => {
    if (!isSourceGpu) {
      LOG_DEBUG('verbose', () => `[WebGPU] jsepCopyCpuToGpu: dataOffset=${src}, gpuDataId=${dst}, size=${size}`);
      const begin = src >>> 0;
      backend.upload(dst, module.HEAPU8.subarray(begin, begin + size));
      return;
    }
    LOG_DEBUG('verbose', () => `[WebGPU] jsepCopyGpuToGpu: src=${src}, dst=${dst}, size=${size}`);
    backend.memcpy(src, dst);
  };

  // jsepCopyAsync(src, dst, size): download GPU data into the heap. The destination view is produced by a
  // callback so that `module.HEAPU8` is read only when the backend asks for it.
  const jsepCopyAsync = async (gpuDataId: number, dataOffset: number, size: number): Promise<void> => {
    LOG_DEBUG(
      'verbose',
      () => `[WebGPU] jsepCopyGpuToCpu: gpuDataId=${gpuDataId}, dataOffset=${dataOffset}, size=${size}`,
    );
    const getTargetBuffer = () => {
      const begin = dataOffset >>> 0;
      return module.HEAPU8.subarray(begin, begin + size);
    };
    await backend.download(gpuDataId, getTargetBuffer);
  };

  const jsepCreateKernel = (kernelType: string, kernelId: number, attribute: unknown) => {
    const nodeName = module.UTF8ToString(module._JsepGetNodeName!(kernelId));
    return backend.createKernel(kernelType, kernelId, attribute, nodeName);
  };

  const jsepReleaseKernel = (kernel: number) => backend.releaseKernel(kernel);

  const jsepRun = (
    kernel: number,
    contextDataOffset: number,
    sessionHandle: number,
    errors: Array<Promise<string | null>>,
  ) => {
    LOG_DEBUG(
      'verbose',
      () =>
        `[WebGPU] jsepRun: sessionHandle=${sessionHandle}, kernel=${kernel}, contextDataOffset=${contextDataOffset}`,
    );
    return backend.computeKernel(kernel, new ComputeContextImpl(module, backend, contextDataOffset), errors);
  };

  const jsepCaptureBegin = () => backend.captureBegin();

  const jsepCaptureEnd = () => backend.captureEnd();

  const jsepReplay = () => backend.replay();

  // The order of the entries is the contract with the WebAssembly side; do not reorder.
  initJsep('webgpu', [
    backend,
    jsepAlloc,
    jsepFree,
    jsepCopy,
    jsepCopyAsync,
    jsepCreateKernel,
    jsepReleaseKernel,
    jsepRun,
    jsepCaptureBegin,
    jsepCaptureEnd,
    jsepReplay,
  ]);
};