react-native-executorch
Version:
An easy way to run AI models in React Native with ExecuTorch
79 lines (73 loc) • 3.09 kB
text/typescript
import { ResourceFetcher } from '../../utils/ResourceFetcher';
import { StyleTransferModelName } from '../../types/styleTransfer';
import { ResourceSource, PixelData } from '../../types/common';
import { parseUnknownError, RnExecutorchError } from '../../errors/errorUtils';
import { RnExecutorchErrorCode } from '../../errors/ErrorCodes';
import { Logger } from '../../common/Logger';
import { VisionModule } from './VisionModule';
/**
 * Style transfer module built on top of `VisionModule`.
 *
 * Instances cannot be constructed directly; obtain one through
 * `fromModelName` or `fromCustomModel`.
 * @category Typescript API
 */
export class StyleTransferModule extends VisionModule<PixelData | string> {
  /**
   * Wraps an already-loaded native module handle.
   * @param nativeModule - Handle stored on the instance as `nativeModule`.
   */
  private constructor(nativeModule: unknown) {
    super();
    this.nativeModule = nativeModule;
  }

  /**
   * Creates a style transfer instance for a built-in model.
   *
   * Fetches `namedSources.modelSource`, then hands the first fetched path to
   * the native `loadStyleTransfer` loader. Only `modelSource` is read here;
   * `modelName` is part of the accepted shape but is not used by this method.
   * @param namedSources - An object specifying which built-in model to load and where to fetch it from.
   * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1.
   * @returns A Promise resolving to a `StyleTransferModule` instance.
   * @throws The result of `parseUnknownError` for any failure during fetch or load,
   * including the `DownloadInterrupted` error raised when no path was fetched.
   */
  static async fromModelName(
    namedSources: {
      modelName: StyleTransferModelName;
      modelSource: ResourceSource;
    },
    onDownloadProgress: (progress: number) => void = () => {}
  ): Promise<StyleTransferModule> {
    try {
      const source = namedSources.modelSource;
      const fetchedPaths = await ResourceFetcher.fetch(
        onDownloadProgress,
        source
      );

      // A missing first entry means the fetch did not produce the model file.
      const modelPath = fetchedPaths?.[0];
      if (!modelPath) {
        throw new RnExecutorchError(
          RnExecutorchErrorCode.DownloadInterrupted,
          'The download has been interrupted. As a result, not every file was downloaded. Please retry the download.'
        );
      }

      const nativeHandle = await global.loadStyleTransfer(modelPath);
      return new StyleTransferModule(nativeHandle);
    } catch (error) {
      // Log the raw failure first, then surface it in the library's error shape.
      Logger.error('Load failed:', error);
      throw parseUnknownError(error);
    }
  }

  /**
   * Creates a style transfer instance with a user-provided model binary.
   * Use this when working with a custom-exported model that is not one of the built-in presets.
   *
   * Delegates to `fromModelName`, passing `'custom'` as the model name.
   * @remarks The native model contract for this method is not formally defined and may change
   * between releases. Refer to the native source code for the current expected tensor interface.
   * @param modelSource - A fetchable resource pointing to the model binary.
   * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1.
   * @returns A Promise resolving to a `StyleTransferModule` instance.
   */
  static fromCustomModel(
    modelSource: ResourceSource,
    onDownloadProgress: (progress: number) => void = () => {}
  ): Promise<StyleTransferModule> {
    const customSources = {
      modelName: 'custom' as StyleTransferModelName,
      modelSource,
    };
    return StyleTransferModule.fromModelName(customSources, onDownloadProgress);
  }

  /**
   * Runs the model on the given input by delegating to the base class `forward`.
   *
   * The base implementation receives `true` as its second argument when
   * `outputType` is `'url'`, and `false` otherwise (including when it is omitted).
   * @param input - Either a string or a `PixelData` object.
   * NOTE(review): the string is presumably an image URI/path - confirm against `VisionModule`.
   * @param outputType - Selects the typed result: `'url'` types the result as `string`,
   * anything else as `PixelData`.
   * @returns A Promise typed as `string` when `O` is `'url'`, otherwise `PixelData`.
   */
  async forward<O extends 'pixelData' | 'url' = 'pixelData'>(
    input: string | PixelData,
    outputType?: O
  ): Promise<O extends 'url' ? string : PixelData> {
    type Result = O extends 'url' ? string : PixelData;
    const wantsUrl = outputType === 'url';
    return super.forward(input, wantsUrl) as Promise<Result>;
  }
}