node-llama-cpp: eraseFirstResponseAndKeepFirstSystemChatContextShiftStrategy.js

Run AI models locally on your machine with Node.js bindings for llama.cpp. Enforce a JSON schema on the model output at the generation level.
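
For context, the schema enforcement mentioned above works by compiling a JSON schema into a grammar that constrains sampling. A minimal usage sketch, based on the package's documented v3 API (getLlama, LlamaChatSession, createGrammarForJsonSchema); the model path is a placeholder and exact option names may vary between versions:

import { getLlama, LlamaChatSession } from "node-llama-cpp";

const llama = await getLlama();
const model = await llama.loadModel({ modelPath: "path/to/model.gguf" }); // placeholder path
const context = await model.createContext();
const session = new LlamaChatSession({ contextSequence: context.getSequence() });

// Build a grammar from a JSON schema; generation is then constrained to match it
const grammar = await llama.createGrammarForJsonSchema({
    type: "object",
    properties: {
        answer: { type: "string" },
        confidence: { type: "number" }
    }
});

const response = await session.prompt("Summarize the plan in one sentence.", { grammar });
console.log(grammar.parse(response)); // an object matching the schema

Because the grammar is applied while tokens are sampled, the output conforms to the schema at the generation level rather than being validated after the fact.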

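Below is the compiled source of eraseFirstResponseAndKeepFirstSystemChatContextShiftStrategy.js, a context-shift strategy that compresses the chat history when it overflows the context window. It works in character space: findCharacterRemovalCountToFitChatHistoryInContext searches for how many characters must be removed for the tokenized history to fit, and each compression pass converts token savings back into characters through an estimatedCharactersPerToken ratio. A toy illustration of that conversion (the numbers are invented):

// Invented numbers, only to show the character/token bookkeeping used below
const estimatedCharactersPerToken = 3.6; // e.g. total history characters / total history tokens
const tokensOverBudget = 120;            // how far the tokenized history overflows the context
const charactersToRemove = Math.ceil(tokensOverBudget * estimatedCharactersPerToken); // 432
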
import { isChatModelResponseFunctionCall, isChatModelResponseSegment } from "../../../../types.js";
import { findCharacterRemovalCountToFitChatHistoryInContext } from "../../../../utils/findCharacterRemovalCountToFitChatHistoryInContext.js";
import { truncateLlamaTextAndRoundToWords, truncateTextAndRoundToWords } from "../../../../utils/truncateTextAndRoundToWords.js";
import { LlamaText } from "../../../../utils/LlamaText.js";

// Context shift strategy: when the tokenized chat history no longer fits in the context,
// compress it by shortening function calls first, then erasing the earliest model
// responses (and the prompts that led to them), and only as a last resort truncating
// the latest response. The first system message is always kept.
export async function eraseFirstResponseAndKeepFirstSystemChatContextShiftStrategy({
    chatHistory, maxTokensCount, tokenizer, chatWrapper, lastShiftMetadata
}) {
    // Resume the character-removal search from where the previous shift left off
    let initialCharactersRemovalCount = 0;
    if (isCalculationMetadata(lastShiftMetadata))
        initialCharactersRemovalCount = lastShiftMetadata.removedCharactersNumber;

    const { removedCharactersCount, compressedChatHistory } = await findCharacterRemovalCountToFitChatHistoryInContext({
        chatHistory,
        tokensCountToFit: maxTokensCount,
        initialCharactersRemovalCount,
        tokenizer,
        chatWrapper,
        failedCompressionErrorMessage: "Failed to compress chat history for context shift due to a too long prompt or system message that cannot be compressed without affecting the generation quality. " +
            "Consider increasing the context size or shortening the long prompt or system message.",
        compressChatHistory({ chatHistory, charactersToRemove, estimatedCharactersPerToken }) {
            const res = chatHistory.map((item) => structuredClone(item));
            let charactersLeftToRemove = charactersToRemove;

            // Re-encode stored raw function call texts with the chat wrapper; when the
            // re-encoded form tokenizes shorter, keep it and credit the token savings,
            // converted to characters
            function compressFunctionCalls() {
                for (let i = res.length - 1; i >= 0 && charactersLeftToRemove > 0; i--) {
                    const historyItem = res[i];
                    if (historyItem.type !== "model")
                        continue;

                    for (let t = historyItem.response.length - 1; t >= 0 && charactersLeftToRemove > 0; t--) {
                        const item = historyItem.response[t];
                        if (typeof item === "string" || item.type !== "functionCall")
                            continue;
                        if (item.rawCall == null)
                            continue;

                        const originalRawCallTokensLength = LlamaText.fromJSON(item.rawCall)
                            .tokenize(tokenizer, "trimLeadingSpace").length;
                        const newRawCallText = chatWrapper.generateFunctionCall(item.name, item.params);
                        const newRawCallTextTokensLength = newRawCallText.tokenize(tokenizer, "trimLeadingSpace").length;

                        if (newRawCallTextTokensLength < originalRawCallTokensLength) {
                            item.rawCall = newRawCallText.toJSON();
                            charactersLeftToRemove -= (
                                (originalRawCallTokensLength - newRawCallTextTokensLength) * estimatedCharactersPerToken
                            );
                        }
                    }
                }
            }

            // Walk backwards from `index` and truncate or remove the user/system items
            // that led to that model response. Returns the number of removed items
            function removeHistoryThatLedToModelResponseAtIndex(index) {
                let removedItems = 0;
                for (let i = index - 1; i >= 0; i--) {
                    const historyItem = res[i];
                    if (historyItem == null)
                        continue;
                    if (historyItem.type === "model")
                        break; // stop removing history items if we reach another model response
                    if (i === 0 && historyItem.type === "system")
                        break; // keep the first system message

                    if (historyItem.type === "user" || historyItem.type === "system") {
                        const newText = truncateLlamaTextAndRoundToWords(
                            LlamaText.fromJSON(historyItem.text), charactersLeftToRemove, undefined, false
                        );
                        const newTextString = newText.toString();
                        const historyItemString = LlamaText.fromJSON(historyItem.text).toString();

                        if (newText.values.length === 0) {
                            res.splice(i, 1);
                            i++;
                            removedItems++;
                            charactersLeftToRemove -= historyItemString.length;
                        } else if (newTextString.length < historyItemString.length) {
                            charactersLeftToRemove -= historyItemString.length - newTextString.length;

                            if (historyItem.type === "user")
                                historyItem.text = newText.toString();
                            else
                                historyItem.text = newText.toJSON();
                        }
                    } else {
                        void historyItem;
                    }
                }
                return removedItems;
            }

            // Truncate the user prompts between the previous model response (or the first
            // system message) and the model response at `index`, optionally preserving
            // roughly `keepTokensCount` tokens of them. Returns the number of removed items
            function compressHistoryThatLedToModelResponseAtIndex(index, keepTokensCount = 0) {
                let removedItems = 0;
                let promptStartIndex = undefined;
                for (let i = index - 1; i >= 0; i--) {
                    const historyItem = res[i];
                    if (historyItem == null)
                        continue;
                    if (historyItem.type === "model") {
                        promptStartIndex = i + 1;
                        break;
                    }
                    if (i === 0 && historyItem.type === "system") {
                        promptStartIndex = i + 1;
                        break; // keep the first system message
                    }
                }
                if (promptStartIndex == null || promptStartIndex >= index)
                    return 0;

                for (let i = promptStartIndex; i < index && charactersLeftToRemove > 0; i++) {
                    const historyItem = res[i];
                    if (historyItem == null || historyItem.type !== "user")
                        continue;

                    let removeChars = Math.min(charactersLeftToRemove, historyItem.text.length);
                    if (keepTokensCount > 0) {
                        removeChars -= Math.floor(keepTokensCount * estimatedCharactersPerToken);
                        if (removeChars < 0)
                            removeChars = 0;
                        keepTokensCount -= Math.min(
                            keepTokensCount,
                            Math.max(0, historyItem.text.length - removeChars) / estimatedCharactersPerToken
                        );
                    }

                    const newText = truncateTextAndRoundToWords(historyItem.text, removeChars, undefined, false);
                    if (newText.length === 0) {
                        res.splice(i, 1);
                        i--;
                        index--;
                        removedItems++;
                        charactersLeftToRemove -= historyItem.text.length;
                    } else {
                        charactersLeftToRemove -= historyItem.text.length - newText.length;
                        historyItem.text = newText;
                    }
                }
                return removedItems;
            }

            // Track open/close segment pairs with a stack and drop the ones
            // that ended up empty after truncation
            function removeEmptySegmentsFromModelResponse(modelResponse) {
                const stack = [];
                for (let t = 0; t < modelResponse.length && charactersLeftToRemove > 0; t++) {
                    const item = modelResponse[t];
                    const isLastItem = t === modelResponse.length - 1;
                    if (!isChatModelResponseSegment(item))
                        continue;

                    const type = item.segmentType;
                    const topStack = stack.at(-1);
                    if (topStack?.type === type) {
                        if (item.ended && item.text === "" && topStack.canRemove) {
                            modelResponse.splice(t, 1);
                            t--;
                            modelResponse.splice(topStack.startIndex, 1);
                            t--;
                            stack.pop();
                        } else if (!item.ended && item.text === "" && !isLastItem) {
                            modelResponse.splice(t, 1);
                            t--;
                        } else if (!item.ended && item.text !== "")
                            topStack.canRemove = false;
                        else if (item.ended)
                            stack.pop();
                    } else if (!item.ended)
                        stack.push({ type, startIndex: t, canRemove: item.text === "" });
                }
            }

            // Truncate or remove the earliest model responses (text, function calls and
            // segments), never touching the final response item of the last history item.
            // A model response that becomes empty is removed together with the history
            // that led to it
            function compressFirstModelResponse() {
                for (let i = 0; i < res.length && charactersLeftToRemove > 0; i++) {
                    const historyItem = res[i];
                    const isLastHistoryItem = i === res.length - 1;
                    if (historyItem.type !== "model")
                        continue;

                    for (let t = 0; t < historyItem.response.length && charactersLeftToRemove > 0; t++) {
                        const item = historyItem.response[t];
                        const isLastText = t === historyItem.response.length - 1;
                        if (isLastHistoryItem && isLastText)
                            continue;

                        if (typeof item === "string") {
                            const newText = truncateTextAndRoundToWords(item, charactersLeftToRemove, undefined, true);
                            if (newText === "") {
                                historyItem.response.splice(t, 1);
                                t--;
                                charactersLeftToRemove -= item.length;
                            } else if (newText.length < item.length) {
                                historyItem.response[t] = newText;
                                charactersLeftToRemove -= item.length - newText.length;
                            }
                        } else if (isChatModelResponseFunctionCall(item)) {
                            historyItem.response.splice(t, 1);
                            t--;
                            const functionCallAndResultTokenUsage = chatWrapper.generateFunctionCallsAndResults([item], true)
                                .tokenize(tokenizer, "trimLeadingSpace").length;
                            charactersLeftToRemove -= functionCallAndResultTokenUsage * estimatedCharactersPerToken;
                        } else if (isChatModelResponseSegment(item)) {
                            if (item.text !== "") {
                                const newText = truncateTextAndRoundToWords(item.text, charactersLeftToRemove, undefined, true);
                                if (newText === "" && item.ended) {
                                    const emptySegmentTokenUsage = chatWrapper.generateModelResponseText([{ ...item, text: "" }], true)
                                        .tokenize(tokenizer, "trimLeadingSpace").length;
                                    historyItem.response.splice(t, 1);
                                    t--;
                                    charactersLeftToRemove -= item.text.length + emptySegmentTokenUsage * estimatedCharactersPerToken;
                                } else {
                                    charactersLeftToRemove -= item.text.length - newText.length;
                                    item.text = newText;
                                }
                            }
                        } else
                            void item;
                    }

                    removeEmptySegmentsFromModelResponse(historyItem.response);

                    if (historyItem.response.length === 0) {
                        // if the model response is removed from the history,
                        // the things that led to it are not important anymore
                        i -= removeHistoryThatLedToModelResponseAtIndex(i);
                        res.splice(i, 1);
                        i--;
                    }
                }
            }

            // Last resort: compress the prompts that led to the final model response
            // (reserving about a quarter of the token budget for them), then truncate the
            // final response text itself down to at least `minCharactersToKeep` characters
            function compressLastModelResponse(minCharactersToKeep = 60) {
                const lastHistoryItem = res[res.length - 1];
                if (lastHistoryItem == null || lastHistoryItem.type !== "model")
                    return;

                const lastResponseItem = lastHistoryItem.response[lastHistoryItem.response.length - 1];
                if (lastResponseItem == null || typeof lastResponseItem !== "string")
                    return;

                compressHistoryThatLedToModelResponseAtIndex(res.length - 1, maxTokensCount / 4);
                if (charactersLeftToRemove <= 0)
                    return;

                const nextTextLength = Math.max(
                    Math.min(lastResponseItem.length, minCharactersToKeep),
                    lastResponseItem.length - charactersLeftToRemove
                );
                const charactersToRemoveFromText = lastResponseItem.length - nextTextLength;
                const newText = truncateTextAndRoundToWords(lastResponseItem, charactersToRemoveFromText, undefined, true);
                if (newText.length < lastResponseItem.length) {
                    lastHistoryItem.response[lastHistoryItem.response.length - 1] = newText;
                    charactersLeftToRemove -= lastResponseItem.length - newText.length;
                }

                if (charactersLeftToRemove <= 0)
                    return;

                compressHistoryThatLedToModelResponseAtIndex(res.length - 1);
            }

            compressFunctionCalls();
            if (charactersLeftToRemove <= 0)
                return res;

            compressFirstModelResponse();
            if (charactersLeftToRemove <= 0)
                return res;

            compressLastModelResponse();
            return res;
        }
    });

    // Carry the removal count forward so the next shift can resume its search from it
    const newMetadata = {
        removedCharactersNumber: removedCharactersCount
    };

    return {
        chatHistory: compressedChatHistory,
        metadata: newMetadata
    };
}

function isCalculationMetadata(metadata) {
    return metadata != null && typeof metadata === "object" &&
        typeof metadata.removedCharactersNumber === "number";
}
//# sourceMappingURL=eraseFirstResponseAndKeepFirstSystemChatContextShiftStrategy.js.map
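
Any function with the same shape can act as a strategy: it receives the chat history, a token budget, the tokenizer, and the chat wrapper, plus whatever metadata it returned on the previous shift, and resolves to the compressed history and new metadata. A much cruder sketch honoring that contract; the strategy name is made up, the token estimate ignores chat-template overhead, and it assumes the tokenizer can be called directly on a string (as it is internally via LlamaText.tokenize(tokenizer, ...) above):

// Illustrative only: drop the oldest non-system turns until a rough token estimate fits
async function dropOldestTurnsStrategy({ chatHistory, maxTokensCount, tokenizer }) {
    const res = chatHistory.map((item) => structuredClone(item));
    const estimateTokens = () => tokenizer(
        res.map((item) => item.type === "model"
            ? item.response.filter((piece) => typeof piece === "string").join("")
            : String(item.text) // system text may be a LlamaText JSON structure; good enough for a rough estimate
        ).join("\n")
    ).length;

    // Keep at least the first item (usually the system message) and the last turn
    while (res.length > 2 && estimateTokens() > maxTokensCount)
        res.splice(1, 1);

    return { chatHistory: res, metadata: null };
}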