UNPKG

react-native-executorch

Version:

An easy way to run AI models in React Native with ExecuTorch

52 lines (49 loc) 1.99 kB
import { TensorPtr } from '../../types/common';
import { BaseModule } from '../BaseModule';
import { ResourceSource } from '../../types/common';
import { ResourceFetcher } from '../../utils/ResourceFetcher';
import { RnExecutorchErrorCode } from '../../errors/ErrorCodes';
import { parseUnknownError, RnExecutorchError } from '../../errors/errorUtils';
import { Logger } from '../../common/Logger';

/**
 * Generic module for running user-supplied Executorch models.
 * @category Typescript API
 */
export class ExecutorchModule extends BaseModule {
  /**
   * Fetches the model binary described by `modelSource` and loads it into the
   * native Executorch module, storing the result on `this.nativeModule`.
   *
   * Any failure along the way is logged via `Logger.error` and then rethrown
   * after being passed through `parseUnknownError`.
   * @param modelSource - Location of the model binary (a string, number, or object).
   * @param onDownloadProgressCallback - Optional callback receiving download progress; defaults to a no-op.
   */
  async load(
    modelSource: ResourceSource,
    onDownloadProgressCallback: (progress: number) => void = () => {}
  ): Promise<void> {
    try {
      const fetchedPaths = await ResourceFetcher.fetch(
        onDownloadProgressCallback,
        modelSource
      );

      // Only the first fetched path is used; a missing or empty entry is
      // reported as an interrupted download.
      const modelPath = fetchedPaths?.[0];
      if (!modelPath) {
        throw new RnExecutorchError(
          RnExecutorchErrorCode.DownloadInterrupted,
          'The download has been interrupted. As a result, not every file was downloaded. Please retry the download.'
        );
      }

      this.nativeModule = await global.loadExecutorchModule(modelPath);
    } catch (error) {
      Logger.error('Load failed:', error);
      throw parseUnknownError(error);
    }
  }

  /**
   * Runs the model's forward pass by delegating to `forwardET`.
   * @param inputTensor - Array of input tensor pointers.
   * @returns The array of output tensor pointers resolved by `forwardET`.
   */
  async forward(inputTensor: TensorPtr[]): Promise<TensorPtr[]> {
    const outputTensors = await this.forwardET(inputTensor);
    return outputTensors;
  }
}