UNPKG

node-llama-cpp

Version:

Run AI models locally on your machine with Node.js bindings for llama.cpp. Enforce a JSON schema on the model output at the generation level.

221 lines 10.1 kB
import { JinjaTemplateChatWrapper } from "../generic/JinjaTemplateChatWrapper.js";
import { SpecialToken, LlamaText } from "../../utils/LlamaText.js";
import { compareTokens } from "../../utils/compareTokens.js";
import { StopGenerationDetector } from "../../utils/StopGenerationDetector.js";

/**
 * Format used when system messages are turned into user messages.
 * The template is split at `{{message}}` and its parts are joined with the system message text.
 */
const convertSystemMessagesToUserMessagesTemplate = "### System message\n\n{{message}}\n\n----";

/** Chat histories rendered by both wrappers when checking equivalence. */
const testChatHistories = [
    [{
        type: "system",
        text: "System message ~!@#$%^&*()\n*"
    }, {
        type: "user",
        text: "Message 1234567890!@#$%^&*()_+-=[]{}|\\:;\"',./<>?`~"
    }, {
        type: "model",
        response: [""]
    }],

    [{
        type: "system",
        text: "System message ~!@#$%^&*()\n*"
    }, {
        type: "user",
        text: "Message 1234567890!@#$%^&*()_+-=[]{}|\\:;\"',./<>?`~"
    }, {
        type: "model",
        response: ["Result 1234567890!@#$%^&*()_+-=[]{}|\\:;\"',./<>?`~"]
    }],

    [{
        type: "system",
        text: "System message ~!@#$%^&*()\n*"
    }, {
        type: "user",
        text: "Message 1234567890!@#$%^&*()_+-=[]{}|\\:;\"',./<>?`~"
    }, {
        type: "model",
        response: ["Result 1234567890!@#$%^&*()_+-=[]{}|\\:;\"',./<>?`~"]
    }, {
        type: "user",
        text: "Message2 1234567890!@#$%^&*()_+-=[]{}|\\:;\"',./<>?`~"
    }, {
        type: "model",
        response: [""]
    }],

    [{
        type: "system",
        text: "System message ~!@#$%^&*()\n*"
    }, {
        type: "user",
        text: "Message 1234567890!@#$%^&*()_+-=[]{}|\\:;\"',./<>?`~"
    }, {
        type: "model",
        response: ["Result 1234567890!@#$%^&*()_+-=[]{}|\\:;\"',./<>?`~"]
    }, {
        type: "user",
        text: "Message2 1234567890!@#$%^&*()_+-=[]{}|\\:;\"',./<>?`~"
    }, {
        type: "model",
        response: ["Result2 1234567890!@#$%^&*()_+-=[]{}|\\:;\"',./<>?`~"]
    }]
];

/**
 * Checks whether a `JinjaTemplateChatWrapper` built from the given options renders the test chat histories
 * the same way as a specialized chat wrapper does.
 *
 * A rendering matches when the context texts are equivalent (see `compareContextTexts`) and every
 * stop-generation trigger of the Jinja wrapper is covered by the specialized wrapper's triggers
 * (see `checkEquivalence`).
 *
 * Several Jinja wrapper configurations are tried, and the first one that matches wins:
 * - `trimLeadingWhitespaceInResponses` set to `false`, then to `true`;
 * - when `convertUnsupportedSystemMessagesToUserMessages` is unset or `"auto"`, first with it set to `false`,
 *   then with `{use: "always", format: convertSystemMessagesToUserMessagesTemplate}` against histories whose
 *   system messages were already rewritten as user messages in that format.
 *   Any other value is used as-is, and only the first pass is made.
 *
 * @param jinjaTemplateWrapperOptions - options spread into the `JinjaTemplateChatWrapper` constructor
 * @param specializedChatWrapper - wrapper to compare against; its `generateContextState({chatHistory})` result
 * is read for `contextText` and `stopGenerationTriggers`
 * @param [tokenizer] - optional; when provided, texts and triggers that differ as `LlamaText` are also compared
 * at the token level
 * @returns `true` if any tried configuration matches, otherwise `false`
 */
