UNPKG

node-llama-cpp

Version:

Run AI models locally on your machine with Node.js bindings for llama.cpp. Enforce a JSON schema on the model output at the generation level.

96 lines 4.01 kB
import { ChatWrapper } from "../ChatWrapper.js";
import { SpecialToken, LlamaText, SpecialTokensText } from "../utils/LlamaText.js";

// source: https://ai.google.dev/gemma/docs/formatting
// source: https://www.promptingguide.ai/models/gemma

/**
 * Chat wrapper that renders a chat history using Gemma's
 * `<start_of_turn>` / `<end_of_turn>` turn markers.
 *
 * Gemma has no system prompt support (`supportsSystemMessages: false`), so any
 * system text is folded into the user message of the same turn.
 */
export class GemmaChatWrapper extends ChatWrapper {
    wrapperName = "Gemma";
    settings = {
        ...ChatWrapper.defaultSettings,
        supportsSystemMessages: false
    };

    /**
     * Builds the prompt text and stop triggers for the given chat history.
     *
     * Consecutive history items are grouped into turns:
     * - a system item starts a new turn unless the previous item was also a system item;
     * - a user item continues a turn that so far only holds system/user text;
     * - model items are appended to the current turn.
     * Item types other than "system", "user" and "model" are skipped.
     *
     * @param {object} options
     * @param {Array} options.chatHistory - history items; system/user items carry `text`, model items carry `response`
     * @param {*} options.availableFunctions - forwarded to `addAvailableFunctionsSystemMessageToHistory`
     * @param {*} options.documentFunctionParams - forwarded as the `documentParams` option
     * @returns {{contextText: LlamaText, stopGenerationTriggers: LlamaText[]}}
     */
    generateContextState({ chatHistory, availableFunctions, documentFunctionParams }) {
        const history = this.addAvailableFunctionsSystemMessageToHistory(chatHistory, availableFunctions, {
            documentParams: documentFunctionParams
        });

        const turns = [];
        let pending = { system: [], user: [], model: [] };
        let focus = null;

        // Closes the turn being accumulated (only if it holds any text) and starts an empty one.
        const commitTurn = () => {
            const hasContent = pending.system.length > 0 || pending.user.length > 0 || pending.model.length > 0;
            if (hasContent) {
                const systemText = LlamaText.joinValues("\n\n", pending.system);
                const userText = LlamaText.joinValues("\n\n", pending.user);

                // Gemma has no system role, so system text is prepended to the user message.
                let combinedUser;
                if (systemText.values.length === 0)
                    combinedUser = userText;
                else if (userText.values.length === 0)
                    combinedUser = systemText;
                else
                    combinedUser = LlamaText([
                        systemText,
                        "\n\n---\n\n",
                        userText
                    ]);

                turns.push({
                    user: combinedUser,
                    model: LlamaText.joinValues("\n\n", pending.model)
                });
            }
            pending = { system: [], user: [], model: [] };
        };

        for (const item of history) {
            switch (item.type) {
                case "system":
                    if (focus !== "system")
                        commitTurn();
                    focus = "system";
                    pending.system.push(LlamaText.fromJSON(item.text));
                    break;
                case "user":
                    // A user message may follow system text (or another user message) in the same turn.
                    if (focus !== "system" && focus !== "user")
                        commitTurn();
                    focus = "user";
                    pending.user.push(LlamaText(item.text));
                    break;
                case "model":
                    focus = "model";
                    pending.model.push(this.generateModelResponseText(item.response));
                    break;
                default:
                    // Unrecognized item types contribute nothing.
                    break;
            }
        }
        commitTurn();

        const lastIndex = turns.length - 1;
        const renderTurn = ({ user, model }, index) => {
            const isFinalTurn = index === lastIndex;

            const userPart = user.values.length > 0
                ? LlamaText([
                    new SpecialTokensText("<start_of_turn>user\n"),
                    user,
                    new SpecialTokensText("<end_of_turn>\n")
                ])
                : LlamaText([]);

            // The final model turn is always opened (even when empty) and left unterminated,
            // so that generation continues from it.
            const modelPart = (model.values.length > 0 || isFinalTurn)
                ? LlamaText([
                    new SpecialTokensText("<start_of_turn>model\n"),
                    model,
                    isFinalTurn
                        ? LlamaText([])
                        : new SpecialTokensText("<end_of_turn>\n")
                ])
                : LlamaText([]);

            return LlamaText([userPart, modelPart]);
        };

        const contextText = LlamaText(new SpecialToken("BOS"), turns.map(renderTurn));

        return {
            contextText,
            stopGenerationTriggers: [
                LlamaText(new SpecialToken("EOS")),
                LlamaText(new SpecialTokensText("<end_of_turn>\n")),
                LlamaText("<end_of_turn>")
            ]
        };
    }
}
//# sourceMappingURL=GemmaChatWrapper.js.map