react-native-executorch

An easy way to run AI models in React Native with ExecuTorch

"use strict"; import { useCallback, useEffect, useRef, useState } from 'react'; import { LLM } from '../../native/RnExecutorchModules'; import { fetchResource } from '../../utils/fetchResource'; import { DEFAULT_CONTEXT_WINDOW_LENGTH, DEFAULT_MESSAGE_HISTORY, DEFAULT_SYSTEM_PROMPT, EOT_TOKEN } from '../../constants/llamaDefaults'; const interrupt = () => { LLM.interrupt(); }; export const useLLM = ({ modelSource, tokenizerSource, systemPrompt = DEFAULT_SYSTEM_PROMPT, messageHistory = DEFAULT_MESSAGE_HISTORY, contextWindowLength = DEFAULT_CONTEXT_WINDOW_LENGTH }) => { const [error, setError] = useState(null); const [isReady, setIsReady] = useState(false); const [isGenerating, setIsGenerating] = useState(false); const [response, setResponse] = useState(''); const [downloadProgress, setDownloadProgress] = useState(0); const tokenGeneratedListener = useRef(null); useEffect(() => { const loadModel = async () => { try { setIsReady(false); const tokenizerFileUri = await fetchResource(tokenizerSource); const modelFileUri = await fetchResource(modelSource, setDownloadProgress); await LLM.loadLLM(modelFileUri, tokenizerFileUri, systemPrompt, messageHistory, contextWindowLength); setIsReady(true); tokenGeneratedListener.current = LLM.onToken(data => { if (!data) { return; } if (data !== EOT_TOKEN) { setResponse(prevResponse => prevResponse + data); } else { setIsGenerating(false); } }); } catch (err) { const message = err.message; setIsReady(false); setError(message); } finally { setDownloadProgress(0); } }; loadModel(); return () => { tokenGeneratedListener.current?.remove(); tokenGeneratedListener.current = null; LLM.deleteModule(); }; }, [modelSource, tokenizerSource, systemPrompt, messageHistory, contextWindowLength]); const generate = useCallback(async input => { if (!isReady) { throw new Error('Model is still loading'); } if (error) { throw new Error(error); } try { setResponse(''); setIsGenerating(true); await LLM.runInference(input); } catch (err) { setIsGenerating(false); throw new Error(err.message); } }, [isReady, error]); return { generate, error, isReady, isGenerating, isModelReady: isReady, isModelGenerating: isGenerating, response, downloadProgress, interrupt }; }; //# sourceMappingURL=useLLM.js.map