UNPKG

node-llama-cpp

Version:

Run AI models locally on your machine with Node.js bindings for llama.cpp. Enforce a JSON schema on the model output at the generation level.

622 lines 29.8 kB
import { ChatWrapper } from "../ChatWrapper.js";
import { isChatModelResponseFunctionCall, isChatModelResponseSegment } from "../types.js";
import { LlamaText, SpecialToken, SpecialTokensText } from "../utils/LlamaText.js";
import { ChatModelFunctionsDocumentationGenerator } from "./utils/ChatModelFunctionsDocumentationGenerator.js";
import { jsonDumps } from "./utils/jsonDumps.js";

// source: https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v2.txt
/**
 * Chat wrapper that renders a chat history into the Functionary prompt formats.
 *
 * Three prompt variations are supported, selected by the `variation` option:
 * - `"v3"` (default): Llama 3 style headers, function calls prefixed with `>>>`.
 * - `"v2.llama3"`: Llama 3 style headers, function calls prefixed with
 *   `<|reserved_special_token_249|>`.
 * - any other value (e.g. `"v2"`): the `<|from|>` / `<|recipient|>` / `<|content|>` format.
 */
export class FunctionaryChatWrapper extends ChatWrapper {
    wrapperName = "Functionary";
    variation;
    settings;

    /**
     * @param {object} [options]
     * @param {string} [options.variation="v3"] - Prompt format to use: `"v3"`, `"v2.llama3"`,
     * or anything else for the v2 format.
     */
    constructor({ variation = "v3" } = {}) {
        super();
        this.variation = variation;

        if (variation === "v3")
            this.settings = {
                ...ChatWrapper.defaultSettings,
                supportsSystemMessages: true,
                functions: {
                    call: {
                        optionalPrefixSpace: true,
                        prefix: LlamaText(new SpecialTokensText(">>>")),
                        paramsPrefix: LlamaText(new SpecialTokensText("\n")),
                        suffix: ""
                    },
                    result: {
                        prefix: LlamaText([
                            new SpecialTokensText("<|start_header_id|>tool<|end_header_id|>\n\n")
                        ]),
                        suffix: LlamaText(new SpecialToken("EOT"))
                    },
                    parallelism: {
                        call: {
                            sectionPrefix: "",
                            betweenCalls: "",
                            sectionSuffix: LlamaText(new SpecialToken("EOT"))
                        },
                        result: {
                            sectionPrefix: "",
                            betweenResults: "",
                            sectionSuffix: ""
                        }
                    }
                }
            };
        else if (variation === "v2.llama3")
            this.settings = {
                ...ChatWrapper.defaultSettings,
                supportsSystemMessages: true,
                functions: {
                    call: {
                        optionalPrefixSpace: true,
                        prefix: LlamaText(new SpecialTokensText("<|reserved_special_token_249|>")),
                        paramsPrefix: LlamaText(new SpecialTokensText("\n")),
                        suffix: ""
                    },
                    result: {
                        prefix: LlamaText([
                            new SpecialTokensText("<|start_header_id|>tool<|end_header_id|>\n\nname="),
                            "{{functionName}}",
                            new SpecialTokensText("\n")
                        ]),
                        suffix: LlamaText(new SpecialToken("EOT"))
                    },
                    parallelism: {
                        call: {
                            sectionPrefix: "",
                            betweenCalls: "",
                            sectionSuffix: LlamaText(new SpecialToken("EOT"))
                        },
                        result: {
                            sectionPrefix: "",
                            betweenResults: "",
                            sectionSuffix: ""
                        }
                    }
                }
            };
        else
            this.settings = {
                ...ChatWrapper.defaultSettings,
                supportsSystemMessages: true,
                functions: {
                    call: {
                        optionalPrefixSpace: true,
                        prefix: LlamaText(new SpecialTokensText("\n<|from|>assistant\n<|recipient|>")),
                        paramsPrefix: LlamaText(new SpecialTokensText("\n<|content|>")),
                        suffix: ""
                    },
                    result: {
                        prefix: LlamaText([
                            new SpecialTokensText("\n<|from|>"),
                            "{{functionName}}",
                            new SpecialTokensText("\n<|recipient|>all\n<|content|>")
                        ]),
                        suffix: ""
                    },
                    parallelism: {
                        call: {
                            sectionPrefix: "",
                            betweenCalls: "\n",
                            sectionSuffix: LlamaText(new SpecialTokensText("<|stop|>"))
                        },
                        result: {
                            sectionPrefix: "",
                            betweenResults: "",
                            sectionSuffix: ""
                        }
                    }
                }
            };
    }

