UNPKG

inference-server

Version:

Libraries and a server for building AI applications, with adapters to various native bindings that enable local inference. Integrate it into your application or run it as a microservice.

363 lines 13 kB
import path from 'node:path';
import fs from 'node:fs';
import {
  loadModel,
  createCompletion,
  createEmbedding,
  DEFAULT_MODEL_LIST_URL,
} from 'gpt4all';
import { LogLevels } from '../../lib/logger.js';
import { downloadModelFile } from '../../lib/downloadModelFile.js';
import { acquireFileLock } from '../../lib/acquireFileLock.js';
import { validateModelFile } from '../../lib/validateModelFile.js';
import { createChatMessageArray } from './util.js';

// Engine capability flag read by the host; its exact meaning is defined outside
// this file. Note that createInstance below maps `device.gpu: 'auto'` to 'cpu'.
export const autoGpu = true;

/**
 * Loads the gpt4all model catalogue. A cached copy at `metaPath` is preferred;
 * otherwise the list is downloaded from DEFAULT_MODEL_LIST_URL and cached there.
 *
 * @param {string} metaPath - Location of the cached `models.json`.
 * @param {AbortSignal} [signal] - Aborts the download.
 * @returns {Promise<Array<object>>} The parsed model list.
 * @throws {Error} If the download responds with a non-2xx status.
 */
async function readOrFetchModelList(metaPath, signal) {
  if (fs.existsSync(metaPath)) {
    return JSON.parse(fs.readFileSync(metaPath, 'utf-8'));
  }
  const res = await fetch(DEFAULT_MODEL_LIST_URL, { signal });
  if (!res.ok) {
    // Never cache an error payload: it would be reused on every later run.
    throw new Error(
      `Failed to fetch gpt4all model list from ${DEFAULT_MODEL_LIST_URL}: ${res.status} ${res.statusText}`,
    );
  }
  const modelList = await res.json();
  fs.writeFileSync(metaPath, JSON.stringify(modelList, null, 2));
  return modelList;
}

/**
 * Finds the catalogue entry describing the configured model.
 * Matching priority: md5 (when both sides have one), then URL (when both
 * sides have one), then the file name of `config.location`.
 *
 * @param {Array<object>} modelList - Entries from the gpt4all model list.
 * @param {object} config - Model config (`md5`, `url`, `location`).
 * @returns {object|undefined} The matching entry, if any.
 */
function findGpt4allModelMeta(modelList, config) {
  const fileName = path.basename(config.location);
  return modelList.find((item) => {
    if (config.md5 && item.md5sum) {
      return item.md5sum === config.md5;
    }
    if (config.url && item.url) {
      return item.url === config.url;
    }
    return item.filename === fileName;
  });
}

/**
 * Ensures the model file at `config.location` exists and is valid, downloading
 * it from `config.url` when validation fails.
 *
 * A file lock on `config.location` is held for the whole preparation and is
 * always released, including on abort and on error.
 *
 * @param {{ config: object, log: Function }} ctx - Engine context; reads
 *   `config.location`, `config.id`, `config.md5`, `config.url` and
 *   `config.modelsCachePath`.
 * @param {Function} [onProgress] - Forwarded to downloadModelFile.
 * @param {AbortSignal} [signal] - Checked between steps; also aborts the
 *   model-list download.
 * @returns {Promise<object|undefined>} `{ gpt4allMeta, ...modelMeta }`, or
 *   `undefined` when `signal` was aborted.
 * @throws {Error} If the file is invalid and no URL is configured, if the
 *   downloaded file fails revalidation, or if the model list cannot be fetched.
 */
export async function prepareModel({ config, log }, onProgress, signal) {
  fs.mkdirSync(path.dirname(config.location), { recursive: true });
  const releaseFileLock = await acquireFileLock(config.location);
  try {
    if (signal?.aborted) {
      return undefined;
    }
    log(LogLevels.info, `Preparing gpt4all model at ${config.location}`, {
      model: config.id,
    });

    const modelMetaPath = path.join(path.dirname(config.location), 'models.json');
    let modelList;
    try {
      modelList = await readOrFetchModelList(modelMetaPath, signal);
    } catch (error) {
      // An aborted download is a cancellation, not a failure.
      if (signal?.aborted) {
        return undefined;
      }
      throw error;
    }

    const gpt4allMeta = findGpt4allModelMeta(modelList, config);
    // An explicitly configured md5 wins over the catalogue checksum.
    const expectedMd5 = config.md5 || gpt4allMeta?.md5sum;

    const validationRes = await validateModelFile({ ...config, md5: expectedMd5 });
    let modelMeta = validationRes.meta;
    if (signal?.aborted) {
      return undefined;
    }

    if (validationRes.error) {
      if (!config.url) {
        throw new Error(`${validationRes.error} - No URL provided`);
      }
      log(LogLevels.info, 'Downloading', {
        model: config.id,
        url: config.url,
        location: config.location,
        error: validationRes.error,
      });
      await downloadModelFile({
        url: config.url,
        filePath: config.location,
        modelsCachePath: config.modelsCachePath,
        onProgress,
        signal,
      });
      const revalidationRes = await validateModelFile({ ...config, md5: expectedMd5 });
      if (revalidationRes.error) {
        throw new Error(`Downloaded files are invalid: ${revalidationRes.error}`);
      }
      modelMeta = revalidationRes.meta;
    }

    if (signal?.aborted) {
      return undefined;
    }
    return {
      gpt4allMeta,
      ...modelMeta,
    };
  } finally {
    releaseFileLock();
  }
}

