UNPKG

inference-server

Version:

Libraries and server to build AI applications. Adapters to various native bindings allowing local inference. Integrate it with your application, or use as a microservice.

873 lines (805 loc) 29.1 kB
import path from 'node:path'
import fs from 'node:fs'
import { nanoid } from 'nanoid'
import {
  getLlama,
  LlamaOptions,
  LlamaChat,
  LlamaModel,
  LlamaContext,
  LlamaCompletion,
  LlamaLogLevel,
  LlamaChatResponseFunctionCall,
  TokenBias,
  Token,
  LlamaContextSequence,
  LlamaGrammar,
  ChatHistoryItem,
  LlamaChatResponse,
  ChatModelResponse,
  LlamaEmbeddingContext,
  defineChatSessionFunction,
  GbnfJsonSchema,
  ChatSessionModelFunction,
  createModelDownloader,
  GgufFileInfo,
  LlamaJsonSchemaGrammar,
  LLamaChatContextShiftOptions,
  LlamaContextOptions,
  ChatWrapper,
  ChatModelFunctionCall,
} from 'node-llama-cpp'
import { StopGenerationTrigger } from 'node-llama-cpp/dist/utils/StopGenerationDetector'
import {
  ChatCompletionTaskResult,
  TextCompletionTaskResult,
  EngineContext,
  ToolDefinition,
  ToolCallResultMessage,
  AssistantMessage,
  EmbeddingTaskResult,
  FileDownloadProgress,
  ModelConfig,
  TextCompletionGrammar,
  ChatMessage,
  EngineTaskContext,
  EngineTextCompletionTaskContext,
  TextCompletionParamsBase,
  ChatCompletionTaskArgs,
  TextCompletionTaskArgs,
  EmbeddingTaskArgs,
} from '#package/types/index.js'
import { LogLevels } from '#package/lib/logger.js'
import { flattenMessageTextContent } from '#package/lib/flattenMessageTextContent.js'
import { acquireFileLock } from '#package/lib/acquireFileLock.js'
import { getRandomNumber } from '#package/lib/util.js'
import { validateModelFile } from '#package/lib/validateModelFile.js'
import { createChatMessageArray, addFunctionCallToChatHistory, mapFinishReason } from './util.js'
import { LlamaChatResult } from './types.js'

/**
 * Runtime state kept for one loaded model.
 * Created by `createInstance` and mutated by the task processors below.
 */
export interface NodeLlamaCppInstance {
  model: LlamaModel
  context: LlamaContext
  /** Chat session; created lazily by the chat task, or eagerly when `initialMessages` is configured. */
  chat?: LlamaChat
  /** Chat history as of the last completed chat task (or the preloaded initial messages). */
  chatHistory: ChatHistoryItem[]
  /** Grammars keyed by name. Always contains `json`, plus everything from `config.grammars`. */
  grammars: Record<string, LlamaGrammar>
  /** Function calls handed back to the caller, keyed by generated call id, awaiting a `tool` result message. */
  pendingFunctionCalls: Record<string, any>
  lastEvaluation?: LlamaChatResponse['lastEvaluation']
  /** Created on first embedding task. */
  embeddingContext?: LlamaEmbeddingContext
  /** Created on first text completion task, or eagerly when `prefix` is configured. */
  completion?: LlamaCompletion
  contextSequence: LlamaContextSequence
  chatWrapper?: ChatWrapper
}

export interface NodeLlamaCppModelMeta {
  gguf: GgufFileInfo
}

export interface NodeLlamaCppModelConfig extends ModelConfig {
  /** Path of the model file on disk. Passed to the loader as `modelPath`. */
  location: string
  /** Named grammars. String values are used as grammar source, anything else is treated as a JSON schema. */
  grammars?: Record<string, TextCompletionGrammar>
  sha256?: string
  /** Fallback sampling parameters used when a task does not set them. */
  completionDefaults?: TextCompletionParamsBase
  /** Chat messages ingested into the context when the instance is created. */
  initialMessages?: ChatMessage[]
  /** Text ingested into the context when the instance is created (text completion). */
  prefix?: string
  tools?: {
    definitions?: Record<string, ToolDefinition>
    documentParams?: boolean
    maxParallelCalls?: number
  }
  contextSize?: number
  batchSize?: number
  lora?: LlamaContextOptions['lora']
  contextShiftStrategy?: LLamaChatContextShiftOptions['strategy']
  chatWrapper?: ChatWrapper
  device?: {
    gpu?: boolean | 'auto' | (string & {})
    gpuLayers?: number
    cpuThreads?: number
    memLock?: boolean
  }
}