    /**
     * Renders the chat history into a context state, dispatching on `this.variation`.
     *
     * @param {object} options
     * @param {Array} options.chatHistory - History items with `type` of `"system"`, `"user"` or `"model"`.
     * @param {object} [options.availableFunctions] - Functions keyed by name; may be nullish or empty.
     * @param {boolean} [options.documentFunctionParams] - Forwarded as `documentParams` when the
     * functions system message is generated.
     * @returns {object} An object with `contextText` and `stopGenerationTriggers`, plus
     * `ignoreStartText` and `functionCall` when the variation starts in function-calling mode.
     */
    generateContextState({ chatHistory, availableFunctions, documentFunctionParams }) {
        if (this.variation === "v3")
            return this._generateContextStateV3({ chatHistory, availableFunctions, documentFunctionParams });
        else if (this.variation === "v2.llama3")
            return this._generateContextStateV2Llama3({ chatHistory, availableFunctions, documentFunctionParams });

        return this._generateContextStateV2({ chatHistory, availableFunctions, documentFunctionParams });
    }

    /** @internal */
    _generateContextStateV3({ chatHistory, availableFunctions, documentFunctionParams }) {
        const hasFunctions = Object.keys(availableFunctions ?? {}).length > 0;
        const historyWithFunctions = this.addAvailableFunctionsSystemMessageToHistory(chatHistory, availableFunctions, {
            documentParams: documentFunctionParams
        });

        const contextText = LlamaText(historyWithFunctions.map((item, index) => {
            const isLastItem = index === historyWithFunctions.length - 1;

            if (item.type === "system") {
                if (item.text.length === 0)
                    return "";

                return LlamaText([
                    new SpecialTokensText("<|start_header_id|>system<|end_header_id|>\n\n"),
                    LlamaText.fromJSON(item.text),
                    new SpecialToken("EOT")
                ]);
            } else if (item.type === "user") {
                return LlamaText([
                    new SpecialTokensText("<|start_header_id|>user<|end_header_id|>\n\n"),
                    item.text,
                    new SpecialToken("EOT")
                ]);
            } else if (item.type === "model") {
                // An empty trailing model item is the slot the model is about to fill
                if (isLastItem && item.response.length === 0)
                    return LlamaText([
                        new SpecialTokensText("<|start_header_id|>assistant<|end_header_id|>\n\n")
                    ]);

                const res = [];
                const pendingFunctionCalls = [];
                const pendingFunctionResults = [];

                // Flushes the queued calls, then EOT, then the queued results
                const addPendingFunctions = () => {
                    if (pendingFunctionResults.length === 0)
                        return;

                    res.push(LlamaText(pendingFunctionCalls));
                    res.push(LlamaText(new SpecialToken("EOT")));
                    res.push(LlamaText(pendingFunctionResults));

                    // Clear both queues so a later flush does not re-emit calls that were already written
                    pendingFunctionCalls.length = 0;
                    pendingFunctionResults.length = 0;
                };

                const simplifiedResponse = convertModelResponseToLamaTextAndFunctionCalls(item.response, this);
                for (let index = 0; index < simplifiedResponse.length; index++) {
                    const response = simplifiedResponse[index];
                    const isLastResponse = index === simplifiedResponse.length - 1;

                    if (response == null)
                        continue;

                    if (LlamaText.isLlamaText(response)) {
                        addPendingFunctions();
                        res.push(LlamaText([
                            new SpecialTokensText("<|start_header_id|>assistant<|end_header_id|>\n\n"),
                            (isLastResponse && response.values.length === 0)
                                ? hasFunctions
                                    ? LlamaText(new SpecialTokensText(">>>"))
                                    : LlamaText(new SpecialTokensText(">>>all\n"))
                                : LlamaText([
                                    new SpecialTokensText(">>>all\n"),
                                    response,
                                    (!isLastResponse || isLastItem)
                                        ? LlamaText([])
                                        : new SpecialToken("EOT")
                                ])
                        ]));
                    } else if (isChatModelResponseFunctionCall(response)) {
                        if (response.startsNewChunk)
                            addPendingFunctions();

                        pendingFunctionCalls.push(response.rawCall != null
                            ? LlamaText.fromJSON(response.rawCall)
                            : LlamaText([
                                new SpecialTokensText(">>>"),
                                response.name,
                                new SpecialTokensText("\n"),
                                response.params === undefined
                                    ? ""
                                    : jsonDumps(response.params)
                            ]));
                        pendingFunctionResults.push(LlamaText([
                            new SpecialTokensText("<|start_header_id|>tool<|end_header_id|>\n\n"),
                            response.result === undefined
                                ? "" // "void"
                                : jsonDumps(response.result),
                            new SpecialToken("EOT")
                        ]));
                    } else
                        void response;
                }

                addPendingFunctions();

                // The last model item did not end with text, so open a new assistant turn for generation
                if (isLastItem && (res.length === 0 || typeof item.response[item.response.length - 1] !== "string"))
                    res.push(hasFunctions
                        ? LlamaText(new SpecialTokensText("<|start_header_id|>assistant<|end_header_id|>\n\n"))
                        : LlamaText(new SpecialTokensText("<|start_header_id|>assistant<|end_header_id|>\n\n>>>all\n")));

                return LlamaText(res);
            }

            void item;
            return "";
        }));

        const lastItem = historyWithFunctions.at(-1);

        // Without functions, or when continuing a non-empty text response, plain text generation is used
        if (!hasFunctions || (lastItem?.type === "model" &&
            lastItem.response.length > 0 &&
            typeof lastItem.response.at(-1) === "string" &&
            lastItem.response.at(-1) !== "")) {
            return {
                contextText,
                stopGenerationTriggers: [
                    LlamaText(new SpecialToken("EOS")),
                    LlamaText(new SpecialToken("EOT")),
                    LlamaText(new SpecialTokensText("<|eot_id|>")),
                    LlamaText(new SpecialTokensText("<|end_of_text|>")),
                    LlamaText("<|eot_id|>"),
                    LlamaText("<|end_of_text|>")
                ]
            };
        }

        const textResponseStart = [
            LlamaText(new SpecialTokensText(">>>all\n")),
            LlamaText(">>>all\n")
        ];

        return {
            contextText,
            stopGenerationTriggers: [
                LlamaText(new SpecialToken("EOS")),
                LlamaText(new SpecialToken("EOT")),
                LlamaText(new SpecialTokensText("<|eot_id|>")),
                LlamaText(new SpecialTokensText("<|end_of_text|>")),
                LlamaText("<|eot_id|>"),
                LlamaText("<|end_of_text|>")
            ],
            ignoreStartText: textResponseStart,
            functionCall: {
                initiallyEngaged: true,
                disengageInitiallyEngaged: textResponseStart
            }
        };
    }

