UNPKG

react-native-executorch

Version:

An easy way to run AI models in React Native with ExecuTorch

92 lines (90 loc) 2.71 kB
"use strict";

import { useEffect, useRef, useState } from 'react';
import { fetchResource } from '../utils/fetchResource';
import { ETError, getError } from '../Error';
import { getTypeIdentifier } from '../types/common';

/**
 * React hook that fetches a model via `fetchResource`, loads it with
 * `module.loadModule`, and exposes helpers that run `module.forward`.
 *
 * @param {object} params
 * @param {*} params.modelSource - Passed to `fetchResource`; when falsy,
 *   nothing is loaded.
 * @param {object} params.module - Object exposing `loadModule(fileUri)` and
 *   `forward(...)`.
 * @returns {{
 *   error: *,
 *   isReady: boolean,
 *   isGenerating: boolean,
 *   downloadProgress: number,
 *   forwardETInput: (input: *, shape: *) => Promise<*>,
 *   forwardImage: (input: *) => Promise<*>
 * }}
 */
export const useModule = ({ modelSource, module }) => {
  const [error, setError] = useState(null);
  const [isReady, setIsReady] = useState(false);
  const [isGenerating, setIsGenerating] = useState(false);
  const [downloadProgress, setDownloadProgress] = useState(0);

  // `isGenerating` is captured per render, so two calls issued before a
  // re-render would both observe `false`. This ref is updated synchronously
  // and is what actually prevents overlapping forward passes.
  const generatingRef = useRef(false);

  useEffect(() => {
    if (!modelSource) return undefined;

    // Flipped by the cleanup below when `modelSource`/`module` change or the
    // component unmounts, so a superseded load cannot overwrite the state of
    // the newer one.
    let cancelled = false;

    const loadModel = async () => {
      try {
        setIsReady(false);
        // Clear any error left over from a previous failed load.
        setError(null);
        const fileUri = await fetchResource(modelSource, (progress) => {
          if (!cancelled) setDownloadProgress(progress);
        });
        if (cancelled) return;
        await module.loadModule(fileUri);
        if (!cancelled) setIsReady(true);
      } catch (e) {
        if (!cancelled) setError(getError(e));
      }
    };

    loadModel();

    return () => {
      cancelled = true;
    };
  }, [modelSource, module]);

  /**
   * Throws unless the model is loaded and no other forward pass is running.
   * @throws {Error} ModuleNotLoaded / ModelGenerating messages from `getError`.
   */
  const ensureCanRun = () => {
    if (!isReady) {
      throw new Error(getError(ETError.ModuleNotLoaded));
    }
    if (generatingRef.current) {
      throw new Error(getError(ETError.ModelGenerating));
    }
  };

  /**
   * Calls `module.forward(...args)` while holding the "generating" flag and
   * always releases it, on success or failure. Failures are rethrown as an
   * `Error` whose message comes from `getError`; the original error is kept
   * as `cause`.
   */
  const runForward = async (...args) => {
    generatingRef.current = true;
    setIsGenerating(true);
    try {
      return await module.forward(...args);
    } catch (e) {
      throw new Error(getError(e), { cause: e });
    } finally {
      generatingRef.current = false;
      setIsGenerating(false);
    }
  };

  /**
   * Runs the model on `input`, passed unchanged to `module.forward`.
   * @throws {Error} if the model is not loaded, a pass is already running, or
   *   the native call fails.
   */
  const forwardImage = async (input) => {
    ensureCanRun();
    return runForward(input);
  };

  /**
   * Runs the model on one or more typed-array inputs.
   *
   * @param {*} input - A single input or an array of inputs.
   * @param {*} shape - A single shape (e.g. `[1, 3, 224, 224]`) or an array of
   *   shapes, one per input.
   * @throws {Error} if the model is not loaded, a pass is already running, an
   *   input has an unsupported type, the input/shape counts differ, or the
   *   native call fails.
   */
  const forwardETInput = async (input, shape) => {
    ensureCanRun();

    // The native module expects an array of inputs and an array of shapes, so
    // a single ETInput (and its single shape) is "unsqueezed" into a
    // one-element array before being sent to the native side.
    const inputs = Array.isArray(input) ? input : [input];
    const shapes = Array.isArray(shape[0]) ? shape : [shape];

    // Each input needs exactly one shape; a mismatch would hand the native
    // side inconsistent data.
    if (inputs.length !== shapes.length) {
      throw new Error(getError(ETError.InvalidArgument));
    }

    const inputTypeIdentifiers = [];
    const modelInputs = [];
    for (const tensor of inputs) {
      const typeIdentifier = getTypeIdentifier(tensor);
      if (typeIdentifier === -1) {
        throw new Error(getError(ETError.InvalidArgument));
      }
      inputTypeIdentifiers.push(typeIdentifier);
      // Copy each input into a plain array for the native call.
      modelInputs.push([...tensor]);
    }

    return runForward(modelInputs, shapes, inputTypeIdentifiers);
  };

  return {
    error,
    isReady,
    isGenerating,
    downloadProgress,
    forwardETInput,
    forwardImage,
  };
};
//# sourceMappingURL=useModule.js.map