node-llama-cpp
Run AI models locally on your machine with Node.js bindings for llama.cpp. Enforce a JSON schema on the model output at the generation level.
import { SpecialToken, LlamaText, SpecialTokensText } from "../../utils/LlamaText.js";
import { ChatWrapper } from "../../ChatWrapper.js";
import { parseTextTemplate } from "../../utils/parseTextTemplate.js";
import { parseFunctionCallMessageTemplate } from "./utils/chatHistoryFunctionCallMessageTemplate.js";
import { templateSegmentOptionsToChatWrapperSettings } from "./utils/templateSegmentOptionsToChatWrapperSettings.js";
/**
 * A chat wrapper based on a simple template.
 * @example
 * <span v-pre>
 *
 * ```ts
 * import {TemplateChatWrapper} from "node-llama-cpp";
 *
 * const chatWrapper = new TemplateChatWrapper({
 *     template: "{{systemPrompt}}\n{{history}}model: {{completion}}\nuser: ",
 *     historyTemplate: {
 *         system: "system: {{message}}\n",
 *         user: "user: {{message}}\n",
 *         model: "model: {{message}}\n"
 *     },
 *     // functionCallMessageTemplate: { // optional
 *     //     call: "[[call: {{functionName}}({{functionParams}})]]",
 *     //     result: " [[result: {{functionCallResult}}]]"
 *     // },
 *     // segments: {
 *     //     thoughtTemplate: "<think>{{content}}</think>",
 *     //     reopenThoughtAfterFunctionCalls: true
 *     // }
 * });
 * ```
 *
 * </span>
 *
 * **<span v-pre>`{{systemPrompt}}`</span>** is optional and is replaced with the first system message
 * (when it is present, that system message is not included in the history).
 *
 * **<span v-pre>`{{history}}`</span>** is replaced with the chat history.
 * Each message in the chat history is converted using the template passed to `historyTemplate` for the message role,
 * and all messages are joined together.
 *
 * **<span v-pre>`{{completion}}`</span>** is where the model's response is generated.
 * The text that comes after <span v-pre>`{{completion}}`</span> is used to determine when the model has finished generating the response,
 * and thus is mandatory.
 *
 * **`functionCallMessageTemplate`** is used to specify the format in which functions can be called by the model and
 * how their results are fed to the model after the function call.
 *
 * **`segments`** is used to specify the format of the segments generated by the model (like thought segments).
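 *
 * For example, with the template shown above and a chat history containing a system message `"S"`,
 * a user message `"Hi"`, and a model response `"Hello"` as the last item, the rendered context text
 * comes out as (a sketch traced from `generateContextState` below):
 *
 * <span v-pre>
 *
 * ```text
 * S
 * user: Hi
 * model: Hello
 * ```
 *
 * </span>
 *
 * Generation then stops when the model emits the text that comes after
 * <span v-pre>`{{completion}}`</span> (here, `"\nuser: "`) or an EOS token.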
 */
export class TemplateChatWrapper extends ChatWrapper {
    wrapperName = "Template";
    settings;
    template;
    historyTemplate;
    joinAdjacentMessagesOfTheSameType;
    /** @internal */ _parsedChatTemplate;
    /** @internal */ _parsedChatHistoryTemplate;
    constructor({ template, historyTemplate, functionCallMessageTemplate, joinAdjacentMessagesOfTheSameType = true, segments }) {
        super();
        if (template == null || historyTemplate == null)
            throw new Error("Template chat wrapper settings must have a template and historyTemplate.");
        if (historyTemplate.system == null || historyTemplate.user == null || historyTemplate.model == null)
            throw new Error("Template chat wrapper historyTemplate must have system, user, and model templates.");
        this.template = template;
        this.historyTemplate = historyTemplate;
        this.joinAdjacentMessagesOfTheSameType = joinAdjacentMessagesOfTheSameType;
        this._parsedChatTemplate = parseChatTemplate(template);
        this._parsedChatHistoryTemplate = {
            system: parseChatHistoryTemplate(historyTemplate.system),
            user: parseChatHistoryTemplate(historyTemplate.user),
            model: parseChatHistoryTemplate(historyTemplate.model)
        };
        this.settings = {
            ...ChatWrapper.defaultSettings,
            functions: parseFunctionCallMessageTemplate(functionCallMessageTemplate) ?? ChatWrapper.defaultSettings.functions,
            segments: templateSegmentOptionsToChatWrapperSettings(segments)
        };
    }
    generateContextState({ chatHistory, availableFunctions, documentFunctionParams }) {
        const historyWithFunctions = this.addAvailableFunctionsSystemMessageToHistory(chatHistory, availableFunctions, {
            documentParams: documentFunctionParams
        });
        const resultItems = [];
        const systemTexts = [];
        const userTexts = [];
        const modelTexts = [];
        let currentAggregateFocus = null;
        function flush() {
            if (systemTexts.length > 0 || userTexts.length > 0 || modelTexts.length > 0)
                resultItems.push({
                    system: LlamaText.joinValues("\n\n", systemTexts),
                    user: LlamaText.joinValues("\n\n", userTexts),
                    model: LlamaText.joinValues("\n\n", modelTexts)
                });
            systemTexts.length = 0;
            userTexts.length = 0;
            modelTexts.length = 0;
        }
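        // Group messages into {system, user, model} aggregates; each aggregate is
        // rendered as one run of history-template entries below. With
        // joinAdjacentMessagesOfTheSameType enabled, adjacent same-type messages
        // (and user messages that follow system messages) share an aggregate.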
        for (const item of historyWithFunctions) {
            if (item.type === "system") {
                if (!this.joinAdjacentMessagesOfTheSameType || currentAggregateFocus !== "system")
                    flush();
                currentAggregateFocus = "system";
                systemTexts.push(LlamaText.fromJSON(item.text));
            }
            else if (item.type === "user") {
                if (!this.joinAdjacentMessagesOfTheSameType || (currentAggregateFocus !== "system" && currentAggregateFocus !== "user"))
                    flush();
                currentAggregateFocus = "user";
                userTexts.push(LlamaText(item.text));
            }
            else if (item.type === "model") {
                if (!this.joinAdjacentMessagesOfTheSameType)
                    flush();
                currentAggregateFocus = "model";
                modelTexts.push(this.generateModelResponseText(item.response));
            }
            else
                void item;
        }
        flush();
        const getHistoryItem = (role, text, prefix) => {
            const { messagePrefix, messageSuffix } = this._parsedChatHistoryTemplate[role];
            return LlamaText([
                new SpecialTokensText((prefix ?? "") + messagePrefix),
                text,
                new SpecialTokensText(messageSuffix)
            ]);
        };
        const contextText = LlamaText(resultItems.map(({ system, user, model }, index) => {
            const isFirstItem = index === 0;
            const isLastItem = index === resultItems.length - 1;
            const res = LlamaText([
                isFirstItem
                    ? system.values.length === 0
                        ? new SpecialTokensText((this._parsedChatTemplate.systemPromptPrefix ?? "") + this._parsedChatTemplate.historyPrefix)
                        : this._parsedChatTemplate.systemPromptPrefix != null
                            ? LlamaText([
                                new SpecialTokensText(this._parsedChatTemplate.systemPromptPrefix),
                                system,
                                new SpecialTokensText(this._parsedChatTemplate.historyPrefix)
                            ])
                            : getHistoryItem("system", system, this._parsedChatTemplate.historyPrefix)
                    : system.values.length === 0
                        ? LlamaText([])
                        : getHistoryItem("system", system),
                user.values.length === 0
                    ? LlamaText([])
                    : getHistoryItem("user", user),
                model.values.length === 0
                    ? LlamaText([])
                    : !isLastItem
                        ? getHistoryItem("model", model)
                        : LlamaText([
                            new SpecialTokensText(this._parsedChatTemplate.completionPrefix),
                            model
                        ])
            ]);
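            // Merge adjacent SpecialTokensText values into a single one, so
            // neighboring template chunks are handled as one contiguous span.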
            return LlamaText(res.values.reduce((res, value) => {
                if (value instanceof SpecialTokensText) {
                    const lastItem = res[res.length - 1];
                    if (lastItem == null || !(lastItem instanceof SpecialTokensText))
                        return res.concat([value]);
                    return res.slice(0, -1).concat([
                        new SpecialTokensText(lastItem.value + value.value)
                    ]);
                }
                return res.concat([value]);
            }, []));
        }));
        return {
            contextText,
            stopGenerationTriggers: [
                LlamaText(new SpecialToken("EOS")),
                LlamaText(this._parsedChatTemplate.completionSuffix),
                LlamaText(new SpecialTokensText(this._parsedChatTemplate.completionSuffix))
            ]
        };
    }
}
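/*
 * Usage sketch: plugging a TemplateChatWrapper into a chat session via the
 * public node-llama-cpp API. The model path is a placeholder; point it at a
 * local GGUF file.
 *
 * import {getLlama, LlamaChatSession, TemplateChatWrapper} from "node-llama-cpp";
 *
 * const llama = await getLlama();
 * const model = await llama.loadModel({modelPath: "./models/model.gguf"}); // placeholder path
 * const context = await model.createContext();
 * const session = new LlamaChatSession({
 *     contextSequence: context.getSequence(),
 *     chatWrapper: new TemplateChatWrapper({
 *         template: "{{systemPrompt}}\n{{history}}model: {{completion}}\nuser: ",
 *         historyTemplate: {
 *             system: "system: {{message}}\n",
 *             user: "user: {{message}}\n",
 *             model: "model: {{message}}\n"
 *         }
 *     })
 * });
 *
 * const response = await session.prompt("Hi there");
 * console.log(response);
 */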
function parseChatTemplate(template) {
    const parsedTemplate = parseTextTemplate(template, [{
        text: "{{systemPrompt}}",
        key: "systemPrompt",
        optional: true
    }, {
        text: "{{history}}",
        key: "history"
    }, {
        text: "{{completion}}",
        key: "completion"
    }]);
    if (parsedTemplate.completion.suffix.length == 0)
        throw new Error('Chat template must have text after "{{completion}}"');
    return {
        systemPromptPrefix: parsedTemplate.systemPrompt?.prefix ?? null,
        historyPrefix: parsedTemplate.history.prefix,
        completionPrefix: parsedTemplate.completion.prefix,
        completionSuffix: parsedTemplate.completion.suffix
    };
}
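// For example, parsing the template from the class JSDoc above,
// "{{systemPrompt}}\n{{history}}model: {{completion}}\nuser: ", yields:
//     systemPromptPrefix: ""         (the text before {{systemPrompt}})
//     historyPrefix: "\n"            (the text between {{systemPrompt}} and {{history}})
//     completionPrefix: "model: "    (the text between {{history}} and {{completion}})
//     completionSuffix: "\nuser: "   (the text after {{completion}}; also used as a stop trigger)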
function parseChatHistoryTemplate(template) {
    const parsedTemplate = parseTextTemplate(template, [{
        text: "{{message}}",
        key: "message"
    }]);
    return {
        messagePrefix: parsedTemplate.message.prefix,
        messageSuffix: parsedTemplate.message.suffix
    };
}
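// For example, parsing the history template "user: {{message}}\n" yields
// { messagePrefix: "user: ", messageSuffix: "\n" }.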
//# sourceMappingURL=TemplateChatWrapper.js.map