    /** @internal */
    _generateContextStateV2Llama3({ chatHistory, availableFunctions, documentFunctionParams }) {
        const historyWithFunctions = this.addAvailableFunctionsSystemMessageToHistory(chatHistory, availableFunctions, {
            documentParams: documentFunctionParams
        });

        const contextText = LlamaText(new SpecialToken("BOS"), historyWithFunctions.map((item, index) => {
            const isLastItem = index === historyWithFunctions.length - 1;

            if (item.type === "system") {
                if (item.text.length === 0)
                    return "";

                return LlamaText([
                    new SpecialTokensText("<|start_header_id|>system<|end_header_id|>\n\n"),
                    LlamaText.fromJSON(item.text),
                    new SpecialToken("EOT")
                ]);
            } else if (item.type === "user") {
                return LlamaText([
                    new SpecialTokensText("<|start_header_id|>user<|end_header_id|>\n\n"),
                    item.text,
                    new SpecialToken("EOT")
                ]);
            } else if (item.type === "model") {
                // An empty trailing model item is the slot the model is about to fill
                if (isLastItem && item.response.length === 0)
                    return LlamaText([
                        new SpecialTokensText("<|start_header_id|>assistant<|end_header_id|>\n\n")
                    ]);

                const res = [];
                const pendingFunctionCalls = [];
                const pendingFunctionResults = [];

                // Flushes the queued calls, then EOT, then the queued results
                const addPendingFunctions = () => {
                    if (pendingFunctionResults.length === 0)
                        return;

                    res.push(LlamaText(pendingFunctionCalls));
                    res.push(LlamaText(new SpecialToken("EOT")));
                    res.push(LlamaText(pendingFunctionResults));

                    // Clear both queues so a later flush does not re-emit calls that were already written
                    pendingFunctionCalls.length = 0;
                    pendingFunctionResults.length = 0;
                };

                const simplifiedResponse = convertModelResponseToLamaTextAndFunctionCalls(item.response, this);
                for (let index = 0; index < simplifiedResponse.length; index++) {
                    const response = simplifiedResponse[index];
                    const isLastResponse = index === simplifiedResponse.length - 1;

                    if (response == null)
                        continue;

                    if (LlamaText.isLlamaText(response)) {
                        addPendingFunctions();
                        res.push(LlamaText([
                            new SpecialTokensText("<|start_header_id|>assistant<|end_header_id|>\n\n"),
                            response,
                            (isLastItem && isLastResponse)
                                ? LlamaText([])
                                : new SpecialToken("EOT")
                        ]));
                    } else if (isChatModelResponseFunctionCall(response)) {
                        if (response.startsNewChunk)
                            addPendingFunctions();

                        pendingFunctionCalls.push(response.rawCall != null
                            ? LlamaText.fromJSON(response.rawCall)
                            : LlamaText([
                                new SpecialTokensText("<|reserved_special_token_249|>"),
                                response.name,
                                new SpecialTokensText("\n"),
                                response.params === undefined
                                    ? ""
                                    : jsonDumps(response.params)
                            ]));
                        pendingFunctionResults.push(LlamaText([
                            new SpecialTokensText("<|start_header_id|>tool<|end_header_id|>\n\nname="),
                            response.name,
                            new SpecialTokensText("\n"),
                            response.result === undefined
                                ? "" // "void"
                                : jsonDumps(response.result),
                            new SpecialToken("EOT")
                        ]));
                    } else
                        void response;
                }

                addPendingFunctions();

                // The last model item did not end with text, so open a new assistant turn for generation
                if (isLastItem && (res.length === 0 || typeof item.response[item.response.length - 1] !== "string"))
                    res.push(LlamaText([
                        new SpecialTokensText("<|start_header_id|>assistant<|end_header_id|>\n\n")
                    ]));

                return LlamaText(res);
            }

            void item;
            return "";
        }));

        return {
            contextText,
            stopGenerationTriggers: [
                LlamaText(new SpecialToken("EOS")),
                LlamaText(new SpecialToken("EOT")),
                LlamaText(new SpecialTokensText("<|eot_id|>")),
                LlamaText(new SpecialTokensText("<|end_of_text|>")),
                LlamaText("<|eot_id|>"),
                LlamaText("<|end_of_text|>")
            ]
        };
    }

