@fugood/onnxruntime-react-native
Version:
ONNX Runtime bridge for React Native
220 lines (160 loc) • 6.9 kB
JavaScript
/**
 * Sets `obj[key] = value` as an own, enumerable, configurable, writable property.
 * When `key` is already reachable on `obj` (own or inherited), the property is
 * defined via Object.defineProperty; otherwise a plain assignment is used.
 *
 * @param {object} obj - target object
 * @param {PropertyKey} key - property name
 * @param {*} value - value to store
 * @returns {object} the same `obj`
 */
function _defineProperty(obj, key, value) {
  if (!(key in obj)) {
    obj[key] = value;
    return obj;
  }
  // defineProperty creates an own data property instead of going through any
  // existing (possibly inherited) accessor or read-only slot.
  Object.defineProperty(obj, key, {
    value,
    enumerable: true,
    configurable: true,
    writable: true
  });
  return obj;
}
/**
 * Registers the initial descriptor of a WeakMap-backed private field for `obj`.
 * @throws {TypeError} if `obj` already has an entry in `privateMap`
 */
function _classPrivateFieldInitSpec(obj, privateMap, value) {
  _checkPrivateRedeclaration(obj, privateMap);
  privateMap.set(obj, value);
}

/** Throws a TypeError when `obj` is already present in `privateCollection`. */
function _checkPrivateRedeclaration(obj, privateCollection) {
  if (!privateCollection.has(obj)) {
    return;
  }
  throw new TypeError("Cannot initialize the same private elements twice on an object");
}

/** Reads a private field: calls the descriptor's getter, or returns its stored value. */
function _classPrivateFieldGet(receiver, privateMap) {
  return _classApplyDescriptorGet(receiver, _classExtractFieldDescriptor(receiver, privateMap, "get"));
}

/** Returns `descriptor.get.call(receiver)` when a getter exists, else `descriptor.value`. */
function _classApplyDescriptorGet(receiver, descriptor) {
  return descriptor.get ? descriptor.get.call(receiver) : descriptor.value;
}

/** Writes a private field and returns the written value. */
function _classPrivateFieldSet(receiver, privateMap, value) {
  const descriptor = _classExtractFieldDescriptor(receiver, privateMap, "set");
  _classApplyDescriptorSet(receiver, descriptor, value);
  return value;
}

/**
 * Looks up the descriptor stored for `receiver`.
 * @throws {TypeError} if `receiver` was never registered in `privateMap`
 */
function _classExtractFieldDescriptor(receiver, privateMap, action) {
  if (privateMap.has(receiver)) {
    return privateMap.get(receiver);
  }
  throw new TypeError(`attempted to ${action} private field on non-instance`);
}

/**
 * Stores `value` through the descriptor's setter, or into `descriptor.value`.
 * @throws {TypeError} for a setter-less, non-writable descriptor
 */
function _classApplyDescriptorSet(receiver, descriptor, value) {
  if (descriptor.set) {
    descriptor.set.call(receiver, value);
    return;
  }
  if (!descriptor.writable) {
    throw new TypeError("attempted to set read only private field");
  }
  descriptor.value = value;
}
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import { Tensor } from '@fugood/onnxruntime-common';
import { Platform } from 'react-native';
import { binding, jsiHelper } from './binding';
/**
 * Maps an ONNX Runtime tensor element type to the TypedArray constructor used
 * to view the raw output bytes returned by the native module.
 *
 * The mapping follows onnxruntime-common's `Tensor.DataTypeMap`; in particular
 * `bool` tensors carry `Uint8Array` data, which is what the `Tensor`
 * constructor accepts for that type (an `Int8Array` is rejected).
 *
 * @param {string} type - tensor element type, e.g. 'float32'
 * @returns {Function} the matching TypedArray constructor
 * @throws {Error} `unsupported type: <type>` for types without a numeric mapping
 */
const tensorTypeToTypedArray = type => {
  switch (type) {
    case 'float32':
      return Float32Array;
    case 'float64':
      return Float64Array;
    case 'int8':
      return Int8Array;
    case 'uint8':
      return Uint8Array;
    case 'bool':
      // One byte per element holding 0/1; same bytes, unsigned view.
      return Uint8Array;
    case 'int16':
      return Int16Array;
    case 'uint16':
      return Uint16Array;
    case 'int32':
      return Int32Array;
    case 'uint32':
      return Uint32Array;
    case 'int64':
      /* global BigInt64Array, BigUint64Array */
      /* eslint no-undef: ["error", { "typeof": true }] */
      return BigInt64Array;
    case 'uint64':
      return BigUint64Array;
    default:
      throw new Error(`unsupported type: ${type}`);
  }
};
/**
 * Strips a leading `file://` scheme (matched case-insensitively) from `path`
 * when running on iOS; on other platforms `path` is returned untouched.
 * Presumably the iOS native loader wants a plain filesystem path — confirm.
 *
 * @param {string} path - model path or file URL
 * @returns {string} the path handed to the native `loadModel`
 */