/**
 * Maps the configured `device.gpu` value to the device string given to gpt4all.
 * Booleans become 'gpu'/'cpu'; a missing value or 'auto' falls back to 'cpu';
 * any other value is passed through unchanged.
 *
 * @param {boolean|string|undefined|null} gpu - The configured GPU setting.
 * @returns {string} Device name for loadModel.
 */
function resolveDevice(gpu) {
  if (typeof gpu === 'boolean') {
    return gpu ? 'gpu' : 'cpu';
  }
  if (gpu == null || gpu === 'auto') {
    return 'cpu';
  }
  return gpu;
}

/**
 * Separates a leading system message from the rest of a chat transcript.
 *
 * @param {Array<{ role: string, content: string }>} messages - Chat messages.
 * @returns {{ systemPrompt: (string|undefined), messages: Array<object> }}
 *   The system prompt (if the first message is one) and the remaining messages.
 */
function splitSystemPrompt(messages) {
  if (messages[0]?.role === 'system') {
    return { systemPrompt: messages[0].content, messages: messages.slice(1) };
  }
  return { systemPrompt: undefined, messages };
}

/**
 * Loads the gpt4all model at `config.location` for the configured task.
 *
 * Text-completion instances are primed afterwards. If `config.initialMessages`
 * is given, a chat session is created from them. Otherwise `config.prefix`
 * (or an empty prompt) is evaluated with `nPredict: 0`.
 *
 * @param {{ config: object, log: Function }} ctx - Engine context.
 * @param {AbortSignal} [signal] - Currently unused: loadModel accepts no signal.
 * @returns {Promise<object>} The loaded gpt4all model instance.
 * @throws {Error} If `config.task` is neither 'text-completion' nor 'embedding'.
 */
export async function createInstance({ config, log }, signal) {
  log(LogLevels.info, `Load GPT4All model ${config.location}`);

  let modelType;
  switch (config.task) {
    case 'text-completion':
      modelType = 'inference';
      break;
    case 'embedding':
      modelType = 'embedding';
      break;
    default:
      throw new Error(`Unsupported task type: ${config.task}`);
  }

  const modelDir = path.dirname(config.location);
  // TODO: loadModel offers no way to cancel, so `signal` cannot abort loading.
  const instance = await loadModel(path.basename(config.location), {
    modelPath: modelDir,
    modelConfigFile: path.join(modelDir, 'models.json'),
    // prepareModel is responsible for fetching the file; never download here.
    allowDownload: false,
    device: resolveDevice(config.device?.gpu),
    ngl: config.device?.gpuLayers ?? 100,
    nCtx: config.contextSize ?? 2048,
    type: modelType,
  });

  if (config.device?.cpuThreads) {
    instance.llm.setThreadCount(config.device.cpuThreads);
  }

  if ('generate' in instance) {
    if (config.initialMessages?.length) {
      const { systemPrompt, messages } = splitSystemPrompt(
        createChatMessageArray(config.initialMessages),
      );
      await instance.createChatSession({ systemPrompt, messages });
    } else {
      // nPredict: 0 presumably ingests the prompt without generating tokens.
      await instance.generate(config.prefix || '', { nPredict: 0 });
    }
  }
  return instance;
}

/**
 * Releases the resources held by a model instance.
 *
 * @param {object} instance - An instance returned by createInstance.
 * @returns {Promise<void>}
 */
export async function disposeInstance(instance) {
  instance.dispose();
}

/**
 * Builds the gpt4all sampling options shared by text and chat completion.
 * Task values take precedence over `completionDefaults`.
 *
 * @param {object} task - The completion task.
 * @param {object} defaults - `config.completionDefaults` (or `{}`).
 * @param {object} config - Model config; `batchSize` is forwarded as nBatch.
 * @returns {object} Options for generate/createCompletion.
 */
