node-llama-cpp
Version:
Run AI models locally on your machine with Node.js bindings for llama.cpp. Enforces a JSON schema on the model output at the generation level.
126 lines • 5.53 kB
JavaScript
import { ChatWrapper } from "../ChatWrapper.js";
import { LlamaText, SpecialToken, SpecialTokensText } from "../utils/LlamaText.js";
/**
 * Chat wrapper that renders history in the Falcon style, using titled
 * sections such as `User: ...` and `Assistant: ...`.
 *
 * This chat wrapper is not safe against chat syntax injection attacks
 * ([learn more](https://node-llama-cpp.withcat.ai/guide/llama-text#input-safety-in-node-llama-cpp)).
 */
export class FalconChatWrapper extends ChatWrapper {
    wrapperName = "Falcon";
    /** @internal */ _userMessageTitle;
    /** @internal */ _modelResponseTitle;
    /** @internal */ _middleSystemMessageTitle;
    /** @internal */ _allowSpecialTokensInTitles;
    /**
     * @param {object} [options]
     * @param {string} [options.userMessageTitle="User"] - title prefixed to user messages
     * @param {string} [options.modelResponseTitle="Assistant"] - title prefixed to model responses
     * @param {string} [options.middleSystemMessageTitle="System"] - title prefixed to system messages
     * that are not at the very start of the context
     * @param {boolean} [options.allowSpecialTokensInTitles=false] - when true, titles and their
     * separators are rendered as special-tokens text instead of plain text
     */
    constructor({ userMessageTitle = "User", modelResponseTitle = "Assistant", middleSystemMessageTitle = "System", allowSpecialTokensInTitles = false } = {}) {
        super();
        this._userMessageTitle = userMessageTitle;
        this._modelResponseTitle = modelResponseTitle;
        this._middleSystemMessageTitle = middleSystemMessageTitle;
        this._allowSpecialTokensInTitles = allowSpecialTokensInTitles;
    }
    get userMessageTitle() {
        return this._userMessageTitle;
    }
    get modelResponseTitle() {
        return this._modelResponseTitle;
    }
    get middleSystemMessageTitle() {
        return this._middleSystemMessageTitle;
    }
    /**
     * Renders the chat history into a single context text plus stop-generation triggers.
     * Consecutive system messages are merged into one section; every user or model
     * message starts a fresh section.
     */
    generateContextState({ chatHistory, availableFunctions, documentFunctionParams }) {
        const historyWithFunctions = this.addAvailableFunctionsSystemMessageToHistory(chatHistory, availableFunctions, {
            documentParams: documentFunctionParams
        });
        const sections = [];
        let pendingSystem = [];
        let pendingUser = [];
        let pendingModel = [];
        let aggregateFocus = null;
        // Moves whatever has accumulated so far into a new section,
        // joining multiple texts of the same role with blank lines
        const commitPending = () => {
            if (pendingSystem.length > 0 || pendingUser.length > 0 || pendingModel.length > 0)
                sections.push({
                    system: LlamaText.joinValues("\n\n", pendingSystem),
                    user: LlamaText.joinValues("\n\n", pendingUser),
                    model: LlamaText.joinValues("\n\n", pendingModel)
                });
            pendingSystem = [];
            pendingUser = [];
            pendingModel = [];
        };
        for (const item of historyWithFunctions) {
            switch (item.type) {
                case "system":
                    // only consecutive system messages aggregate into one section
                    if (aggregateFocus !== "system")
                        commitPending();
                    aggregateFocus = "system";
                    pendingSystem.push(LlamaText.fromJSON(item.text));
                    break;
                case "user":
                    commitPending();
                    aggregateFocus = null;
                    pendingUser.push(LlamaText(item.text));
                    break;
                case "model":
                    commitPending();
                    aggregateFocus = null;
                    pendingModel.push(this.generateModelResponseText(item.response));
                    break;
                default:
                    void item;
            }
        }
        commitPending();
        // Titles and "\n\n" separators become special-tokens text only when allowed
        const wrap = (text) => SpecialTokensText.wrapIf(this._allowSpecialTokensInTitles, text);
        const contextText = LlamaText(new SpecialToken("BOS"), sections.map(({ system, user, model }, index) => {
            const isFirst = index === 0;
            const isLast = index === sections.length - 1;
            const systemPart = (system.values.length === 0)
                ? LlamaText([])
                : LlamaText([
                    // the leading system message carries no title; later ones are titled
                    isFirst
                        ? LlamaText([])
                        : wrap(`${this._middleSystemMessageTitle}: `),
                    system,
                    wrap("\n\n")
                ]);
            const userPart = (user.values.length === 0)
                ? LlamaText([])
                : LlamaText([
                    wrap(`${this._userMessageTitle}: `),
                    user,
                    wrap("\n\n")
                ]);
            // The last section always opens a model-response title (even with no
            // model text yet) so the model continues from there, and it gets no
            // trailing separator
            const modelPart = (model.values.length === 0 && !isLast)
                ? LlamaText([])
                : LlamaText([
                    wrap(`${this._modelResponseTitle}: `),
                    model,
                    isLast
                        ? LlamaText([])
                        : wrap("\n\n")
                ]);
            return LlamaText([systemPart, userPart, modelPart]);
        }));
        const titles = [this._userMessageTitle, this._modelResponseTitle, this._middleSystemMessageTitle];
        return {
            contextText,
            stopGenerationTriggers: [
                LlamaText(new SpecialToken("EOS")),
                ...titles.map((title) => LlamaText(`\n${title}:`)),
                ...(this._allowSpecialTokensInTitles
                    ? titles.map((title) => LlamaText(new SpecialTokensText(`\n${title}:`)))
                    : [])
            ]
        };
    }
    /** @internal */
    static _getOptionConfigurationsToTestIfCanSupersedeJinjaTemplate() {
        return [
            {},
            { allowSpecialTokensInTitles: true }
        ];
    }
}
//# sourceMappingURL=FalconChatWrapper.js.map