const normalizePath = path => {
  // Platform is checked first so non-iOS inputs are never inspected.
  const isIosFileUrl = Platform.OS === 'ios' && path.toLowerCase().startsWith('file://');
  return isIosFileUrl ? path.substring('file://'.length) : path;
};
// WeakMap-backed storage for OnnxruntimeSessionHandler's private fields
// (presumably the compiled form of `#inferenceSession`, `#key`, `#pathOrBuffer`).
var _inferenceSession = /*#__PURE__*/new WeakMap();
var _key = /*#__PURE__*/new WeakMap();
var _pathOrBuffer = /*#__PURE__*/new WeakMap();

/**
 * Session handler that forwards model loading, inference and disposal to the
 * native module exposed as `binding`, exchanging tensor bytes via `jsiHelper`.
 */
class OnnxruntimeSessionHandler {
  /**
   * @param {*} pathOrBuffer - a model path when it is a string; any other value
   *   is treated as in-memory model data and passed to `jsiHelper.storeArrayBuffer`.
   */
  constructor(pathOrBuffer) {
    _classPrivateFieldInitSpec(this, _inferenceSession, {
      writable: true,
      value: void 0
    });
    _classPrivateFieldInitSpec(this, _key, {
      writable: true,
      value: void 0
    });
    _classPrivateFieldInitSpec(this, _pathOrBuffer, {
      writable: true,
      value: void 0
    });
    _defineProperty(this, "inputNames", void 0);
    _defineProperty(this, "outputNames", void 0);
    _classPrivateFieldSet(this, _inferenceSession, binding);
    _classPrivateFieldSet(this, _pathOrBuffer, pathOrBuffer);
    // Empty until loadModel() stores the key returned by the native module.
    _classPrivateFieldSet(this, _key, '');
    this.inputNames = [];
    this.outputNames = [];
  }

  /**
   * Creates the native session for this handler's model and records the
   * returned session key and input/output names.
   *
   * @param {object} options - session options, forwarded unchanged to the native module
   * @returns {Promise<void>}
   * @throws {Error} `Can't load a model: <reason>` wrapping any failure; the
   *   original error is preserved as `cause`
   */
  async loadModel(options) {
    try {
      const session = _classPrivateFieldGet(this, _inferenceSession);
      const pathOrBuffer = _classPrivateFieldGet(this, _pathOrBuffer);
      let results;
      if (typeof pathOrBuffer === 'string') {
        results = await session.loadModel(normalizePath(pathOrBuffer), options);
      } else {
        if (!session.loadModelFromBlob) {
          throw new Error('Native module method "loadModelFromBlob" is not defined');
        }
        // NOTE(review): the value itself (not its `.buffer`) is stored; confirm the
        // JSI helper accepts whatever type callers pass here (e.g. Uint8Array).
        const modelBlob = jsiHelper.storeArrayBuffer(pathOrBuffer);
        results = await session.loadModelFromBlob(modelBlob, options);
      }
      // The session was created: remember its key and I/O names.
      _classPrivateFieldSet(this, _key, results.key);
      this.inputNames = results.inputNames;
      this.outputNames = results.outputNames;
    } catch (e) {
      // Non-Error rejections (e.g. plain strings) have no `message`.
      const reason = e != null && e.message !== undefined ? e.message : String(e);
      throw new Error(`Can't load a model: ${reason}`, { cause: e });
    }
  }

  /**
   * Asks the native module to dispose of the session identified by this handler's key.
   * @returns {Promise<*>} whatever the native `dispose` resolves with
   */
  async dispose() {
    return _classPrivateFieldGet(this, _inferenceSession).dispose(_classPrivateFieldGet(this, _key));
  }

  startProfiling() {// TODO: implement profiling
  }

  endProfiling() {// TODO: implement profiling
  }

