inference-server

Libraries and a server for building AI applications, with adapters to various native bindings for local inference. Integrate it with your application, or run it as a microservice.
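For orientation before the source below: the following sketch shows the kind of model config the node-llama-cpp engine adapter in this file reads. Field names are taken from how `config` is accessed further down; the values, paths, and URL are placeholders, and this is not the package's documented configuration schema.

// Illustrative only: field names mirror the `config` reads in engine.js below;
// values, the path, and the URL are placeholders, not package defaults.
const exampleConfig = {
  id: 'example-model',
  location: '/absolute/path/to/model.gguf', // GGUF file the adapter validates, downloads (if missing/invalid), and loads
  url: 'https://example.com/model.gguf', // optional download source used when validation fails
  device: { gpu: 'auto', gpuLayers: undefined, cpuThreads: undefined, memLock: false },
  contextSize: 4096,
  batchSize: 512,
  grammars: {}, // name -> GBNF grammar string or JSON schema object
  completionDefaults: { temperature: 0.7, maxTokens: 512 },
  tools: { definitions: {}, documentParams: true, maxParallelCalls: 2 },
  // also read by the adapter: initialMessages, prefix, lora, chatWrapper, contextShiftStrategy
};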

import path from 'node:path';
import fs from 'node:fs';
import { nanoid } from 'nanoid';
import {
  getLlama,
  LlamaChat,
  LlamaCompletion,
  LlamaLogLevel,
  TokenBias,
  LlamaGrammar,
  defineChatSessionFunction,
  createModelDownloader,
  LlamaJsonSchemaGrammar,
} from 'node-llama-cpp';
import { LogLevels } from '../../lib/logger.js';
import { flattenMessageTextContent } from '../../lib/flattenMessageTextContent.js';
import { acquireFileLock } from '../../lib/acquireFileLock.js';
import { getRandomNumber } from '../../lib/util.js';
import { validateModelFile } from '../../lib/validateModelFile.js';
import { createChatMessageArray, addFunctionCallToChatHistory, mapFinishReason } from './util.js';

export const autoGpu = true;

export async function prepareModel({ config, log }, onProgress, signal) {
  fs.mkdirSync(path.dirname(config.location), { recursive: true });
  const releaseFileLock = await acquireFileLock(config.location, signal);
  if (signal?.aborted) {
    releaseFileLock();
    return;
  }
  log(LogLevels.info, `Preparing node-llama-cpp model at ${config.location}`, {
    model: config.id,
  });
  const downloadModel = async (url, validationResult) => {
    log(LogLevels.info, `Downloading model files`, {
      model: config.id,
      url: url,
      location: config.location,
      error: validationResult,
    });
    const downloader = await createModelDownloader({
      modelUrl: url,
      dirPath: path.dirname(config.location),
      fileName: path.basename(config.location),
      deleteTempFileOnCancel: false,
      onProgress: (status) => {
        if (onProgress) {
          onProgress({
            file: config.location,
            loadedBytes: status.downloadedSize,
            totalBytes: status.totalSize,
          });
        }
      },
    });
    await downloader.download();
  };
  try {
    if (signal?.aborted) {
      return;
    }
    const validationRes = await validateModelFile(config);
    let modelMeta = validationRes.meta;
    if (signal?.aborted) {
      return;
    }
    if (validationRes.error) {
      if (!config.url) {
        throw new Error(`${validationRes.error} - No URL provided`);
      }
      await downloadModel(config.url, validationRes.error);
      const revalidationRes = await validateModelFile(config);
      if (revalidationRes.error) {
        throw new Error(`Downloaded files are invalid: ${revalidationRes.error}`);
      }
      modelMeta = revalidationRes.meta;
    }
    return modelMeta;
  } catch (err) {
    throw err;
  } finally {
    releaseFileLock();
  }
}

export async function createInstance({ config, log }, signal) {
  log(LogLevels.debug, 'Load Llama model', config.device);
  // takes "auto" | "metal" | "cuda" | "vulkan"
  const gpuSetting = (config.device?.gpu ?? 'auto');
  const llama = await getLlama({
    gpu: gpuSetting,
    // forwarding llama logger
    logLevel: LlamaLogLevel.debug,
    logger: (level, message) => {
      if (level === LlamaLogLevel.warn) {
        log(LogLevels.warn, message);
      } else if (level === LlamaLogLevel.error || level === LlamaLogLevel.fatal) {
        log(LogLevels.error, message);
      } else if (level === LlamaLogLevel.info || level === LlamaLogLevel.debug) {
        log(LogLevels.verbose, message);
      }
    },
  });
  const llamaGrammars = {
    json: await LlamaGrammar.getFor(llama, 'json'),
  };
  if (config.grammars) {
    for (const key in config.grammars) {
      const input = config.grammars[key];
      if (typeof input === 'string') {
        llamaGrammars[key] = new LlamaGrammar(llama, {
          grammar: input,
        });
      } else {
        // assume input is a JSON schema object
        llamaGrammars[key] = new LlamaJsonSchemaGrammar(llama, input);
      }
    }
  }
  const llamaModel = await llama.loadModel({
    modelPath: config.location, // full model absolute path
    loadSignal: signal,
    useMlock: config.device?.memLock ?? false,
    gpuLayers: config.device?.gpuLayers,
    // onLoadProgress: (percent) => {}
  });
  const context = await llamaModel.createContext({
    sequences: 1,
    lora: config.lora,
    threads: config.device?.cpuThreads,
    batchSize: config.batchSize,
    contextSize: config.contextSize,
    flashAttention: true,
    createSignal: signal,
  });
  const instance = {
    model: llamaModel,
    context,
    grammars: llamaGrammars,
    chat: undefined,
    chatHistory: [],
    pendingFunctionCalls: {},
    lastEvaluation: undefined,
    completion: undefined,
    contextSequence: context.getSequence(),
    chatWrapper: config.chatWrapper,
  };
  if (config.initialMessages) {
    const initialChatHistory = createChatMessageArray(config.initialMessages);
    const chat = new LlamaChat({
      contextSequence: instance.contextSequence,
      chatWrapper: instance.chatWrapper,
      // autoDisposeSequence: true,
    });
    let inputFunctions;
    if (config.tools?.definitions && Object.keys(config.tools.definitions).length > 0) {
      const functionDefs = config.tools.definitions;
      inputFunctions = {};
      for (const functionName in functionDefs) {
        const functionDef = functionDefs[functionName];
        inputFunctions[functionName] = defineChatSessionFunction({
          description: functionDef.description,
          params: functionDef.parameters,
          handler: functionDef.handler || (() => { }),
        });
      }
    }
    const loadMessagesRes = await chat.loadChatAndCompleteUserMessage(initialChatHistory, {
      initialUserPrompt: '',
      functions: inputFunctions,
      documentFunctionParams: config.tools?.documentParams,
    });
    instance.chat = chat;
    instance.chatHistory = initialChatHistory;
    instance.lastEvaluation = {
      cleanHistory: initialChatHistory,
      contextWindow: loadMessagesRes.lastEvaluation.contextWindow,
      contextShiftMetadata: loadMessagesRes.lastEvaluation.contextShiftMetadata,
    };
  }
  if (config.prefix) {
    const contextSequence = instance.contextSequence;
    const completion = new LlamaCompletion({
      contextSequence: contextSequence,
    });
    await completion.generateCompletion(config.prefix, {
      maxTokens: 0,
    });
    instance.completion = completion;
    instance.contextSequence = contextSequence;
  }
  return instance;
}

export async function disposeInstance(instance) {
  await instance.model.dispose();
}

export async function processChatCompletionTask(task, ctx, signal) {
  const { instance, resetContext, config, log } = ctx;
  if (!instance.chat || resetContext) {
    log(LogLevels.debug, 'Recreating chat context', {
      resetContext: resetContext,
      willDisposeChat: !!instance.chat,
    });
    // if context reset is requested, dispose the chat instance
    if (instance.chat) {
      instance.chat.dispose();
    }
    let contextSequence = instance.contextSequence;
    if (!contextSequence || contextSequence.disposed) {
      if (instance.context.sequencesLeft) {
        contextSequence = instance.context.getSequence();
        instance.contextSequence = contextSequence;
      } else {
        throw new Error('No context sequence available');
      }
    } else {
      contextSequence.clearHistory();
    }
    instance.chat = new LlamaChat({
      contextSequence: contextSequence,
      chatWrapper: instance.chatWrapper,
      // autoDisposeSequence: true,
    });
    // reset state and reingest the conversation history
    instance.lastEvaluation = undefined;
    instance.pendingFunctionCalls = {};
    instance.chatHistory = createChatMessageArray(task.messages);
    // drop the last user message; it's going to be added again later, after resolved function calls
    if (instance.chatHistory[instance.chatHistory.length - 1].type === 'user') {
      instance.chatHistory.pop();
    }
  }
  // set additional stop generation triggers for this completion
  const customStopTriggers = [];
  const stopTrigger = task.stop ?? config.completionDefaults?.stop;
  if (stopTrigger) {
    customStopTriggers.push(...stopTrigger.map((t) => [t]));
  }
  // setting up logit/token bias dictionary
  let tokenBias;
  const completionTokenBias = task.tokenBias ?? config.completionDefaults?.tokenBias;
  if (completionTokenBias) {
    tokenBias = new TokenBias(instance.model.tokenizer);
    for (const key in completionTokenBias) {
      const bias = completionTokenBias[key] / 10;
      const tokenId = parseInt(key);
      if (!isNaN(tokenId)) {
        tokenBias.set(tokenId, bias);
      } else {
        tokenBias.set(key, bias);
      }
    }
  }
  // setting up available function definitions
  const functionDefinitions = {
    ...config.tools?.definitions,
    ...task.tools?.definitions,
  };
  // see if the user submitted any function call results
  const maxParallelCalls = task.tools?.maxParallelCalls ?? config.tools?.maxParallelCalls;
  const chatWrapperSupportsParallelism = !!instance.chat.chatWrapper.settings.functions.parallelism;
  const supportsParallelFunctionCalling = chatWrapperSupportsParallelism && !!maxParallelCalls;
  const resolvedFunctionCalls = [];
  const functionCallResultMessages = task.messages.filter((m) => m.role === 'tool');
  let startsNewChunk = supportsParallelFunctionCalling;
  for (const message of functionCallResultMessages) {
    if (!instance.pendingFunctionCalls[message.callId]) {
      log(LogLevels.warn, `Received function result for non-existing call id "${message.callId}"`);
      continue;
    }
    log(LogLevels.debug, 'Resolving pending function call', {
      id: message.callId,
      result: message.content,
    });
    const functionCall = instance.pendingFunctionCalls[message.callId];
    const functionDef = functionDefinitions[functionCall.functionName];
    const resolvedFunctionCall = {
      type: 'functionCall',
      name: functionCall.functionName,
      description: functionDef?.description,
      params: functionCall.params,
      result: message.content,
      rawCall: functionCall.raw,
    };
    if (startsNewChunk) {
      resolvedFunctionCall.startsNewChunk = true;
      startsNewChunk = false;
    }
    resolvedFunctionCalls.push(resolvedFunctionCall);
    delete instance.pendingFunctionCalls[message.callId];
  }
  // only grammar or functions can be used, not both.
  // currently ignoring function definitions if grammar is provided
  let inputGrammar;
  let inputFunctions;
  if (task.grammar) {
    if (!instance.grammars[task.grammar]) {
      throw new Error(`Grammar "${task.grammar}" not found.`);
    }
    inputGrammar = instance.grammars[task.grammar];
  } else if (Object.keys(functionDefinitions).length > 0) {
    inputFunctions = {};
    for (const functionName in functionDefinitions) {
      const functionDef = functionDefinitions[functionName];
      inputFunctions[functionName] = defineChatSessionFunction({
        description: functionDef.description,
        params: functionDef.parameters,
        handler: functionDef.handler || (() => { }),
      });
    }
  }
  let lastEvaluation = instance.lastEvaluation;
  const appendResolvedFunctionCalls = (history, response) => {
    const lastMessage = history[history.length - 1];
    // append to existing response item if last message in history is a model response
    if (lastMessage.type === 'model') {
      const lastMessageResponse = lastMessage;
      if (Array.isArray(response)) {
        lastMessageResponse.response.push(...response);
        // if we don't add a fresh empty message, llama 3.2 3b will keep trying to call functions; not sure why
        history.push({
          type: 'model',
          response: [],
        });
      }
      return;
    }
    // otherwise append a new one with the calls
    history.push({
      type: 'model',
      response: response,
    });
  };
  // if the incoming messages resolved any pending function calls, add them to history
  if (resolvedFunctionCalls.length) {
    appendResolvedFunctionCalls(instance.chatHistory, resolvedFunctionCalls);
    if (lastEvaluation?.contextWindow) {
      appendResolvedFunctionCalls(lastEvaluation.contextWindow, resolvedFunctionCalls);
    }
  }
  // add the new user message to the chat history
  let assistantPrefill = '';
  const lastMessage = task.messages[task.messages.length - 1];
  if (lastMessage.role === 'user' && lastMessage.content) {
    const newUserText = flattenMessageTextContent(lastMessage.content);
    if (newUserText) {
      instance.chatHistory.push({
        type: 'user',
        text: newUserText,
      });
    }
  } else if (lastMessage.role === 'assistant') {
    // use the last message as prefill for the response, if it's an assistant message
    assistantPrefill = flattenMessageTextContent(lastMessage.content);
  } else if (!resolvedFunctionCalls.length) {
    log(LogLevels.warn, 'Last message is not valid for chat completion. This is likely a mistake.', lastMessage);
    throw new Error('Invalid last chat message');
  }
  const defaults = config.completionDefaults ?? {};
  let newChatHistory = instance.chatHistory.slice();
  let newContextWindowChatHistory = !lastEvaluation?.contextWindow ? undefined : instance.chatHistory.slice();
  if (instance.chatHistory[instance.chatHistory.length - 1].type !== 'model' || assistantPrefill) {
    const newModelResponse = assistantPrefill ? [assistantPrefill] : [];
    newChatHistory.push({
      type: 'model',
      response: newModelResponse,
    });
    if (newContextWindowChatHistory) {
      newContextWindowChatHistory.push({
        type: 'model',
        response: newModelResponse,
      });
    }
  }
  const functionsOrGrammar = inputFunctions ? {
    // clone the input funcs because the dict gets mutated in the loop below to enable preventFurtherCalls
    functions: { ...inputFunctions },
    documentFunctionParams: task.tools?.documentParams ?? config.tools?.documentParams,
    maxParallelFunctionCalls: maxParallelCalls,
    onFunctionCall: (functionCall) => {
      // log(LogLevels.debug, 'Called function', functionCall)
    },
  } : {
    grammar: inputGrammar,
  };
  const initialTokenMeterState = instance.chat.sequence.tokenMeter.getState();
  let completionResult;
  while (true) {
    // console.debug('before eval newChatHistory', JSON.stringify(newChatHistory, null, 2))
    // console.debug('before eval newContextWindowChatHistory', JSON.stringify(newContextWindowChatHistory, null, 2))
    const {
      functionCalls,
      lastEvaluation: currentLastEvaluation,
      metadata,
    } = await instance.chat.generateResponse(newChatHistory, {
      signal,
      stopOnAbortSignal: true, // this will make aborted completions resolve (with a partial response)
      maxTokens: task.maxTokens ?? defaults.maxTokens,
      temperature: task.temperature ?? defaults.temperature,
      topP: task.topP ?? defaults.topP,
      topK: task.topK ?? defaults.topK,
      minP: task.minP ?? defaults.minP,
      seed: task.seed ?? config.completionDefaults?.seed ?? getRandomNumber(0, 1000000),
      tokenBias,
      customStopTriggers,
      trimWhitespaceSuffix: false,
      ...functionsOrGrammar,
      repeatPenalty: {
        lastTokens: task.repeatPenaltyNum ?? defaults.repeatPenaltyNum,
        frequencyPenalty: task.frequencyPenalty ?? defaults.frequencyPenalty,
        presencePenalty: task.presencePenalty ?? defaults.presencePenalty,
      },
      contextShift: {
        strategy: config.contextShiftStrategy,
        lastEvaluationMetadata: lastEvaluation?.contextShiftMetadata,
      },
      lastEvaluationContextWindow: {
        history: newContextWindowChatHistory,
        minimumOverlapPercentageToPreventContextShift: 0.5,
      },
      onToken: (tokens) => {
        const text = instance.model.detokenize(tokens);
        if (task.onChunk) {
          task.onChunk({
            tokens,
            text,
          });
        }
      },
    });
    lastEvaluation = currentLastEvaluation;
    newChatHistory = lastEvaluation.cleanHistory;
    // console.debug('after eval newChatHistory', JSON.stringify(newChatHistory, null, 2))
    // console.debug('after eval newContextWindowChatHistory', JSON.stringify(newContextWindowChatHistory, null, 2))
    if (functionCalls) {
      // find leading immediately invokable function calls (=have a handler function)
      const invokableFunctionCalls = [];
      for (const functionCall of functionCalls) {
        const functionDef = functionDefinitions[functionCall.functionName];
        if (functionDef.handler) {
          invokableFunctionCalls.push(functionCall);
        } else {
          break;
        }
      }
      // if the model output text before the call, pass it on into the function handlers
      // the response tokens will also be available via onChunk but this is more convenient
      const lastMessage = newChatHistory[newChatHistory.length - 1];
      const lastResponsePart = lastMessage.response[lastMessage.response.length - 1];
      let leadingResponseText;
      if (typeof lastResponsePart === 'string' && lastResponsePart) {
        leadingResponseText = lastResponsePart;
      }
      // resolve function call results
      const results = await Promise.all(invokableFunctionCalls.map(async (functionCall) => {
        const functionDef = functionDefinitions[functionCall.functionName];
        if (!functionDef) {
          throw new Error(`The model tried to call undefined function "${functionCall.functionName}"`);
        }
        let functionCallResult = await functionDef.handler(functionCall.params, leadingResponseText);
        log(LogLevels.debug, 'Function handler resolved', {
          function: functionCall.functionName,
          args: functionCall.params,
          result: functionCallResult,
        });
        if (typeof functionCallResult !== 'string') {
          if (functionsOrGrammar.functions && functionCallResult.preventFurtherCalls) {
            // remove the function we just called from the list of available functions
            functionsOrGrammar.functions = Object.fromEntries(
              Object.entries(functionsOrGrammar.functions).filter(([key]) => key !== functionCall.functionName)
            );
            if (Object.keys(functionsOrGrammar.functions).length === 0) {
              // @ts-ignore
              functionsOrGrammar.functions = undefined;
            }
            functionCallResult = functionCallResult.text;
          }
        }
        return {
          functionDef,
          functionCall,
          functionCallResult,
        };
      }));
      newContextWindowChatHistory = lastEvaluation.contextWindow;
      let startsNewChunk = supportsParallelFunctionCalling;
      // add results to chat history in the order they were called
      for (const callResult of results) {
        newChatHistory = addFunctionCallToChatHistory({
          chatHistory: newChatHistory,
          functionName: callResult.functionCall.functionName,
          functionDescription: callResult.functionDef.description,
          callParams: callResult.functionCall.params,
          callResult: callResult.functionCallResult,
          rawCall: callResult.functionCall.raw,
          startsNewChunk: startsNewChunk,
        });
        newContextWindowChatHistory = addFunctionCallToChatHistory({
          chatHistory: newContextWindowChatHistory,
          functionName: callResult.functionCall.functionName,
          functionDescription: callResult.functionDef.description,
          callParams: callResult.functionCall.params,
          callResult: callResult.functionCallResult,
          rawCall: callResult.functionCall.raw,
          startsNewChunk: startsNewChunk,
        });
        startsNewChunk = false;
      }
      // check whether any functions without a handler were called; those need to be returned as messages
      const remainingFunctionCalls = functionCalls.slice(invokableFunctionCalls.length);
      if (remainingFunctionCalls.length === 0) {
        // all calls had handlers and were resolved, continue with generation
        lastEvaluation.cleanHistory = newChatHistory;
        lastEvaluation.contextWindow = newContextWindowChatHistory;
        continue;
      } else {
        // some calls have no handler; return the function calls and skip further generation
        instance.lastEvaluation = lastEvaluation;
        instance.chatHistory = newChatHistory;
        completionResult = {
          responseText: null,
          stopReason: 'functionCalls',
          functionCalls: remainingFunctionCalls,
        };
        break;
      }
    }
    // no function calls happened, we got a model response.
    instance.lastEvaluation = lastEvaluation;
    instance.chatHistory = newChatHistory;
    const lastMessage = instance.chatHistory[instance.chatHistory.length - 1];
    const responseText = lastMessage.response.filter((item) => typeof item === 'string').join('');
    completionResult = {
      responseText,
      stopReason: metadata.stopReason,
    };
    break;
  }
  const assistantMessage = {
    role: 'assistant',
    content: completionResult.responseText || '',
  };
  if (completionResult.functionCalls) {
    // TODO it's possible that there are trailing immediately-evaluatable function calls.
    // function call results need to be added in the order the functions were called, so
    // we need to wait for the pending calls to complete before we can add the trailing calls.
    // as is, these may never resolve
    const pendingFunctionCalls = completionResult.functionCalls.filter((call) => {
      const functionDef = functionDefinitions[call.functionName];
      return !functionDef.handler;
    });
    // TODO write a test that triggers a parallel call to a handlerless function and to a function with one
    const trailingFunctionCalls = completionResult.functionCalls.filter((call) => {
      const functionDef = functionDefinitions[call.functionName];
      return functionDef.handler;
    });
    if (trailingFunctionCalls.length) {
      console.debug(trailingFunctionCalls);
      log(LogLevels.warn, 'Trailing function calls not resolved');
    }
    assistantMessage.toolCalls = pendingFunctionCalls.map((call) => {
      const callId = nanoid();
      instance.pendingFunctionCalls[callId] = call;
      log(LogLevels.debug, 'Saving pending tool call', {
        id: callId,
        function: call.functionName,
        args: call.params,
      });
      return {
        id: callId,
        name: call.functionName,
        parameters: call.params,
      };
    });
  }
  const tokenDifference = instance.chat.sequence.tokenMeter.diff(initialTokenMeterState);
  // console.debug('final chatHistory', JSON.stringify(instance.chatHistory, null, 2))
  // console.debug('final lastEvaluation', JSON.stringify(instance.lastEvaluation, null, 2))
  return {
    finishReason: mapFinishReason(completionResult.stopReason),
    message: assistantMessage,
    promptTokens: tokenDifference.usedInputTokens,
    completionTokens: tokenDifference.usedOutputTokens,
    contextTokens: instance.chat.sequence.contextTokens.length,
  };
}

export async function processTextCompletionTask(task, ctx, signal) {
  const { instance, resetContext, config, log } = ctx;
  if (!task.prompt) {
    throw new Error('Prompt is required for text completion.');
  }
  let completion;
  let contextSequence;
  if (resetContext && instance.contextSequence) {
    instance.contextSequence.clearHistory();
  }
  if (!instance.completion || instance.completion.disposed) {
    if (instance.contextSequence) {
      contextSequence = instance.contextSequence;
    } else if (instance.context.sequencesLeft) {
      contextSequence = instance.context.getSequence();
    } else {
      throw new Error('No context sequence available');
    }
    instance.contextSequence = contextSequence;
    completion = new LlamaCompletion({
      contextSequence,
    });
    instance.completion = completion;
  } else {
    completion = instance.completion;
    contextSequence = instance.contextSequence;
  }
  if (!contextSequence || contextSequence.disposed) {
    contextSequence = instance.context.getSequence();
    instance.contextSequence = contextSequence;
    completion = new LlamaCompletion({
      contextSequence,
    });
    instance.completion = completion;
  }
  const stopGenerationTriggers = [];
  const stopTrigger = task.stop ?? config.completionDefaults?.stop;
  if (stopTrigger) {
    stopGenerationTriggers.push(...stopTrigger.map((t) => [t]));
  }
  const initialTokenMeterState = contextSequence.tokenMeter.getState();
  const defaults = config.completionDefaults ?? {};
  const result = await completion.generateCompletionWithMeta(task.prompt, {
    maxTokens: task.maxTokens ?? defaults.maxTokens,
    temperature: task.temperature ?? defaults.temperature,
    topP: task.topP ?? defaults.topP,
    topK: task.topK ?? defaults.topK,
    minP: task.minP ?? defaults.minP,
    repeatPenalty: {
      lastTokens: task.repeatPenaltyNum ?? defaults.repeatPenaltyNum,
      frequencyPenalty: task.frequencyPenalty ?? defaults.frequencyPenalty,
      presencePenalty: task.presencePenalty ?? defaults.presencePenalty,
    },
    signal: signal,
    customStopTriggers: stopGenerationTriggers.length ? stopGenerationTriggers : undefined,
    seed: task.seed ?? config.completionDefaults?.seed ?? getRandomNumber(0, 1000000),
    onToken: (tokens) => {
      const text = instance.model.detokenize(tokens);
      if (task.onChunk) {
        task.onChunk({
          tokens,
          text,
        });
      }
    },
  });
  const tokenDifference = contextSequence.tokenMeter.diff(initialTokenMeterState);
  return {
    finishReason: mapFinishReason(result.metadata.stopReason),
    text: result.response,
    promptTokens: tokenDifference.usedInputTokens,
    completionTokens: tokenDifference.usedOutputTokens,
    contextTokens: contextSequence.contextTokens.length,
  };
}

export async function processEmbeddingTask(task, ctx, signal) {
  const { instance, config, log } = ctx;
  if (!task.input) {
    throw new Error('Input is required for embedding.');
  }
  const texts = [];
  if (typeof task.input === 'string') {
    texts.push(task.input);
  } else if (Array.isArray(task.input)) {
    for (const input of task.input) {
      if (typeof input === 'string') {
        texts.push(input);
      } else if (input.type === 'text') {
        texts.push(input.content);
      } else if (input.type === 'image') {
        throw new Error('Image inputs not implemented.');
      }
    }
  }
  if (!instance.embeddingContext) {
    instance.embeddingContext = await instance.model.createEmbeddingContext({
      batchSize: config.batchSize,
      createSignal: signal,
      threads: config.device?.cpuThreads,
      contextSize: config.contextSize,
    });
  }
  // @ts-ignore - private property
  const contextSize = instance.embeddingContext._llamaContext.contextSize;
  const embeddings = [];
  let inputTokens = 0;
  for (const text of texts) {
    let tokenizedInput = instance.model.tokenize(text);
    if (tokenizedInput.length > contextSize) {
      log(LogLevels.warn, 'Truncated input that exceeds context size');
      tokenizedInput = tokenizedInput.slice(0, contextSize);
    }
    inputTokens += tokenizedInput.length;
    const embedding = await instance.embeddingContext.getEmbeddingFor(tokenizedInput);
    embeddings.push(new Float32Array(embedding.vector));
    if (signal?.aborted) {
      break;
    }
  }
  return {
    embeddings,
    inputTokens,
  };
}
//# sourceMappingURL=engine.js.map
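A minimal driving sketch of the exported functions, assuming they are called directly rather than through the rest of the package (which presumably wires them up itself). The relative import path, the console logger, and all config values are illustrative assumptions; only the call signatures and return shapes are taken from the module above.

// Hypothetical standalone usage; import path, logger, and config values are placeholders.
import { prepareModel, createInstance, processChatCompletionTask, disposeInstance } from './engine.js';

const log = (level, message, meta) => console.log(`[${level}]`, message, meta ?? '');
const config = {
  id: 'example-model',
  location: '/absolute/path/to/model.gguf', // placeholder path
  url: 'https://example.com/model.gguf', // placeholder download source
  device: { gpu: 'auto' },
};

// validate (and if necessary download) the model file, then load it
await prepareModel({ config, log }, (progress) => {
  log('info', `download progress ${progress.loadedBytes}/${progress.totalBytes}`);
});
const instance = await createInstance({ config, log });

// run one chat completion against the loaded instance
const result = await processChatCompletionTask(
  { messages: [{ role: 'user', content: 'Hello!' }], maxTokens: 64 },
  { instance, config, log, resetContext: false },
);
console.log(result.message.content, result.finishReason);

await disposeInstance(instance);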