react-native-executorch

An easy way to run AI models in React Native with ExecuTorch

import { useEffect, useState } from 'react';
import { fetchResource } from '../utils/fetchResource';
import { ETError, getError } from '../Error';
import { ETInput, Module } from '../types/common';
import { _ETModule } from '../native/RnExecutorchModules';
import { getTypeIdentifier } from '../types/common';

interface Props {
  modelSource: string | number;
  module: Module;
}

interface _Module {
  error: string | null;
  isReady: boolean;
  isGenerating: boolean;
  downloadProgress: number;
  forwardETInput: (
    input: ETInput[] | ETInput,
    shape: number[][] | number[]
  ) => ReturnType<_ETModule['forward']>;
  forwardImage: (input: string) => Promise<any>;
}

export const useModule = ({ modelSource, module }: Props): _Module => {
  const [error, setError] = useState<null | string>(null);
  const [isReady, setIsReady] = useState(false);
  const [isGenerating, setIsGenerating] = useState(false);
  const [downloadProgress, setDownloadProgress] = useState(0);

  useEffect(() => {
    const loadModel = async () => {
      if (!modelSource) return;

      try {
        setIsReady(false);
        const fileUri = await fetchResource(modelSource, setDownloadProgress);
        await module.loadModule(fileUri);
        setIsReady(true);
      } catch (e) {
        setError(getError(e));
      }
    };

    loadModel();
  }, [modelSource, module]);

  const forwardImage = async (input: string) => {
    if (!isReady) {
      throw new Error(getError(ETError.ModuleNotLoaded));
    }
    if (isGenerating) {
      throw new Error(getError(ETError.ModelGenerating));
    }

    try {
      setIsGenerating(true);
      const output = await module.forward(input);
      return output;
    } catch (e) {
      throw new Error(getError(e));
    } finally {
      setIsGenerating(false);
    }
  };

  const forwardETInput = async (
    input: ETInput[] | ETInput,
    shape: number[][] | number[]
  ) => {
    if (!isReady) {
      throw new Error(getError(ETError.ModuleNotLoaded));
    }
    if (isGenerating) {
      throw new Error(getError(ETError.ModelGenerating));
    }

    // Since the native module expects an array of inputs and an array of shapes,
    // if the user provides a single ETInput, we want to "unsqueeze" the array so
    // the data is properly processed on the native side
    if (!Array.isArray(input)) {
      input = [input];
    }
    if (!Array.isArray(shape[0])) {
      shape = [shape] as number[][];
    }

    let inputTypeIdentifiers: any[] = [];
    let modelInputs: any[] = [];
    for (let idx = 0; idx < input.length; idx++) {
      let currentInputTypeIdentifier = getTypeIdentifier(input[idx] as ETInput);
      if (currentInputTypeIdentifier === -1) {
        throw new Error(getError(ETError.InvalidArgument));
      }
      inputTypeIdentifiers.push(currentInputTypeIdentifier);
      modelInputs.push([...(input[idx] as ETInput)]);
    }

    try {
      setIsGenerating(true);
      const output = await module.forward(
        modelInputs,
        shape,
        inputTypeIdentifiers
      );
      setIsGenerating(false);
      return output;
    } catch (e) {
      setIsGenerating(false);
      throw new Error(getError(e));
    }
  };

  return {
    error,
    isReady,
    isGenerating,
    downloadProgress,
    forwardETInput,
    forwardImage,
  };
};
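
For orientation, here is a minimal sketch of how a task-specific hook might wrap useModule. The useStyleTransfer name, the styleTransfer export from RnExecutorchModules, and the assumption that ETInput covers Float32Array are illustrative only; the one thing taken from the file above is useModule itself and the shape of the object it returns.

// Sketch only: assumes the file above is saved as useModule.ts and that
// ../native/RnExecutorchModules exports a `styleTransfer` object satisfying
// the Module interface (loadModule / forward) that useModule expects.
import { useModule } from './useModule';
import { styleTransfer } from '../native/RnExecutorchModules'; // hypothetical export

export const useStyleTransfer = (modelSource: string | number) => {
  const { error, isReady, isGenerating, downloadProgress, forwardImage } =
    useModule({ modelSource, module: styleTransfer });

  // Expose the image-based forward pass under a task-specific name.
  return { error, isReady, isGenerating, downloadProgress, forward: forwardImage };
};

// For the raw tensor path, forwardETInput accepts either a single input or an
// array of inputs and "unsqueezes" a single input/shape pair into arrays before
// handing them to the native module (see the comment in the code above).
// Assuming ETInput accepts a Float32Array, a call might look like:
//
//   const output = await forwardETInput(new Float32Array([1, 2, 3, 4]), [1, 4]);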