@llamaindex/core

import { Settings } from '../../global/dist/index.js';
import { SimpleChatStore } from '../../storage/chat-store/dist/index.js';
import { extractText, messagesToHistory } from '../../utils/dist/index.js';
import { tokenizers } from '@llamaindex/env/tokenizers';
import { defaultSummaryPrompt } from '../../prompts/dist/index.js';

const DEFAULT_TOKEN_LIMIT_RATIO = 0.75;
const DEFAULT_CHAT_STORE_KEY = "chat_history";

/**
 * A ChatMemory is used to keep the state of back and forth chat messages
 */
class BaseMemory {
  _tokenCountForMessages(messages) {
    if (messages.length === 0) {
      return 0;
    }
    const tokenizer = Settings.tokenizer;
    const str = messages.map((m) => extractText(m.content)).join(" ");
    return tokenizer.encode(str).length;
  }
}

class BaseChatStoreMemory extends BaseMemory {
  constructor(chatStore = new SimpleChatStore(), chatStoreKey = DEFAULT_CHAT_STORE_KEY) {
    super();
    this.chatStore = chatStore;
    this.chatStoreKey = chatStoreKey;
  }
  getAllMessages() {
    return this.chatStore.getMessages(this.chatStoreKey);
  }
  put(messages) {
    this.chatStore.addMessage(this.chatStoreKey, messages);
  }
  set(messages) {
    this.chatStore.setMessages(this.chatStoreKey, messages);
  }
  reset() {
    this.chatStore.deleteMessages(this.chatStoreKey);
  }
}

class ChatMemoryBuffer extends BaseChatStoreMemory {
  constructor(options) {
    super(options?.chatStore, options?.chatStoreKey);
    const llm = options?.llm ?? Settings.llm;
    const contextWindow = llm.metadata.contextWindow;
    this.tokenLimit =
      options?.tokenLimit ?? Math.ceil(contextWindow * DEFAULT_TOKEN_LIMIT_RATIO);
    if (options?.chatHistory) {
      this.chatStore.setMessages(this.chatStoreKey, options.chatHistory);
    }
  }
  async getMessages(transientMessages, initialTokenCount = 0) {
    const messages = await this.getAllMessages();
    if (initialTokenCount > this.tokenLimit) {
      throw new Error("Initial token count exceeds token limit");
    }
    // Add input messages as transient messages
    const messagesWithInput = transientMessages
      ? [...transientMessages, ...messages]
      : messages;
    let messageCount = messagesWithInput.length;
    let currentMessages = messagesWithInput.slice(-messageCount);
    let tokenCount =
      this._tokenCountForMessages(messagesWithInput) + initialTokenCount;
    while (tokenCount > this.tokenLimit && messageCount > 1) {
      messageCount -= 1;
      if (messagesWithInput.at(-messageCount).role === "assistant") {
        messageCount -= 1;
      }
      currentMessages = messagesWithInput.slice(-messageCount);
      tokenCount =
        this._tokenCountForMessages(currentMessages) + initialTokenCount;
    }
    if (tokenCount > this.tokenLimit && messageCount <= 0) {
      return [];
    }
    return messagesWithInput.slice(-messageCount);
  }
}

class ChatSummaryMemoryBuffer extends BaseMemory {
  constructor(options) {
    super();
    this.messages = options?.messages ?? [];
    this.summaryPrompt = options?.summaryPrompt ?? defaultSummaryPrompt;
    this.llm = options?.llm ?? Settings.llm;
    if (!this.llm.metadata.maxTokens) {
      throw new Error(
        "LLM maxTokens is not set. Needed so the summarizer ensures the context window size of the LLM.",
      );
    }
    this.tokenizer = options?.tokenizer ?? tokenizers.tokenizer();
    this.tokensToSummarize =
      this.llm.metadata.contextWindow - this.llm.metadata.maxTokens;
    if (this.tokensToSummarize < this.llm.metadata.contextWindow * 0.25) {
      throw new Error(
        "The number of tokens that trigger the summarize process are less than 25% of the context window. Try lowering maxTokens or use a model with a larger context window.",
      );
    }
  }
  async summarize() {
    // get the conversation messages to create summary
    const messagesToSummarize = this.calcConversationMessages();
    let promptMessages;
    do {
      promptMessages = [
        {
          content: this.summaryPrompt.format({
            context: messagesToHistory(messagesToSummarize),
          }),
          role: "user",
          options: {},
        },
      ];
      // remove oldest message until the chat history is short enough for the context window
      messagesToSummarize.shift();
    } while (
      this.tokenizer.encode(promptMessages[0].content).length >
      this.tokensToSummarize
    );
    const response = await this.llm.chat({
      messages: promptMessages,
    });
    return {
      content: response.message.content,
      role: "memory",
    };
  }
  // Find last summary message
  get lastSummaryIndex() {
    const reversedMessages = this.messages.slice().reverse();
    const index = reversedMessages.findIndex(
      (message) => message.role === "memory",
    );
    if (index === -1) {
      return null;
    }
    return this.messages.length - 1 - index;
  }
  getLastSummary() {
    const lastSummaryIndex = this.lastSummaryIndex;
    return lastSummaryIndex ? this.messages[lastSummaryIndex] : null;
  }
  get systemMessages() {
    // get array of all system messages
    return this.messages.filter((message) => message.role === "system");
  }
  get nonSystemMessages() {
    // get array of all non-system messages
    return this.messages.filter((message) => message.role !== "system");
  }
  /**
   * Calculates the messages that describe the conversation so far.
   * If there's no memory, all non-system messages are used.
   * If there's a memory, uses all messages after the last summary message.
   */
  calcConversationMessages(transformSummary) {
    const lastSummaryIndex = this.lastSummaryIndex;
    if (!lastSummaryIndex) {
      // there's no memory, so just use all non-system messages
      return this.nonSystemMessages;
    } else {
      // there's a memory, so use all messages after the last summary message
      // and convert the summary message so it can be sent to the LLM
      const summaryMessage = transformSummary
        ? {
            content: `Summary of the conversation so far: ${this.messages[lastSummaryIndex].content}`,
            role: "system",
          }
        : this.messages[lastSummaryIndex];
      return [summaryMessage, ...this.messages.slice(lastSummaryIndex + 1)];
    }
  }
  calcCurrentRequestMessages(transientMessages) {
    // currently, we're sending:
    // system messages first, then transient messages, and then the messages that describe the conversation so far
    return [
      ...this.systemMessages,
      ...(transientMessages ? transientMessages : []),
      ...this.calcConversationMessages(true),
    ];
  }
  reset() {
    this.messages = [];
  }
  async getMessages(transientMessages) {
    const requestMessages = this.calcCurrentRequestMessages(transientMessages);
    // get tokens of current request messages and the transient messages
    const tokens = requestMessages.reduce(
      (count, message) =>
        count + this.tokenizer.encode(extractText(message.content)).length,
      0,
    );
    if (tokens > this.tokensToSummarize) {
      // if there are too many tokens for the next request, call summarize
      const memoryMessage = await this.summarize();
      const lastMessage = this.messages.at(-1);
      if (lastMessage && lastMessage.role === "user") {
        // if the last message is a user message, ensure that it's sent after the new memory message
        this.messages.pop();
        this.messages.push(memoryMessage);
        this.messages.push(lastMessage);
      } else {
        // otherwise just add the memory message
        this.messages.push(memoryMessage);
      }
      // TODO: we still might have too many tokens
      // e.g. too large system messages or transient messages
      // how should we deal with that?
      return this.calcCurrentRequestMessages(transientMessages);
    }
    return requestMessages;
  }
  async getAllMessages() {
    return this.getMessages();
  }
  put(message) {
    this.messages.push(message);
  }
}

export { BaseMemory, ChatMemoryBuffer, ChatSummaryMemoryBuffer };
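/*
 * Usage sketch (illustrative, not part of the published module): how ChatMemoryBuffer
 * can be wired up as a sliding-window chat memory. The provider package, the
 * "@llamaindex/core/memory" import path, and the model name are assumptions for
 * illustration; any LLM whose `metadata.contextWindow` is set works, and omitting
 * `llm` falls back to `Settings.llm`.
 *
 *   import { OpenAI } from "@llamaindex/openai"; // assumed provider package
 *   import { ChatMemoryBuffer } from "@llamaindex/core/memory"; // assumed import path
 *
 *   const llm = new OpenAI({ model: "gpt-4o-mini" }); // hypothetical model choice
 *   // tokenLimit defaults to ceil(contextWindow * 0.75) when not provided
 *   const memory = new ChatMemoryBuffer({ llm });
 *
 *   memory.put({ role: "user", content: "Hello!" });
 *   memory.put({ role: "assistant", content: "Hi! How can I help?" });
 *
 *   // Returns the newest messages that still fit under tokenLimit; older ones are dropped.
 *   const history = await memory.getMessages();
 */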
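/*
 * Usage sketch (illustrative, not part of the published module): how ChatSummaryMemoryBuffer
 * condenses older turns into a summary. The provider package, import path, model name,
 * and maxTokens value are assumptions; note that the constructor throws unless
 * `llm.metadata.maxTokens` is set and `contextWindow - maxTokens` is at least 25% of
 * the context window.
 *
 *   import { OpenAI } from "@llamaindex/openai"; // assumed provider package
 *   import { ChatSummaryMemoryBuffer } from "@llamaindex/core/memory"; // assumed import path
 *
 *   const llm = new OpenAI({ model: "gpt-4o-mini", maxTokens: 1024 }); // hypothetical settings
 *   const memory = new ChatSummaryMemoryBuffer({ llm });
 *
 *   memory.put({ role: "system", content: "You are a helpful assistant." });
 *   memory.put({ role: "user", content: "Let's plan a week-long trip to Japan." });
 *
 *   // Once a request would exceed tokensToSummarize (contextWindow - maxTokens),
 *   // older turns are replaced by a single role "memory" summary message.
 *   const requestMessages = await memory.getMessages();
 */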