// NOTE(review): not read anywhere in this module; presumably consumed by the engine host - confirm.
export const autoGpu = true

/**
 * Converts tool definitions into node-llama-cpp chat session functions.
 * Definitions without a handler get a no-op handler.
 */
function createSessionFunctions(
  definitions: Record<string, ToolDefinition>,
): Record<string, ChatSessionModelFunction> {
  const sessionFunctions: Record<string, ChatSessionModelFunction> = {}
  for (const [functionName, functionDef] of Object.entries(definitions)) {
    sessionFunctions[functionName] = defineChatSessionFunction<any>({
      description: functionDef.description,
      params: functionDef.parameters,
      handler: functionDef.handler || (() => {}),
    })
  }
  return sessionFunctions
}

/**
 * Makes sure a valid model file exists at `config.location`.
 *
 * Holds a file lock on the location while validating. If validation fails and
 * `config.url` is set, the file is downloaded and validated again.
 *
 * @param onProgress - Receives download progress for the model file.
 * @param signal - When aborted, the function returns `undefined` at the next check point.
 * @returns The metadata reported by `validateModelFile`, or `undefined` when aborted.
 * @throws When the file is invalid and no URL is configured, or the download is still invalid.
 */
export async function prepareModel(
  { config, log }: EngineContext<NodeLlamaCppModelConfig>,
  onProgress?: (progress: FileDownloadProgress) => void,
  signal?: AbortSignal,
) {
  fs.mkdirSync(path.dirname(config.location), { recursive: true })
  const releaseFileLock = await acquireFileLock(config.location, signal)

  // everything after acquiring the lock runs inside try/finally so the lock can never leak
  try {
    if (signal?.aborted) {
      return
    }
    log(LogLevels.info, `Preparing node-llama-cpp model at ${config.location}`, {
      model: config.id,
    })

    const downloadModel = async (url: string, validationResult: string) => {
      log(LogLevels.info, `Downloading model files`, {
        model: config.id,
        url: url,
        location: config.location,
        error: validationResult,
      })
      const downloader = await createModelDownloader({
        modelUrl: url,
        dirPath: path.dirname(config.location),
        fileName: path.basename(config.location),
        deleteTempFileOnCancel: false,
        onProgress: (status) => {
          onProgress?.({
            file: config.location,
            loadedBytes: status.downloadedSize,
            totalBytes: status.totalSize,
          })
        },
      })
      await downloader.download()
    }

    const validationRes = await validateModelFile(config)
    let modelMeta = validationRes.meta
    if (signal?.aborted) {
      return
    }
    if (validationRes.error) {
      if (!config.url) {
        throw new Error(`${validationRes.error} - No URL provided`)
      }
      await downloadModel(config.url, validationRes.error)
      const revalidationRes = await validateModelFile(config)
      if (revalidationRes.error) {
        throw new Error(`Downloaded files are invalid: ${revalidationRes.error}`)
      }
      modelMeta = revalidationRes.meta
    }
    return modelMeta
  } finally {
    releaseFileLock()
  }
}

/**
 * Loads the model and creates a single-sequence context for it.
 *
 * Optionally preloads `config.initialMessages` into a chat session and/or
 * ingests `config.prefix` for text completion. If anything fails after the
 * model has been loaded, the model is disposed before the error is rethrown.
 *
 * @param signal - Forwarded to model loading and context creation.
 */
