UNPKG

onnxruntime-web

Version:

A JavaScript library for running ONNX models in browsers

271 lines (226 loc) 9.13 kB
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import { resolveBackend, SessionHandlerType } from './backend';
import { ExecutionPlan } from './execution-plan';
import { Graph } from './graph';
import { Profiler } from './instrument';
import { Model } from './model';
import { Operator } from './operators';
import { Tensor } from './tensor';

export declare namespace Session {
  /** Options accepted by the {@link Session} constructor. */
  export interface Config {
    /** Hint forwarded to `resolveBackend` when a model is loaded. */
    backendHint?: string;
    /** Configuration forwarded to `Profiler.create`. */
    profiler?: Profiler.Config;
  }

  /**
   * State handed to the backend's session handler (`createSessionHandler`).
   *
   * `graphInputTypes` / `graphInputDims` are filled on the first successful `run()` and are used to
   * validate every later run of the same session.
   */
  export interface Context {
    profiler: Readonly<Profiler>;
    graphInputTypes?: Tensor.DataType[];
    graphInputDims?: Array<readonly number[]>;
  }
}

/**
 * Loads an ONNX model, resolves a backend session handler for it and executes it.
 *
 * Lifecycle: construct, call `loadModel` exactly once, then call `run` any number of times.
 */
export class Session {
  constructor(config: Session.Config = {}) {
    this._initialized = false;
    this.backendHint = config.backendHint;
    this.profiler = Profiler.create(config.profiler);
    this.context = { profiler: this.profiler, graphInputTypes: [], graphInputDims: [] };
  }

  /** Names of the model graph's inputs. Only valid once `loadModel` has created the model. */
  get inputNames(): readonly string[] {
    return this._model.graph.getInputNames();
  }

  /** Names of the model graph's outputs. Only valid once `loadModel` has created the model. */
  get outputNames(): readonly string[] {
    return this._model.graph.getOutputNames();
  }

  /** Starts the session profiler. */
  startProfiling() {
    this.profiler.start();
  }

  /** Stops the session profiler. */
  endProfiling() {
    this.profiler.stop();
  }

  async loadModel(uri: string): Promise<void>;
  async loadModel(buffer: ArrayBuffer, byteOffset?: number, length?: number): Promise<void>;
  async loadModel(buffer: Uint8Array): Promise<void>;
  /**
   * Resolves the backend, creates the session handler and loads the model.
   *
   * @param arg - a path/URL (read with `fs/promises` under Node, `fetch` otherwise; a `.ort` suffix marks
   *   ORT format), an `ArrayBuffer`, or a `Uint8Array` holding the model bytes.
   * @param byteOffset - start of the model inside an `ArrayBuffer` (defaults to 0).
   * @param length - model size in bytes inside an `ArrayBuffer` (defaults to the rest of the buffer).
   * @throws Error `'already initialized'` when a model has already been loaded into this session.
   */
  async loadModel(arg: string | ArrayBuffer | Uint8Array, byteOffset?: number, length?: number): Promise<void> {
    await this.profiler.event('session', 'Session.loadModel', async () => {
      // Reject a second load before touching any state: the statements below replace `sessionHandler`
      // and `_model`, which would otherwise break an already-initialized session before `initialize`
      // gets a chance to throw.
      if (this._initialized) {
        throw new Error('already initialized');
      }

      // resolve backend and session handler
      const backend = await resolveBackend(this.backendHint);
      this.sessionHandler = backend.createSessionHandler(this.context);

      this._model = new Model();
      if (typeof arg === 'string') {
        const isOrtFormat = arg.endsWith('.ort');
        if (typeof process !== 'undefined' && process.versions && process.versions.node) {
          // node
          const { readFile } = require('node:fs/promises');
          const buf = await readFile(arg);
          this.initialize(buf, isOrtFormat);
        } else {
          // browser
          const response = await fetch(arg);
          const buf = await response.arrayBuffer();
          this.initialize(new Uint8Array(buf), isOrtFormat);
        }
      } else if (!ArrayBuffer.isView(arg)) {
        // load model from ArrayBuffer. Without an explicit length, read to the end of the buffer:
        // using the full byteLength together with a non-zero offset would overrun the buffer and make
        // the Uint8Array constructor throw a RangeError.
        const offset = byteOffset || 0;
        const arr = new Uint8Array(arg, offset, length || arg.byteLength - offset);
        this.initialize(arr);
      } else {
        // load model from Uint8Array
        this.initialize(arg);
      }
    });
  }