function buildSamplingOptions(task, defaults, config) {
  return {
    temperature: task.temperature ?? defaults.temperature,
    nPredict: task.maxTokens ?? defaults.maxTokens,
    topP: task.topP ?? defaults.topP,
    topK: task.topK ?? defaults.topK,
    minP: task.minP ?? defaults.minP,
    nBatch: config.batchSize,
    repeatLastN: task.repeatPenaltyNum ?? defaults.repeatPenaltyNum,
    // Repeat penalty differs from both frequency and presence penalty, so
    // neither is used as a fallback here.
    repeatPenalty: task.repeatPenalty ?? defaults.repeatPenalty,
    // `seed` is not passed; see https://github.com/nomic-ai/gpt4all/issues/1952
  };
}

/**
 * Creates the per-token callbacks passed to gpt4all. They do three things:
 * forward chunks to `task.onChunk`, stop on a stop trigger, and stop once
 * `signal` is aborted.
 *
 * Stop triggers are matched within a single chunk of text only. When one
 * matches, the whole chunk text is recorded as `state.suffixToRemove` and is
 * not forwarded to `onChunk`.
 *
 * @param {object} task - Completion task (optional `onChunk` callback).
 * @param {string[]} stopTriggers - Strings that end generation.
 * @param {AbortSignal} [signal] - Ends generation when aborted.
 * @returns {{ state: { finishReason: string, suffixToRemove: (string|undefined) },
 *   onResponseToken: Function, onResponseTokens: Function }}
 */
function createTokenHandlers(task, stopTriggers, signal) {
  // Empty triggers are ignored; previously an empty string was treated as a
  // non-match and also masked every trigger listed after it.
  const triggers = stopTriggers.filter(Boolean);
  const state = { finishReason: 'eogToken', suffixToRemove: undefined };

  // Returns false to ask gpt4all to stop generating.
  const handleChunk = (text, tokens) => {
    if (triggers.some((trigger) => text.includes(trigger))) {
      state.finishReason = 'stopTrigger';
      state.suffixToRemove = text;
      return false;
    }
    task.onChunk?.({ text, tokens });
    return !signal?.aborted;
  };

  return {
    state,
    onResponseToken: (tokenId, text) => handleChunk(text, [tokenId]),
    onResponseTokens: ({ tokenIds, text }) => handleChunk(text, tokenIds),
  };
}

/**
 * Determines the final finish reason.
 * 'maxTokens' is reported only when generation ended without a stop trigger
 * and the effective token limit was reached.
 *
 * @param {{ finishReason: string }} state - State from createTokenHandlers.
 * @param {number} generatedTokens - Tokens generated by the model.
 * @param {number|undefined} maxTokens - Effective limit (task or defaults).
 * @returns {string} The finish reason.
 */
function resolveFinishReason(state, generatedTokens, maxTokens) {
  if (
    state.finishReason === 'eogToken' &&
    typeof maxTokens === 'number' &&
    generatedTokens === maxTokens
  ) {
    return 'maxTokens';
  }
  return state.finishReason;
}

/**
 * Removes the chunk that contained a stop trigger from the end of the output.
 * The text is stripped only when it actually ends with that chunk.
 *
 * @param {string} text - Generated text.
 * @param {string|undefined} suffix - Chunk recorded by createTokenHandlers.
 * @returns {string} The text without the trailing stop chunk.
 */
function stripStopSuffix(text, suffix) {
  if (!suffix || !text.endsWith(suffix)) {
    return text;
  }
  return text.slice(0, -suffix.length);
}

/**
 * Runs a raw text completion. The prompt is passed verbatim, including special
 * tokens, via `special: true` and the '%1' template.
 *
 * @param {object} task - `prompt`, sampling overrides, `stop`, `onChunk`.
 * @param {{ instance: object, config: object }} ctx - Engine context.
 * @param {AbortSignal} [signal] - Stops generation when aborted.
 * @returns {Promise<{ finishReason: string, text: string, promptTokens: number,
 *   completionTokens: number, contextTokens: number }>}
 * @throws {Error} If the instance cannot generate or the prompt is empty.
 */
export async function processTextCompletionTask(task, ctx, signal) {
  const { instance, config } = ctx;
  if (!('generate' in instance)) {
    throw new Error('Instance does not support text completion.');
  }
  if (!task.prompt) {
    throw new Error('Prompt is required for text completion.');
  }

  const defaults = config.completionDefaults ?? {};
  const sampling = buildSamplingOptions(task, defaults, config);
  const handlers = createTokenHandlers(task, task.stop ?? defaults.stop ?? [], signal);

  const result = await instance.generate(task.prompt, {
    // @ts-ignore
    special: true, // allows passing in raw prompt (including <|start|> etc.)
    promptTemplate: '%1',
    ...sampling,
    // @ts-ignore
    onResponseToken: handlers.onResponseToken,
    // @ts-ignore
    onResponseTokens: handlers.onResponseTokens,
  });

  return {
    finishReason: resolveFinishReason(handlers.state, result.tokensGenerated, sampling.nPredict),
    text: stripStopSuffix(result.text, handlers.state.suffixToRemove),
    promptTokens: result.tokensIngested,
    completionTokens: result.tokensGenerated,
    contextTokens: instance.activeChatSession?.promptContext.nPast ?? 0,
  };
}