    /** @internal */
    _generateContextStateV2({ chatHistory, availableFunctions, documentFunctionParams }) {
        const hasFunctions = Object.keys(availableFunctions ?? {}).length > 0;
        const historyWithFunctions = this.addAvailableFunctionsSystemMessageToHistory(chatHistory, availableFunctions, {
            documentParams: documentFunctionParams
        });

        const contextText = LlamaText(new SpecialToken("BOS"), historyWithFunctions.map((item, index) => {
            const isFirstItem = index === 0;
            const isLastItem = index === historyWithFunctions.length - 1;

            if (item.type === "system") {
                if (item.text.length === 0)
                    return "";

                return LlamaText([
                    isFirstItem
                        ? LlamaText([])
                        : new SpecialTokensText("\n"),
                    new SpecialTokensText("<|from|>system\n"),
                    new SpecialTokensText("<|recipient|>all\n"),
                    new SpecialTokensText("<|content|>"),
                    LlamaText.fromJSON(item.text)
                ]);
            } else if (item.type === "user") {
                return LlamaText([
                    isFirstItem
                        ? LlamaText([])
                        : new SpecialTokensText("\n"),
                    new SpecialTokensText("<|from|>user\n"),
                    new SpecialTokensText("<|recipient|>all\n"),
                    new SpecialTokensText("<|content|>"),
                    item.text
                ]);
            } else if (item.type === "model") {
                // Only pre-open a text response when there are no functions to choose from
                if (isLastItem && item.response.length === 0 && !hasFunctions)
                    return LlamaText([
                        isFirstItem
                            ? LlamaText([])
                            : new SpecialTokensText("\n"),
                        new SpecialTokensText("<|from|>assistant\n"),
                        new SpecialTokensText("<|recipient|>all\n"),
                        new SpecialTokensText("<|content|>")
                    ]);

                const res = [];
                const pendingFunctionCalls = [];
                const pendingFunctionResults = [];

                // Flushes the queued calls, then <|stop|>, then the queued results
                const addPendingFunctions = () => {
                    if (pendingFunctionResults.length === 0)
                        return;

                    res.push(LlamaText(pendingFunctionCalls));
                    res.push(LlamaText(new SpecialTokensText("<|stop|>")));
                    res.push(LlamaText(pendingFunctionResults));

                    // Clear both queues so a later flush does not re-emit calls that were already written
                    pendingFunctionCalls.length = 0;
                    pendingFunctionResults.length = 0;
                };

                const simplifiedResponse = convertModelResponseToLamaTextAndFunctionCalls(item.response, this);
                for (let index = 0; index < simplifiedResponse.length; index++) {
                    const response = simplifiedResponse[index];
                    const isFirstResponse = index === 0;

                    if (response == null)
                        continue;

                    if (LlamaText.isLlamaText(response)) {
                        addPendingFunctions();
                        res.push(LlamaText([
                            (isFirstItem && isFirstResponse)
                                ? LlamaText([])
                                : new SpecialTokensText("\n"),
                            new SpecialTokensText("<|from|>assistant\n"),
                            new SpecialTokensText("<|recipient|>all\n"),
                            new SpecialTokensText("<|content|>"),
                            response
                        ]));
                    } else if (isChatModelResponseFunctionCall(response)) {
                        pendingFunctionCalls.push(response.rawCall != null
                            ? LlamaText.fromJSON(response.rawCall)
                            : LlamaText([
                                (isFirstItem && isFirstResponse)
                                    ? LlamaText([])
                                    : new SpecialTokensText("\n"),
                                new SpecialTokensText("<|from|>assistant\n"),
                                new SpecialTokensText("<|recipient|>"),
                                response.name,
                                new SpecialTokensText("\n"),
                                new SpecialTokensText("<|content|>"),
                                response.params === undefined
                                    ? ""
                                    : jsonDumps(response.params)
                            ]));
                        pendingFunctionResults.push(LlamaText([
                            new SpecialTokensText("\n"),
                            new SpecialTokensText("<|from|>"),
                            response.name,
                            new SpecialTokensText("\n"),
                            new SpecialTokensText("<|recipient|>all\n"),
                            new SpecialTokensText("<|content|>"),
                            response.result === undefined
                                ? "" // "void"
                                : jsonDumps(response.result)
                        ]));
                    } else
                        void response;
                }

                addPendingFunctions();

                if (res.length === 0) {
                    if (isLastItem) {
                        if (!hasFunctions)
                            res.push(LlamaText([
                                isFirstItem
                                    ? LlamaText([])
                                    : new SpecialTokensText("\n"),
                                new SpecialTokensText("<|from|>assistant\n"),
                                new SpecialTokensText("<|recipient|>all\n"),
                                new SpecialTokensText("<|content|>")
                            ]));
                    } else
                        res.push(LlamaText([
                            isFirstItem
                                ? LlamaText([])
                                : new SpecialTokensText("\n"),
                            new SpecialTokensText("<|from|>assistant\n"),
                            new SpecialTokensText("<|recipient|>all\n"),
                            new SpecialTokensText("<|content|>")
                        ]));
                } else if (isLastItem && typeof item.response[item.response.length - 1] !== "string") {
                    if (!hasFunctions)
                        res.push(LlamaText([
                            isFirstItem
                                ? LlamaText([])
                                : new SpecialTokensText("\n"),
                            new SpecialTokensText("<|from|>assistant\n"),
                            new SpecialTokensText("<|recipient|>all\n"),
                            new SpecialTokensText("<|content|>")
                        ]));
                }

                // Every model item except the last one is closed with <|stop|>
                if (!isLastItem)
                    res.push(LlamaText(new SpecialTokensText("<|stop|>")));

                return LlamaText(res);
            }

            void item;
            return "";
        }));

        if (!hasFunctions) {
            return {
                contextText,
                stopGenerationTriggers: [
                    LlamaText(new SpecialToken("EOS")),
                    LlamaText(new SpecialTokensText("<|stop|>")),
                    LlamaText(" <|stop|>"),
                    LlamaText("<|stop|>"),
                    LlamaText("\n<|from|>user"),
                    LlamaText("\n<|from|>assistant"),
                    LlamaText("\n<|from|>system"),
                    LlamaText(new SpecialTokensText(" <|stop|>")),
                    LlamaText(new SpecialTokensText("<|stop|>")),
                    LlamaText(new SpecialTokensText("\n<|from|>user")),
                    LlamaText(new SpecialTokensText("\n<|from|>assistant")),
                    LlamaText(new SpecialTokensText("\n<|from|>system"))
                ]
            };
        }

        // Each accepted text-response header, as special-token text and as plain text,
        // for every leading-whitespace prefix listed here
        const textResponseStart = [
            "\n",
            "\n\n",
            " \n",
            " \n\n"
        ].flatMap((prefix) => [
            LlamaText(new SpecialTokensText(prefix + "<|from|>assistant\n<|recipient|>all\n<|content|>")),
            LlamaText(prefix + "<|from|>assistant\n<|recipient|>all\n<|content|>")
        ]);

        return {
            contextText,
            stopGenerationTriggers: [
                LlamaText(new SpecialToken("EOS")),
                LlamaText(new SpecialTokensText("<|stop|>")),
                LlamaText(" <|stop|>"),
                LlamaText("<|stop|>"),
                LlamaText("\n<|from|>user"),
                LlamaText(new SpecialTokensText(" <|stop|>")),
                LlamaText(new SpecialTokensText("<|stop|>")),
                LlamaText(new SpecialTokensText("\n<|from|>user"))
            ],
            ignoreStartText: textResponseStart,
            functionCall: {
                initiallyEngaged: true,
                disengageInitiallyEngaged: textResponseStart
            }
        };
    }