  /**
   * Parses the model bytes into the graph, notifies the session handler, resolves every operator and
   * builds the execution plan. Marks the session initialized only if all of that succeeded.
   *
   * @throws Error `'already initialized'` when called on an initialized session.
   */
  private initialize(modelProtoBlob: Uint8Array, isOrtFormat?: boolean): void {
    if (this._initialized) {
      throw new Error('already initialized');
    }

    this.profiler.event('session', 'Session.initialize', () => {
      // load graph; the session handler doubles as a graph initializer when it offers `transformGraph`
      const graphInitializer = this.sessionHandler.transformGraph
        ? (this.sessionHandler as Graph.Initializer)
        : undefined;
      this._model.load(modelProtoBlob, graphInitializer, isOrtFormat);

      // graph is completely initialized at this stage, let the interested handlers know
      if (this.sessionHandler.onGraphInitialized) {
        this.sessionHandler.onGraphInitialized(this._model.graph);
      }
      // initialize each operator in the graph
      this.initializeOps(this._model.graph);

      // instantiate an ExecutionPlan object to be used by the Session object
      this._executionPlan = new ExecutionPlan(this._model.graph, this._ops, this.profiler);
    });

    this._initialized = true;
  }

  /**
   * Runs the model.
   *
   * @param inputs - either an array ordered like `inputNames`, or a map keyed by input name.
   * @returns the produced tensors keyed by output name.
   * @throws Error when the session is not initialized or the inputs fail validation.
   */
  async run(inputs: Map<string, Tensor> | Tensor[]): Promise<Map<string, Tensor>> {
    if (!this._initialized) {
      throw new Error('session not initialized yet');
    }

    return this.profiler.event('session', 'Session.run', async () => {
      const inputTensors = this.normalizeAndValidateInputs(inputs);

      const outputTensors = await this._executionPlan.execute(this.sessionHandler, inputTensors);

      return this.createOutput(outputTensors);
    });
  }

  /**
   * Converts `inputs` to an array ordered like the graph's input names and validates their shapes and
   * types.
   *
   * On the first successful call the validated types and the given shapes are cached in `context`.
   * Later calls must match that cache exactly.
   */
  private normalizeAndValidateInputs(inputs: Map<string, Tensor> | Tensor[]): Tensor[] {
    const modelInputNames = this._model.graph.getInputNames();

    // normalize inputs
    // inputs: Tensor[]
    if (Array.isArray(inputs)) {
      if (inputs.length !== modelInputNames.length) {
        throw new Error(`incorrect input array length: expected ${modelInputNames.length} but got ${inputs.length}`);
      }
    }
    // convert map to array, ordered like the model's input names
    // inputs: Map<string, Tensor>
    else {
      if (inputs.size !== modelInputNames.length) {
        throw new Error(`incorrect input map size: expected ${modelInputNames.length} but got ${inputs.size}`);
      }

      const sortedInputs = new Array<Tensor>(inputs.size);
      for (let i = 0; i < modelInputNames.length; ++i) {
        const inputName = modelInputNames[i];
        const tensor = inputs.get(inputName);
        if (!tensor) {
          throw new Error(`missing input tensor for: '${inputName}'`);
        }
        sortedInputs[i] = tensor;
      }

      inputs = sortedInputs;
    }

    // validate dims requirements
    // First session run - graph input data is not cached for the session
    if (
      !this.context.graphInputTypes ||
      this.context.graphInputTypes.length === 0 ||
      !this.context.graphInputDims ||
      this.context.graphInputDims.length === 0
    ) {
      const modelInputIndices = this._model.graph.getInputIndices();
      const modelValues = this._model.graph.getValues();

      const graphInputTypes = new Array<Tensor.DataType>(modelInputIndices.length);
      const graphInputDims = new Array<readonly number[]>(modelInputIndices.length);
      for (let i = 0; i < modelInputIndices.length; ++i) {
        const graphInput = modelValues[modelInputIndices[i]];
        graphInputTypes[i] = graphInput.type!.tensorType;
        graphInputDims[i] = graphInput.type!.shape.dims;
      }

      // validate against the graph declaration; a 0 dim in the graph accepts any size
      this.validateInputTensorDims(graphInputDims, inputs, true);
      this.validateInputTensorTypes(graphInputTypes, inputs);

      // Cache for second and subsequent runs. This happens only after validation succeeded, so a
      // rejected first run cannot pin the session to its invalid shapes.
      // Some parts of the framework work on the assumption that the graph input types and shapes
      // are static. The context arrays are created by the constructor.
      for (let i = 0; i < modelInputIndices.length; ++i) {
        this.context.graphInputTypes!.push(graphInputTypes[i]);
        this.context.graphInputDims!.push(inputs[i].dims);
      }
    }
    // Second and subsequent session runs - graph input data is cached for the session
    else {
      this.validateInputTensorDims(this.context.graphInputDims, inputs, false);
      // validate types requirement
      this.validateInputTensorTypes(this.context.graphInputTypes, inputs);
    }

    return inputs;
  }

