// node-llama-cpp: GemmaChatWrapper.js
// Run AI models locally on your machine with Node.js bindings for llama.cpp,
// and enforce a JSON schema on the model output at the generation level.
import { ChatWrapper } from "../ChatWrapper.js";
import { SpecialToken, LlamaText, SpecialTokensText } from "../utils/LlamaText.js";
// source: https://ai.google.dev/gemma/docs/formatting
// source: https://www.promptingguide.ai/models/gemma
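//
// This wrapper renders the chat history into Gemma's instruction-tuned turn format.
// As an illustration (a sketch of the output built below, not normative), a two-turn
// conversation is laid out roughly as:
//
//   <BOS token><start_of_turn>user
//   first user message<end_of_turn>
//   <start_of_turn>model
//   first model response<end_of_turn>
//   <start_of_turn>user
//   second user message<end_of_turn>
//   <start_of_turn>model
//
// The final model turn is intentionally left open so generation continues from it.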
export class GemmaChatWrapper extends ChatWrapper {
    wrapperName = "Gemma";
    settings = {
        ...ChatWrapper.defaultSettings,
        supportsSystemMessages: false
    };
    generateContextState({ chatHistory, availableFunctions, documentFunctionParams }) {
        const historyWithFunctions = this.addAvailableFunctionsSystemMessageToHistory(chatHistory, availableFunctions, {
            documentParams: documentFunctionParams
        });
        const resultItems = [];
        let systemTexts = [];
        let userTexts = [];
        let modelTexts = [];
        let currentAggregateFocus = null;
        function flush() {
            if (systemTexts.length > 0 || userTexts.length > 0 || modelTexts.length > 0) {
                const systemText = LlamaText.joinValues("\n\n", systemTexts);
                let userText = LlamaText.joinValues("\n\n", userTexts);
                // there's no system prompt support in Gemma, so we'll prepend the system text to the user message
                // (the merged turn is illustrated right after this function)
                if (systemText.values.length > 0) {
                    if (userText.values.length === 0)
                        userText = systemText;
                    else
                        userText = LlamaText([
                            systemText,
                            "\n\n---\n\n",
                            userText
                        ]);
                }
                resultItems.push({
                    user: userText,
                    model: LlamaText.joinValues("\n\n", modelTexts)
                });
            }
            systemTexts = [];
            userTexts = [];
            modelTexts = [];
        }
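        // For example (illustration only), a round that contains both system and user
        // messages is flushed as a single user turn shaped like:
        //
        //   <system text>
        //
        //   ---
        //
        //   <user text>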
        for (const item of historyWithFunctions) {
            if (item.type === "system") {
                if (currentAggregateFocus !== "system")
                    flush();
                currentAggregateFocus = "system";
                systemTexts.push(LlamaText.fromJSON(item.text));
            }
            else if (item.type === "user") {
                if (currentAggregateFocus !== "system" && currentAggregateFocus !== "user")
                    flush();
                currentAggregateFocus = "user";
                userTexts.push(LlamaText(item.text));
            }
            else if (item.type === "model") {
                currentAggregateFocus = "model";
                modelTexts.push(this.generateModelResponseText(item.response));
            }
            else
                void item;
        }
        flush();
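        // At this point resultItems holds one {user, model} pair per round, with
        // consecutive system/user messages already merged by flush().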
        const contextText = LlamaText(new SpecialToken("BOS"), resultItems.map(({ user, model }, index) => {
            const isLastItem = index === resultItems.length - 1;
            return LlamaText([
                (user.values.length === 0)
                    ? LlamaText([])
                    : LlamaText([
                        new SpecialTokensText("<start_of_turn>user\n"),
                        user,
                        new SpecialTokensText("<end_of_turn>\n")
                    ]),
                (model.values.length === 0 && !isLastItem)
                    ? LlamaText([])
                    : LlamaText([
                        new SpecialTokensText("<start_of_turn>model\n"),
                        model,
                        isLastItem
                            ? LlamaText([])
                            : new SpecialTokensText("<end_of_turn>\n")
                    ])
            ]);
        }));
        return {
            contextText,
            stopGenerationTriggers: [
                LlamaText(new SpecialToken("EOS")),
                LlamaText(new SpecialTokensText("<end_of_turn>\n")),
                LlamaText("<end_of_turn>")
            ]
        };
    }
}
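// Example usage (a sketch, not part of this file): passing the wrapper to a chat session.
// This assumes the public node-llama-cpp API (getLlama, LlamaChatSession, the chatWrapper
// option) and a local Gemma GGUF model file at a path of your choosing.
//
//   import { getLlama, LlamaChatSession, GemmaChatWrapper } from "node-llama-cpp";
//
//   const llama = await getLlama();
//   const model = await llama.loadModel({ modelPath: "path/to/gemma-model.gguf" });
//   const context = await model.createContext();
//   const session = new LlamaChatSession({
//       contextSequence: context.getSequence(),
//       chatWrapper: new GemmaChatWrapper()
//   });
//   console.log(await session.prompt("Hi there"));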
//# sourceMappingURL=GemmaChatWrapper.js.map