    /**
     * Builds the system text that documents the available functions as a
     * TypeScript-style `namespace functions { ... }` block. For the `"v3"` variation
     * the block is preceded by usage instructions.
     *
     * @param {object} [availableFunctions] - Functions keyed by name.
     * @param {object} [options]
     * @param {boolean} [options.documentParams=true] - Forwarded to the documentation generator.
     * @returns An empty `LlamaText` when there are no functions, otherwise the joined text.
     */
    generateAvailableFunctionsSystemText(availableFunctions, { documentParams = true } = {}) {
        const functionsDocumentationGenerator = new ChatModelFunctionsDocumentationGenerator(availableFunctions);

        if (!functionsDocumentationGenerator.hasAnyFunctions)
            return LlamaText([]);

        const availableFunctionNames = Object.keys(availableFunctions ?? {});

        if (availableFunctionNames.length === 0)
            return LlamaText([]);

        if (this.variation === "v3") {
            // "${recipient}" and "${content}" are literal placeholders shown to the model, not template strings
            return LlamaText.joinValues("\n", [
                "You are capable of executing available function(s) if required.",
                "Only execute function(s) when absolutely necessary.",
                "Ask for the required input to:recipient==all",
                "Use JSON for function arguments.",
                "Respond in this format:",
                ">>>${recipient}",
                "${content}",
                "Available functions:",
                "// Supported function definitions that should be called when necessary.",
                "namespace functions {",
                "",
                functionsDocumentationGenerator.getTypeScriptFunctionTypes({ documentParams, reservedFunctionNames: ["all"] }),
                "",
                "} // namespace functions"
            ]);
        }

        return LlamaText.joinValues("\n", [
            "// Supported function definitions that should be called when necessary.",
            "namespace functions {",
            "",
            functionsDocumentationGenerator.getTypeScriptFunctionTypes({ documentParams, reservedFunctionNames: ["all"] }),
            "",
            "} // namespace functions"
        ]);
    }

