UNPKG

react-native-executorch

Version:

An easy way to run AI models in React Native with ExecuTorch

79 lines (78 loc) 2.9 kB
import { useCallback, useEffect, useMemo, useState } from 'react';
import { LLMController } from '../../controllers/LLMController';

/**
 * Hook version of LLMModule.
 *
 * Creates an `LLMController`, loads the model described by `model` (unless
 * `preventLoad` is true) and exposes the controller's state as React state
 * together with memoized action callbacks.
 *
 * @param {object} params
 * @param {object} params.model - Object whose `modelSource`, `tokenizerSource`
 *   and `tokenizerConfigSource` fields are forwarded to `LLMController.load`.
 * @param {boolean} [params.preventLoad=false] - When true, no load is started.
 * @returns {object} State: `messageHistory`, `response`, `token`, `isReady`,
 *   `isGenerating`, `downloadProgress`, `error`; actions: `configure`,
 *   `generate`, `sendMessage`, `deleteMessage`, `interrupt`.
 */
export const useLLM = ({ model, preventLoad = false }) => {
  const [token, setToken] = useState('');
  const [response, setResponse] = useState('');
  const [messageHistory, setMessageHistory] = useState([]);
  const [isReady, setIsReady] = useState(false);
  const [isGenerating, setIsGenerating] = useState(false);
  const [downloadProgress, setDownloadProgress] = useState(0);
  const [error, setError] = useState(null);

  // Stores the latest token and appends it to the accumulated response.
  // Empty deps keep this identity stable, so the controller below is created once.
  const tokenCallback = useCallback((newToken) => {
    setToken(newToken);
    setResponse((prevResponse) => prevResponse + newToken);
  }, []);

  const controllerInstance = useMemo(
    () =>
      new LLMController({
        tokenCallback,
        messageHistoryCallback: setMessageHistory,
        isReadyCallback: setIsReady,
        isGeneratingCallback: setIsGenerating,
      }),
    [tokenCallback]
  );

  useEffect(() => {
    setDownloadProgress(0);
    setError(null);

    if (preventLoad) return;

    // Cleared by this run's cleanup. A load started by an earlier run (model
    // sources changed, or the component unmounted) can still be in flight; its
    // progress reports and its rejection must not overwrite the state that now
    // belongs to the newer load.
    let isCurrent = true;

    (async () => {
      try {
        await controllerInstance.load({
          modelSource: model.modelSource,
          tokenizerSource: model.tokenizerSource,
          tokenizerConfigSource: model.tokenizerConfigSource,
          onDownloadProgressCallback: (progress) => {
            if (isCurrent) setDownloadProgress(progress);
          },
        });
      } catch (e) {
        if (isCurrent) setError(e);
      }
    })();

    return () => {
      // Mark stale before delete(): any rejection it causes is delivered
      // asynchronously and will then be ignored.
      isCurrent = false;
      controllerInstance.delete();
    };
  }, [
    controllerInstance,
    model.modelSource,
    model.tokenizerSource,
    model.tokenizerConfigSource,
    preventLoad,
  ]);

  // Memoized so consumers get stable function identities across renders.
  const configure = useCallback(
    ({ chatConfig, toolsConfig }) =>
      controllerInstance.configure({ chatConfig, toolsConfig }),
    [controllerInstance]
  );

  const generate = useCallback(
    (messages, tools) => {
      // Start a fresh response; tokenCallback appends to it as tokens stream in.
      setResponse('');
      return controllerInstance.generate(messages, tools);
    },
    [controllerInstance]
  );

  const sendMessage = useCallback(
    (message) => {
      setResponse('');
      return controllerInstance.sendMessage(message);
    },
    [controllerInstance]
  );

  const deleteMessage = useCallback(
    (index) => controllerInstance.deleteMessage(index),
    [controllerInstance]
  );

  const interrupt = useCallback(
    () => controllerInstance.interrupt(),
    [controllerInstance]
  );

  return {
    messageHistory,
    response,
    token,
    isReady,
    isGenerating,
    downloadProgress,
    error,
    configure,
    generate,
    sendMessage,
    deleteMessage,
    interrupt,
  };
};