export async function createInstance({ config, log }: EngineContext<NodeLlamaCppModelConfig>, signal?: AbortSignal) {
  log(LogLevels.debug, 'Load Llama model', config.device)
  // takes "auto" | "metal" | "cuda" | "vulkan"
  const gpuSetting = (config.device?.gpu ?? 'auto') as LlamaOptions['gpu']
  const llama = await getLlama({
    gpu: gpuSetting,
    // forwarding llama logger
    logLevel: LlamaLogLevel.debug,
    logger: (level, message) => {
      if (level === LlamaLogLevel.warn) {
        log(LogLevels.warn, message)
      } else if (level === LlamaLogLevel.error || level === LlamaLogLevel.fatal) {
        log(LogLevels.error, message)
      } else if (level === LlamaLogLevel.info || level === LlamaLogLevel.debug) {
        log(LogLevels.verbose, message)
      }
    },
  })

  const llamaGrammars: Record<string, LlamaGrammar> = {
    json: await LlamaGrammar.getFor(llama, 'json'),
  }
  for (const [key, input] of Object.entries(config.grammars ?? {})) {
    if (typeof input === 'string') {
      llamaGrammars[key] = new LlamaGrammar(llama, {
        grammar: input,
      })
    } else {
      // assume input is a JSON schema object
      llamaGrammars[key] = new LlamaJsonSchemaGrammar<GbnfJsonSchema>(llama, input as GbnfJsonSchema)
    }
  }

  const llamaModel = await llama.loadModel({
    modelPath: config.location, // full model absolute path
    loadSignal: signal,
    useMlock: config.device?.memLock ?? false,
    gpuLayers: config.device?.gpuLayers,
    // onLoadProgress: (percent) => {}
  })

  try {
    const context = await llamaModel.createContext({
      sequences: 1,
      lora: config.lora,
      threads: config.device?.cpuThreads,
      batchSize: config.batchSize,
      contextSize: config.contextSize,
      flashAttention: true,
      createSignal: signal,
    })

    const instance: NodeLlamaCppInstance = {
      model: llamaModel,
      context,
      grammars: llamaGrammars,
      chat: undefined,
      chatHistory: [],
      pendingFunctionCalls: {},
      lastEvaluation: undefined,
      completion: undefined,
      contextSequence: context.getSequence(),
      chatWrapper: config.chatWrapper,
    }

    if (config.initialMessages) {
      const initialChatHistory = createChatMessageArray(config.initialMessages)
      const chat = new LlamaChat({
        contextSequence: instance.contextSequence,
        chatWrapper: instance.chatWrapper,
        // autoDisposeSequence: true,
      })

      let inputFunctions: Record<string, ChatSessionModelFunction> | undefined
      if (config.tools?.definitions && Object.keys(config.tools.definitions).length > 0) {
        inputFunctions = createSessionFunctions(config.tools.definitions)
      }

      const loadMessagesRes = await chat.loadChatAndCompleteUserMessage(initialChatHistory, {
        initialUserPrompt: '',
        functions: inputFunctions,
        documentFunctionParams: config.tools?.documentParams,
      })

      instance.chat = chat
      instance.chatHistory = initialChatHistory
      instance.lastEvaluation = {
        cleanHistory: initialChatHistory,
        contextWindow: loadMessagesRes.lastEvaluation.contextWindow,
        contextShiftMetadata: loadMessagesRes.lastEvaluation.contextShiftMetadata,
      }
    }

    if (config.prefix) {
      const completion = new LlamaCompletion({
        contextSequence: instance.contextSequence,
      })
      // maxTokens: 0 - only ingest the prefix, do not generate anything
      await completion.generateCompletion(config.prefix, {
        maxTokens: 0,
      })
      instance.completion = completion
    }

    return instance
  } catch (error) {
    // do not keep the model loaded when the instance could not be fully set up
    try {
      await llamaModel.dispose()
    } catch {
      // the setup error is the one worth reporting
    }
    throw error
  }
}

/** Disposes the model held by the instance. */
export async function disposeInstance(instance: NodeLlamaCppInstance) {
  await instance.model.dispose()
}

/**
 * Runs a chat completion on the instance.
 *
 * Steps:
 * 1. (Re)create the chat session if there is none or `resetContext` is set, and rebuild the
 *    history from `task.messages`.
 * 2. Resolve pending function calls from incoming `tool` messages and append them to the history.
 * 3. Append the trailing user message, or use a trailing assistant message as response prefill.
 * 4. Generate. Leading function calls that have a handler are invoked and generation continues;
 *    calls without a handler are returned to the caller as `toolCalls` and remembered in
 *    `instance.pendingFunctionCalls`.
 *
 * @param signal - Aborting stops generation; the partial response is returned.
 * @throws When `task.messages` is empty, the last message is not usable, the requested grammar
 *   does not exist, no context sequence is available, or the model calls an unknown function.
 */