    /**
     * Returns a copy of `history` with two system messages inserted: the functions
     * documentation and a short instruction about calling functions. They are placed
     * at the index of the first existing system message, or at the start when there is none.
     * When there are no functions, the original `history` array is returned as is.
     *
     * @param {Array} history - Chat history items; not mutated.
     * @param {object} [availableFunctions] - Functions keyed by name.
     * @param {object} [options]
     * @param {boolean} [options.documentParams=true] - Forwarded to `generateAvailableFunctionsSystemText`.
     * @returns {Array} The history to render.
     */
    addAvailableFunctionsSystemMessageToHistory(history, availableFunctions, { documentParams = true } = {}) {
        const availableFunctionNames = Object.keys(availableFunctions ?? {});

        if (availableFunctions == null || availableFunctionNames.length === 0)
            return history;

        const res = history.slice();

        // findIndex yields -1 when no system message exists; Math.max clamps that to the start
        const firstSystemMessageIndex = res.findIndex((item) => item.type === "system");
        res.splice(Math.max(0, firstSystemMessageIndex), 0, {
            type: "system",
            text: this.generateAvailableFunctionsSystemText(availableFunctions, { documentParams }).toJSON()
        }, {
            type: "system",
            text: "The assistant calls functions with appropriate input when necessary. The assistant writes <|stop|> when finished answering."
        });

        return res;
    }

