react-native-executorch
Version:
An easy way to run AI models in React Native with ExecuTorch
79 lines (78 loc) • 2.9 kB
JavaScript
import { useCallback, useEffect, useMemo, useState } from 'react';
import { LLMController } from '../../controllers/LLMController';
/**
 * React hook wrapping an `LLMController` (the hook counterpart of LLMModule).
 *
 * Creates a controller whose callbacks feed React state, loads the model
 * described by `model` whenever its sources change (unless `preventLoad` is
 * set), and exposes the controller's actions as memoized callbacks.
 *
 * @param {object} params
 * @param {object} params.model - Model description; its `modelSource`,
 *   `tokenizerSource` and `tokenizerConfigSource` are forwarded to
 *   `LLMController.load`.
 * @param {boolean} [params.preventLoad=false] - When true, no load is started.
 * @returns {object} `{ messageHistory, response, token, isReady, isGenerating,
 *   downloadProgress, error, configure, generate, sendMessage, deleteMessage,
 *   interrupt }`. `response` accumulates streamed tokens and is cleared by
 *   `generate` / `sendMessage`; `token` is the most recent streamed token;
 *   `error` holds whatever the current load rejected with (or null).
 */
export const useLLM = ({ model, preventLoad = false, }) => {
  const [token, setToken] = useState('');
  const [response, setResponse] = useState('');
  const [messageHistory, setMessageHistory] = useState([]);
  const [isReady, setIsReady] = useState(false);
  const [isGenerating, setIsGenerating] = useState(false);
  const [downloadProgress, setDownloadProgress] = useState(0);
  const [error, setError] = useState(null);

  // Each streamed token is stored on its own and appended to `response`.
  const tokenCallback = useCallback((newToken) => {
    setToken(newToken);
    setResponse((prevResponse) => prevResponse + newToken);
  }, []);

  const controllerInstance = useMemo(() => new LLMController({
    tokenCallback,
    messageHistoryCallback: setMessageHistory,
    isReadyCallback: setIsReady,
    isGeneratingCallback: setIsGenerating,
  }), [tokenCallback]);

  useEffect(() => {
    setDownloadProgress(0);
    setError(null);
    if (preventLoad) {
      return;
    }

    // Flipped by the cleanup below. A load started by an earlier run of this
    // effect (sources changed, or the component unmounted) can still report
    // progress or reject after cleanup; those late results must not
    // overwrite the state that belongs to the current load.
    let cancelled = false;

    const loadModel = async () => {
      try {
        await controllerInstance.load({
          modelSource: model.modelSource,
          tokenizerSource: model.tokenizerSource,
          tokenizerConfigSource: model.tokenizerConfigSource,
          onDownloadProgressCallback: (progress) => {
            if (!cancelled) {
              setDownloadProgress(progress);
            }
          },
        });
      } catch (e) {
        if (!cancelled) {
          setError(e);
        }
      }
    };

    // loadModel handles its own rejection, so it is deliberately not awaited.
    void loadModel();

    return () => {
      cancelled = true;
      controllerInstance.delete();
    };
  }, [
    controllerInstance,
    model.modelSource,
    model.tokenizerSource,
    model.tokenizerConfigSource,
    preventLoad,
  ]);

  // Memoized so consumers get stable function identities across renders.
  const configure = useCallback(
    ({ chatConfig, toolsConfig, }) => controllerInstance.configure({ chatConfig, toolsConfig }),
    [controllerInstance],
  );

  // Both entry points clear the accumulated response before a new generation.
  const generate = useCallback((messages, tools) => {
    setResponse('');
    return controllerInstance.generate(messages, tools);
  }, [controllerInstance]);

  const sendMessage = useCallback((message) => {
    setResponse('');
    return controllerInstance.sendMessage(message);
  }, [controllerInstance]);

  const deleteMessage = useCallback(
    (index) => controllerInstance.deleteMessage(index),
    [controllerInstance],
  );

  const interrupt = useCallback(() => controllerInstance.interrupt(), [controllerInstance]);

  return {
    messageHistory,
    response,
    token,
    isReady,
    isGenerating,
    downloadProgress,
    error,
    configure,
    generate,
    sendMessage,
    deleteMessage,
    interrupt,
  };
};