// node-llama-cpp
// Run AI models locally on your machine with Node.js bindings for llama.cpp.
// Enforces a JSON schema on the model output at the generation level.
// (Scraped page metadata: 446 lines • 21.1 kB, JavaScript)
import { splitText } from "lifecycle-utils";
import { LlamaText, SpecialToken, SpecialTokensText } from "../../../utils/LlamaText.js";
import { getFirstValidResult } from "./getFirstValidResult.js";
/**
 * Detects a Jinja chat template's function-calling syntax by rendering probe chat
 * histories filled with unique placeholder IDs and diffing the rendered outputs.
 *
 * The returned `settings` describe how the template wraps a function call (prefix,
 * params prefix, suffix), how it wraps a function result, and — when the template
 * supports parallel calls — the section prefixes/suffixes and the separators
 * between calls and between results. `settings` is `null` when the template's
 * structure could not be recognized from the rendered probes.
 *
 * @param {object} options
 * @param {{generateId(textOnly?: boolean): string}} options.idsGenerator - source of
 *     unique placeholder IDs embedded in the probe chat histories
 * @param {(options: object) => string} options.renderTemplate - renders a chat history
 *     through the Jinja template and returns the produced text; may throw when the
 *     template rejects the given history shape
 * @returns {{
 *     settings: object | null,
 *     stringifyParams: boolean,
 *     stringifyResult: boolean,
 *     combineModelMessageAndToolCalls: boolean
 * }}
 */