    /** @internal */
    static _getOptionConfigurationsToTestIfCanSupersedeJinjaTemplate() {
        return [
            { variation: "v3" },
            { variation: "v2.llama3" },
            { variation: "v2" }
        ];
    }
}

/**
 * Collapses a model response into a list in which every consecutive run of text
 * items (strings and response segments) is merged into a single value produced by
 * `chatWrapper.generateModelResponseText`, while all other items (e.g. function
 * calls) are kept as is and in their original order.
 *
 * @param {Array} modelResponse - The `response` array of a model history item.
 * @param {ChatWrapper} chatWrapper - Wrapper used to render the merged text runs.
 * @returns {Array} Rendered text values interleaved with the non-text items.
 */
function convertModelResponseToLamaTextAndFunctionCalls(modelResponse, chatWrapper) {
    const pendingItems = [];
    const res = [];

    function pushPendingItems() {
        if (pendingItems.length === 0)
            return;

        res.push(chatWrapper.generateModelResponseText(pendingItems));
        pendingItems.length = 0;
    }

    for (const item of modelResponse) {
        if (typeof item === "string" || isChatModelResponseSegment(item))
            pendingItems.push(item);
        else {
            pushPendingItems();
            res.push(item);
        }
    }

    pushPendingItems();

    return res;
}
//# sourceMappingURL=FunctionaryChatWrapper.js.map