// react-native-executorch: an easy way to run AI models in React Native with ExecuTorch
import { ResourceSource } from '../types/common';
import { ResourceFetcher } from '../utils/ResourceFetcher';
import { Template } from '@huggingface/jinja';
import { DEFAULT_CHAT_CONFIG } from '../constants/llmDefaults';
import {
ChatConfig,
GenerationConfig,
LLMCapability,
LLMTool,
Message,
SPECIAL_TOKENS,
ToolsConfig,
} from '../types/llm';
import { parseToolCall } from '../utils/llm';
import { Logger } from '../common/Logger';
import { RnExecutorchError, parseUnknownError } from '../errors/errorUtils';
import { RnExecutorchErrorCode } from '../errors/ErrorCodes';
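/**
 * Controller wrapping the native ExecuTorch LLM module: downloads model and
 * tokenizer resources, renders chat templates, streams generated tokens,
 * maintains message history, and dispatches tool calls.
 */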
export class LLMController {
private nativeModule: any;
private chatConfig: ChatConfig = DEFAULT_CHAT_CONFIG;
private toolsConfig: ToolsConfig | undefined;
private tokenizerConfig: any;
private onToken?: (token: string) => void;
private _isReady = false;
private _isGenerating = false;
private _messageHistory: Message[] = [];
// User callbacks
private tokenCallback: (token: string) => void;
private messageHistoryCallback: (messageHistory: Message[]) => void;
private isReadyCallback: (isReady: boolean) => void;
private isGeneratingCallback: (isGenerating: boolean) => void;
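  /**
   * All callbacks are optional. The wrappers below mirror each value into
   * private fields so the isReady/isGenerating/messageHistory getters stay
   * in sync with whatever the consumer observes.
   */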
constructor({
tokenCallback,
messageHistoryCallback,
isReadyCallback,
isGeneratingCallback,
}: {
tokenCallback?: (token: string) => void;
messageHistoryCallback?: (messageHistory: Message[]) => void;
isReadyCallback?: (isReady: boolean) => void;
isGeneratingCallback?: (isGenerating: boolean) => void;
}) {
this.tokenCallback = (token) => {
tokenCallback?.(token);
};
this.messageHistoryCallback = (messageHistory) => {
this._messageHistory = messageHistory;
messageHistoryCallback?.(messageHistory);
};
this.isReadyCallback = (isReady) => {
this._isReady = isReady;
isReadyCallback?.(isReady);
};
this.isGeneratingCallback = (isGenerating) => {
this._isGenerating = isGenerating;
isGeneratingCallback?.(isGenerating);
};
}
public get isReady() {
return this._isReady;
}
public get isGenerating() {
return this._isGenerating;
}
public get messageHistory() {
return this._messageHistory;
}
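  /**
   * Downloads (or resolves locally) the model, tokenizer, and tokenizer
   * config, fetching the tokenizer files in parallel with the model binary,
   * then loads the model into the native module. Download progress is
   * reported for the model file only. Throws if any file is missing.
   */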
public async load({
modelSource,
tokenizerSource,
tokenizerConfigSource,
capabilities,
onDownloadProgressCallback,
}: {
modelSource: ResourceSource;
tokenizerSource: ResourceSource;
tokenizerConfigSource: ResourceSource;
capabilities?: readonly LLMCapability[];
onDownloadProgressCallback?: (downloadProgress: number) => void;
}) {
    // reset inner state when loading a new model
this.messageHistoryCallback(this.chatConfig.initialMessageHistory);
this.isGeneratingCallback(false);
this.isReadyCallback(false);
try {
const tokenizersPromise = ResourceFetcher.fetch(
undefined,
tokenizerSource,
tokenizerConfigSource
);
const modelPromise = ResourceFetcher.fetch(
onDownloadProgressCallback,
modelSource
);
const [tokenizersResults, modelResult] = await Promise.all([
tokenizersPromise,
modelPromise,
]);
const tokenizerPath = tokenizersResults?.[0];
const tokenizerConfigPath = tokenizersResults?.[1];
const modelPath = modelResult?.[0];
if (!tokenizerPath || !tokenizerConfigPath || !modelPath) {
throw new RnExecutorchError(
RnExecutorchErrorCode.DownloadInterrupted,
'The download has been interrupted. As a result, not every file was downloaded. Please retry the download.'
);
}
      this.tokenizerConfig = JSON.parse(
        await ResourceFetcher.fs.readAsString(tokenizerConfigPath)
      );
if (this.nativeModule) {
this.nativeModule.unload();
}
this.nativeModule = await global.loadLLM(
modelPath,
tokenizerPath,
capabilities ?? []
);
this.isReadyCallback(true);
this.onToken = (data: string) => {
if (!data) {
return;
}
const filtered = this.filterSpecialTokens(data);
if (filtered.length === 0) {
return;
}
this.tokenCallback(filtered);
};
} catch (e) {
Logger.error('Load failed:', e);
this.isReadyCallback(false);
throw parseUnknownError(e);
}
}
public setTokenCallback(tokenCallback: (token: string) => void) {
this.tokenCallback = tokenCallback;
}
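  /**
   * Applies chat, tools, and generation settings. Generation options are
   * forwarded to the native module, so they require a loaded model; note
   * that falsy values (e.g. a temperature of 0) are skipped by the checks
   * below. Reconfiguring also resets the message history.
   */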
public configure({
chatConfig,
toolsConfig,
generationConfig,
}: {
chatConfig?: Partial<ChatConfig>;
toolsConfig?: ToolsConfig;
generationConfig?: GenerationConfig;
}) {
this.chatConfig = { ...DEFAULT_CHAT_CONFIG, ...chatConfig };
this.toolsConfig = toolsConfig;
if (generationConfig?.outputTokenBatchSize) {
this.nativeModule.setCountInterval(generationConfig.outputTokenBatchSize);
}
if (generationConfig?.batchTimeInterval) {
this.nativeModule.setTimeInterval(generationConfig.batchTimeInterval);
}
if (generationConfig?.temperature) {
this.nativeModule.setTemperature(generationConfig.temperature);
}
if (generationConfig?.topp) {
if (generationConfig.topp < 0 || generationConfig.topp > 1) {
throw new RnExecutorchError(
RnExecutorchErrorCode.InvalidConfig,
          'Top P must be in the range [0, 1].'
);
}
this.nativeModule.setTopp(generationConfig.topp);
}
    // reset inner state when applying a new configuration
this.messageHistoryCallback(this.chatConfig.initialMessageHistory);
this.isGeneratingCallback(false);
}
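  /** Returns the placeholder token that marks image positions in multimodal prompts. */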
private getImageToken(): string {
const token = this.tokenizerConfig.image_token;
if (!token) {
throw new RnExecutorchError(
RnExecutorchErrorCode.InvalidConfig,
"Tokenizer config is missing 'image_token'. Vision models require tokenizerConfigSource with an 'image_token' field."
);
}
return token;
}
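  /** Removes EOS and PAD tokens (when defined in the tokenizer config) from model output. */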
private filterSpecialTokens(text: string): string {
let filtered = text;
if (
SPECIAL_TOKENS.EOS_TOKEN in this.tokenizerConfig &&
this.tokenizerConfig.eos_token
) {
filtered = filtered.replaceAll(this.tokenizerConfig.eos_token, '');
}
if (
SPECIAL_TOKENS.PAD_TOKEN in this.tokenizerConfig &&
this.tokenizerConfig.pad_token
) {
filtered = filtered.replaceAll(this.tokenizerConfig.pad_token, '');
}
return filtered;
}
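  /**
   * Unloads the native model and resets state. Throws if called while the
   * model is generating; interrupt() first.
   */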
public delete() {
if (this._isGenerating) {
throw new RnExecutorchError(
RnExecutorchErrorCode.ModelGenerating,
        'You cannot delete the model now. You need to interrupt it first.'
);
}
this.onToken = () => {};
if (this.nativeModule) {
this.nativeModule.unload();
}
this.isReadyCallback(false);
this.isGeneratingCallback(false);
}
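  /**
   * Runs inference on an already-rendered prompt string. Resets the native
   * module state, streams tokens through the registered callbacks, and
   * returns the full response with special tokens filtered out. Prefer
   * generate()/sendMessage() for chat-formatted input.
   */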
public async forward(input: string, imagePaths?: string[]): Promise<string> {
if (!this._isReady) {
throw new RnExecutorchError(
RnExecutorchErrorCode.ModuleNotLoaded,
'The model is currently not loaded. Please load the model before calling forward().'
);
}
if (this._isGenerating) {
throw new RnExecutorchError(
RnExecutorchErrorCode.ModelGenerating,
        'The model is currently generating. Please wait until the previous run is complete.'
);
}
try {
this.isGeneratingCallback(true);
this.nativeModule.reset();
const response =
imagePaths && imagePaths.length > 0
? await this.nativeModule.generateMultimodal(
input,
imagePaths,
this.getImageToken(),
this.onToken
)
: await this.nativeModule.generate(input, this.onToken);
return this.filterSpecialTokens(response);
} catch (e) {
throw parseUnknownError(e);
} finally {
this.isGeneratingCallback(false);
}
}
public interrupt() {
if (!this.nativeModule) {
throw new RnExecutorchError(
RnExecutorchErrorCode.ModuleNotLoaded,
"Cannot interrupt a model that's not loaded."
);
}
this.nativeModule.interrupt();
}
public getGeneratedTokenCount(): number {
if (!this.nativeModule) {
throw new RnExecutorchError(
RnExecutorchErrorCode.ModuleNotLoaded,
"Cannot get token count for a model that's not loaded."
);
}
return this.nativeModule.getGeneratedTokenCount();
}
public getPromptTokenCount(): number {
if (!this.nativeModule) {
throw new RnExecutorchError(
RnExecutorchErrorCode.ModuleNotLoaded,
"Cannot get prompt token count for a model that's not loaded."
);
}
return this.nativeModule.getPromptTokenCount();
}
public getTotalTokenCount(): number {
return this.getGeneratedTokenCount() + this.getPromptTokenCount();
}
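  /**
   * Renders the messages (and optional tools) with the tokenizer's chat
   * template and runs inference. Image attachments are collected from
   * messages that carry a mediaPath.
   */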
public async generate(
messages: Message[],
tools?: LLMTool[]
): Promise<string> {
if (!this._isReady) {
throw new RnExecutorchError(
RnExecutorchErrorCode.ModuleNotLoaded,
'The model is currently not loaded. Please load the model before calling generate().'
);
}
if (messages.length === 0) {
throw new RnExecutorchError(
RnExecutorchErrorCode.InvalidUserInput,
'Messages array is empty!'
);
}
if (messages[0] && messages[0].role !== 'system') {
Logger.warn(
        `You are not providing a system prompt. You can pass one in the first message using { role: 'system', content: YOUR_PROMPT }. Otherwise the prompt from your model's chat template will be used.`
);
}
const imagePaths = messages
.filter((m) => m.mediaPath)
.map((m) => m.mediaPath!);
const renderedChat: string = this.applyChatTemplate(
messages,
this.tokenizerConfig,
tools,
// eslint-disable-next-line camelcase
{ tools_in_user_message: false, add_generation_prompt: true }
);
return await this.forward(
renderedChat,
imagePaths.length > 0 ? imagePaths : undefined
);
}
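  /**
   * High-level chat entry point: appends the user message to the history,
   * trims the context to the model's window via the configured context
   * strategy, generates a reply, and, if tools are configured, executes any
   * parsed tool calls fire-and-forget, appending their responses to the
   * history as they resolve.
   */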
public async sendMessage(
message: string,
media?: { imagePath?: string }
): Promise<string> {
const mediaPath = media?.imagePath;
const newMessage: Message = {
content: message,
role: 'user',
...(mediaPath ? { mediaPath } : {}),
};
const updatedHistory = [...this._messageHistory, newMessage];
this.messageHistoryCallback(updatedHistory);
const historyForTemplate = updatedHistory.map((m) =>
m.mediaPath
? {
...m,
content: [
{ type: 'image' },
{ type: 'text', text: m.content },
] as any,
}
: m
);
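    // Estimate prompt size for context trimming: countTextTokens() counts the
    // rendered text, in which each image appears as a single placeholder
    // token, so every image contributes (visualTokenCount - 1) extra tokens.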
const visualTokenCount = this.nativeModule.getVisualTokenCount();
const countTokensCallback = (messages: Message[]) => {
const rendered = this.applyChatTemplate(
messages,
this.tokenizerConfig,
this.toolsConfig?.tools,
// eslint-disable-next-line camelcase
{ tools_in_user_message: false, add_generation_prompt: true }
);
const textTokens = this.nativeModule.countTextTokens(rendered);
const imageCount = messages.filter((m) => m.mediaPath).length;
return textTokens + imageCount * (visualTokenCount - 1);
};
const maxContextLength = this.nativeModule.getMaxContextLength();
const messageHistoryWithPrompt =
this.chatConfig.contextStrategy.buildContext(
this.chatConfig.systemPrompt,
historyForTemplate,
maxContextLength,
countTokensCallback
);
const response = await this.generate(
messageHistoryWithPrompt,
this.toolsConfig?.tools
);
if (!this.toolsConfig || this.toolsConfig.displayToolCalls) {
this.messageHistoryCallback([
...this._messageHistory,
{ content: response, role: 'assistant' },
]);
}
if (this.toolsConfig) {
const toolCalls = parseToolCall(response);
for (const toolCall of toolCalls) {
this.toolsConfig
.executeToolCallback(toolCall)
.then((toolResponse: string | null) => {
if (toolResponse) {
this.messageHistoryCallback([
...this._messageHistory,
{ content: toolResponse, role: 'assistant' },
]);
}
});
}
}
return response;
}
public deleteMessage(index: number) {
    // we delete the referenced message and all messages after it,
    // so the model responses that used them are deleted as well
const newMessageHistory = this._messageHistory.slice(0, index);
this.messageHistoryCallback(newMessageHistory);
}
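  /**
   * Renders messages through the Jinja chat_template from the tokenizer
   * config, exposing the config's special tokens (eos_token, pad_token, etc.)
   * as template variables alongside any extra template flags.
   */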
private applyChatTemplate(
messages: Message[],
tokenizerConfig: any,
tools?: LLMTool[],
    templateFlags?: Record<string, unknown>
): string {
if (!tokenizerConfig.chat_template) {
throw new RnExecutorchError(
RnExecutorchErrorCode.TokenizerError,
"Tokenizer config doesn't include chat_template"
);
}
const template = new Template(tokenizerConfig.chat_template);
const specialTokens = Object.fromEntries(
Object.values(SPECIAL_TOKENS)
.filter((key) => key in tokenizerConfig)
.map((key) => [key, tokenizerConfig[key]])
);
const result = template.render({
messages,
tools,
...templateFlags,
...specialTokens,
});
return result;
}
}
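/*
 * Minimal usage sketch (illustrative only; the URLs below are placeholders,
 * not real assets; point them at your own .pte model and tokenizer files):
 *
 *   const controller = new LLMController({
 *     tokenCallback: (token) => console.log(token),
 *   });
 *   await controller.load({
 *     modelSource: 'https://example.com/model.pte',
 *     tokenizerSource: 'https://example.com/tokenizer.json',
 *     tokenizerConfigSource: 'https://example.com/tokenizer_config.json',
 *   });
 *   controller.configure({ generationConfig: { temperature: 0.7 } });
 *   const reply = await controller.sendMessage('Hello!');
 *   controller.delete();
 */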