export function extractFunctionCallSettingsFromJinjaTemplate({ idsGenerator, renderTemplate }) {
    // IDs that revive into static special tokens when separator text is rebuilt
    const idToStaticContent = new Map();
    const bosTokenId = idsGenerator.generateId();
    const eosTokenId = idsGenerator.generateId();
    const eotTokenId = idsGenerator.generateId();
    idToStaticContent.set(bosTokenId, new SpecialToken("BOS"));
    idToStaticContent.set(eosTokenId, new SpecialToken("EOS"));
    idToStaticContent.set(eotTokenId, new SpecialToken("EOT"));
    // IDs that stand for dynamic message content; if one of these ever appears
    // inside extracted separator text the extraction is invalid
    // (enforced by `reviveSeparatorText`)
    const contentIds = new Set();
    const addContentId = (id) => {
        contentIds.add(id);
        return id;
    };
    const systemMessage = addContentId(idsGenerator.generateId());
    const userMessage1 = addContentId(idsGenerator.generateId());
    const modelMessage1 = addContentId(idsGenerator.generateId());
    const func1name = addContentId(idsGenerator.generateId());
    const func1description = addContentId(idsGenerator.generateId());
    const func1params = addContentId(idsGenerator.generateId(true));
    const func1result = addContentId(idsGenerator.generateId(true));
    const func2name = addContentId(idsGenerator.generateId());
    const func2description = addContentId(idsGenerator.generateId());
    const func2params = addContentId(idsGenerator.generateId(true));
    const func2result = addContentId(idsGenerator.generateId(true));
    const modelMessage2 = addContentId(idsGenerator.generateId());
    const func1StringifyParam = addContentId(idsGenerator.generateId());
    const func1StringifyResult = addContentId(idsGenerator.generateId());
    const functions1 = {
        [func1name]: {
            description: func1description,
            params: {
                type: "number"
            }
        }
    };
    const functions2 = {
        ...functions1,
        [func2name]: {
            description: func2description,
            params: {
                type: "number"
            }
        }
    };
    const baseChatHistory = [{
        type: "system",
        text: systemMessage
    }, {
        type: "user",
        text: userMessage1
    }];
    // Probe history: a single function call between two textual model responses
    const chatHistory1Call = [...baseChatHistory, {
        type: "model",
        response: [
            modelMessage1,
            {
                type: "functionCall",
                name: func1name,
                // convert to number since this will go through JSON.stringify,
                // and we want to avoid escaping characters in the rendered output
                params: Number(func1params),
                result: Number(func1result),
                startsNewChunk: true
            },
            modelMessage2
        ]
    }];
    // Probe history: two parallel calls inside one chunk
    const chatHistory2Calls = [...baseChatHistory, {
        type: "model",
        response: [
            modelMessage1,
            {
                type: "functionCall",
                name: func1name,
                // convert to number since this will go through JSON.stringify,
                // and we want to avoid escaping characters in the rendered output
                params: Number(func1params),
                result: Number(func1result),
                startsNewChunk: true
            },
            {
                type: "functionCall",
                name: func2name,
                params: Number(func2params),
                result: Number(func2result),
                startsNewChunk: false
            },
            modelMessage2
        ]
    }];
    // Fallback probe history: two calls where each starts its own chunk,
    // for templates that reject the parallel (same-chunk) shape
    const chatHistory2CallsNewChunk = [...baseChatHistory, {
        type: "model",
        response: [
            modelMessage1,
            {
                type: "functionCall",
                name: func1name,
                // convert to number since this will go through JSON.stringify,
                // and we want to avoid escaping characters in the rendered output
                params: Number(func1params),
                result: Number(func1result),
                startsNewChunk: true
            },
            {
                type: "functionCall",
                name: func2name,
                params: Number(func2params),
                result: Number(func2result),
                startsNewChunk: true
            },
            modelMessage2
        ]
    }];
    const additionalParams = {
        "bos_token": bosTokenId,
        "eos_token": eosTokenId,
        "eot_token": eotTokenId
    };
    let combineModelMessageAndToolCalls = true;
    let stringifyParams = true;
    let stringifyResult = true;
    // Probe whether the template renders object params natively; when rendering
    // fails or the object's key leaks into the output, fall back to stringifying
    try {
        const paramsObjectTest = renderTemplate({
            chatHistory: [...baseChatHistory, {
                type: "model",
                response: [
                    modelMessage1,
                    {
                        type: "functionCall",
                        name: func1name,
                        params: { [func1StringifyParam]: "test" },
                        result: func1StringifyResult,
                        startsNewChunk: true
                    },
                    modelMessage2
                ]
            }],
            functions: functions1,
            additionalParams,
            stringifyFunctionParams: false,
            stringifyFunctionResults: false,
            combineModelMessageAndToolCalls
        });
        stringifyParams = (!paramsObjectTest.includes(`"${func1StringifyParam}":`) &&
            !paramsObjectTest.includes(`'${func1StringifyParam}':`));
    }
    catch (err) {
        // do nothing - keep the safe default (stringify)
    }
    // Same probe for object results
    try {
        const resultObjectTest = renderTemplate({
            chatHistory: [...baseChatHistory, {
                type: "model",
                response: [
                    modelMessage1,
                    {
                        type: "functionCall",
                        name: func1name,
                        params: func1StringifyParam,
                        result: { [func1StringifyResult]: "test" },
                        startsNewChunk: true
                    },
                    modelMessage2
                ]
            }],
            functions: functions1,
            additionalParams,
            stringifyFunctionParams: false,
            stringifyFunctionResults: false,
            combineModelMessageAndToolCalls
        });
        stringifyResult = (!resultObjectTest.includes(`"${func1StringifyResult}":`) &&
            !resultObjectTest.includes(`'${func1StringifyResult}':`));
    }
    catch (err) {
        // do nothing - keep the safe default (stringify)
    }
    // If the textual model message survives when rendered alongside tool calls,
    // the template supports combining them in one message
    combineModelMessageAndToolCalls = renderTemplate({
        chatHistory: chatHistory1Call,
        functions: functions1,
        additionalParams,
        stringifyFunctionParams: true,
        stringifyFunctionResults: true,
        combineModelMessageAndToolCalls
    }).includes(modelMessage1);
    // Separator the template emits between two consecutive textual model messages;
    // later removed from the extracted call prefix so it isn't duplicated
    let textBetween2TextualModelResponses = LlamaText();
    if (!combineModelMessageAndToolCalls) {
        try {
            const betweenModelTextualResponsesTest = renderTemplate({
                chatHistory: [...baseChatHistory, {
                    type: "model",
                    response: [modelMessage1]
                }, {
                    type: "model",
                    response: [modelMessage2]
                }],
                functions: {},
                additionalParams,
                stringifyFunctionParams: false,
                stringifyFunctionResults: false,
                combineModelMessageAndToolCalls,
                squashModelTextResponses: false
            });
            const textDiff = getTextBetweenIds(betweenModelTextualResponsesTest, modelMessage1, modelMessage2).text ?? "";
            textBetween2TextualModelResponses = reviveSeparatorText(textDiff, idToStaticContent, contentIds);
        }
        catch (err) {
            // do nothing - an empty separator is a safe fallback
        }
    }
    let usedNewChunkFor2Calls = false;
    const rendered1Call = renderTemplate({
        chatHistory: chatHistory1Call,
        functions: functions1,
        additionalParams,
        stringifyFunctionParams: stringifyParams,
        stringifyFunctionResults: stringifyResult,
        combineModelMessageAndToolCalls
    });
    // Prefer the parallel (same-chunk) shape; fall back to separate chunks
    const rendered2Calls = getFirstValidResult([
        () => renderTemplate({
            chatHistory: chatHistory2Calls,
            functions: functions2,
            additionalParams,
            stringifyFunctionParams: stringifyParams,
            stringifyFunctionResults: stringifyResult,
            combineModelMessageAndToolCalls
        }),
        () => {
            usedNewChunkFor2Calls = true;
            return renderTemplate({
                chatHistory: chatHistory2CallsNewChunk,
                functions: functions2,
                additionalParams,
                stringifyFunctionParams: stringifyParams,
                stringifyFunctionResults: stringifyResult,
                combineModelMessageAndToolCalls
            });
        }
    ]);
    const modelMessage1ToFunc1Name = getTextBetweenIds(rendered2Calls, modelMessage1, func1name);
    const func1NameToFunc1Params = getTextBetweenIds(rendered2Calls, func1name, func1params, modelMessage1ToFunc1Name.endIndex);
    const func1ResultIndex = rendered2Calls.indexOf(func1result, func1NameToFunc1Params.endIndex);
    const func2NameIndex = rendered2Calls.indexOf(func2name, modelMessage1ToFunc1Name.endIndex);
    if (modelMessage1ToFunc1Name.text == null ||
        func1NameToFunc1Params.text == null ||
        func1ResultIndex < 0 ||
        func2NameIndex < 0)
        return { settings: null, stringifyParams, stringifyResult, combineModelMessageAndToolCalls };
    // Parallel calls are supported when the second call's name is rendered
    // before the first call's result (calls grouped together, then results)
    const supportsParallelCalls = func1ResultIndex > func2NameIndex;
    if (!supportsParallelCalls || usedNewChunkFor2Calls) {
        // Sequential-call template: extract the wrappers from the single-call render
        const prefix = getTextBetweenIds(rendered1Call, modelMessage1, func1name);
        const paramsPrefix = getTextBetweenIds(rendered1Call, func1name, func1params, prefix.endIndex);
        const resultPrefix = getTextBetweenIds(rendered1Call, func1params, func1result, paramsPrefix.endIndex);
        const resultSuffix = getTextBetweenIds(rendered1Call, func1result, modelMessage2, resultPrefix.endIndex);
        if (prefix.text == null || prefix.text === "" || paramsPrefix.text == null || resultPrefix.text == null || resultSuffix.text == null)
            return { settings: null, stringifyParams, stringifyResult, combineModelMessageAndToolCalls };
        return {
            stringifyParams,
            stringifyResult,
            combineModelMessageAndToolCalls,
            settings: {
                call: {
                    optionalPrefixSpace: true,
                    // drop the generic between-messages separator from the call prefix
                    prefix: removeCommonRevivedPrefix(reviveSeparatorText(prefix.text, idToStaticContent, contentIds), !combineModelMessageAndToolCalls
                        ? textBetween2TextualModelResponses
                        : LlamaText()),
                    paramsPrefix: reviveSeparatorText(paramsPrefix.text, idToStaticContent, contentIds),
                    suffix: "",
                    emptyCallParamsPlaceholder: {}
                },
                result: {
                    // the result wrappers may embed the call's name/params, so map
                    // those IDs to template placeholders instead of static content
                    prefix: reviveSeparatorText(resultPrefix.text, new Map([
                        ...idToStaticContent.entries(),
                        [func1name, LlamaText("{{functionName}}")],
                        [func1params, LlamaText("{{functionParams}}")]
                    ]), contentIds),
                    suffix: reviveSeparatorText(resultSuffix.text, new Map([
                        ...idToStaticContent.entries(),
                        [func1name, LlamaText("{{functionName}}")],
                        [func1params, LlamaText("{{functionParams}}")]
                    ]), contentIds)
                }
            }
        };
    }
    // Parallel-call template: diff the two-call render to split shared wrappers
    // from the section-level separators
    const func1ParamsToFunc2Name = getTextBetweenIds(rendered2Calls, func1params, func2name, func1NameToFunc1Params.endIndex);
    const func2ParamsToFunc1Result = getTextBetweenIds(rendered2Calls, func2params, func1result, func1ParamsToFunc2Name.endIndex);
    const func1ResultToFunc2Result = getTextBetweenIds(rendered2Calls, func1result, func2result, func2ParamsToFunc1Result.endIndex);
    const func2ResultToModelMessage2 = getTextBetweenIds(rendered2Calls, func2result, modelMessage2, func1ResultToFunc2Result.endIndex);
    if (func1ParamsToFunc2Name.text == null || func2ParamsToFunc1Result.text == null || func1ResultToFunc2Result.text == null ||
        func2ResultToModelMessage2.text == null)
        return { settings: null, stringifyParams, stringifyResult, combineModelMessageAndToolCalls };
    // The per-call prefix is the common tail of "before call 1" and "before call 2";
    // whatever precedes it in "before call 1" is the calls-section prefix
    const callPrefixLength = findCommonEndLength(modelMessage1ToFunc1Name.text, func1ParamsToFunc2Name.text);
    const callPrefixText = func1ParamsToFunc2Name.text.slice(-callPrefixLength);
    const parallelismCallPrefix = modelMessage1ToFunc1Name.text.slice(0, -callPrefixLength);
    // The per-call suffix is the common head of "after call 1" and "after call 2";
    // the remainder between suffix and prefix separates adjacent calls
    const callSuffixLength = findCommandStartLength(func1ParamsToFunc2Name.text, func2ParamsToFunc1Result.text);
    const callSuffixText = func1ParamsToFunc2Name.text.slice(0, callSuffixLength);
    const parallelismBetweenCallsText = func1ParamsToFunc2Name.text.slice(callSuffixLength, -callPrefixLength);
    const callParamsPrefixText = func1NameToFunc1Params.text;
    // Same decomposition for the results section
    const resultPrefixLength = findCommonEndLength(func2ParamsToFunc1Result.text, func1ResultToFunc2Result.text);
    const resultPrefixText = func2ParamsToFunc1Result.text.slice(-resultPrefixLength);
    const resultSuffixLength = findCommandStartLength(func1ResultToFunc2Result.text, func2ResultToModelMessage2.text);
    const resultSuffixText = func1ResultToFunc2Result.text.slice(0, resultSuffixLength);
    const parallelismResultBetweenResultsText = func1ResultToFunc2Result.text.slice(resultSuffixLength, -resultPrefixLength);
    const parallelismResultSuffixText = func2ResultToModelMessage2.text.slice(resultSuffixLength);
    // Split the text between the calls section and the results section:
    // prefer cutting right after an end token (EOS/EOT), then right before a BOS,
    // otherwise attribute all of it to the calls-section suffix
    const resolveParallelismBetweenSectionsParts = (betweenSectionsText) => {
        // fixed: the second match candidate was a duplicated `eosTokenId`,
        // so templates ending the calls section with an EOT token were missed
        const { index: endTokenIndex, text: endTokenId } = findFirstTextMatch(betweenSectionsText, [eosTokenId, eotTokenId]);
        if (endTokenIndex >= 0 && endTokenId != null)
            return {
                parallelismCallSuffixText: betweenSectionsText.slice(0, endTokenIndex + endTokenId.length),
                parallelismResultPrefix: betweenSectionsText.slice(endTokenIndex + endTokenId.length)
            };
        const bosIndex = betweenSectionsText.indexOf(bosTokenId);
        if (bosIndex >= 0)
            return {
                parallelismCallSuffixText: betweenSectionsText.slice(0, bosIndex),
                parallelismResultPrefix: betweenSectionsText.slice(bosIndex)
            };
        return {
            parallelismCallSuffixText: betweenSectionsText,
            parallelismResultPrefix: ""
        };
    };
    const { parallelismCallSuffixText, parallelismResultPrefix } = resolveParallelismBetweenSectionsParts(func2ParamsToFunc1Result.text.slice(callSuffixLength, -resultPrefixLength));
    return {
        stringifyParams,
        stringifyResult,
        combineModelMessageAndToolCalls,
        settings: {
            call: {
                optionalPrefixSpace: true,
                prefix: reviveSeparatorText(callPrefixText, idToStaticContent, contentIds),
                paramsPrefix: reviveSeparatorText(callParamsPrefixText, idToStaticContent, contentIds),
                suffix: reviveSeparatorText(callSuffixText, idToStaticContent, contentIds),
                emptyCallParamsPlaceholder: {}
            },
            result: {
                prefix: reviveSeparatorText(resultPrefixText, new Map([
                    ...idToStaticContent.entries(),
                    [func1name, LlamaText("{{functionName}}")],
                    [func1params, LlamaText("{{functionParams}}")]
                ]), contentIds),
                suffix: reviveSeparatorText(resultSuffixText, new Map([
                    ...idToStaticContent.entries(),
                    [func1name, LlamaText("{{functionName}}")],
                    [func1params, LlamaText("{{functionParams}}")]
                ]), contentIds)
            },
            parallelism: {
                call: {
                    sectionPrefix: removeCommonRevivedPrefix(reviveSeparatorText(parallelismCallPrefix, idToStaticContent, contentIds), !combineModelMessageAndToolCalls
                        ? textBetween2TextualModelResponses
                        : LlamaText()),
                    betweenCalls: reviveSeparatorText(parallelismBetweenCallsText, idToStaticContent, contentIds),
                    sectionSuffix: reviveSeparatorText(parallelismCallSuffixText, idToStaticContent, contentIds)
                },
                result: {
                    sectionPrefix: reviveSeparatorText(parallelismResultPrefix, idToStaticContent, contentIds),
                    betweenResults: reviveSeparatorText(parallelismResultBetweenResultsText, idToStaticContent, contentIds),
                    sectionSuffix: reviveSeparatorText(parallelismResultSuffixText, idToStaticContent, contentIds)
                }
            }
        }
    };
}
// Locates the substring of `text` that sits between the first occurrence of
// `startId` (searching from `startIndex`) and the next occurrence of `endId`.
// Returns the in-between text together with the index at which `endId` begins,
// or `{text: undefined, endIndex: -1}` when either marker is missing.
function getTextBetweenIds(text, startId, endId, startIndex = 0) {
    const startPos = text.indexOf(startId, startIndex);
    if (startPos >= 0) {
        const afterStart = startPos + startId.length;
        const endPos = text.indexOf(endId, afterStart);
        if (endPos >= 0)
            return { text: text.slice(afterStart, endPos), endIndex: endPos };
    }
    return { text: undefined, endIndex: -1 };
}
// Converts a rendered-template snippet back into a LlamaText: each known ID is
// replaced by its mapped content, plain text becomes SpecialTokensText, and any
// leftover content ID means dynamic message text leaked into what should be a
// pure separator - which is an error.
function reviveSeparatorText(text, idMap, contentIds) {
    const separators = [...new Set([...idMap.keys(), ...contentIds])];
    const revivedParts = splitText(text, separators).map((part) => {
        if (typeof part === "string")
            return new SpecialTokensText(part);
        const replacement = idMap.get(part.separator);
        if (replacement != null)
            return replacement;
        if (contentIds.has(part.separator))
            throw new Error("Content ID found in separator text");
        return new SpecialTokensText(part.separator);
    });
    return LlamaText(revivedParts);
}
// Strips from the start of `target` the longest leading run it shares with
// `matchStart`, comparing the two LlamaText `.values` arrays entry by entry.
// Used to remove the template's generic between-messages separator from an
// extracted function-call prefix so only the call-specific syntax remains.
// NOTE(review): assumes `.values` entries are strings, SpecialTokensText or
// SpecialToken instances, as produced by LlamaText - confirm in LlamaText.js.
function removeCommonRevivedPrefix(target, matchStart) {
for (let commonStartLength = 0; commonStartLength < target.values.length && commonStartLength < matchStart.values.length; commonStartLength++) {
const targetValue = target.values[commonStartLength];
const matchStartValue = matchStart.values[commonStartLength];
if (typeof targetValue === "string" && typeof matchStartValue === "string") {
// plain strings must match exactly to count as common
if (targetValue === matchStartValue)
continue;
}
else if (targetValue instanceof SpecialTokensText && matchStartValue instanceof SpecialTokensText) {
// special-tokens text may match partially: when only a leading portion
// is shared, keep the unshared tail of this entry plus everything after it
const commonLength = findCommandStartLength(targetValue.value, matchStartValue.value);
if (commonLength === targetValue.value.length && commonLength === matchStartValue.value.length)
continue;
return LlamaText([
new SpecialTokensText(targetValue.value.slice(commonLength)),
...target.values.slice(commonStartLength + 1)
]);
}
else if (targetValue instanceof SpecialToken && matchStartValue instanceof SpecialToken) {
// special tokens compare by their token type (e.g. "BOS" === "BOS")
if (targetValue.value === matchStartValue.value)
continue;
}
// mismatched types or unequal values: the common prefix ends here
return LlamaText(target.values.slice(commonStartLength));
}
// `matchStart` was fully consumed as a prefix of `target` (or `target` ran out,
// in which case the slice yields an empty LlamaText)
return LlamaText(target.values.slice(matchStart.values.length));
}
// Returns the length of the longest common prefix of `text1` and `text2`.
// (Name kept as-is for compatibility with existing callers, despite the typo.)
function findCommandStartLength(text1, text2) {
    const limit = Math.min(text1.length, text2.length);
    let length = 0;
    for (; length < limit; length++) {
        if (text1[length] !== text2[length])
            break;
    }
    return length;
}
// Returns the length of the longest common suffix of `text1` and `text2`.
function findCommonEndLength(text1, text2) {
    const limit = Math.min(text1.length, text2.length);
    let length = 0;
    while (length < limit &&
        text1[text1.length - 1 - length] === text2[text2.length - 1 - length]) {
        length++;
    }
    return length;
}
// Finds the first entry of `matchTexts` (tried in array order, not by position
// in `text`) that occurs in `text` at or after `startIndex`.
// Returns its index and the matched string, or `{index: -1, text: undefined}`.
function findFirstTextMatch(text, matchTexts, startIndex = 0) {
    for (const candidate of matchTexts) {
        const foundIndex = text.indexOf(candidate, startIndex);
        if (foundIndex >= 0)
            return { index: foundIndex, text: candidate };
    }
    return { index: -1, text: undefined };
}
//# sourceMappingURL=extractFunctionCallSettingsFromJinjaTemplate.js.map