UNPKG

node-llama-cpp

Version:

Run AI models locally on your machine with node.js bindings for llama.cpp. Enforce a JSON schema on the model output on the generation level

415 lines 18.1 kB
import { Template } from "@huggingface/jinja"; import { splitText } from "lifecycle-utils"; import { SpecialToken, LlamaText, SpecialTokensText } from "../../utils/LlamaText.js"; import { ChatWrapper } from "../../ChatWrapper.js"; import { parseFunctionCallMessageTemplate } from "./utils/chatHistoryFunctionCallMessageTemplate.js"; import { templateSegmentOptionsToChatWrapperSettings } from "./utils/templateSegmentOptionsToChatWrapperSettings.js"; const defaultConvertUnsupportedSystemMessagesToUserMessagesFormat = { format: "### System message\n\n{{message}}\n\n----" }; /** * A chat wrapper based on a Jinja template. * Useful for using the original model's Jinja template as-is without any additional conversion work to chat with a model. * * If you want to create a new chat wrapper from scratch, using this chat wrapper is not recommended, and instead you better inherit * from the `ChatWrapper` class and implement a custom chat wrapper of your own in TypeScript. * * For a simpler way to create a chat wrapper, see the `TemplateChatWrapper` class. * @example * <span v-pre> * * ```ts * import {JinjaTemplateChatWrapper} from "node-llama-cpp"; * * const chatWrapper = new JinjaTemplateChatWrapper({ * template: "<Jinja template here>", * // functionCallMessageTemplate: { // optional * // call: "[[call: {{functionName}}({{functionParams}})]]", * // result: " [[result: {{functionCallResult}}]]" * // }, * // segments: { * // thoughtTemplate: "<think>{{content}}</think>", * // reopenThoughtAfterFunctionCalls: true * // } * }); * ``` * * </span> */ export class JinjaTemplateChatWrapper extends ChatWrapper { wrapperName = "JinjaTemplate"; settings; template; modelRoleName; userRoleName; systemRoleName; convertUnsupportedSystemMessagesToUserMessages; joinAdjacentMessagesOfTheSameType; trimLeadingWhitespaceInResponses; additionalRenderParameters; /** @internal */ _jinjaTemplate; /** * @param options */ constructor({ template, modelRoleName = "assistant", userRoleName = "user", systemRoleName = "system", convertUnsupportedSystemMessagesToUserMessages = defaultConvertUnsupportedSystemMessagesToUserMessagesFormat, functionCallMessageTemplate, joinAdjacentMessagesOfTheSameType = true, trimLeadingWhitespaceInResponses = true, additionalRenderParameters, segments }) { super(); if (template == null) throw new Error("template cannot be null"); this.template = template; this.modelRoleName = modelRoleName; this.userRoleName = userRoleName; this.systemRoleName = systemRoleName; this.convertUnsupportedSystemMessagesToUserMessages = resolveConvertUnsupportedSystemMessagesToUserMessagesOption(convertUnsupportedSystemMessagesToUserMessages); this.joinAdjacentMessagesOfTheSameType = joinAdjacentMessagesOfTheSameType; this.trimLeadingWhitespaceInResponses = trimLeadingWhitespaceInResponses; this.additionalRenderParameters = additionalRenderParameters; this.settings = { ...ChatWrapper.defaultSettings, functions: parseFunctionCallMessageTemplate(functionCallMessageTemplate) ?? ChatWrapper.defaultSettings.functions, segments: templateSegmentOptionsToChatWrapperSettings(segments) }; if (this.convertUnsupportedSystemMessagesToUserMessages != null && !this.convertUnsupportedSystemMessagesToUserMessages.format.includes("{{message}}")) throw new Error('convertUnsupportedSystemMessagesToUserMessages format must include "{{message}}"'); this._jinjaTemplate = new Template(this.template); const { supportsSystemMessages } = this._runSanityTest(); this.settings = { ...this.settings, supportsSystemMessages }; } generateContextState({ chatHistory, availableFunctions, documentFunctionParams }) { const historyWithFunctions = this.addAvailableFunctionsSystemMessageToHistory(chatHistory, availableFunctions, { documentParams: documentFunctionParams }); if (this.convertUnsupportedSystemMessagesToUserMessages == null) { return this._generateContextText(historyWithFunctions, { convertSystemMessagesToUserMessagesFormat: undefined }); } else if (this.convertUnsupportedSystemMessagesToUserMessages.use === "always") { return this._generateContextText(historyWithFunctions, { convertSystemMessagesToUserMessagesFormat: this.convertUnsupportedSystemMessagesToUserMessages.format }); } try { return this._generateContextText(historyWithFunctions, { convertSystemMessagesToUserMessagesFormat: undefined }); } catch (error) { return this._generateContextText(historyWithFunctions, { convertSystemMessagesToUserMessagesFormat: this.convertUnsupportedSystemMessagesToUserMessages.format }); } } /** @internal */ _generateContextText(history, { convertSystemMessagesToUserMessagesFormat }) { let transformedSystemMessagesToUserMessages = false; const transformedHistory = convertSystemMessagesToUserMessagesFormat == null ? history : history.map((item) => { if (item.type === "system") { transformedSystemMessagesToUserMessages = true; return { type: "user", text: LlamaText.joinValues(LlamaText.fromJSON(item.text), convertSystemMessagesToUserMessagesFormat.split("{{message}}")).toString() }; } return item; }); const resultItems = []; const currentTexts = []; let currentAggregateFocus = null; function flush() { if (currentTexts.length > 0 && currentAggregateFocus != null) resultItems.push({ role: currentAggregateFocus, content: LlamaText.joinValues("\n\n", currentTexts) }); currentTexts.length = 0; } for (const item of transformedHistory) { if (item.type === "system") { if (!this.joinAdjacentMessagesOfTheSameType || currentAggregateFocus !== "system") flush(); currentAggregateFocus = "system"; currentTexts.push(LlamaText.fromJSON(item.text)); } else if (item.type === "user") { if (!this.joinAdjacentMessagesOfTheSameType || currentAggregateFocus !== "user") flush(); currentAggregateFocus = "user"; currentTexts.push(LlamaText(item.text)); } else if (item.type === "model") { if (!this.joinAdjacentMessagesOfTheSameType || currentAggregateFocus !== "model") flush(); currentAggregateFocus = "model"; currentTexts.push(this.generateModelResponseText(item.response)); } else void item; } const lastItemIsModelMessage = currentAggregateFocus === "model"; flush(); const idsGenerator = new UniqueTemplateId(this.template + this.modelRoleName + this.userRoleName + this.systemRoleName + (convertSystemMessagesToUserMessagesFormat ?? "") + resultItems.map(({ content }) => content.toString()).join("\n\n")); const jinjaItems = []; const jinjaRoleMap = { system: this.systemRoleName, user: this.userRoleName, model: this.modelRoleName }; const idToContent = new Map(); const modelMessageIds = new Set(); const messageIds = new Set(); for (const resultItem of resultItems) { const id = idsGenerator.generateId(); messageIds.add(id); idToContent.set(id, resultItem.content); jinjaItems.push({ role: jinjaRoleMap[resultItem.role], content: id }); if (resultItem.role === "model") modelMessageIds.add(id); } const bosTokenId = idsGenerator.generateId(); const eosTokenId = idsGenerator.generateId(); const eotTokenId = idsGenerator.generateId(); idToContent.set(bosTokenId, new SpecialToken("BOS")); idToContent.set(eosTokenId, new SpecialToken("EOS")); idToContent.set(eotTokenId, new SpecialToken("EOT")); function tryOptions(options) { for (let i = 0; i < options.length; i++) { if (i === options.length - 1) return options[i](); try { return options[i](); } catch (err) { // do nothing } } throw new Error("All options failed"); } const renderJinjaText = () => { return tryOptions([ () => this._jinjaTemplate.render({ ...(this.additionalRenderParameters == null ? {} : structuredClone(this.additionalRenderParameters)), messages: jinjaItems, "bos_token": bosTokenId, "eos_token": eosTokenId, "eot_token": eotTokenId }), () => this._jinjaTemplate.render({ ...(this.additionalRenderParameters == null ? {} : structuredClone(this.additionalRenderParameters)), messages: jinjaItems, "bos_token": bosTokenId, "eos_token": eosTokenId, "eot_token": eotTokenId, "add_generation_prompt": true }) ]); }; const validateThatAllMessageIdsAreUsed = (parts) => { const messageIdsLeft = new Set(messageIds); for (const part of parts) { if (typeof part === "string") continue; messageIdsLeft.delete(part.separator); } if (messageIdsLeft.size !== 0) throw new Error("Some input messages are not present in the generated Jinja template output"); }; const renderJinjaAndSplitIntoParts = () => { const splitJinjaParts = splitText(renderJinjaText(), [...idToContent.keys()]); if (lastItemIsModelMessage) { let lastModelResponseIndex = -1; for (let i = splitJinjaParts.length - 1; i >= 0; i--) { const part = splitJinjaParts[i]; if (part == null || typeof part === "string") continue; if (modelMessageIds.has(part.separator)) { lastModelResponseIndex = i; break; } else if (messageIds.has(part.separator)) { validateThatAllMessageIdsAreUsed(splitJinjaParts); throw new Error("Last message was expected to be a model message, but it was not"); } } if (lastModelResponseIndex < 0) { validateThatAllMessageIdsAreUsed(splitJinjaParts); throw new Error("A model message was expected to be the last message, but it was not found"); } return { splitJinjaParts: splitJinjaParts.slice(0, lastModelResponseIndex + 1), stopGenerationJinjaParts: splitJinjaParts.slice(lastModelResponseIndex + 1) }; } return { splitJinjaParts, stopGenerationJinjaParts: [] }; }; const { splitJinjaParts, stopGenerationJinjaParts } = renderJinjaAndSplitIntoParts(); const messageIdsLeftToProcess = new Set(messageIds); const contextText = LlamaText(splitJinjaParts.map((part) => { if (typeof part === "string") return new SpecialTokensText(part); // things that are not message content can be tokenized with special tokens const message = idToContent.get(part.separator); if (message == null) throw new Error(`Message with id "${part.separator}" not found`); messageIdsLeftToProcess.delete(part.separator); return message; })); if (messageIdsLeftToProcess.size !== 0) throw new Error("Some input messages are not present in the generated Jinja template output"); return { contextText, ignoreStartText: !this.trimLeadingWhitespaceInResponses ? [] : [ // ignore up to 4 leading spaces ...Array(4).fill(0) .map((_, index) => LlamaText(" ".repeat(index + 1))), LlamaText("\t"), LlamaText("\t\t"), LlamaText("\t "), LlamaText(" \t") ], stopGenerationTriggers: [ LlamaText(new SpecialToken("EOS")), ...(stopGenerationJinjaParts.length === 0 ? [] : [ LlamaText(stopGenerationJinjaParts.map((part) => { if (typeof part === "string") return new SpecialTokensText(part); const message = idToContent.get(part.separator); if (message == null) throw new Error(`Message with id "${part.separator}" not found`); return message; })) ]) ], transformedSystemMessagesToUserMessages }; } /** * Validate that this Jinja template can be rendered * @internal */ _runSanityTest() { try { let supportsSystemMessages = true; for (const chatHistory of chatHistoriesForSanityTest) { const { transformedSystemMessagesToUserMessages } = this.generateContextState({ chatHistory }); if (transformedSystemMessagesToUserMessages) supportsSystemMessages = false; } return { supportsSystemMessages }; } catch (err) { throw new Error("The provided Jinja template failed the sanity test: " + String(err) + ". Inspect the Jinja template to find out what went wrong"); } } } class UniqueTemplateId { antiText; _ids = new Set(); constructor(antiText) { this.antiText = antiText; } generateId() { let id; do { id = "W" + (Math.random() .toString(36) .slice(2)) + "W"; } while (this._ids.has(id) || this.antiText.includes(id)); this._ids.add(id); return id; } removeId(id) { this._ids.delete(id); } } function resolveConvertUnsupportedSystemMessagesToUserMessagesOption(convertUnsupportedSystemMessagesToUserMessages) { if (convertUnsupportedSystemMessagesToUserMessages === false) return undefined; if (convertUnsupportedSystemMessagesToUserMessages === true) return { ...defaultConvertUnsupportedSystemMessagesToUserMessagesFormat, use: "always" }; if (convertUnsupportedSystemMessagesToUserMessages === "auto") return { ...defaultConvertUnsupportedSystemMessagesToUserMessagesFormat, use: "ifNeeded" }; if (typeof convertUnsupportedSystemMessagesToUserMessages === "object") return { ...convertUnsupportedSystemMessagesToUserMessages, use: convertUnsupportedSystemMessagesToUserMessages.use ?? "ifNeeded" }; return { ...defaultConvertUnsupportedSystemMessagesToUserMessagesFormat, use: "ifNeeded" }; } const chatHistoriesForSanityTest = [ [{ type: "system", text: "System message ~!@#$%^&*()\n*" }, { type: "user", text: "Message 1234567890!@#$%^&*()_+-=[]{}|\\:;\"',./<>?`~" }, { type: "model", response: [""] }], [{ type: "system", text: "System message ~!@#$%^&*()\n*" }, { type: "user", text: "Message 1234567890!@#$%^&*()_+-=[]{}|\\:;\"',./<>?`~" }, { type: "model", response: ["Result 1234567890!@#$%^&*()_+-=[]{}|\\:;\"',./<>?`~"] }], [{ type: "system", text: "System message ~!@#$%^&*()\n*" }, { type: "user", text: "Message 1234567890!@#$%^&*()_+-=[]{}|\\:;\"',./<>?`~" }, { type: "model", response: ["Result 1234567890!@#$%^&*()_+-=[]{}|\\:;\"',./<>?`~"] }, { type: "user", text: "Message2 1234567890!@#$%^&*()_+-=[]{}|\\:;\"',./<>?`~" }, { type: "model", response: [""] }], [{ type: "system", text: "System message ~!@#$%^&*()\n*" }, { type: "user", text: "Message 1234567890!@#$%^&*()_+-=[]{}|\\:;\"',./<>?`~" }, { type: "model", response: ["Result 1234567890!@#$%^&*()_+-=[]{}|\\:;\"',./<>?`~"] }, { type: "user", text: "Message2 1234567890!@#$%^&*()_+-=[]{}|\\:;\"',./<>?`~" }, { type: "model", response: ["Result2 1234567890!@#$%^&*()_+-=[]{}|\\:;\"',./<>?`~"] }] ]; //# sourceMappingURL=JinjaTemplateChatWrapper.js.map