export async function processChatCompletionTask(
  task: ChatCompletionTaskArgs,
  ctx: EngineTextCompletionTaskContext<NodeLlamaCppInstance, NodeLlamaCppModelConfig, NodeLlamaCppModelMeta>,
  signal?: AbortSignal,
): Promise<ChatCompletionTaskResult> {
  const { instance, resetContext, config, log } = ctx

  if (!task.messages?.length) {
    throw new Error('Chat completion requires at least one message.')
  }

  let chat = instance.chat
  if (!chat || resetContext) {
    log(LogLevels.debug, 'Recreating chat context', {
      resetContext: resetContext,
      willDisposeChat: !!chat,
    })
    // if context reset is requested, dispose the chat instance
    if (chat) {
      chat.dispose()
      // never leave a disposed chat on the instance, even if acquiring a sequence fails below
      instance.chat = undefined
    }

    let contextSequence = instance.contextSequence
    if (!contextSequence || contextSequence.disposed) {
      if (!instance.context.sequencesLeft) {
        throw new Error('No context sequence available')
      }
      contextSequence = instance.context.getSequence()
      instance.contextSequence = contextSequence
    } else {
      await contextSequence.clearHistory()
    }

    chat = new LlamaChat({
      contextSequence: contextSequence,
      chatWrapper: instance.chatWrapper,
      // autoDisposeSequence: true,
    })
    instance.chat = chat

    // reset state and reingest the conversation history
    instance.lastEvaluation = undefined
    instance.pendingFunctionCalls = {}
    instance.chatHistory = createChatMessageArray(task.messages)
    // drop last user message. its gonna be added later, after resolved function calls
    if (instance.chatHistory[instance.chatHistory.length - 1]?.type === 'user') {
      instance.chatHistory.pop()
    }
  }

  // set additional stop generation triggers for this completion
  const customStopTriggers: StopGenerationTrigger[] = []
  const stopTrigger = task.stop ?? config.completionDefaults?.stop
  if (stopTrigger) {
    customStopTriggers.push(...stopTrigger.map((t) => [t]))
  }

  // setting up logit/token bias dictionary
  let tokenBias: TokenBias | undefined
  const completionTokenBias = task.tokenBias ?? config.completionDefaults?.tokenBias
  if (completionTokenBias) {
    tokenBias = new TokenBias(instance.model.tokenizer)
    for (const [key, value] of Object.entries(completionTokenBias)) {
      const bias = value / 10
      // only keys made up entirely of digits are token ids; anything else (including
      // text that merely starts with a digit, like "2nd") is biased as text
      if (/^\d+$/.test(key)) {
        tokenBias.set(Number.parseInt(key, 10) as Token, bias)
      } else {
        tokenBias.set(key, bias)
      }
    }
  }

  // setting up available function definitions
  const functionDefinitions: Record<string, ToolDefinition> = {
    ...config.tools?.definitions,
    ...task.tools?.definitions,
  }

  // see if the user submitted any function call results
  const maxParallelCalls = task.tools?.maxParallelCalls ?? config.tools?.maxParallelCalls
  const chatWrapperSupportsParallelism = !!chat.chatWrapper.settings.functions.parallelism
  const supportsParallelFunctionCalling = chatWrapperSupportsParallelism && !!maxParallelCalls
  const resolvedFunctionCalls: ChatModelFunctionCall[] = []
  const functionCallResultMessages = task.messages.filter((m) => m.role === 'tool') as ToolCallResultMessage[]
  let resolvedStartsNewChunk = supportsParallelFunctionCalling
  for (const message of functionCallResultMessages) {
    const functionCall = instance.pendingFunctionCalls[message.callId]
    if (!functionCall) {
      log(LogLevels.warn, `Received function result for non-existing call id "${message.callId}"`)
      continue
    }
    log(LogLevels.debug, 'Resolving pending function call', {
      id: message.callId,
      result: message.content,
    })
    const functionDef = functionDefinitions[functionCall.functionName]
    const resolvedFunctionCall: ChatModelFunctionCall = {
      type: 'functionCall',
      name: functionCall.functionName,
      description: functionDef?.description,
      params: functionCall.params,
      result: message.content,
      rawCall: functionCall.raw,
    }
    // only the first resolved call of a batch opens a new chunk
    if (resolvedStartsNewChunk) {
      resolvedFunctionCall.startsNewChunk = true
      resolvedStartsNewChunk = false
    }
    resolvedFunctionCalls.push(resolvedFunctionCall)
    delete instance.pendingFunctionCalls[message.callId]
  }

  // only grammar or functions can be used, not both.
  // currently ignoring function definitions if grammar is provided
  let inputGrammar: LlamaGrammar | undefined
  let inputFunctions: Record<string, ChatSessionModelFunction> | undefined
  if (task.grammar) {
    if (!instance.grammars[task.grammar]) {
      throw new Error(`Grammar "${task.grammar}" not found.`)
    }
    inputGrammar = instance.grammars[task.grammar]
  } else if (Object.keys(functionDefinitions).length > 0) {
    inputFunctions = createSessionFunctions(functionDefinitions)
  }

  let lastEvaluation: LlamaChatResponse['lastEvaluation'] | undefined = instance.lastEvaluation

  const appendResolvedFunctionCalls = (history: ChatHistoryItem[], response: ChatModelFunctionCall[]) => {
    const previousItem = history[history.length - 1]
    // append to existing response item if last message in history is a model response
    if (previousItem?.type === 'model') {
      const previousResponse = previousItem as ChatModelResponse
      previousResponse.response.push(...response)
      // if we dont add a fresh empty message llama 3.2 3b will keep trying to call functions, not sure why this is
      history.push({
        type: 'model',
        response: [],
      })
      return
    }
    // otherwise append a new one with the calls
    history.push({
      type: 'model',
      response: response,
    })
  }

  // if the incoming messages resolved any pending function calls, add them to history
  if (resolvedFunctionCalls.length) {
    appendResolvedFunctionCalls(instance.chatHistory, resolvedFunctionCalls)
    if (lastEvaluation?.contextWindow) {
      appendResolvedFunctionCalls(lastEvaluation.contextWindow, resolvedFunctionCalls)
    }
  }

  // add the new user message to the chat history
  let assistantPrefill = ''
  const lastTaskMessage = task.messages[task.messages.length - 1]
  if (lastTaskMessage.role === 'user' && lastTaskMessage.content) {
    const newUserText = flattenMessageTextContent(lastTaskMessage.content)
    if (newUserText) {
      instance.chatHistory.push({
        type: 'user',
        text: newUserText,
      })
    }
  } else if (lastTaskMessage.role === 'assistant') {
    // use last message as prefill for response, if its an assistant message
    assistantPrefill = flattenMessageTextContent(lastTaskMessage.content)
  } else if (!resolvedFunctionCalls.length) {
    log(LogLevels.warn, 'Last message is not valid for chat completion. This is likely a mistake.', lastTaskMessage)
    throw new Error('Invalid last chat message')
  }

  const defaults = config.completionDefaults ?? {}
  let newChatHistory = instance.chatHistory.slice()
  // NOTE(review): this copies chatHistory rather than lastEvaluation.contextWindow - confirm that is intended.
  let newContextWindowChatHistory = !lastEvaluation?.contextWindow ? undefined : instance.chatHistory.slice()

  // generation needs a trailing model item to write into
  if (instance.chatHistory[instance.chatHistory.length - 1]?.type !== 'model' || assistantPrefill) {
    const newModelResponse = assistantPrefill ? [assistantPrefill] : []
    newChatHistory.push({
      type: 'model',
      response: newModelResponse,
    })
    if (newContextWindowChatHistory) {
      newContextWindowChatHistory.push({
        type: 'model',
        response: newModelResponse,
      })
    }
  }

  const functionsOrGrammar = inputFunctions
    ? {
        // clone the input funcs because the dict gets mutated in the loop below to enable preventFurtherCalls
        functions: { ...inputFunctions },
        documentFunctionParams: task.tools?.documentParams ?? config.tools?.documentParams,
        maxParallelFunctionCalls: maxParallelCalls,
        onFunctionCall: (functionCall: LlamaChatResponseFunctionCall<any>) => {
          // log(LogLevels.debug, 'Called function', functionCall)
        },
      }
    : {
        grammar: inputGrammar,
      }

  const initialTokenMeterState = chat.sequence.tokenMeter.getState()
  let completionResult: LlamaChatResult

  while (true) {
    const {
      functionCalls,
      lastEvaluation: currentLastEvaluation,
      metadata,
    } = await chat.generateResponse(newChatHistory, {
      signal,
      stopOnAbortSignal: true, // this will make aborted completions resolve (with a partial response)
      maxTokens: task.maxTokens ?? defaults.maxTokens,
      temperature: task.temperature ?? defaults.temperature,
      topP: task.topP ?? defaults.topP,
      topK: task.topK ?? defaults.topK,
      minP: task.minP ?? defaults.minP,
      seed: task.seed ?? defaults.seed ?? getRandomNumber(0, 1000000),
      tokenBias,
      customStopTriggers,
      trimWhitespaceSuffix: false,
      ...functionsOrGrammar,
      repeatPenalty: {
        lastTokens: task.repeatPenaltyNum ?? defaults.repeatPenaltyNum,
        frequencyPenalty: task.frequencyPenalty ?? defaults.frequencyPenalty,
        presencePenalty: task.presencePenalty ?? defaults.presencePenalty,
      },
      contextShift: {
        strategy: config.contextShiftStrategy,
        lastEvaluationMetadata: lastEvaluation?.contextShiftMetadata,
      },
      lastEvaluationContextWindow: {
        history: newContextWindowChatHistory,
        minimumOverlapPercentageToPreventContextShift: 0.5,
      },
      onToken: (tokens) => {
        // skip detokenizing when nobody listens
        if (!task.onChunk) {
          return
        }
        const text = instance.model.detokenize(tokens)
        task.onChunk({
          tokens,
          text,
        })
      },
    })

    lastEvaluation = currentLastEvaluation
    newChatHistory = lastEvaluation.cleanHistory

    if (functionCalls) {
      // fail with a clear message before touching any definition of a function that does not exist
      for (const functionCall of functionCalls) {
        if (!functionDefinitions[functionCall.functionName]) {
          throw new Error(`The model tried to call undefined function "${functionCall.functionName}"`)
        }
      }

      // find leading immediately invokable function calls (=have a handler function)
      const firstHandlerlessIndex = functionCalls.findIndex(
        (functionCall) => !functionDefinitions[functionCall.functionName]?.handler,
      )
      const invokableCount = firstHandlerlessIndex === -1 ? functionCalls.length : firstHandlerlessIndex
      const invokableFunctionCalls = functionCalls.slice(0, invokableCount)
      const remainingFunctionCalls = functionCalls.slice(invokableCount)

      // if the model output text before the call, pass it on into the function handlers
      // the response tokens will also be available via onChunk but this is more convenient
      const lastModelItem = newChatHistory[newChatHistory.length - 1] as ChatModelResponse
      const lastResponsePart = lastModelItem.response[lastModelItem.response.length - 1]
      let leadingResponseText: string | undefined
      if (typeof lastResponsePart === 'string' && lastResponsePart) {
        leadingResponseText = lastResponsePart
      }

      // resolve function call results
      const results = await Promise.all(
        invokableFunctionCalls.map(async (functionCall) => {
          const functionDef = functionDefinitions[functionCall.functionName]
          const handler = functionDef?.handler
          if (!functionDef || !handler) {
            throw new Error(`The model tried to call undefined function "${functionCall.functionName}"`)
          }
          let functionCallResult = await handler(functionCall.params, leadingResponseText)
          log(LogLevels.debug, 'Function handler resolved', {
            function: functionCall.functionName,
            args: functionCall.params,
            result: functionCallResult,
          })
          if (typeof functionCallResult !== 'string') {
            if (functionsOrGrammar.functions && functionCallResult.preventFurtherCalls) {
              // remove the function we just called from the list of available functions
              functionsOrGrammar.functions = Object.fromEntries(
                Object.entries(functionsOrGrammar.functions).filter(([key]) => key !== functionCall.functionName),
              )
              if (Object.keys(functionsOrGrammar.functions).length === 0) {
                // @ts-ignore
                functionsOrGrammar.functions = undefined
              }
              functionCallResult = functionCallResult.text
            }
          }
          return {
            functionDef,
            functionCall,
            functionCallResult,
          }
        }),
      )

      let contextWindowHistory = lastEvaluation.contextWindow
      let callStartsNewChunk = supportsParallelFunctionCalling
      // add results to chat history in the order they were called
      for (const callResult of results) {
        const historyEntry = {
          functionName: callResult.functionCall.functionName,
          functionDescription: callResult.functionDef.description,
          callParams: callResult.functionCall.params,
          callResult: callResult.functionCallResult,
          rawCall: callResult.functionCall.raw,
          startsNewChunk: callStartsNewChunk,
        }
        newChatHistory = addFunctionCallToChatHistory({
          chatHistory: newChatHistory,
          ...historyEntry,
        })
        contextWindowHistory = addFunctionCallToChatHistory({
          chatHistory: contextWindowHistory,
          ...historyEntry,
        })
        callStartsNewChunk = false
      }
      newContextWindowChatHistory = contextWindowHistory

      if (remainingFunctionCalls.length === 0) {
        // every call had a handler and has been resolved: continue with generation
        lastEvaluation.cleanHistory = newChatHistory
        lastEvaluation.contextWindow = contextWindowHistory
        continue
      } else {
        // functions without handler have been called: return the calls and skip generation
        instance.lastEvaluation = lastEvaluation
        instance.chatHistory = newChatHistory
        completionResult = {
          responseText: null,
          stopReason: 'functionCalls',
          functionCalls: remainingFunctionCalls,
        }
        break
      }
    }

    // no function calls happened, we got a model response.
    instance.lastEvaluation = lastEvaluation
    instance.chatHistory = newChatHistory
    const finalModelItem = instance.chatHistory[instance.chatHistory.length - 1] as ChatModelResponse
    const responseText = finalModelItem.response.filter((item: any) => typeof item === 'string').join('')
    completionResult = {
      responseText,
      stopReason: metadata.stopReason,
    }
    break
  }

  const assistantMessage: AssistantMessage = {
    role: 'assistant',
    content: completionResult.responseText || '',
  }

  if (completionResult.functionCalls) {
    // TODO its possible that there are trailing immediately-evaluatable function calls.
    // function call results need to be added in the order the functions were called, so
    // we need to wait for the pending calls to complete before we can add the trailing calls.
    // as is, these may never resolve
    const pendingFunctionCalls = completionResult.functionCalls.filter(
      (call) => !functionDefinitions[call.functionName]?.handler,
    )
    // TODO write a test that triggers a parallel call to a handlerless function and to a function with one
    const trailingFunctionCalls = completionResult.functionCalls.filter(
      (call) => !!functionDefinitions[call.functionName]?.handler,
    )
    if (trailingFunctionCalls.length) {
      log(LogLevels.warn, 'Trailing function calls not resolved', {
        functionCalls: trailingFunctionCalls,
      })
    }

    assistantMessage.toolCalls = pendingFunctionCalls.map((call) => {
      const callId = nanoid()
      instance.pendingFunctionCalls[callId] = call
      log(LogLevels.debug, 'Saving pending tool call', {
        id: callId,
        function: call.functionName,
        args: call.params,
      })
      return {
        id: callId,
        name: call.functionName,
        parameters: call.params,
      }
    })
  }

  const tokenDifference = chat.sequence.tokenMeter.diff(initialTokenMeterState)
  return {
    finishReason: mapFinishReason(completionResult.stopReason),
    message: assistantMessage,
    promptTokens: tokenDifference.usedInputTokens,
    completionTokens: tokenDifference.usedOutputTokens,
    contextTokens: chat.sequence.contextTokens.length,
  }
}