export function isJinjaTemplateEquivalentToSpecializedChatWrapper(jinjaTemplateWrapperOptions, specializedChatWrapper, tokenizer) {
    const requestedConversion = jinjaTemplateWrapperOptions.convertUnsupportedSystemMessagesToUserMessages;
    const canTestMultipleConvertSystemMessagesToUserMessages = requestedConversion == null || requestedConversion === "auto";

    const matchesAsRequested = isEquivalentWithAnyWhitespaceTrimming(
        jinjaTemplateWrapperOptions,
        canTestMultipleConvertSystemMessagesToUserMessages ? false : requestedConversion,
        specializedChatWrapper,
        testChatHistories,
        tokenizer
    );
    if (matchesAsRequested)
        return true;

    if (!canTestMultipleConvertSystemMessagesToUserMessages)
        return false;

    // Both wrappers get histories in which the system messages were already rewritten as user messages.
    const transformedTestChatHistories = testChatHistories.map((history) => (
        convertSystemMessagesToUserMessages(history, convertSystemMessagesToUserMessagesTemplate)
    ));

    return isEquivalentWithAnyWhitespaceTrimming(
        jinjaTemplateWrapperOptions,
        {
            use: "always",
            format: convertSystemMessagesToUserMessagesTemplate
        },
        specializedChatWrapper,
        transformedTestChatHistories,
        tokenizer
    );
}

/**
 * Builds a `JinjaTemplateChatWrapper` with the given `convertUnsupportedSystemMessagesToUserMessages` value and
 * checks it against the specialized wrapper.
 * The check is run first with `trimLeadingWhitespaceInResponses: false` and then with `true`.
 *
 * Errors thrown while constructing the wrapper or checking it are swallowed on purpose: a configuration that
 * throws is treated as not matching, and the next configuration is tried.
 *
 * @returns `true` as soon as one configuration matches, otherwise `false`
 */
function isEquivalentWithAnyWhitespaceTrimming(
    jinjaTemplateWrapperOptions,
    convertUnsupportedSystemMessagesToUserMessages,
    specializedChatWrapper,
    chatHistories,
    tokenizer
) {
    for (const trimLeadingWhitespaceInResponses of [false, true]) {
        try {
            const jinjaChatWrapper = new JinjaTemplateChatWrapper({
                ...jinjaTemplateWrapperOptions,
                convertUnsupportedSystemMessagesToUserMessages,
                trimLeadingWhitespaceInResponses
            });

            if (checkEquivalence(jinjaChatWrapper, specializedChatWrapper, chatHistories, tokenizer))
                return true;
        } catch (err) {
            // Deliberately ignored: a configuration that fails to build or render is just "not equivalent".
        }
    }

    return false;
}

/**
 * Returns a new chat history in which every system message is replaced by a user message.
 * The replacement text is the system text formatted with `template`: the template is split at `{{message}}`
 * and its parts are joined with the system text.
 *
 * A system message at index 0 that is directly followed by a user message is merged into that user message.
 * The merged text is the formatted system text, then "\n\n", then the user text.
 * The input history and its items are not modified.
 */
function convertSystemMessagesToUserMessages(chatHistory, template) {
    const formatSystemText = (text) => LlamaText.joinValues(LlamaText.fromJSON(text), template.split("{{message}}"));
    const convertedHistory = [];

    for (let index = 0; index < chatHistory.length; index++) {
        const item = chatHistory[index];

        if (item.type !== "system") {
            convertedHistory.push(item);
            continue;
        }

        const mergeWithNextUserMessage = index === 0 && chatHistory.length > 1 && chatHistory[1].type === "user";

        if (mergeWithNextUserMessage) {
            convertedHistory.push({
                type: "user",
                text: LlamaText([
                    formatSystemText(item.text),
                    "\n\n",
                    chatHistory[1].text
                ]).toString()
            });

            index++; // the following user message was merged into this one, so skip it
            continue;
        }

        convertedHistory.push({
            type: "user",
            text: formatSystemText(item.text).toString()
        });
    }

    return convertedHistory;
}

/**
 * Renders every chat history with both wrappers and checks that the results match.
 *
 * For each history:
 * - the context texts must be equivalent (see `compareContextTexts`);
 * - every stop-generation trigger produced by the Jinja wrapper must be covered by the specialized wrapper's
 *   triggers (see `isStopGenerationTriggerCovered`).
 *
 * @returns `true` if all histories match, `false` at the first mismatch
 */
function checkEquivalence(jinjaChatWrapper, specializedChatWrapper, testChatHistories, tokenizer) {
    for (const testChatHistory of testChatHistories) {
        const jinjaRes = jinjaChatWrapper.generateContextState({ chatHistory: testChatHistory });
        const specializedWrapperRes = specializedChatWrapper.generateContextState({ chatHistory: testChatHistory });

        if (!compareContextTexts(jinjaRes.contextText, specializedWrapperRes.contextText, tokenizer))
            return false;

        // Only the Jinja triggers are iterated, so extra triggers on the specialized side do not cause a mismatch.
        const everyJinjaTriggerIsCovered = jinjaRes.stopGenerationTriggers.every((jinjaTrigger) => (
            isStopGenerationTriggerCovered(jinjaTrigger, specializedWrapperRes.stopGenerationTriggers, tokenizer)
        ));

        if (!everyJinjaTriggerIsCovered)
            return false;
    }

    return true;
}