  /**
   * Runs inference on the native session.
   *
   * @param {object} feeds - input tensors keyed by input name
   * @param {object} fetches - own keys are the requested output names; every
   *   value must be falsy because preallocated outputs are not supported
   * @param {object} options - run options, forwarded unchanged
   * @returns {Promise<object>} output `Tensor`s keyed by output name
   * @throws {Error} when any fetch entry carries a preallocated value
   */
  async run(feeds, fetches, options) {
    const outputNames = [];
    for (const name in fetches) {
      if (Object.prototype.hasOwnProperty.call(fetches, name)) {
        if (fetches[name]) {
          throw new Error('Preallocated output is not supported and only names as string array is allowed as parameter');
        }
        outputNames.push(name);
      }
    }
    const input = this.encodeFeedsType(feeds);
    const results = await _classPrivateFieldGet(this, _inferenceSession).run(_classPrivateFieldGet(this, _key), input, outputNames, options);
    return this.decodeReturnType(results);
  }

  /**
   * Converts input tensors into the `{ dims, type, data }` records sent to the
   * native module. Array data (e.g. string tensors) is forwarded as-is; typed
   * array data is stored through `jsiHelper.storeArrayBuffer`.
   *
   * @param {object} feeds - input tensors keyed by name
   * @returns {object} encoded feeds keyed by name
   */
  encodeFeedsType(feeds) {
    const returnValue = {};
    for (const key in feeds) {
      if (Object.prototype.hasOwnProperty.call(feeds, key)) {
        const feed = feeds[key];
        let data;
        if (Array.isArray(feed.data)) {
          data = feed.data;
        } else {
          const { buffer, byteOffset, byteLength } = feed.data;
          // A typed array may view only part of its buffer (e.g. `subarray()`):
          // send exactly the bytes it covers. Full-buffer views are not copied.
          const coversWholeBuffer = byteOffset === 0 && byteLength === buffer.byteLength;
          const bytes = coversWholeBuffer ? buffer : buffer.slice(byteOffset, byteOffset + byteLength);
          data = jsiHelper.storeArrayBuffer(bytes);
        }
        returnValue[key] = {
          dims: feed.dims,
          type: feed.type,
          data
        };
      }
    }
    return returnValue;
  }

  /**
   * Converts native run results into onnxruntime-common `Tensor`s. Array data is
   * used directly; otherwise the blob is resolved through
   * `jsiHelper.resolveArrayBuffer` and viewed with the TypedArray for its type.
   *
   * @param {object} results - `{ dims, type, data }` records keyed by output name
   * @returns {object} `Tensor`s keyed by output name
   */
  decodeReturnType(results) {
    const returnValue = {};
    for (const key in results) {
      if (Object.prototype.hasOwnProperty.call(results, key)) {
        const result = results[key];
        let tensorData;
        if (Array.isArray(result.data)) {
          tensorData = result.data;
        } else {
          const resolved = jsiHelper.resolveArrayBuffer(result.data);
          const typedArray = tensorTypeToTypedArray(result.type);
          const elementSize = typedArray.BYTES_PER_ELEMENT;
          // Usually an ArrayBuffer (offset 0). If a view comes back instead,
          // reinterpret the bytes it covers rather than copying its elements.
          let arrayBuffer = resolved;
          let byteOffset = 0;
          if (ArrayBuffer.isView(resolved)) {
            arrayBuffer = resolved.buffer;
            byteOffset = resolved.byteOffset;
          }
          const { byteLength } = resolved;
          if (byteOffset % elementSize !== 0) {
            // TypedArray constructors require an element-aligned offset.
            arrayBuffer = arrayBuffer.slice(byteOffset, byteOffset + byteLength);
            byteOffset = 0;
          }
          tensorData = new typedArray(arrayBuffer, byteOffset, Math.floor(byteLength / elementSize));
        }
        returnValue[key] = new Tensor(result.type, tensorData, result.dims);
      }
    }
    return returnValue;
  }
}
/**
 * Backend entry point: builds an OnnxruntimeSessionHandler per session and
 * loads its model before handing it back.
 */
class OnnxruntimeBackend {
  /** No asynchronous setup is performed; resolves immediately with undefined. */
  async init() {
  }

  /**
   * @param {*} pathOrBuffer - model path or in-memory model data
   * @param {object} [options] - session options; any falsy value becomes `{}`
   * @returns {Promise<OnnxruntimeSessionHandler>} a handler whose model is loaded
   */
  async createSessionHandler(pathOrBuffer, options) {
    const sessionOptions = options || {};
    const sessionHandler = new OnnxruntimeSessionHandler(pathOrBuffer);
    await sessionHandler.loadModel(sessionOptions);
    return sessionHandler;
  }
}
export const onnxruntimeBackend = new OnnxruntimeBackend();
//# sourceMappingURL=backend.js.map