UNPKG

node-llama-cpp

Version:

Run AI models locally on your machine with node.js bindings for llama.cpp. Enforce a JSON schema on the model output on the generation level

126 lines 5.53 kB
import { ChatWrapper } from "../ChatWrapper.js"; import { LlamaText, SpecialToken, SpecialTokensText } from "../utils/LlamaText.js"; /** * This chat wrapper is not safe against chat syntax injection attacks * ([learn more](https://node-llama-cpp.withcat.ai/guide/llama-text#input-safety-in-node-llama-cpp)). */ export class FalconChatWrapper extends ChatWrapper { wrapperName = "Falcon"; /** @internal */ _userMessageTitle; /** @internal */ _modelResponseTitle; /** @internal */ _middleSystemMessageTitle; /** @internal */ _allowSpecialTokensInTitles; constructor({ userMessageTitle = "User", modelResponseTitle = "Assistant", middleSystemMessageTitle = "System", allowSpecialTokensInTitles = false } = {}) { super(); this._userMessageTitle = userMessageTitle; this._modelResponseTitle = modelResponseTitle; this._middleSystemMessageTitle = middleSystemMessageTitle; this._allowSpecialTokensInTitles = allowSpecialTokensInTitles; } get userMessageTitle() { return this._userMessageTitle; } get modelResponseTitle() { return this._modelResponseTitle; } get middleSystemMessageTitle() { return this._middleSystemMessageTitle; } generateContextState({ chatHistory, availableFunctions, documentFunctionParams }) { const historyWithFunctions = this.addAvailableFunctionsSystemMessageToHistory(chatHistory, availableFunctions, { documentParams: documentFunctionParams }); const resultItems = []; let systemTexts = []; let userTexts = []; let modelTexts = []; let currentAggregateFocus = null; function flush() { if (systemTexts.length > 0 || userTexts.length > 0 || modelTexts.length > 0) resultItems.push({ system: LlamaText.joinValues("\n\n", systemTexts), user: LlamaText.joinValues("\n\n", userTexts), model: LlamaText.joinValues("\n\n", modelTexts) }); systemTexts = []; userTexts = []; modelTexts = []; } for (const item of historyWithFunctions) { if (item.type === "system") { if (currentAggregateFocus !== "system") flush(); currentAggregateFocus = "system"; systemTexts.push(LlamaText.fromJSON(item.text)); } else if (item.type === "user") { flush(); currentAggregateFocus = null; userTexts.push(LlamaText(item.text)); } else if (item.type === "model") { flush(); currentAggregateFocus = null; modelTexts.push(this.generateModelResponseText(item.response)); } else void item; } flush(); const contextText = LlamaText(new SpecialToken("BOS"), resultItems.map(({ system, user, model }, index) => { const isFirstItem = index === 0; const isLastItem = index === resultItems.length - 1; return LlamaText([ (system.values.length === 0) ? LlamaText([]) : LlamaText([ isFirstItem ? LlamaText([]) : SpecialTokensText.wrapIf(this._allowSpecialTokensInTitles, `${this._middleSystemMessageTitle}: `), system, SpecialTokensText.wrapIf(this._allowSpecialTokensInTitles, "\n\n") ]), (user.values.length === 0) ? LlamaText([]) : LlamaText([ SpecialTokensText.wrapIf(this._allowSpecialTokensInTitles, `${this._userMessageTitle}: `), user, SpecialTokensText.wrapIf(this._allowSpecialTokensInTitles, "\n\n") ]), (model.values.length === 0 && !isLastItem) ? LlamaText([]) : LlamaText([ SpecialTokensText.wrapIf(this._allowSpecialTokensInTitles, `${this._modelResponseTitle}: `), model, isLastItem ? LlamaText([]) : SpecialTokensText.wrapIf(this._allowSpecialTokensInTitles, "\n\n") ]) ]); })); return { contextText, stopGenerationTriggers: [ LlamaText(new SpecialToken("EOS")), LlamaText(`\n${this._userMessageTitle}:`), LlamaText(`\n${this._modelResponseTitle}:`), LlamaText(`\n${this._middleSystemMessageTitle}:`), ...(!this._allowSpecialTokensInTitles ? [] : [ LlamaText(new SpecialTokensText(`\n${this._userMessageTitle}:`)), LlamaText(new SpecialTokensText(`\n${this._modelResponseTitle}:`)), LlamaText(new SpecialTokensText(`\n${this._middleSystemMessageTitle}:`)) ]) ] }; } /** @internal */ static _getOptionConfigurationsToTestIfCanSupersedeJinjaTemplate() { return [ {}, { allowSpecialTokensInTitles: true } ]; } } //# sourceMappingURL=FalconChatWrapper.js.map