react-native-executorch
An easy way to run AI models in React Native with ExecuTorch
import { useCallback, useEffect, useRef, useState } from 'react';
import { LLM } from '../../native/RnExecutorchModules';
import { fetchResource } from '../../utils/fetchResource';
import { DEFAULT_CONTEXT_WINDOW_LENGTH, DEFAULT_MESSAGE_HISTORY, DEFAULT_SYSTEM_PROMPT, EOT_TOKEN } from '../../constants/llamaDefaults';
// Ask the native module to stop the current token generation.
const interrupt = () => {
  LLM.interrupt();
};
// Hook that fetches the model and tokenizer, loads them into the native
// ExecuTorch LLM module, and exposes streaming generation state.
export const useLLM = ({
  modelSource,
  tokenizerSource,
  systemPrompt = DEFAULT_SYSTEM_PROMPT,
  messageHistory = DEFAULT_MESSAGE_HISTORY,
  contextWindowLength = DEFAULT_CONTEXT_WINDOW_LENGTH
}) => {
  const [error, setError] = useState(null);
  const [isReady, setIsReady] = useState(false);
  const [isGenerating, setIsGenerating] = useState(false);
  const [response, setResponse] = useState('');
  const [downloadProgress, setDownloadProgress] = useState(0);
  const tokenGeneratedListener = useRef(null);
  useEffect(() => {
    const loadModel = async () => {
      try {
        setIsReady(false);
        // Fetch the tokenizer and model files, reporting model download
        // progress to state, then load them into the native module.
        const tokenizerFileUri = await fetchResource(tokenizerSource);
        const modelFileUri = await fetchResource(modelSource, setDownloadProgress);
        await LLM.loadLLM(modelFileUri, tokenizerFileUri, systemPrompt, messageHistory, contextWindowLength);
        setIsReady(true);
        // Stream generated tokens into `response` until the end-of-text token arrives.
        tokenGeneratedListener.current = LLM.onToken(data => {
          if (!data) {
            return;
          }
          if (data !== EOT_TOKEN) {
            setResponse(prevResponse => prevResponse + data);
          } else {
            setIsGenerating(false);
          }
        });
      } catch (err) {
        const message = err.message;
        setIsReady(false);
        setError(message);
      } finally {
        setDownloadProgress(0);
      }
    };
    loadModel();
    // Remove the token listener and release the native model when the
    // component unmounts or any of the configuration inputs change.
    return () => {
      tokenGeneratedListener.current?.remove();
      tokenGeneratedListener.current = null;
      LLM.deleteModule();
    };
  }, [modelSource, tokenizerSource, systemPrompt, messageHistory, contextWindowLength]);
  // Run inference on the given input; tokens arrive through the onToken listener above.
  const generate = useCallback(async input => {
    if (!isReady) {
      throw new Error('Model is still loading');
    }
    if (error) {
      throw new Error(error);
    }
    try {
      setResponse('');
      setIsGenerating(true);
      await LLM.runInference(input);
    } catch (err) {
      setIsGenerating(false);
      throw new Error(err.message);
    }
  }, [isReady, error]);
  return {
    generate,
    error,
    isReady,
    isGenerating,
    // Aliases of isReady / isGenerating.
    isModelReady: isReady,
    isModelGenerating: isGenerating,
    response,
    downloadProgress,
    interrupt
  };
};
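
Below is a minimal usage sketch showing how the hook might be wired into a React Native screen. It is illustrative only: the import path assumes the package root re-exports useLLM, and the model/tokenizer URLs, component names, and the scale of downloadProgress are assumptions rather than values shipped with the library.

import React from 'react';
import { Button, Text, View } from 'react-native';
import { useLLM } from 'react-native-executorch';

export const ChatScreen = () => {
  const llm = useLLM({
    // Hypothetical resource locations; replace with real model/tokenizer URIs.
    modelSource: 'https://example.com/llama.pte',
    tokenizerSource: 'https://example.com/tokenizer.bin',
  });

  return (
    <View>
      {/* downloadProgress is reset to 0 once loading finishes (see the finally block above). */}
      {!llm.isReady && <Text>Downloading model: {llm.downloadProgress}</Text>}
      {llm.error && <Text>Error: {llm.error}</Text>}
      <Text>{llm.response}</Text>
      <Button
        title={llm.isGenerating ? 'Stop' : 'Ask'}
        onPress={() =>
          llm.isGenerating
            ? llm.interrupt()
            : llm.generate('Hello!').catch(console.error)
        }
      />
    </View>
  );
};

Because response is accumulated from the native onToken events, the Text element re-renders as each token streams in, and interrupt() simply asks the native side to stop emitting further tokens.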