react-native-executorch

An easy way to run AI models in React Native with ExecuTorch

LLMController.js (243 lines, 239 loc, 8.24 kB)
"use strict"; import { ResourceFetcher } from '../utils/ResourceFetcher'; import { ETError, getError } from '../Error'; import { Template } from '@huggingface/jinja'; import { DEFAULT_CHAT_CONFIG } from '../constants/llmDefaults'; import { SPECIAL_TOKENS } from '../types/llm'; import { parseToolCall } from '../utils/llm'; import { Logger } from '../common/Logger'; import { readAsStringAsync } from 'expo-file-system/legacy'; export class LLMController { chatConfig = DEFAULT_CHAT_CONFIG; _response = ''; _isReady = false; _isGenerating = false; _messageHistory = []; // User callbacks constructor({ tokenCallback, responseCallback, messageHistoryCallback, isReadyCallback, isGeneratingCallback }) { if (responseCallback !== undefined) { Logger.warn('Passing response callback is deprecated and will be removed in 0.6.0'); } this.tokenCallback = token => { tokenCallback?.(token); }; this.responseCallback = response => { this._response = response; responseCallback?.(response); }; this.messageHistoryCallback = messageHistory => { this._messageHistory = messageHistory; messageHistoryCallback?.(messageHistory); }; this.isReadyCallback = isReady => { this._isReady = isReady; isReadyCallback?.(isReady); }; this.isGeneratingCallback = isGenerating => { this._isGenerating = isGenerating; isGeneratingCallback?.(isGenerating); }; } get response() { return this._response; } get isReady() { return this._isReady; } get isGenerating() { return this._isGenerating; } get messageHistory() { return this._messageHistory; } async load({ modelSource, tokenizerSource, tokenizerConfigSource, onDownloadProgressCallback }) { // reset inner state when loading new model this.responseCallback(''); this.messageHistoryCallback(this.chatConfig.initialMessageHistory); this.isGeneratingCallback(false); this.isReadyCallback(false); try { const tokenizersPromise = ResourceFetcher.fetch(undefined, tokenizerSource, tokenizerConfigSource); const modelPromise = ResourceFetcher.fetch(onDownloadProgressCallback, modelSource); const [tokenizersResults, modelResult] = await Promise.all([tokenizersPromise, modelPromise]); const tokenizerPath = tokenizersResults?.[0]; const tokenizerConfigPath = tokenizersResults?.[1]; const modelPath = modelResult?.[0]; if (!tokenizerPath || !tokenizerConfigPath || !modelPath) { throw new Error('Download interrupted!'); } this.tokenizerConfig = JSON.parse(await readAsStringAsync('file://' + tokenizerConfigPath)); this.nativeModule = global.loadLLM(modelPath, tokenizerPath); this.isReadyCallback(true); this.onToken = data => { if (!data) { return; } if (SPECIAL_TOKENS.EOS_TOKEN in this.tokenizerConfig && data.indexOf(this.tokenizerConfig.eos_token) >= 0) { data = data.replaceAll(this.tokenizerConfig.eos_token, ''); } if (SPECIAL_TOKENS.PAD_TOKEN in this.tokenizerConfig && data.indexOf(this.tokenizerConfig.pad_token) >= 0) { data = data.replaceAll(this.tokenizerConfig.pad_token, ''); } if (data.length === 0) { return; } this.tokenCallback(data); this.responseCallback(this._response + data); }; } catch (e) { this.isReadyCallback(false); throw new Error(getError(e)); } } setTokenCallback(tokenCallback) { this.tokenCallback = tokenCallback; } configure({ chatConfig, toolsConfig, generationConfig }) { this.chatConfig = { ...DEFAULT_CHAT_CONFIG, ...chatConfig }; this.toolsConfig = toolsConfig; if (generationConfig?.outputTokenBatchSize) { this.nativeModule.setCountInterval(generationConfig.outputTokenBatchSize); } if (generationConfig?.batchTimeInterval) { 
this.nativeModule.setTimeInterval(generationConfig.batchTimeInterval); } if (generationConfig?.temperature) { this.nativeModule.setTemperature(generationConfig.temperature); } if (generationConfig?.topp) { if (generationConfig.topp < 0 || generationConfig.topp > 1) { throw new Error(getError(ETError.InvalidConfig) + 'TopP has to be in range [0, 1].'); } this.nativeModule.setTopp(generationConfig.topp); } // reset inner state when loading new configuration this.responseCallback(''); this.messageHistoryCallback(this.chatConfig.initialMessageHistory); this.isGeneratingCallback(false); } delete() { if (this._isGenerating) { throw new Error(getError(ETError.ModelGenerating) + 'You cannot delete the model now. You need to interrupt first.'); } this.onToken = () => {}; this.nativeModule.unload(); this.isReadyCallback(false); this.isGeneratingCallback(false); } async forward(input) { if (!this._isReady) { throw new Error(getError(ETError.ModuleNotLoaded)); } if (this._isGenerating) { throw new Error(getError(ETError.ModelGenerating)); } try { this.responseCallback(''); this.isGeneratingCallback(true); await this.nativeModule.generate(input, this.onToken); } catch (e) { throw new Error(getError(e)); } finally { this.isGeneratingCallback(false); } } interrupt() { this.nativeModule.interrupt(); } getGeneratedTokenCount() { return this.nativeModule.getGeneratedTokenCount(); } async generate(messages, tools) { if (!this._isReady) { throw new Error(getError(ETError.ModuleNotLoaded)); } if (messages.length === 0) { throw new Error(`Empty 'messages' array!`); } if (messages[0] && messages[0].role !== 'system') { Logger.warn(`You are not providing system prompt. You can pass it in the first message using { role: 'system', content: YOUR_PROMPT }. Otherwise prompt from your model's chat template will be used.`); } const renderedChat = this.applyChatTemplate(messages, this.tokenizerConfig, tools, // eslint-disable-next-line camelcase { tools_in_user_message: false, add_generation_prompt: true }); await this.forward(renderedChat); } async sendMessage(message) { this.messageHistoryCallback([...this._messageHistory, { content: message, role: 'user' }]); const messageHistoryWithPrompt = [{ content: this.chatConfig.systemPrompt, role: 'system' }, ...this._messageHistory.slice(-this.chatConfig.contextWindowLength)]; await this.generate(messageHistoryWithPrompt, this.toolsConfig?.tools); if (!this.toolsConfig || this.toolsConfig.displayToolCalls) { this.messageHistoryCallback([...this._messageHistory, { content: this._response, role: 'assistant' }]); } if (!this.toolsConfig) { return; } const toolCalls = parseToolCall(this._response); for (const toolCall of toolCalls) { this.toolsConfig.executeToolCallback(toolCall).then(toolResponse => { if (toolResponse) { this.messageHistoryCallback([...this._messageHistory, { content: toolResponse, role: 'assistant' }]); } }); } } deleteMessage(index) { // we delete referenced message and all messages after it // so the model responses that used them are deleted as well const newMessageHistory = this._messageHistory.slice(0, index); this.messageHistoryCallback(newMessageHistory); } applyChatTemplate(messages, tokenizerConfig, tools, templateFlags) { if (!tokenizerConfig.chat_template) { throw Error("Tokenizer config doesn't include chat_template"); } const template = new Template(tokenizerConfig.chat_template); const specialTokens = Object.fromEntries(Object.values(SPECIAL_TOKENS).filter(key => key in tokenizerConfig).map(key => [key, tokenizerConfig[key]])); const result = 
template.render({ messages, tools, ...templateFlags, ...specialTokens }); return result; } } //# sourceMappingURL=LLMController.js.map
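
For orientation, the sketch below exercises the public surface defined in this file: the constructor callbacks, load(), configure(), sendMessage() and delete(). It is a minimal illustration under stated assumptions, not taken from the package documentation; the import specifier, the asset URLs and the async wrapper are placeholders, and the real export path and model/tokenizer sources should be checked against the package docs.

// usage sketch (not part of LLMController.js); import path and asset URLs are assumptions
import { LLMController } from 'react-native-executorch'; // assumed export path

async function runChatExample() {
  const controller = new LLMController({
    tokenCallback: token => console.log('token:', token),
    messageHistoryCallback: history => console.log('messages:', history.length),
    isReadyCallback: ready => console.log('ready:', ready),
    isGeneratingCallback: generating => console.log('generating:', generating)
  });

  // hypothetical asset locations; replace with real ExecuTorch model/tokenizer sources
  await controller.load({
    modelSource: 'https://example.com/model.pte',
    tokenizerSource: 'https://example.com/tokenizer.json',
    tokenizerConfigSource: 'https://example.com/tokenizer_config.json',
    onDownloadProgressCallback: progress => console.log('download:', progress)
  });

  // configure() touches the native module when generationConfig is set, so call it after load()
  controller.configure({
    chatConfig: { systemPrompt: 'You are a helpful assistant.', contextWindowLength: 6 },
    generationConfig: { temperature: 0.7, topp: 0.9 }
  });

  await controller.sendMessage('Hello!'); // streams chunks through tokenCallback
  console.log(controller.response);       // full text of the last response

  controller.delete();                     // unload the native module when done
}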