import { useCallback, useEffect, useState } from 'react';
import { LLMController } from '../../controllers/LLMController';
import { parseUnknownError } from '../../errors/errorUtils';
/**
* React hook for managing a Large Language Model (LLM) instance.
* @category Hooks
* @param props - Object with a `model` config (model, tokenizer, and tokenizer config sources, plus optional `capabilities`) and an optional `preventLoad` flag that defers loading while `true`.
* @returns An object implementing the `LLMTypeMultimodal` interface when `model.capabilities` is provided, otherwise `LLMType`.
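* @example
* // Minimal usage sketch; `LLAMA3_2_1B` stands in for a model config object
* // (an illustrative name, not guaranteed by this file).
* const llm = useLLM({ model: LLAMA3_2_1B });
* // e.g. inside an onPress handler:
* await llm.sendMessage('Hello!');
* // llm.response accumulates the streamed completion; llm.token is the latest token.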
*/
export function useLLM({
model,
preventLoad = false
}) {
const [token, setToken] = useState('');
const [response, setResponse] = useState('');
const [messageHistory, setMessageHistory] = useState([]);
const [isReady, setIsReady] = useState(false);
const [isGenerating, setIsGenerating] = useState(false);
const [downloadProgress, setDownloadProgress] = useState(0);
const [error, setError] = useState(null);
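// Serialize the capabilities array to a stable string so the effect below can
// compare it by value instead of by array identity.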
const capabilitiesKey = model.capabilities?.join(',') ?? '';
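// Streaming callback: exposes the newest token and appends it to the full response.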
const tokenCallback = useCallback(newToken => {
setToken(newToken);
setResponse(prevResponse => prevResponse + newToken);
}, []);
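// Lazy useState initializer: the controller is created exactly once per component instance.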
const [controllerInstance] = useState(() => new LLMController({
tokenCallback,
messageHistoryCallback: setMessageHistory,
isReadyCallback: setIsReady,
isGeneratingCallback: setIsGenerating
}));
useEffect(() => {
setDownloadProgress(0);
setError(null);
if (preventLoad) return;
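// Load (or re-load) the model; failures surface through the error state rather than throwing.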
(async () => {
try {
await controllerInstance.load({
modelSource: model.modelSource,
tokenizerSource: model.tokenizerSource,
tokenizerConfigSource: model.tokenizerConfigSource,
capabilities: model.capabilities,
onDownloadProgressCallback: setDownloadProgress
});
} catch (e) {
setError(parseUnknownError(e));
}
})();
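// Cleanup: free the native model when the sources change or the component unmounts.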
return () => {
if (controllerInstance.isReady) {
controllerInstance.delete();
}
};
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [controllerInstance, model.modelName, model.modelSource, model.tokenizerSource, model.tokenizerConfigSource,
// intentional: capabilities serialized to a string to avoid re-runs on array reference changes
capabilitiesKey, preventLoad]);
// Memoize the returned functions so they stay referentially stable across renders.
const configure = useCallback(({
chatConfig,
toolsConfig,
generationConfig
}) => controllerInstance.configure({
chatConfig,
toolsConfig,
generationConfig
}), [controllerInstance]);
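// generate and sendMessage reset the accumulated response before starting a new completion.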
const generate = useCallback((messages, tools) => {
setResponse('');
return controllerInstance.generate(messages, tools);
}, [controllerInstance]);
const sendMessage = useCallback((message, media) => {
setResponse('');
return controllerInstance.sendMessage(message, media);
}, [controllerInstance]);
const deleteMessage = useCallback(index => controllerInstance.deleteMessage(index), [controllerInstance]);
const interrupt = useCallback(() => controllerInstance.interrupt(), [controllerInstance]);
const getGeneratedTokenCount = useCallback(() => controllerInstance.getGeneratedTokenCount(), [controllerInstance]);
const getPromptTokenCount = useCallback(() => controllerInstance.getPromptTokenCount(), [controllerInstance]);
const getTotalTokenCount = useCallback(() => controllerInstance.getTotalTokenCount(), [controllerInstance]);
return {
messageHistory,
response,
token,
isReady,
isGenerating,
downloadProgress,
error,
getGeneratedTokenCount,
getPromptTokenCount,
getTotalTokenCount,
configure,
generate,
sendMessage,
deleteMessage,
interrupt
};
}