import { ResourceFetcher } from '../utils/ResourceFetcher';
import { ETError, getError } from '../Error';
import { Template } from '@huggingface/jinja';
import { DEFAULT_CHAT_CONFIG } from '../constants/llmDefaults';
import { SPECIAL_TOKENS } from '../types/llm';
import { parseToolCall } from '../utils/llm';
import { Logger } from '../common/Logger';
import { readAsStringAsync } from 'expo-file-system/legacy';
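/**
 * Controller from react-native-executorch that manages a single LLM instance loaded
 * through ExecuTorch. It handles downloading and loading the model and tokenizer,
 * chat/tools/generation configuration, streaming token output, and message history.
 *
 * Illustrative usage (model and tokenizer sources below are placeholders):
 *
 *   const controller = new LLMController({ tokenCallback: console.log });
 *   await controller.load({
 *     modelSource: 'https://example.com/model.pte',
 *     tokenizerSource: 'https://example.com/tokenizer.json',
 *     tokenizerConfigSource: 'https://example.com/tokenizer_config.json',
 *   });
 *   await controller.sendMessage('Hello!');
 */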
export class LLMController {
chatConfig = DEFAULT_CHAT_CONFIG;
_response = '';
_isReady = false;
_isGenerating = false;
_messageHistory = [];
// User callbacks
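// Each callback is wrapped so that, where applicable, the latest value is mirrored
// into the controller's internal state before the user-provided callback is invoked.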
constructor({
tokenCallback,
responseCallback,
messageHistoryCallback,
isReadyCallback,
isGeneratingCallback
}) {
if (responseCallback !== undefined) {
Logger.warn('Passing a response callback is deprecated and will be removed in 0.6.0.');
}
this.tokenCallback = token => {
tokenCallback?.(token);
};
this.responseCallback = response => {
this._response = response;
responseCallback?.(response);
};
this.messageHistoryCallback = messageHistory => {
this._messageHistory = messageHistory;
messageHistoryCallback?.(messageHistory);
};
this.isReadyCallback = isReady => {
this._isReady = isReady;
isReadyCallback?.(isReady);
};
this.isGeneratingCallback = isGenerating => {
this._isGenerating = isGenerating;
isGeneratingCallback?.(isGenerating);
};
}
get response() {
return this._response;
}
get isReady() {
return this._isReady;
}
get isGenerating() {
return this._isGenerating;
}
get messageHistory() {
return this._messageHistory;
}
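/**
 * Fetches the tokenizer, tokenizer config and model in parallel via ResourceFetcher,
 * loads the model into the native ExecuTorch runtime and wires up the token streaming
 * handler. Resets the response, message history and status flags first.
 */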
async load({
modelSource,
tokenizerSource,
tokenizerConfigSource,
onDownloadProgressCallback
}) {
// reset inner state when loading new model
this.responseCallback('');
this.messageHistoryCallback(this.chatConfig.initialMessageHistory);
this.isGeneratingCallback(false);
this.isReadyCallback(false);
try {
const tokenizersPromise = ResourceFetcher.fetch(undefined, tokenizerSource, tokenizerConfigSource);
const modelPromise = ResourceFetcher.fetch(onDownloadProgressCallback, modelSource);
const [tokenizersResults, modelResult] = await Promise.all([tokenizersPromise, modelPromise]);
const tokenizerPath = tokenizersResults?.[0];
const tokenizerConfigPath = tokenizersResults?.[1];
const modelPath = modelResult?.[0];
if (!tokenizerPath || !tokenizerConfigPath || !modelPath) {
throw new Error('Download interrupted!');
}
this.tokenizerConfig = JSON.parse(await readAsStringAsync('file://' + tokenizerConfigPath));
this.nativeModule = global.loadLLM(modelPath, tokenizerPath);
this.isReadyCallback(true);
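// Streaming handler: strips EOS/PAD special tokens from each incoming chunk
// and forwards the cleaned text to the token and response callbacks.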
this.onToken = data => {
if (!data) {
return;
}
if (SPECIAL_TOKENS.EOS_TOKEN in this.tokenizerConfig && data.indexOf(this.tokenizerConfig.eos_token) >= 0) {
data = data.replaceAll(this.tokenizerConfig.eos_token, '');
}
if (SPECIAL_TOKENS.PAD_TOKEN in this.tokenizerConfig && data.indexOf(this.tokenizerConfig.pad_token) >= 0) {
data = data.replaceAll(this.tokenizerConfig.pad_token, '');
}
if (data.length === 0) {
return;
}
this.tokenCallback(data);
this.responseCallback(this._response + data);
};
} catch (e) {
this.isReadyCallback(false);
throw new Error(getError(e));
}
}
setTokenCallback(tokenCallback) {
this.tokenCallback = tokenCallback;
}
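/**
 * Applies chat, tools and generation settings. Generation options are forwarded to the
 * native module; topp must lie in [0, 1]. Configuring also resets the response, message
 * history and generating flag.
 */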
configure({
chatConfig,
toolsConfig,
generationConfig
}) {
this.chatConfig = {
...DEFAULT_CHAT_CONFIG,
...chatConfig
};
this.toolsConfig = toolsConfig;
if (generationConfig?.outputTokenBatchSize) {
this.nativeModule.setCountInterval(generationConfig.outputTokenBatchSize);
}
if (generationConfig?.batchTimeInterval) {
this.nativeModule.setTimeInterval(generationConfig.batchTimeInterval);
}
if (generationConfig?.temperature) {
this.nativeModule.setTemperature(generationConfig.temperature);
}
if (generationConfig?.topp) {
if (generationConfig.topp < 0 || generationConfig.topp > 1) {
throw new Error(getError(ETError.InvalidConfig) + 'topp must be in the range [0, 1].');
}
this.nativeModule.setTopp(generationConfig.topp);
}
// reset inner state when loading new configuration
this.responseCallback('');
this.messageHistoryCallback(this.chatConfig.initialMessageHistory);
this.isGeneratingCallback(false);
}
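/**
 * Unloads the native model and clears status flags. Throws if a generation is still
 * running; call interrupt() first.
 */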
delete() {
if (this._isGenerating) {
throw new Error(getError(ETError.ModelGenerating) + 'You cannot delete the model while it is generating. Call interrupt() first.');
}
this.onToken = () => {};
this.nativeModule.unload();
this.isReadyCallback(false);
this.isGeneratingCallback(false);
}
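/**
 * Runs generation on a raw, already-templated prompt string and streams tokens through
 * onToken. Throws if the model is not loaded or is already generating.
 */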
async forward(input) {
if (!this._isReady) {
throw new Error(getError(ETError.ModuleNotLoaded));
}
if (this._isGenerating) {
throw new Error(getError(ETError.ModelGenerating));
}
try {
this.responseCallback('');
this.isGeneratingCallback(true);
await this.nativeModule.generate(input, this.onToken);
} catch (e) {
throw new Error(getError(e));
} finally {
this.isGeneratingCallback(false);
}
}
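/** Requests that the native module stop the current generation. */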
interrupt() {
this.nativeModule.interrupt();
}
getGeneratedTokenCount() {
return this.nativeModule.getGeneratedTokenCount();
}
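/**
 * Renders the given messages (and optional tools) with the tokenizer's chat template
 * and forwards the result to the model. Warns if no system prompt is provided.
 */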
async generate(messages, tools) {
if (!this._isReady) {
throw new Error(getError(ETError.ModuleNotLoaded));
}
if (messages.length === 0) {
throw new Error(`Empty 'messages' array!`);
}
if (messages[0] && messages[0].role !== 'system') {
Logger.warn(`You did not provide a system prompt. You can pass one in the first message using { role: 'system', content: YOUR_PROMPT }. Otherwise, the prompt from your model's chat template will be used.`);
}
const renderedChat = this.applyChatTemplate(messages, this.tokenizerConfig, tools,
// eslint-disable-next-line camelcase
{
tools_in_user_message: false,
add_generation_prompt: true
});
await this.forward(renderedChat);
}
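/**
 * Appends a user message to the history, generates a reply using the configured system
 * prompt and context window, and (unless hidden by toolsConfig) appends the assistant
 * response. Any tool calls parsed from the response are executed via executeToolCallback.
 */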
async sendMessage(message) {
this.messageHistoryCallback([...this._messageHistory, {
content: message,
role: 'user'
}]);
const messageHistoryWithPrompt = [{
content: this.chatConfig.systemPrompt,
role: 'system'
}, ...this._messageHistory.slice(-this.chatConfig.contextWindowLength)];
await this.generate(messageHistoryWithPrompt, this.toolsConfig?.tools);
if (!this.toolsConfig || this.toolsConfig.displayToolCalls) {
this.messageHistoryCallback([...this._messageHistory, {
content: this._response,
role: 'assistant'
}]);
}
if (!this.toolsConfig) {
return;
}
const toolCalls = parseToolCall(this._response);
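// Tool calls are executed asynchronously (fire-and-forget); each resolved
// response is appended to the message history as an assistant message.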
for (const toolCall of toolCalls) {
this.toolsConfig.executeToolCallback(toolCall).then(toolResponse => {
if (toolResponse) {
this.messageHistoryCallback([...this._messageHistory, {
content: toolResponse,
role: 'assistant'
}]);
}
});
}
}
deleteMessage(index) {
// we delete referenced message and all messages after it
// so the model responses that used them are deleted as well
const newMessageHistory = this._messageHistory.slice(0, index);
this.messageHistoryCallback(newMessageHistory);
}
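/**
 * Renders messages, tools and template flags with the tokenizer's Jinja chat_template,
 * exposing the tokenizer's special tokens to the template context.
 */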
applyChatTemplate(messages, tokenizerConfig, tools, templateFlags) {
if (!tokenizerConfig.chat_template) {
throw Error("Tokenizer config doesn't include chat_template");
}
const template = new Template(tokenizerConfig.chat_template);
const specialTokens = Object.fromEntries(Object.values(SPECIAL_TOKENS).filter(key => key in tokenizerConfig).map(key => [key, tokenizerConfig[key]]));
const result = template.render({
messages,
tools,
...templateFlags,
...specialTokens
});
return result;
}
}
//# sourceMappingURL=LLMController.js.map