UNPKG

node-llama-cpp

Version:

Run AI models locally on your machine with Node.js bindings for llama.cpp. Enforce a JSON schema on the model output at the generation level.

388 lines 17.7 kB
import { splitText } from "lifecycle-utils";
import { JinjaTemplateChatWrapper } from "../generic/JinjaTemplateChatWrapper.js";
import { SpecialToken, LlamaText, SpecialTokensText } from "../../utils/LlamaText.js";
import { compareTokens } from "../../utils/compareTokens.js";
import { StopGenerationDetector } from "../../utils/StopGenerationDetector.js";
import { jsonDumps } from "./jsonDumps.js";

/**
 * Checks whether a `JinjaTemplateChatWrapper` built from the given options produces the same
 * context state as the given specialized chat wrapper for the module's test chat histories.
 *
 * Up to four configurations are tried, in this order, and the first one that matches wins:
 *  1. system-message conversion disabled (or the explicitly configured value), no response trimming
 *  2. the same, with `trimLeadingWhitespaceInResponses: true`
 *  3. system messages always converted to user messages, no response trimming
 *  4. the same, with `trimLeadingWhitespaceInResponses: true`
 *
 * Configurations 3 and 4 are only tried when `convertUnsupportedSystemMessagesToUserMessages`
 * is unset or `"auto"` in the given options.
 * Any error thrown while building or checking a configuration is swallowed and counts as "not equivalent".
 *
 * @param jinjaTemplateWrapperOptions - options forwarded to the `JinjaTemplateChatWrapper` constructor
 * @param specializedChatWrapper - the chat wrapper to compare against
 * @param [tokenizer] - when provided, texts and stop triggers are also compared in tokenized form
 * @returns {boolean} `true` if any of the tried configurations is equivalent
 */
export function isJinjaTemplateEquivalentToSpecializedChatWrapper(jinjaTemplateWrapperOptions, specializedChatWrapper, tokenizer) {
    const configuredSystemMessagesConversion = jinjaTemplateWrapperOptions.convertUnsupportedSystemMessagesToUserMessages;
    const canTestMultipleConvertSystemMessagesToUserMessages = configuredSystemMessagesConversion == null ||
        configuredSystemMessagesConversion === "auto";
    const trimLeadingWhitespaceModes = [false, true];

    // Function-call histories are only relevant when the Jinja wrapper handles function calls via the template
    const selectChatHistories = (jinjaChatWrapper) => {
        const shouldTestFunctionCalls = jinjaChatWrapper.usingJinjaFunctionCallTemplate ||
            jinjaTemplateWrapperOptions.functionCallMessageTemplate === "auto";

        return [
            ...testChatHistories,
            ...(shouldTestFunctionCalls ? testChatHistoriesWithFunctionCalls : [])
        ];
    };

    // Builds a Jinja wrapper with the given overrides and checks it against the specialized wrapper.
    // When `systemMessageTemplate` is given, system messages in the test histories are converted to user messages first.
    const isEquivalentUsing = (systemMessagesConversion, trimLeadingWhitespaceInResponses, systemMessageTemplate) => {
        try {
            const jinjaChatWrapper = new JinjaTemplateChatWrapper({
                ...jinjaTemplateWrapperOptions,
                convertUnsupportedSystemMessagesToUserMessages: systemMessagesConversion,
                trimLeadingWhitespaceInResponses
            });
            const selectedChatHistories = selectChatHistories(jinjaChatWrapper);
            const chatHistoriesToCheck = systemMessageTemplate == null
                ? selectedChatHistories
                : convertTestChatHistoriesSystemMessagesToUserMessages(selectedChatHistories, systemMessageTemplate);

            return checkEquivalence(jinjaChatWrapper, specializedChatWrapper, chatHistoriesToCheck, tokenizer);
        } catch (err) {
            // Do nothing
            return false;
        }
    };

    for (const trimLeadingWhitespace of trimLeadingWhitespaceModes) {
        const systemMessagesConversion = canTestMultipleConvertSystemMessagesToUserMessages
            ? false
            : jinjaTemplateWrapperOptions.convertUnsupportedSystemMessagesToUserMessages;

        if (isEquivalentUsing(systemMessagesConversion, trimLeadingWhitespace))
            return true;
    }

    if (!canTestMultipleConvertSystemMessagesToUserMessages)
        return false;

    const convertSystemMessagesToUserMessagesTemplate = "### System message\n\n{{message}}\n\n----";

    for (const trimLeadingWhitespace of trimLeadingWhitespaceModes) {
        const systemMessagesConversion = {
            use: "always",
            format: convertSystemMessagesToUserMessagesTemplate
        };

        if (isEquivalentUsing(systemMessagesConversion, trimLeadingWhitespace, convertSystemMessagesToUserMessagesTemplate))
            return true;
    }

    return false;
}

