UNPKG

node-llama-cpp

Version:

Run AI models locally on your machine with Node.js bindings for llama.cpp. Enforce a JSON schema on the model output at the generation level.

45 lines 2.43 kB
import { LlamaGrammar } from "../../LlamaGrammar.js";
import { LlamaText } from "../../../utils/LlamaText.js";
import { validateObjectAgainstGbnfSchema } from "../../../utils/gbnfJson/utils/validateObjectAgainstGbnfSchema.js";
import { GbnfGrammarGenerator } from "../../../utils/gbnfJson/GbnfGrammarGenerator.js";
import { getGbnfJsonTerminalForGbnfJsonSchema } from "../../../utils/gbnfJson/utils/getGbnfJsonTerminalForGbnfJsonSchema.js";
import { LlamaFunctionCallValidationError } from "./LlamaFunctionCallValidationError.js";

// How many consecutive newlines terminate a function call's params text.
const STOP_TRIGGER_NEWLINE_COUNT = 4;

// The terminator as raw text. It is used both as the generation stop trigger
// and as the delimiter that parseParams() searches for.
const STOP_GENERATION_TRIGGER = "\n".repeat(STOP_TRIGGER_NEWLINE_COUNT);

// The same terminator spelled as a quoted GBNF string literal (escaped "\n"
// sequences). It is appended after the params rule in the generated grammar.
const STOP_TRIGGER_GBNF_LITERAL = `"${"\\n".repeat(STOP_TRIGGER_NEWLINE_COUNT)}"`;

/**
 * Grammar that constrains generation to the JSON params of one function call,
 * followed by four newlines that act as the stop generation trigger.
 */
export class FunctionCallParamsGrammar extends LlamaGrammar {
    _functions;
    _chatWrapper;
    _functionName;
    _paramsSchema;

    /**
     * @param llama - forwarded unchanged to the LlamaGrammar constructor
     * @param functions - stored and attached to any LlamaFunctionCallValidationError thrown by parseParams()
     * @param chatWrapper - stored and attached to any LlamaFunctionCallValidationError thrown by parseParams()
     * @param functionName - name of the function whose params are generated; used in error messages
     * @param paramsSchema - schema the params must follow; it is compiled into the grammar
     *   and later used to validate the parsed params
     */
    constructor(llama, functions, chatWrapper, functionName, paramsSchema) {
        super(llama, {
            grammar: getGbnfGrammarForFunctionParams(paramsSchema),
            stopGenerationTriggers: [LlamaText(STOP_GENERATION_TRIGGER)],
            trimWhitespaceSuffix: true
        });

        this._functions = functions;
        this._chatWrapper = chatWrapper;
        this._functionName = functionName;
        this._paramsSchema = paramsSchema;
    }

    /**
     * Extracts and validates the function call params from text generated with this grammar.
     *
     * Everything before the LAST occurrence of the four-newline stop trigger is
     * treated as the params JSON. That text is parsed and then passed to
     * validateObjectAgainstGbnfSchema together with the params schema.
     *
     * @param {string} callText - generated text, including the stop trigger
     * @returns {{params: any, raw: string}} the parsed params and the exact JSON text they were parsed from
     * @throws {LlamaFunctionCallValidationError} when the stop trigger is missing,
     *   or the text before it is empty or whitespace-only
     * @throws {SyntaxError} when the params text is not valid JSON (thrown by JSON.parse, not wrapped)
     *
     * NOTE(review): errors from validateObjectAgainstGbnfSchema (presumably raised on a schema
     * mismatch — confirm) also propagate unwrapped.
     */
    parseParams(callText) {
        const validationError = (reason) => new LlamaFunctionCallValidationError(
            `Expected function call params for function "${this._functionName}" ${reason}`,
            this._functions,
            this._chatWrapper,
            callText
        );

        const triggerIndex = callText.lastIndexOf(STOP_GENERATION_TRIGGER);
        if (triggerIndex === -1) {
            throw validationError("to end with stop generation trigger");
        }

        const rawParams = callText.slice(0, triggerIndex);
        if (rawParams.trim() === "") {
            throw validationError("to not be empty");
        }

        const parsedParams = JSON.parse(rawParams);
        validateObjectAgainstGbnfSchema(parsedParams, this._paramsSchema);

        // The TypeScript source casts `params` to `any` at this point
        // ("prevent infinite TS type instantiation"); it has no runtime effect.
        return {
            params: parsedParams,
            raw: rawParams
        };
    }
}

/**
 * Builds the GBNF grammar text for a function's params.
 *
 * The grammar is the rule generated for `paramsSchema`, followed by the quoted
 * four-newline stop trigger literal.
 *
 * @param paramsSchema - schema handed to getGbnfJsonTerminalForGbnfJsonSchema
 * @returns whatever GbnfGrammarGenerator#generateGbnfFile returns for the combined root rule
 */
function getGbnfGrammarForFunctionParams(paramsSchema) {
    const generator = new GbnfGrammarGenerator();
    const schemaTerminal = getGbnfJsonTerminalForGbnfJsonSchema(paramsSchema, generator);
    const schemaRule = schemaTerminal.getGrammar(generator);

    return generator.generateGbnfFile(schemaRule + " " + STOP_TRIGGER_GBNF_LITERAL);
}
//# sourceMappingURL=FunctionCallParamsGrammar.js.map