/**
 * Runs a chat completion against the instance's chat session.
 *
 * The active session is reused unless there is none or `ctx.resetContext` is
 * set. In that case a new session is built from `task.messages`, excluding a
 * leading system message (used as the system prompt) and the final user
 * message. Only the final user message is sent as input.
 *
 * @param {object} task - `messages`, sampling overrides, `stop`, `onChunk`.
 * @param {{ config: object, instance: object, resetContext: boolean,
 *   log: Function }} ctx - Engine context.
 * @param {AbortSignal} [signal] - Stops generation when aborted.
 * @returns {Promise<object>} finishReason, assistant message and token counts.
 * @throws {Error} If the instance has no chat support or the conversation
 *   does not end with a non-empty user message.
 */
export async function processChatCompletionTask(task, ctx, signal) {
  const { config, instance, resetContext, log } = ctx;
  if (!('createChatSession' in instance)) {
    throw new Error('Instance does not support chat completion.');
  }

  // Validate before touching the session so a bad request cannot replace it.
  const conversationMessages = createChatMessageArray(task.messages).filter(
    (m) => m.role !== 'system',
  );
  const lastMessage = conversationMessages[conversationMessages.length - 1];
  if (lastMessage?.role !== 'user' || !lastMessage.content) {
    throw new Error('Chat completions require a final user message.');
  }

  let session = instance.activeChatSession;
  if (!session || resetContext) {
    log(LogLevels.debug, 'Resetting chat context');
    const { systemPrompt, messages: history } = splitSystemPrompt(
      createChatMessageArray(task.messages),
    );
    // The final user message is sent as the completion input, not as history.
    const messages =
      history[history.length - 1]?.role === 'user' ? history.slice(0, -1) : history;
    session = await instance.createChatSession({ systemPrompt, messages });
  }

  const defaults = config.completionDefaults ?? {};
  const sampling = buildSamplingOptions(task, defaults, config);
  const handlers = createTokenHandlers(task, task.stop ?? defaults.stop ?? [], signal);

  const result = await createCompletion(session, lastMessage.content, {
    ...sampling,
    // @ts-ignore
    onResponseToken: handlers.onResponseToken,
    // @ts-ignore
    onResponseTokens: handlers.onResponseTokens,
  });

  return {
    finishReason: resolveFinishReason(
      handlers.state,
      result.usage.completion_tokens,
      sampling.nPredict,
    ),
    message: {
      role: 'assistant',
      content: stripStopSuffix(result.choices[0].message.content, handlers.state.suffixToRemove),
    },
    promptTokens: result.usage.prompt_tokens,
    completionTokens: result.usage.completion_tokens,
    contextTokens: session.promptContext.nPast,
  };
}

/**
 * Computes embeddings for a string or an array of strings / `{ type: 'text' }`
 * items.
 *
 * @param {object} task - `input` and optional `dimensions`.
 * @param {{ instance: object }} ctx - Engine context.
 * @param {AbortSignal} [signal] - Currently unused.
 * @returns {Promise<{ embeddings: Array, inputTokens: number }>}
 * @throws {Error} If the instance cannot embed, input is missing, or an image
 *   input is given.
 */
export async function processEmbeddingTask(task, ctx, signal) {
  const { instance } = ctx;
  if (!('embed' in instance)) {
    throw new Error('Instance does not support embedding.');
  }
  if (!task.input) {
    throw new Error('Input is required for embedding.');
  }

  const texts = [];
  if (typeof task.input === 'string') {
    texts.push(task.input);
  } else if (Array.isArray(task.input)) {
    for (const input of task.input) {
      if (typeof input === 'string') {
        texts.push(input);
      } else if (input.type === 'text') {
        texts.push(input.content);
      } else if (input.type === 'image') {
        throw new Error('Image inputs not implemented.');
      }
      // NOTE(review): items of any other type are skipped silently.
    }
  }

  const res = await createEmbedding(instance, texts, {
    dimensionality: task.dimensions,
  });
  return {
    embeddings: res.embeddings,
    inputTokens: res.n_prompt_tokens,
  };
}
//# sourceMappingURL=engine.js.map