/**
 * Runs a plain text completion for `task.prompt`.
 *
 * Reuses the instance's context sequence and completion object when they are still usable.
 * A missing or disposed sequence is replaced by a fresh one (together with its completion);
 * `resetContext` clears the history of a live sequence.
 *
 * @param signal - Forwarded to the generation call.
 * @throws When the prompt is empty or no context sequence is available.
 */
export async function processTextCompletionTask(
  task: TextCompletionTaskArgs,
  ctx: EngineTextCompletionTaskContext<NodeLlamaCppInstance, NodeLlamaCppModelConfig, NodeLlamaCppModelMeta>,
  signal?: AbortSignal,
): Promise<TextCompletionTaskResult> {
  const { instance, resetContext, config } = ctx
  if (!task.prompt) {
    throw new Error('Prompt is required for text completion.')
  }

  // settle on a usable sequence first, then on a completion that is bound to it
  let contextSequence = instance.contextSequence
  let completion = instance.completion
  if (!contextSequence || contextSequence.disposed) {
    if (!instance.context.sequencesLeft) {
      throw new Error('No context sequence available')
    }
    contextSequence = instance.context.getSequence()
    instance.contextSequence = contextSequence
    // the previous completion (if any) belongs to the old sequence
    completion = undefined
  } else if (resetContext) {
    await contextSequence.clearHistory()
  }
  if (!completion || completion.disposed) {
    completion = new LlamaCompletion({
      contextSequence,
    })
    instance.completion = completion
  }

  const stopGenerationTriggers: StopGenerationTrigger[] = []
  const stopTrigger = task.stop ?? config.completionDefaults?.stop
  if (stopTrigger) {
    stopGenerationTriggers.push(...stopTrigger.map((t) => [t]))
  }

  const initialTokenMeterState = contextSequence.tokenMeter.getState()
  const defaults = config.completionDefaults ?? {}
  const result = await completion.generateCompletionWithMeta(task.prompt, {
    maxTokens: task.maxTokens ?? defaults.maxTokens,
    temperature: task.temperature ?? defaults.temperature,
    topP: task.topP ?? defaults.topP,
    topK: task.topK ?? defaults.topK,
    minP: task.minP ?? defaults.minP,
    repeatPenalty: {
      lastTokens: task.repeatPenaltyNum ?? defaults.repeatPenaltyNum,
      frequencyPenalty: task.frequencyPenalty ?? defaults.frequencyPenalty,
      presencePenalty: task.presencePenalty ?? defaults.presencePenalty,
    },
    signal: signal,
    customStopTriggers: stopGenerationTriggers.length ? stopGenerationTriggers : undefined,
    seed: task.seed ?? defaults.seed ?? getRandomNumber(0, 1000000),
    onToken: (tokens) => {
      // skip detokenizing when nobody listens
      if (!task.onChunk) {
        return
      }
      const text = instance.model.detokenize(tokens)
      task.onChunk({
        tokens,
        text,
      })
    },
  })

  const tokenDifference = contextSequence.tokenMeter.diff(initialTokenMeterState)
  return {
    finishReason: mapFinishReason(result.metadata.stopReason),
    text: result.response,
    promptTokens: tokenDifference.usedInputTokens,
    completionTokens: tokenDifference.usedOutputTokens,
    contextTokens: contextSequence.contextTokens.length,
  }
}