/**
 * Compares the context state generated by both wrappers for every given chat history.
 *
 * A history matches when the context texts are equivalent (see `compareContextTexts`) and every
 * stop generation trigger of the Jinja wrapper is covered by a trigger of the specialized wrapper
 * (see `isStopGenerationTriggerCovered`).
 *
 * @returns {boolean} `true` only if all the histories match
 */
function checkEquivalence(jinjaChatWrapper, specializedChatWrapper, testChatHistories, tokenizer) {
    for (const chatHistory of testChatHistories) {
        const jinjaState = jinjaChatWrapper.generateContextState({ chatHistory });
        jinjaState.contextText = convertFunctionNameAndParamsToRegularText(jinjaState.contextText, chatHistory);

        const specializedState = generateSpecializedContextState(specializedChatWrapper, chatHistory);

        if (!compareContextTexts(jinjaState.contextText, specializedState.contextText, tokenizer))
            return false;

        for (const jinjaTrigger of jinjaState.stopGenerationTriggers) {
            if (!isStopGenerationTriggerCovered(jinjaTrigger, specializedState.stopGenerationTriggers, tokenizer))
                return false;
        }
    }

    return true;
}

/**
 * Generates the context state of the specialized wrapper while its function-related settings are
 * temporarily replaced with a special-tokens-text version of them.
 * The original settings object is always put back afterwards, even if the generation throws.
 */
function generateSpecializedContextState(specializedChatWrapper, chatHistory) {
    const specialTokensSettings = convertChatWrapperSettingsToUseSpecialTokensText(specializedChatWrapper.settings);
    const originalSettings = specializedChatWrapper.settings;

    if (specialTokensSettings == null)
        return specializedChatWrapper.generateContextState({ chatHistory });

    specializedChatWrapper.settings = specialTokensSettings;
    try {
        return specializedChatWrapper.generateContextState({ chatHistory });
    } finally {
        specializedChatWrapper.settings = originalSettings;
    }
}

/**
 * Checks whether a stop generation trigger of the Jinja wrapper is covered by the stop generation
 * triggers of the specialized wrapper.
 *
 * The trigger is tested as-is and in its end-trimmed, start-trimmed and fully trimmed forms.
 * A form is considered covered when it has no values, when it `includes` one of the specialized triggers,
 * or - only when a tokenizer is available - when every item of some resolved specialized trigger equals
 * the item at the same index of the resolved form.
 */