/**
 * Checks whether a Jinja stop-generation trigger is covered by the specialized wrapper's triggers.
 *
 * The trigger is tried as-is and with its trailing, leading, or both kinds of whitespace trimmed.
 * A variant is covered when any of these holds:
 * - it has no values;
 * - it `includes` one of the specialized triggers;
 * - a tokenizer is given, and one of the specialized triggers, once resolved, is an item-by-item prefix of the
 *   resolved variant.
 */
function isStopGenerationTriggerCovered(jinjaTrigger, specializedTriggers, tokenizer) {
    const triggerVariants = [
        jinjaTrigger,
        jinjaTrigger.trimEnd(),
        jinjaTrigger.trimStart(),
        jinjaTrigger.trimStart().trimEnd()
    ];

    return triggerVariants.some((triggerVariant) => {
        if (triggerVariant.values.length === 0)
            return true;

        const includesSpecializedTrigger = specializedTriggers.some((specializedTrigger) => (
            triggerVariant.includes(specializedTrigger)
        ));
        if (includesSpecializedTrigger)
            return true;

        if (tokenizer == null)
            return false;

        const resolvedTriggerVariant = StopGenerationDetector.resolveLlamaTextTrigger(triggerVariant, tokenizer);

        return specializedTriggers.some((specializedTrigger) => (
            isResolvedTriggerPrefixOf(
                StopGenerationDetector.resolveLlamaTextTrigger(specializedTrigger, tokenizer),
                resolvedTriggerVariant
            )
        ));
    });
}

/**
 * Checks, item by item, whether `resolvedPrefix` is a prefix of `resolvedTrigger`.
 * Both values come from `StopGenerationDetector.resolveLlamaTextTrigger`.
 *
 * Matching rules for each pair of items:
 * - two strings must be equal;
 * - two non-strings are compared with `compareTokens`;
 * - a string never matches a non-string;
 * - a missing item in `resolvedTrigger` fails the match.
 */
function isResolvedTriggerPrefixOf(resolvedPrefix, resolvedTrigger) {
    return resolvedPrefix.every((prefixItem, index) => {
        const triggerItem = resolvedTrigger[index];

        if (typeof prefixItem === "string" && typeof triggerItem === "string")
            return prefixItem === triggerItem;

        if (typeof prefixItem === "string" || typeof triggerItem === "string" || triggerItem == null)
            return false;

        return compareTokens(prefixItem, triggerItem);
    });
}

/**
 * Compares two context texts, ignoring trailing whitespace and a leading `BOS` special token.
 *
 * When a tokenizer is given and a text started with `BOS`, the version that keeps the `BOS` is compared too.
 * A match in any combination counts.
 */
function compareContextTexts(text1, text2, tokenizer) {
    const candidates1 = getContextTextComparisonCandidates(text1.trimEnd(), tokenizer);
    const candidates2 = getContextTextComparisonCandidates(text2.trimEnd(), tokenizer);

    return candidates1.some((candidate1) => (
        candidates2.some((candidate2) => areTextsEquivalent(candidate1, candidate2, tokenizer))
    ));
}

/**
 * Returns the variants of an already right-trimmed context text to compare.
 * The variant without a leading `BOS` is always included.
 * If a `BOS` was removed and a tokenizer is available, the variant with the `BOS` is included as well.
 */
function getContextTextComparisonCandidates(trimmedText, tokenizer) {
    const textWithoutBos = removeLeadingBos(trimmedText);

    if (tokenizer != null && textWithoutBos.values.length !== trimmedText.values.length)
        return [trimmedText, textWithoutBos];

    return [textWithoutBos];
}

/**
 * Two texts are equivalent when `LlamaText.compare` says so.
 * Failing that, if a tokenizer is given, they are equivalent when they tokenize to the same number of tokens
 * and `compareTokens` matches every pair.
 */
function areTextsEquivalent(text1, text2, tokenizer) {
    if (LlamaText.compare(text1, text2))
        return true;

    if (tokenizer == null)
        return false;

    const tokens1 = text1.tokenize(tokenizer);
    const tokens2 = text2.tokenize(tokenizer);

    if (tokens1.length !== tokens2.length)
        return false;

    return tokens1.every((token, index) => compareTokens(token, tokens2[index]));
}

/** Returns `llamaText` without its first value when that value is the `BOS` special token; otherwise returns it unchanged. */
function removeLeadingBos(llamaText) {
    if (llamaText.values.length === 0)
        return llamaText;

    const firstValue = llamaText.values[0];

    if (firstValue instanceof SpecialToken && firstValue.value === "BOS")
        return LlamaText(llamaText.values.slice(1));

    return llamaText;
}
//# sourceMappingURL=isJinjaTemplateEquivalentToSpecializedChatWrapper.js.map