inference-server
Version:
Libraries and a server for building AI applications. Adapters to various native bindings allow local inference. Integrate it with your application, or run it as a microservice.
713 lines • 30.9 kB
JavaScript
import path from 'node:path';
import fs from 'node:fs';
import { nanoid } from 'nanoid';
import { getLlama, LlamaChat, LlamaCompletion, LlamaLogLevel, TokenBias, LlamaGrammar, defineChatSessionFunction, createModelDownloader, LlamaJsonSchemaGrammar, } from 'node-llama-cpp';
import { LogLevels } from '../../lib/logger.js';
import { flattenMessageTextContent } from '../../lib/flattenMessageTextContent.js';
import { acquireFileLock } from '../../lib/acquireFileLock.js';
import { getRandomNumber } from '../../lib/util.js';
import { validateModelFile } from '../../lib/validateModelFile.js';
import { createChatMessageArray, addFunctionCallToChatHistory, mapFinishReason } from './util.js';
export const autoGpu = true;
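// Ensures the model file exists at config.location: acquires a file lock,
// validates the local file, downloads it from config.url if it is missing or
// invalid, and returns the model metadata from validation.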
export async function prepareModel({ config, log }, onProgress, signal) {
fs.mkdirSync(path.dirname(config.location), { recursive: true });
const releaseFileLock = await acquireFileLock(config.location, signal);
if (signal?.aborted) {
releaseFileLock();
return;
}
log(LogLevels.info, `Preparing node-llama-cpp model at ${config.location}`, {
model: config.id,
});
const downloadModel = async (url, validationResult) => {
log(LogLevels.info, `Downloading model files`, {
model: config.id,
url: url,
location: config.location,
error: validationResult,
});
const downloader = await createModelDownloader({
modelUrl: url,
dirPath: path.dirname(config.location),
fileName: path.basename(config.location),
deleteTempFileOnCancel: false,
onProgress: (status) => {
if (onProgress) {
onProgress({
file: config.location,
loadedBytes: status.downloadedSize,
totalBytes: status.totalSize,
});
}
},
});
await downloader.download();
};
try {
if (signal?.aborted) {
return;
}
const validationRes = await validateModelFile(config);
let modelMeta = validationRes.meta;
if (signal?.aborted) {
return;
}
if (validationRes.error) {
if (!config.url) {
throw new Error(`${validationRes.error} - No URL provided`);
}
await downloadModel(config.url, validationRes.error);
const revalidationRes = await validateModelFile(config);
if (revalidationRes.error) {
throw new Error(`Downloaded files are invalid: ${revalidationRes.error}`);
}
modelMeta = revalidationRes.meta;
}
return modelMeta;
}
finally {
releaseFileLock();
}
}
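// Loads the model and prepares an inference instance: initializes the llama.cpp
// bindings with the configured GPU backend, compiles any configured grammars,
// creates a context with a single sequence, and optionally ingests initial chat
// messages and/or a completion prefix into that sequence.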
export async function createInstance({ config, log }, signal) {
log(LogLevels.debug, 'Load Llama model', config.device);
// takes "auto" | "metal" | "cuda" | "vulkan"
const gpuSetting = (config.device?.gpu ?? 'auto');
const llama = await getLlama({
gpu: gpuSetting,
// forwarding llama logger
logLevel: LlamaLogLevel.debug,
logger: (level, message) => {
if (level === LlamaLogLevel.warn) {
log(LogLevels.warn, message);
}
else if (level === LlamaLogLevel.error || level === LlamaLogLevel.fatal) {
log(LogLevels.error, message);
}
else if (level === LlamaLogLevel.info || level === LlamaLogLevel.debug) {
log(LogLevels.verbose, message);
}
},
});
const llamaGrammars = {
json: await LlamaGrammar.getFor(llama, 'json'),
};
if (config.grammars) {
for (const key in config.grammars) {
const input = config.grammars[key];
if (typeof input === 'string') {
llamaGrammars[key] = new LlamaGrammar(llama, {
grammar: input,
});
}
else {
// assume input is a JSON schema object
llamaGrammars[key] = new LlamaJsonSchemaGrammar(llama, input);
}
}
}
const llamaModel = await llama.loadModel({
modelPath: config.location, // full model absolute path
loadSignal: signal,
useMlock: config.device?.memLock ?? false,
gpuLayers: config.device?.gpuLayers,
// onLoadProgress: (percent) => {}
});
const context = await llamaModel.createContext({
sequences: 1,
lora: config.lora,
threads: config.device?.cpuThreads,
batchSize: config.batchSize,
contextSize: config.contextSize,
flashAttention: true,
createSignal: signal,
});
const instance = {
model: llamaModel,
context,
grammars: llamaGrammars,
chat: undefined,
chatHistory: [],
pendingFunctionCalls: {},
lastEvaluation: undefined,
completion: undefined,
contextSequence: context.getSequence(),
chatWrapper: config.chatWrapper,
};
if (config.initialMessages) {
const initialChatHistory = createChatMessageArray(config.initialMessages);
const chat = new LlamaChat({
contextSequence: instance.contextSequence,
chatWrapper: instance.chatWrapper,
// autoDisposeSequence: true,
});
let inputFunctions;
if (config.tools?.definitions && Object.keys(config.tools.definitions).length > 0) {
const functionDefs = config.tools.definitions;
inputFunctions = {};
for (const functionName in functionDefs) {
const functionDef = functionDefs[functionName];
inputFunctions[functionName] = defineChatSessionFunction({
description: functionDef.description,
params: functionDef.parameters,
handler: functionDef.handler || (() => { }),
});
}
}
const loadMessagesRes = await chat.loadChatAndCompleteUserMessage(initialChatHistory, {
initialUserPrompt: '',
functions: inputFunctions,
documentFunctionParams: config.tools?.documentParams,
});
instance.chat = chat;
instance.chatHistory = initialChatHistory;
instance.lastEvaluation = {
cleanHistory: initialChatHistory,
contextWindow: loadMessagesRes.lastEvaluation.contextWindow,
contextShiftMetadata: loadMessagesRes.lastEvaluation.contextShiftMetadata,
};
}
if (config.prefix) {
const contextSequence = instance.contextSequence;
const completion = new LlamaCompletion({
contextSequence: contextSequence,
});
await completion.generateCompletion(config.prefix, {
maxTokens: 0,
});
instance.completion = completion;
instance.contextSequence = contextSequence;
}
return instance;
}
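// Releases the loaded model and the resources held by it.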
export async function disposeInstance(instance) {
await instance.model.dispose();
}
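// Runs a chat completion: recreates the LlamaChat session if needed, resolves
// any pending function calls reported via incoming "tool" messages, then
// generates in a loop, invoking function handlers as the model calls them, and
// finally returns the assistant message together with token usage.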
export async function processChatCompletionTask(task, ctx, signal) {
const { instance, resetContext, config, log } = ctx;
if (!instance.chat || resetContext) {
log(LogLevels.debug, 'Recreating chat context', {
resetContext: resetContext,
willDisposeChat: !!instance.chat,
});
// if context reset is requested, dispose the chat instance
if (instance.chat) {
instance.chat.dispose();
}
let contextSequence = instance.contextSequence;
if (!contextSequence || contextSequence.disposed) {
if (instance.context.sequencesLeft) {
contextSequence = instance.context.getSequence();
instance.contextSequence = contextSequence;
}
else {
throw new Error('No context sequence available');
}
}
else {
contextSequence.clearHistory();
}
instance.chat = new LlamaChat({
contextSequence: contextSequence,
chatWrapper: instance.chatWrapper,
// autoDisposeSequence: true,
});
// reset state and reingest the conversation history
instance.lastEvaluation = undefined;
instance.pendingFunctionCalls = {};
instance.chatHistory = createChatMessageArray(task.messages);
// drop the last user message; it will be re-added later, after pending function calls are resolved
if (instance.chatHistory[instance.chatHistory.length - 1].type === 'user') {
instance.chatHistory.pop();
}
}
// set additional stop generation triggers for this completion
const customStopTriggers = [];
const stopTrigger = task.stop ?? config.completionDefaults?.stop;
if (stopTrigger) {
customStopTriggers.push(...stopTrigger.map((t) => [t]));
}
// setting up logit/token bias dictionary
let tokenBias;
const completionTokenBias = task.tokenBias ?? config.completionDefaults?.tokenBias;
if (completionTokenBias) {
tokenBias = new TokenBias(instance.model.tokenizer);
for (const key in completionTokenBias) {
const bias = completionTokenBias[key] / 10;
const tokenId = parseInt(key);
if (!isNaN(tokenId)) {
tokenBias.set(tokenId, bias);
}
else {
tokenBias.set(key, bias);
}
}
}
// setting up available function definitions
const functionDefinitions = {
...config.tools?.definitions,
...task.tools?.definitions,
};
// see if the user submitted any function call results
const maxParallelCalls = task.tools?.maxParallelCalls ?? config.tools?.maxParallelCalls;
const chatWrapperSupportsParallelism = !!instance.chat.chatWrapper.settings.functions.parallelism;
const supportsParallelFunctionCalling = chatWrapperSupportsParallelism && !!maxParallelCalls;
const resolvedFunctionCalls = [];
const functionCallResultMessages = task.messages.filter((m) => m.role === 'tool');
let startsNewChunk = supportsParallelFunctionCalling;
for (const message of functionCallResultMessages) {
if (!instance.pendingFunctionCalls[message.callId]) {
log(LogLevels.warn, `Received function result for non-existent call id "${message.callId}"`);
continue;
}
log(LogLevels.debug, 'Resolving pending function call', {
id: message.callId,
result: message.content,
});
const functionCall = instance.pendingFunctionCalls[message.callId];
const functionDef = functionDefinitions[functionCall.functionName];
const resolvedFunctionCall = {
type: 'functionCall',
name: functionCall.functionName,
description: functionDef?.description,
params: functionCall.params,
result: message.content,
rawCall: functionCall.raw,
};
if (startsNewChunk) {
resolvedFunctionCall.startsNewChunk = true;
startsNewChunk = false;
}
resolvedFunctionCalls.push(resolvedFunctionCall);
delete instance.pendingFunctionCalls[message.callId];
}
// only a grammar or functions can be used, not both.
// function definitions are currently ignored if a grammar is provided
let inputGrammar;
let inputFunctions;
if (task.grammar) {
if (!instance.grammars[task.grammar]) {
throw new Error(`Grammar "${task.grammar}" not found.`);
}
inputGrammar = instance.grammars[task.grammar];
}
else if (Object.keys(functionDefinitions).length > 0) {
inputFunctions = {};
for (const functionName in functionDefinitions) {
const functionDef = functionDefinitions[functionName];
inputFunctions[functionName] = defineChatSessionFunction({
description: functionDef.description,
params: functionDef.parameters,
handler: functionDef.handler || (() => { }),
});
}
}
let lastEvaluation = instance.lastEvaluation;
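// appends resolved function call results to a chat history, merging them into
// the last model message when possible, otherwise pushing a new model message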
const appendResolvedFunctionCalls = (history, response) => {
const lastMessage = history[history.length - 1];
// append to existing response item if last message in history is a model response
if (lastMessage.type === 'model') {
const lastMessageResponse = lastMessage;
if (Array.isArray(response)) {
lastMessageResponse.response.push(...response);
// if we don't add a fresh empty message, llama 3.2 3b keeps trying to call functions; the cause is unclear
history.push({
type: 'model',
response: [],
});
}
return;
}
// otherwise append a new one with the calls
history.push({
type: 'model',
response: response,
});
};
// if the incoming messages resolved any pending function calls, add them to history
if (resolvedFunctionCalls.length) {
appendResolvedFunctionCalls(instance.chatHistory, resolvedFunctionCalls);
if (lastEvaluation?.contextWindow) {
appendResolvedFunctionCalls(lastEvaluation.contextWindow, resolvedFunctionCalls);
}
}
// add the new user message to the chat history
let assistantPrefill = '';
const lastMessage = task.messages[task.messages.length - 1];
if (lastMessage.role === 'user' && lastMessage.content) {
const newUserText = flattenMessageTextContent(lastMessage.content);
if (newUserText) {
instance.chatHistory.push({
type: 'user',
text: newUserText,
});
}
}
else if (lastMessage.role === 'assistant') {
// if the last message is an assistant message, use it as a prefill for the response
assistantPrefill = flattenMessageTextContent(lastMessage.content);
}
else if (!resolvedFunctionCalls.length) {
log(LogLevels.warn, 'Last message is not valid for chat completion. This is likely a mistake.', lastMessage);
throw new Error('Invalid last chat message');
}
const defaults = config.completionDefaults ?? {};
let newChatHistory = instance.chatHistory.slice();
let newContextWindowChatHistory = !lastEvaluation?.contextWindow ? undefined : instance.chatHistory.slice();
if (instance.chatHistory[instance.chatHistory.length - 1].type !== 'model' || assistantPrefill) {
const newModelResponse = assistantPrefill ? [assistantPrefill] : [];
newChatHistory.push({
type: 'model',
response: newModelResponse,
});
if (newContextWindowChatHistory) {
newContextWindowChatHistory.push({
type: 'model',
response: newModelResponse,
});
}
}
const functionsOrGrammar = inputFunctions
? {
// clone the input functions because the dict is mutated in the loop below to support preventFurtherCalls
functions: { ...inputFunctions },
documentFunctionParams: task.tools?.documentParams ?? config.tools?.documentParams,
maxParallelFunctionCalls: maxParallelCalls,
onFunctionCall: (functionCall) => {
// log(LogLevels.debug, 'Called function', functionCall)
},
}
: {
grammar: inputGrammar,
};
const initialTokenMeterState = instance.chat.sequence.tokenMeter.getState();
let completionResult;
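// generation loop: keep generating until the model produces a final response or
// emits function calls that have no handler and must be returned to the caller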
while (true) {
// console.debug('before eval newChatHistory', JSON.stringify(newChatHistory, null, 2))
// console.debug('before eval newContextWindowChatHistory', JSON.stringify(newContextWindowChatHistory, null, 2))
const { functionCalls, lastEvaluation: currentLastEvaluation, metadata, } = await instance.chat.generateResponse(newChatHistory, {
signal,
stopOnAbortSignal: true, // this will make aborted completions resolve (with a partial response)
maxTokens: task.maxTokens ?? defaults.maxTokens,
temperature: task.temperature ?? defaults.temperature,
topP: task.topP ?? defaults.topP,
topK: task.topK ?? defaults.topK,
minP: task.minP ?? defaults.minP,
seed: task.seed ?? config.completionDefaults?.seed ?? getRandomNumber(0, 1000000),
tokenBias,
customStopTriggers,
trimWhitespaceSuffix: false,
...functionsOrGrammar,
repeatPenalty: {
lastTokens: task.repeatPenaltyNum ?? defaults.repeatPenaltyNum,
frequencyPenalty: task.frequencyPenalty ?? defaults.frequencyPenalty,
presencePenalty: task.presencePenalty ?? defaults.presencePenalty,
},
contextShift: {
strategy: config.contextShiftStrategy,
lastEvaluationMetadata: lastEvaluation?.contextShiftMetadata,
},
lastEvaluationContextWindow: {
history: newContextWindowChatHistory,
minimumOverlapPercentageToPreventContextShift: 0.5,
},
onToken: (tokens) => {
const text = instance.model.detokenize(tokens);
if (task.onChunk) {
task.onChunk({
tokens,
text,
});
}
},
});
lastEvaluation = currentLastEvaluation;
newChatHistory = lastEvaluation.cleanHistory;
// console.debug('after eval newChatHistory', JSON.stringify(newChatHistory, null, 2))
// console.debug('after eval newContextWindowChatHistory', JSON.stringify(newContextWindowChatHistory, null, 2))
if (functionCalls) {
// find the leading function calls that can be invoked immediately (i.e. that have a handler function)
const invokableFunctionCalls = [];
for (const functionCall of functionCalls) {
const functionDef = functionDefinitions[functionCall.functionName];
if (functionDef.handler) {
invokableFunctionCalls.push(functionCall);
}
else {
break;
}
}
// if the model emitted text before the call, pass it on to the function handlers
// the response tokens are also available via onChunk, but this is more convenient
const lastMessage = newChatHistory[newChatHistory.length - 1];
const lastResponsePart = lastMessage.response[lastMessage.response.length - 1];
let leadingResponseText;
if (typeof lastResponsePart === 'string' && lastResponsePart) {
leadingResponseText = lastResponsePart;
}
// resolve function call results
const results = await Promise.all(invokableFunctionCalls.map(async (functionCall) => {
const functionDef = functionDefinitions[functionCall.functionName];
if (!functionDef) {
throw new Error(`The model tried to call undefined function "${functionCall.functionName}"`);
}
let functionCallResult = await functionDef.handler(functionCall.params, leadingResponseText);
log(LogLevels.debug, 'Function handler resolved', {
function: functionCall.functionName,
args: functionCall.params,
result: functionCallResult,
});
if (typeof functionCallResult !== 'string') {
if (functionsOrGrammar.functions && functionCallResult.preventFurtherCalls) {
// remove the function we just called from the list of available functions
functionsOrGrammar.functions = Object.fromEntries(Object.entries(functionsOrGrammar.functions).filter(([key]) => key !== functionCall.functionName));
if (Object.keys(functionsOrGrammar.functions).length === 0) {
// @ts-ignore
functionsOrGrammar.functions = undefined;
}
functionCallResult = functionCallResult.text;
}
}
return {
functionDef,
functionCall,
functionCallResult,
};
}));
newContextWindowChatHistory = lastEvaluation.contextWindow;
let startsNewChunk = supportsParallelFunctionCalling;
// add results to chat history in the order they were called
for (const callResult of results) {
newChatHistory = addFunctionCallToChatHistory({
chatHistory: newChatHistory,
functionName: callResult.functionCall.functionName,
functionDescription: callResult.functionDef.description,
callParams: callResult.functionCall.params,
callResult: callResult.functionCallResult,
rawCall: callResult.functionCall.raw,
startsNewChunk: startsNewChunk,
});
newContextWindowChatHistory = addFunctionCallToChatHistory({
chatHistory: newContextWindowChatHistory,
functionName: callResult.functionCall.functionName,
functionDescription: callResult.functionDef.description,
callParams: callResult.functionCall.params,
callResult: callResult.functionCallResult,
rawCall: callResult.functionCall.raw,
startsNewChunk: startsNewChunk,
});
startsNewChunk = false;
}
// check for function calls without a handler; those must be returned to the caller as tool calls
const remainingFunctionCalls = functionCalls.slice(invokableFunctionCalls.length);
if (remainingFunctionCalls.length === 0) {
// all calls had handlers and were resolved above; continue generation
lastEvaluation.cleanHistory = newChatHistory;
lastEvaluation.contextWindow = newContextWindowChatHistory;
continue;
}
else {
// handlerless calls remain; return them and skip further generation
instance.lastEvaluation = lastEvaluation;
instance.chatHistory = newChatHistory;
completionResult = {
responseText: null,
stopReason: 'functionCalls',
functionCalls: remainingFunctionCalls,
};
break;
}
}
// no function calls in this response; we have the final model output.
instance.lastEvaluation = lastEvaluation;
instance.chatHistory = newChatHistory;
const lastMessage = instance.chatHistory[instance.chatHistory.length - 1];
const responseText = lastMessage.response.filter((item) => typeof item === 'string').join('');
completionResult = {
responseText,
stopReason: metadata.stopReason,
};
break;
}
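// package the result as an assistant message; handlerless function calls are
// surfaced as tool calls for the caller to resolve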
const assistantMessage = {
role: 'assistant',
content: completionResult.responseText || '',
};
if (completionResult.functionCalls) {
// TODO: it's possible that there are trailing immediately-evaluatable function calls.
// Function call results must be added in the order the functions were called, so
// we need to wait for the pending calls to complete before we can add the trailing calls.
// As is, these may never resolve.
const pendingFunctionCalls = completionResult.functionCalls.filter((call) => {
const functionDef = functionDefinitions[call.functionName];
return !functionDef.handler;
});
// TODO write a test that triggers a parallel call to a handlerless function and to a function with one
const trailingFunctionCalls = completionResult.functionCalls.filter((call) => {
const functionDef = functionDefinitions[call.functionName];
return functionDef.handler;
});
if (trailingFunctionCalls.length) {
console.debug(trailingFunctionCalls);
log(LogLevels.warn, 'Trailing function calls not resolved');
}
assistantMessage.toolCalls = pendingFunctionCalls.map((call) => {
const callId = nanoid();
instance.pendingFunctionCalls[callId] = call;
log(LogLevels.debug, 'Saving pending tool call', {
id: callId,
function: call.functionName,
args: call.params,
});
return {
id: callId,
name: call.functionName,
parameters: call.params,
};
});
}
const tokenDifference = instance.chat.sequence.tokenMeter.diff(initialTokenMeterState);
// console.debug('final chatHistory', JSON.stringify(instance.chatHistory, null, 2))
// console.debug('final lastEvaluation', JSON.stringify(instance.lastEvaluation, null, 2))
return {
finishReason: mapFinishReason(completionResult.stopReason),
message: assistantMessage,
promptTokens: tokenDifference.usedInputTokens,
completionTokens: tokenDifference.usedOutputTokens,
contextTokens: instance.chat.sequence.contextTokens.length,
};
}
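// Runs a plain text completion against the instance's context sequence, reusing
// or recreating the LlamaCompletion object as needed, and returns the generated
// text together with token usage.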
export async function processTextCompletionTask(task, ctx, signal) {
const { instance, resetContext, config, log } = ctx;
if (!task.prompt) {
throw new Error('Prompt is required for text completion.');
}
let completion;
let contextSequence;
if (resetContext && instance.contextSequence) {
instance.contextSequence.clearHistory();
}
if (!instance.completion || instance.completion.disposed) {
if (instance.contextSequence) {
contextSequence = instance.contextSequence;
}
else if (instance.context.sequencesLeft) {
contextSequence = instance.context.getSequence();
}
else {
throw new Error('No context sequence available');
}
instance.contextSequence = contextSequence;
completion = new LlamaCompletion({
contextSequence,
});
instance.completion = completion;
}
else {
completion = instance.completion;
contextSequence = instance.contextSequence;
}
if (!contextSequence || contextSequence.disposed) {
contextSequence = instance.context.getSequence();
instance.contextSequence = contextSequence;
completion = new LlamaCompletion({
contextSequence,
});
instance.completion = completion;
}
const stopGenerationTriggers = [];
const stopTrigger = task.stop ?? config.completionDefaults?.stop;
if (stopTrigger) {
stopGenerationTriggers.push(...stopTrigger.map((t) => [t]));
}
const initialTokenMeterState = contextSequence.tokenMeter.getState();
const defaults = config.completionDefaults ?? {};
const result = await completion.generateCompletionWithMeta(task.prompt, {
maxTokens: task.maxTokens ?? defaults.maxTokens,
temperature: task.temperature ?? defaults.temperature,
topP: task.topP ?? defaults.topP,
topK: task.topK ?? defaults.topK,
minP: task.minP ?? defaults.minP,
repeatPenalty: {
lastTokens: task.repeatPenaltyNum ?? defaults.repeatPenaltyNum,
frequencyPenalty: task.frequencyPenalty ?? defaults.frequencyPenalty,
presencePenalty: task.presencePenalty ?? defaults.presencePenalty,
},
signal: signal,
customStopTriggers: stopGenerationTriggers.length ? stopGenerationTriggers : undefined,
seed: task.seed ?? config.completionDefaults?.seed ?? getRandomNumber(0, 1000000),
onToken: (tokens) => {
const text = instance.model.detokenize(tokens);
if (task.onChunk) {
task.onChunk({
tokens,
text,
});
}
},
});
const tokenDifference = contextSequence.tokenMeter.diff(initialTokenMeterState);
return {
finishReason: mapFinishReason(result.metadata.stopReason),
text: result.response,
promptTokens: tokenDifference.usedInputTokens,
completionTokens: tokenDifference.usedOutputTokens,
contextTokens: contextSequence.contextTokens.length,
};
}
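// Computes embeddings for one or more text inputs, lazily creating an embedding
// context, truncating inputs that exceed the context size, and returning
// Float32Array vectors along with the number of input tokens consumed.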
export async function processEmbeddingTask(task, ctx, signal) {
const { instance, config, log } = ctx;
if (!task.input) {
throw new Error('Input is required for embedding.');
}
const texts = [];
if (typeof task.input === 'string') {
texts.push(task.input);
}
else if (Array.isArray(task.input)) {
for (const input of task.input) {
if (typeof input === 'string') {
texts.push(input);
}
else if (input.type === 'text') {
texts.push(input.content);
}
else if (input.type === 'image') {
throw new Error('Image inputs not implemented.');
}
}
}
if (!instance.embeddingContext) {
instance.embeddingContext = await instance.model.createEmbeddingContext({
batchSize: config.batchSize,
createSignal: signal,
threads: config.device?.cpuThreads,
contextSize: config.contextSize,
});
}
// @ts-ignore - private property
const contextSize = instance.embeddingContext._llamaContext.contextSize;
const embeddings = [];
let inputTokens = 0;
for (const text of texts) {
let tokenizedInput = instance.model.tokenize(text);
if (tokenizedInput.length > contextSize) {
log(LogLevels.warn, 'Truncated input that exceeds context size');
tokenizedInput = tokenizedInput.slice(0, contextSize);
}
inputTokens += tokenizedInput.length;
const embedding = await instance.embeddingContext.getEmbeddingFor(tokenizedInput);
embeddings.push(new Float32Array(embedding.vector));
if (signal?.aborted) {
break;
}
}
return {
embeddings,
inputTokens,
};
}
//# sourceMappingURL=engine.js.map