function isStopGenerationTriggerCovered(jinjaTrigger, specializedTriggers, tokenizer) {
    const triggerForms = [
        jinjaTrigger,
        jinjaTrigger.trimEnd(),
        jinjaTrigger.trimStart(),
        jinjaTrigger.trimStart().trimEnd()
    ];

    return triggerForms.some((triggerForm) => {
        if (triggerForm.values.length === 0)
            return true;

        if (specializedTriggers.some((specializedTrigger) => triggerForm.includes(specializedTrigger)))
            return true;

        if (tokenizer == null)
            return false;

        const resolvedTriggerForm = StopGenerationDetector.resolveLlamaTextTrigger(triggerForm, tokenizer);

        return specializedTriggers.some((specializedTrigger) => {
            const resolvedSpecializedTrigger = StopGenerationDetector.resolveLlamaTextTrigger(specializedTrigger, tokenizer);

            return resolvedSpecializedTrigger.every((specializedItem, index) => {
                const jinjaItem = resolvedTriggerForm[index];
                const specializedItemIsText = typeof specializedItem === "string";
                const jinjaItemIsText = typeof jinjaItem === "string";

                if (specializedItemIsText && jinjaItemIsText)
                    return specializedItem === jinjaItem;

                // a text item never matches a non-text item, and a missing counterpart is a mismatch
                if (specializedItemIsText || jinjaItemIsText || jinjaItem == null)
                    return false;

                return compareTokens(specializedItem, jinjaItem);
            });
        });
    });
}

/**
 * Compares two context texts, ignoring trailing whitespace and a leading BOS special token.
 *
 * When a tokenizer is available, texts are also compared token by token, and a text that had
 * a leading BOS is tested both with it and without it.
 */
function compareContextTexts(text1, text2, tokenizer) {
    const areEquivalent = (firstText, secondText) => {
        if (LlamaText.compare(firstText, secondText))
            return true;

        if (tokenizer == null)
            return false;

        const firstTokens = firstText.tokenize(tokenizer);
        const secondTokens = secondText.tokenize(tokenizer);

        return firstTokens.length === secondTokens.length &&
            firstTokens.every((token, index) => compareTokens(token, secondTokens[index]));
    };

    const getComparisonCandidates = (text) => {
        const trimmedText = text.trimEnd();
        const textWithoutBos = removeLeadingBos(trimmedText);
        const removedBos = textWithoutBos.values.length !== trimmedText.values.length;

        if (removedBos && tokenizer != null)
            return [trimmedText, textWithoutBos];

        return [textWithoutBos];
    };

    const firstCandidates = getComparisonCandidates(text1);
    const secondCandidates = getComparisonCandidates(text2);

    return firstCandidates.some((firstCandidate) => (
        secondCandidates.some((secondCandidate) => areEquivalent(firstCandidate, secondCandidate))
    ));
}

/**
 * Returns copies of the given chat histories in which every system message is turned into a user message
 * rendered with the given template (the message is placed where `{{message}}` appears in it).
 *
 * A system message at the start of a history that is directly followed by a user message is merged
 * into that user message, separated from its text by an empty line.
 * The given histories are not modified.
 */
function convertTestChatHistoriesSystemMessagesToUserMessages(chatHistories, template) {
    const renderSystemMessage = (systemItem) => (
        LlamaText.joinValues(LlamaText.fromJSON(systemItem.text), template.split("{{message}}"))
    );
    const convertItem = (item) => {
        if (item.type !== "system")
            return item;

        return {
            type: "user",
            text: renderSystemMessage(item).toString()
        };
    };

    return chatHistories.map((history) => {
        const [firstItem, secondItem, ...remainingItems] = history;
        const shouldMergeIntoFirstUserMessage = history.length > 1 &&
            firstItem.type === "system" &&
            secondItem.type === "user";

        if (!shouldMergeIntoFirstUserMessage)
            return history.map(convertItem);

        const mergedUserMessage = {
            type: "user",
            text: LlamaText([
                renderSystemMessage(firstItem),
                "\n\n",
                secondItem.text
            ]).toString()
        };

        return [mergedUserMessage, ...remainingItems.map(convertItem)];
    });
}

/**
 * Creates a copy of the given chat wrapper settings in which the string parts of the function-calling
 * texts are wrapped in `SpecialTokensText`.
 * In the function result prefix and suffix, the `{{functionName}}` and `{{functionParams}}`
 * placeholders are kept as they are.
 *
 * @returns the converted settings, or `null` when the settings have no `functions`
 */