/**
 * Computes one embedding per text input.
 *
 * Accepts a single string or an array of strings / `{ type: 'text' }` parts. Inputs longer than
 * the embedding context are truncated (with a warning). The embedding context is created on first
 * use and kept on the instance.
 *
 * @param signal - Checked after each input; when aborted, the embeddings computed so far are returned.
 * @throws When no input is given or an image input is passed.
 */
export async function processEmbeddingTask(
  task: EmbeddingTaskArgs,
  ctx: EngineTaskContext<NodeLlamaCppInstance, NodeLlamaCppModelConfig, NodeLlamaCppModelMeta>,
  signal?: AbortSignal,
): Promise<EmbeddingTaskResult> {
  const { instance, config, log } = ctx
  if (!task.input) {
    throw new Error('Input is required for embedding.')
  }

  const texts: string[] = []
  if (typeof task.input === 'string') {
    texts.push(task.input)
  } else if (Array.isArray(task.input)) {
    for (const input of task.input) {
      if (typeof input === 'string') {
        texts.push(input)
      } else if (input.type === 'text') {
        texts.push(input.content)
      } else if (input.type === 'image') {
        throw new Error('Image inputs not implemented.')
      }
    }
  }

  if (!instance.embeddingContext) {
    instance.embeddingContext = await instance.model.createEmbeddingContext({
      batchSize: config.batchSize,
      createSignal: signal,
      threads: config.device?.cpuThreads,
      contextSize: config.contextSize,
    })
  }
  const embeddingContext = instance.embeddingContext

  // @ts-ignore - private property
  const contextSize = embeddingContext._llamaContext.contextSize

  const embeddings: Float32Array[] = []
  let inputTokens = 0
  for (const text of texts) {
    let tokenizedInput = instance.model.tokenize(text)
    if (tokenizedInput.length > contextSize) {
      log(LogLevels.warn, 'Truncated input that exceeds context size')
      tokenizedInput = tokenizedInput.slice(0, contextSize)
    }
    inputTokens += tokenizedInput.length
    const embedding = await embeddingContext.getEmbeddingFor(tokenizedInput)
    embeddings.push(new Float32Array(embedding.vector))
    if (signal?.aborted) {
      break
    }
  }

  return {
    embeddings,
    inputTokens,
  }
}