  /** Throws when `givenInputs[i].type` differs from `graphInputTypes[i]` for any input. */
  private validateInputTensorTypes(graphInputTypes: Tensor.DataType[], givenInputs: Tensor[]) {
    for (let i = 0; i < givenInputs.length; i++) {
      const expectedType = graphInputTypes[i];
      const actualType = givenInputs[i].type;
      if (expectedType !== actualType) {
        throw new Error(`input tensor[${i}] check failed: expected type '${expectedType}' but got ${actualType}`);
      }
    }
  }

  /**
   * Throws when any given input's dims do not match the expected dims (see `compareTensorDims`).
   *
   * @param noneDimSupported - when true, an expected dim of 0 matches any actual size.
   */
  private validateInputTensorDims(
    graphInputDims: Array<readonly number[]>,
    givenInputs: Tensor[],
    noneDimSupported: boolean,
  ) {
    for (let i = 0; i < givenInputs.length; i++) {
      const expectedDims = graphInputDims[i];
      const actualDims = givenInputs[i].dims;
      if (!this.compareTensorDims(expectedDims, actualDims, noneDimSupported)) {
        throw new Error(
          `input tensor[${i}] check failed: expected shape '[${expectedDims.join(',')}]' but got [${actualDims.join(
            ',',
          )}]`,
        );
      }
    }
  }

  /**
   * Returns true when both shapes have the same rank and every dim matches. When `noneDimSupported` is
   * true, an expected dim of 0 acts as a wildcard.
   */
  private compareTensorDims(
    expectedDims: readonly number[],
    actualDims: readonly number[],
    noneDimSupported: boolean,
  ): boolean {
    if (expectedDims.length !== actualDims.length) {
      return false;
    }

    for (let i = 0; i < expectedDims.length; ++i) {
      if (expectedDims[i] !== actualDims[i] && (!noneDimSupported || expectedDims[i] !== 0)) {
        // data shape mis-match AND not a 'None' dimension.
        return false;
      }
    }

    return true;
  }

  /**
   * Pairs the produced tensors with the graph's output names by position.
   *
   * @throws Error when the number of tensors differs from the number of graph outputs.
   */
  private createOutput(outputTensors: Tensor[]): Map<string, Tensor> {
    const modelOutputNames = this._model.graph.getOutputNames();
    if (outputTensors.length !== modelOutputNames.length) {
      throw new Error('expected number of outputs do not match number of generated outputs');
    }

    const output = new Map<string, Tensor>();
    for (let i = 0; i < modelOutputNames.length; ++i) {
      output.set(modelOutputNames[i], outputTensors[i]);
    }

    return output;
  }

  /** Resolves every graph node to an `Operator` through the session handler, in node order. */
  private initializeOps(graph: Graph): void {
    const nodes = graph.getNodes();
    this._ops = new Array(nodes.length);

    for (let i = 0; i < nodes.length; i++) {
      this._ops[i] = this.sessionHandler.resolve(nodes[i], this._model.opsets, graph);
    }
  }

  // created by loadModel
  private _model: Model;
  // set by initialize once the model, operators and execution plan are ready
  private _initialized: boolean;
  // one resolved operator per graph node, in graph.getNodes() order
  private _ops: Operator[];
  private _executionPlan: ExecutionPlan;

  private backendHint?: string;

  // created by loadModel from the resolved backend
  private sessionHandler: SessionHandlerType;
  private context: Session.Context;
  private profiler: Readonly<Profiler>;
}