react-native-executorch
Version:
An easy way to run AI models in React Native with ExecuTorch
52 lines (49 loc) • 1.99 kB
text/typescript
import { TensorPtr } from '../../types/common';
import { BaseModule } from '../BaseModule';
import { ResourceSource } from '../../types/common';
import { ResourceFetcher } from '../../utils/ResourceFetcher';
import { RnExecutorchErrorCode } from '../../errors/ErrorCodes';
import { parseUnknownError, RnExecutorchError } from '../../errors/errorUtils';
import { Logger } from '../../common/Logger';
/**
 * General module for executing custom Executorch models.
 * @category Typescript API
 */
export class ExecutorchModule extends BaseModule {
  /**
   * Downloads (if necessary) and loads the model binary identified by
   * `modelSource`, then binds the resulting native module to this instance.
   * @param modelSource - Location of the model binary (string, number, or object).
   * @param onDownloadProgressCallback - Optional observer invoked with download progress in [0, 1].
   * @throws RnExecutorchError with `DownloadInterrupted` when no file path was produced by the fetch.
   */
  async load(
    modelSource: ResourceSource,
    onDownloadProgressCallback: (progress: number) => void = () => {}
  ): Promise<void> {
    try {
      const fetchedPaths = await ResourceFetcher.fetch(
        onDownloadProgressCallback,
        modelSource
      );
      // A missing or empty first path means the fetch did not complete.
      const modelPath = fetchedPaths?.[0];
      if (!modelPath) {
        throw new RnExecutorchError(
          RnExecutorchErrorCode.DownloadInterrupted,
          'The download has been interrupted. As a result, not every file was downloaded. Please retry the download.'
        );
      }
      this.nativeModule = await global.loadExecutorchModule(modelPath);
    } catch (error) {
      // Log with context, then surface a normalized error to the caller.
      Logger.error('Load failed:', error);
      throw parseUnknownError(error);
    }
  }

  /**
   * Runs the model's forward pass on the given inputs.
   * @param inputTensor - Array of input tensor pointers.
   * @returns The output tensor pointers produced by the inference.
   */
  async forward(inputTensor: TensorPtr[]): Promise<TensorPtr[]> {
    const outputs = await this.forwardET(inputTensor);
    return outputs;
  }
}