UNPKG

inference-server

Version:

Libraries and server to build AI applications. Adapters to various native bindings allowing local inference. Integrate it with your application, or use as a microservice.

357 lines 14.6 kB
import { parseJSONRequestBody } from '../../../api/parseJSONRequestBody.js'; import { omitEmptyValues } from '../../../lib/util.js'; import { loadImageFromUrl } from '../../../lib/loadImage.js'; import { finishReasonMap } from '../enums.js'; async function prepareIncomingMessages(messages) { const downloadPromises = {}; const resultMessages = []; for (const message of messages) { const role = message.role; const resultMessage = { role, content: [], }; if (role === 'tool' && 'tool_call_id' in message) { resultMessage.callId = message.tool_call_id; } if (typeof message.content === 'string') { resultMessage.content.push({ type: 'text', text: message.content, }); } else if (Array.isArray(message.content)) { for (const part of message.content) { if (typeof part === 'string') { resultMessage.content.push({ type: 'text', text: part, }); } else if (part.type === 'text') { resultMessage.content.push({ type: 'text', text: part.text, }); } else if (part.type === 'image_url') { if (!downloadPromises[part.image_url.url]) { downloadPromises[part.image_url.url] = loadImageFromUrl(part.image_url.url); } const content = { type: 'image', }; resultMessage.content.push(content); downloadPromises[part.image_url.url].then((image) => { content.image = image; }); } else if (part.type === 'input_audio') { resultMessage.content.push({ type: 'audio', audio: part.input_audio, }); } else if (part.type === 'refusal') { resultMessage.content.push({ type: 'text', text: part.refusal, }); } } } else { throw new Error('Invalid message content'); } resultMessages.push(resultMessage); } await Promise.all(Object.values(downloadPromises)); return resultMessages; } function createResponseMessageContent(content) { if (!content) { return null; } if (typeof content === 'string') { return content; } if (!Array.isArray(content)) { throw new Error('Invalid response message content'); } let text = ''; for (const part of content) { if (part.type === 'text') { text += part.text; } // assistant may only respond with text in openai chat completions } return text; } export function createChatCompletionHandler(inferenceServer) { return async (req, res) => { let args; try { const body = await parseJSONRequestBody(req); args = body; } catch (e) { console.error(e); res.writeHead(400, { 'Content-Type': 'application/json' }); res.end(JSON.stringify({ error: 'Invalid request' })); return; } // TODO ajv schema validation? if (!args.model || !args.messages) { res.writeHead(400, { 'Content-Type': 'application/json' }); res.end(JSON.stringify({ error: 'Invalid request (need at least model and messages)' })); return; } if (!inferenceServer.modelExists(args.model)) { res.writeHead(400, { 'Content-Type': 'application/json' }); res.end(JSON.stringify({ error: 'Model does not exist' })); return; } const controller = new AbortController(); req.on('close', () => { console.debug('Client closed connection'); controller.abort(); }); req.on('end', () => { console.debug('Client ended connection'); controller.abort(); }); req.on('aborted', () => { console.debug('Client aborted connection'); controller.abort(); }); req.on('error', () => { console.debug('Client error'); controller.abort(); }); try { let ssePing; if (args.stream) { res.writeHead(200, { 'Content-Type': 'text/event-stream', 'Cache-Control': 'no-cache', Connection: 'keep-alive', }); res.flushHeaders(); ssePing = setInterval(() => { res.write(':ping\n\n'); }, 30000); } let stop = args.stop ? args.stop : undefined; if (typeof stop === 'string') { stop = [stop]; } let completionGrammar; if (args.response_format) { if (args.response_format.type === 'json_object') { completionGrammar = 'json'; } } let toolDefinitions; if (args.tools) { const functionTools = args.tools .filter((tool) => tool.type === 'function') .map((tool) => { return { name: tool.function.name, description: tool.function.description, parameters: tool.function.parameters, }; }); if (functionTools.length) { if (!toolDefinitions) { toolDefinitions = {}; } for (const tool of functionTools) { toolDefinitions[tool.name] = { description: tool.description, parameters: tool.parameters, }; } } } const messages = await prepareIncomingMessages(args.messages); const completionReq = omitEmptyValues({ model: args.model, messages, temperature: args.temperature ? args.temperature : undefined, // stream: args.stream ? Boolean(args.stream) : false, maxTokens: args.max_tokens ? args.max_tokens : undefined, seed: args.seed ? args.seed : undefined, stop, frequencyPenalty: args.frequency_penalty ? args.frequency_penalty : undefined, presencePenalty: args.presence_penalty ? args.presence_penalty : undefined, topP: args.top_p ? args.top_p : undefined, tokenBias: args.logit_bias ? args.logit_bias : undefined, grammar: completionGrammar, tools: toolDefinitions ? { definitions: toolDefinitions } : undefined, // additional non-spec params repeatPenaltyNum: args.repeat_penalty_num ? args.repeat_penalty_num : undefined, minP: args.min_p ? args.min_p : undefined, topK: args.top_k ? args.top_k : undefined, }); const { instance, release } = await inferenceServer.requestInstance(completionReq, controller.signal); if (ssePing) { clearInterval(ssePing); } const task = instance.processChatCompletionTask({ ...completionReq, signal: controller.signal, onChunk: (chunk) => { if (args.stream) { const chunkData = { id: task.id, object: 'chat.completion.chunk', model: task.model, created: Math.floor(task.createdAt.getTime() / 1000), choices: [ { index: 0, delta: { role: 'assistant', content: chunk.text, }, logprobs: null, finish_reason: null, }, ], }; res.write(`data: ${JSON.stringify(chunkData)}\n\n`); } }, }); const result = await task.result; release(); if (args.stream) { if (result.finishReason === 'toolCalls') { // currently not possible to stream function calls // imitating a stream here by sending two chunks. makes it work with the openai client const streamedToolCallChunk = { id: task.id, object: 'chat.completion.chunk', model: task.model, created: Math.floor(task.createdAt.getTime() / 1000), choices: [ { index: 0, delta: { role: 'assistant', content: null, }, logprobs: null, finish_reason: result.finishReason ? finishReasonMap[result.finishReason] : 'stop', }, ], }; const toolCalls = result.message.toolCalls.map((call, index) => { return { index, id: call.id, type: 'function', function: { name: call.name, arguments: JSON.stringify(call.parameters), }, }; }); streamedToolCallChunk.choices[0].delta.tool_calls = toolCalls; res.write(`data: ${JSON.stringify(streamedToolCallChunk)}\n\n`); } if (args.stream_options?.include_usage) { const finalChunk = { id: task.id, object: 'chat.completion.chunk', model: task.model, created: Math.floor(task.createdAt.getTime() / 1000), system_fingerprint: instance.fingerprint, choices: [ { index: 0, delta: {}, logprobs: null, finish_reason: result.finishReason ? finishReasonMap[result.finishReason] : 'stop', }, ], usage: { prompt_tokens: result.promptTokens, completion_tokens: result.completionTokens, total_tokens: result.contextTokens, }, }; res.write(`data: ${JSON.stringify(finalChunk)}\n\n`); } res.write('data: [DONE]'); res.end(); } else { const response = { id: task.id, model: task.model, object: 'chat.completion', created: Math.floor(task.createdAt.getTime() / 1000), system_fingerprint: instance.fingerprint, choices: [ { index: 0, message: { role: 'assistant', content: createResponseMessageContent(result.message.content), refusal: null, }, logprobs: null, finish_reason: result.finishReason ? finishReasonMap[result.finishReason] : 'stop', }, ], usage: { prompt_tokens: result.promptTokens, completion_tokens: result.completionTokens, total_tokens: result.contextTokens, }, }; if ('toolCalls' in result.message && result.message.toolCalls?.length) { response.choices[0].message.tool_calls = result.message.toolCalls.map((call) => { return { id: call.id, type: 'function', function: { name: call.name, arguments: JSON.stringify(call.parameters), }, }; }); } res.writeHead(200, { 'Content-Type': 'application/json' }); res.end(JSON.stringify(response, null, 2)); } } catch (e) { console.error(e); if (args.stream) { res.write('data: [ERROR]'); } else { res.writeHead(500, { 'Content-Type': 'application/json' }); res.end(JSON.stringify({ error: 'Internal server error' })); } } }; } //# sourceMappingURL=chat.js.map