function convertChatWrapperSettingsToUseSpecialTokensText(settings) {
    if (settings?.functions == null)
        return null;

    const toSpecialTokensText = (value, keepTexts) => {
        if (value == null)
            return value;

        const hasNoTextsToKeep = keepTexts == null || keepTexts.length === 0;
        const convertedValues = LlamaText(value).values.map((item) => {
            if (typeof item !== "string")
                return item;

            if (hasNoTextsToKeep)
                return new SpecialTokensText(item);

            return splitText(item, keepTexts).map((textPart) => (
                typeof textPart === "string"
                    ? new SpecialTokensText(textPart)
                    : textPart.separator
            ));
        });

        return LlamaText(convertedValues);
    };

    const { functions } = settings;
    const { call, result, parallelism } = functions;

    const convertParallelism = () => {
        if (parallelism == null)
            return parallelism;

        const parallelCall = parallelism.call;
        const parallelResult = parallelism.result;

        return {
            ...parallelism,
            call: {
                ...parallelCall,
                sectionPrefix: toSpecialTokensText(parallelCall.sectionPrefix),
                betweenCalls: toSpecialTokensText(parallelCall.betweenCalls),
                sectionSuffix: toSpecialTokensText(parallelCall.sectionSuffix)
            },
            result: parallelResult == null
                ? parallelResult
                : {
                    ...parallelResult,
                    sectionPrefix: toSpecialTokensText(parallelResult.sectionPrefix),
                    betweenResults: toSpecialTokensText(parallelResult.betweenResults),
                    sectionSuffix: toSpecialTokensText(parallelResult.sectionSuffix)
                }
        };
    };

    return {
        ...settings,
        functions: {
            ...functions,
            call: {
                ...call,
                prefix: toSpecialTokensText(call.prefix),
                suffix: toSpecialTokensText(call.suffix),
                paramsPrefix: toSpecialTokensText(call.paramsPrefix)
            },
            result: {
                ...result,
                prefix: toSpecialTokensText(result.prefix, ["{{functionName}}", "{{functionParams}}"]),
                suffix: toSpecialTokensText(result.suffix, ["{{functionName}}", "{{functionParams}}"])
            },
            parallelism: convertParallelism()
        }
    };
}

/**
 * Splits the `SpecialTokensText` items of the given context text so that the names and the
 * `jsonDumps`-serialized params of the function calls found in the chat history become regular text,
 * while the text around them stays `SpecialTokensText`.
 */
function convertFunctionNameAndParamsToRegularText(contextText, chatHistory) {
    const regularTexts = new Set();

    for (const historyItem of chatHistory) {
        if (historyItem.type !== "model")
            continue;

        for (const responseItem of historyItem.response) {
            const isFunctionCall = typeof responseItem !== "string" && responseItem.type === "functionCall";
            if (!isFunctionCall)
                continue;

            regularTexts.add(responseItem.name);

            if (responseItem.params !== undefined && responseItem.params !== "")
                regularTexts.add(jsonDumps(responseItem.params));
        }
    }

    const regularTextsList = [...regularTexts];
    const convertedValues = contextText.values.map((value) => {
        if (!(value instanceof SpecialTokensText))
            return value;

        return splitText(value.value, regularTextsList).map((textPart) => (
            typeof textPart === "string"
                ? new SpecialTokensText(textPart)
                : textPart.separator
        ));
    });

    return LlamaText(convertedValues);
}

// Chat histories without function calls that are used for the equivalence checks
const testChatHistories = [
    [
        { type: "system", text: "System message ~!@#$%^&*()\n*" },
        { type: "user", text: "Message 1234567890!@#$%^&*()_+-=[]{}|\\:;\"',./<>?`~" },
        { type: "model", response: [""] }
    ],
    [
        { type: "system", text: "System message ~!@#$%^&*()\n*" },
        { type: "user", text: "Message 1234567890!@#$%^&*()_+-=[]{}|\\:;\"',./<>?`~" },
        { type: "model", response: ["Result 1234567890!@#$%^&*()_+-=[]{}|\\:;\"',./<>?`~"] }
    ],
    [
        { type: "system", text: "System message ~!@#$%^&*()\n*" },
        { type: "user", text: "Message 1234567890!@#$%^&*()_+-=[]{}|\\:;\"',./<>?`~" },
        { type: "model", response: ["Result 1234567890!@#$%^&*()_+-=[]{}|\\:;\"',./<>?`~"] },
        { type: "user", text: "Message2 1234567890!@#$%^&*()_+-=[]{}|\\:;\"',./<>?`~" },
        { type: "model", response: [""] }
    ],
    [
        { type: "system", text: "System message ~!@#$%^&*()\n*" },
        { type: "user", text: "Message 1234567890!@#$%^&*()_+-=[]{}|\\:;\"',./<>?`~" },
        { type: "model", response: ["Result 1234567890!@#$%^&*()_+-=[]{}|\\:;\"',./<>?`~"] },
        { type: "user", text: "Message2 1234567890!@#$%^&*()_+-=[]{}|\\:;\"',./<>?`~" },
        { type: "model", response: ["Result2 1234567890!@#$%^&*()_+-=[]{}|\\:;\"',./<>?`~"] }
    ]
];

// Chat histories with function calls; only checked when the Jinja wrapper handles function calls via the template
const testChatHistoriesWithFunctionCalls = [
    [
        { type: "system", text: "System message ~!@#$%^&*()\n*" },
        { type: "user", text: "Message 1234567890!@#$%^&*()_+-=[]{}|\\:;\"',./<>?`~" },
        { type: "model", response: ["Result 1234567890!@#$%^&*()_+-=[]{}|\\:;\"',./<>?`~"] },
        { type: "user", text: "Message2 1234567890!@#$%^&*()_+-=[]{}|\\:;\"',./<>?`~" },
        {
            type: "model",
            response: [
                "Result2 1234567890!@#$%^&*()_+-=[]{}|\\:;\"',./<>?`~",
                {
                    type: "functionCall",
                    name: "func1name",
                    params: { param1: "value1" },
                    result: "func1result"
                },
                "Result3 1234567890!@#$%^&*()_+-=[]{}|\\:;\"',./<>?`~"
            ]
        }
    ],
    [
        { type: "system", text: "System message ~!@#$%^&*()\n*" },
        { type: "user", text: "Message 1234567890!@#$%^&*()_+-=[]{}|\\:;\"',./<>?`~" },
        { type: "model", response: ["Result 1234567890!@#$%^&*()_+-=[]{}|\\:;\"',./<>?`~"] },
        { type: "user", text: "Message2 1234567890!@#$%^&*()_+-=[]{}|\\:;\"',./<>?`~" },
        {
            type: "model",
            response: [
                "Result2 1234567890!@#$%^&*()_+-=[]{}|\\:;\"',./<>?`~",
                {
                    type: "functionCall",
                    name: "func1name",
                    params: { param1: "value1" },
                    result: "func1result"
                },
                {
                    type: "functionCall",
                    name: "func2name",
                    params: { param1: "value2" },
                    result: "func2result"
                },
                "Result3 1234567890!@#$%^&*()_+-=[]{}|\\:;\"',./<>?`~"
            ]
        }
    ]
];

/**
 * Returns the given text without its first value when that value is a `SpecialToken` whose value is `"BOS"`.
 * Any other text (including an empty one) is returned as is.
 */
function removeLeadingBos(llamaText) {
    const { values } = llamaText;
    if (values.length === 0)
        return llamaText;

    const firstValue = values[0];
    const startsWithBos = firstValue instanceof SpecialToken && firstValue.value === "BOS";

    return startsWithBos
        ? LlamaText(values.slice(1))
        : llamaText;
}
//# sourceMappingURL=isJinjaTemplateEquivalentToSpecializedChatWrapper.js.map