inference-server

Libraries and a server for building AI applications, with adapters to various native bindings for local inference. Integrate it into your application, or run it as a standalone microservice.

import { parseJSONRequestBody } from '../../../api/parseJSONRequestBody.js';
import { omitEmptyValues } from '../../../lib/util.js';
import { finishReasonMap } from '../enums.js';

export function createCompletionHandler(inferenceServer) {
    return async (req, res) => {
        let args;
        try {
            const body = await parseJSONRequestBody(req);
            args = body;
        }
        catch (e) {
            console.error(e);
            res.writeHead(400, { 'Content-Type': 'application/json' });
            res.end(JSON.stringify({ error: 'Invalid request' }));
            return;
        }
        // TODO ajv schema validation?
        if (!args.model || !args.prompt) {
            res.writeHead(400, { 'Content-Type': 'application/json' });
            res.end(JSON.stringify({ error: 'Invalid request' }));
            return;
        }
        if (!inferenceServer.modelExists(args.model)) {
            res.writeHead(400, { 'Content-Type': 'application/json' });
            res.end(JSON.stringify({ error: 'Invalid model' }));
            return;
        }
        const controller = new AbortController();
        req.on('close', () => {
            console.debug('Client closed connection');
            controller.abort();
        });
        req.on('end', () => {
            console.debug('Client ended connection');
            controller.abort();
        });
        try {
            if (args.stream) {
                res.writeHead(200, {
                    'Content-Type': 'text/event-stream',
                    'Cache-Control': 'no-cache',
                    Connection: 'keep-alive',
                });
                res.flushHeaders();
            }
            let prompt = args.prompt;
            if (typeof prompt !== 'string') {
                throw new Error('Prompt must be a string');
            }
            let stop = args.stop ? args.stop : undefined;
            if (typeof stop === 'string') {
                stop = [stop];
            }
            const completionReq = omitEmptyValues({
                model: args.model,
                prompt: args.prompt,
                temperature: args.temperature ? args.temperature : undefined,
                maxTokens: args.max_tokens ? args.max_tokens : undefined,
                seed: args.seed ? args.seed : undefined,
                stop,
                frequencyPenalty: args.frequency_penalty ? args.frequency_penalty : undefined,
                presencePenalty: args.presence_penalty ? args.presence_penalty : undefined,
                tokenBias: args.logit_bias ? args.logit_bias : undefined,
                topP: args.top_p ? args.top_p : undefined,
                // additional non-spec params
                repeatPenaltyNum: args.repeat_penalty_num ? args.repeat_penalty_num : undefined,
                minP: args.min_p ? args.min_p : undefined,
                topK: args.top_k ? args.top_k : undefined,
            });
            const { instance, release } = await inferenceServer.requestInstance(completionReq, controller.signal);
            const task = instance.processTextCompletionTask({
                ...completionReq,
                signal: controller.signal,
                onChunk: (chunk) => {
                    if (args.stream) {
                        const chunkData = {
                            id: task.id,
                            model: task.model,
                            object: 'text_completion',
                            created: Math.floor(task.createdAt.getTime() / 1000),
                            choices: [
                                {
                                    index: 0,
                                    text: chunk.text,
                                    logprobs: null,
                                    // @ts-ignore official api returns null here in the same case
                                    finish_reason: null,
                                },
                            ],
                        };
                        res.write(`data: ${JSON.stringify(chunkData)}\n\n`);
                    }
                },
            });
            const result = await task.result;
            release();
            if (args.stream) {
                if (args.stream_options?.include_usage) {
                    const finalChunk = {
                        id: task.id,
                        model: task.model,
                        object: 'text_completion',
                        created: Math.floor(task.createdAt.getTime() / 1000),
                        choices: [
                            {
                                index: 0,
                                text: '',
                                logprobs: null,
                                // @ts-ignore
                                finish_reason: result.finishReason ? finishReasonMap[result.finishReason] : 'stop',
                            },
                        ],
                    };
                    res.write(`data: ${JSON.stringify(finalChunk)}\n\n`);
                }
                res.write('data: [DONE]');
                res.end();
            }
            else {
                const response = {
                    id: task.id,
                    model: task.model,
                    object: 'text_completion',
                    created: Math.floor(task.createdAt.getTime() / 1000),
                    system_fingerprint: instance.fingerprint,
                    choices: [
                        {
                            index: 0,
                            text: result.text,
                            logprobs: null,
                            // @ts-ignore
                            finish_reason: result.finishReason ? finishReasonMap[result.finishReason] : 'stop',
                        },
                    ],
                    usage: {
                        prompt_tokens: result.promptTokens,
                        completion_tokens: result.completionTokens,
                        total_tokens: result.contextTokens,
                    },
                };
                res.writeHead(200, { 'Content-Type': 'application/json' });
                res.end(JSON.stringify(response, null, 2));
            }
        }
        catch (err) {
            console.error(err);
            if (args.stream) {
                res.write('data: [ERROR]');
            }
            else {
                res.writeHead(500, { 'Content-Type': 'application/json' });
                res.end(JSON.stringify({ error: 'Internal server error' }));
            }
        }
    };
}
//# sourceMappingURL=completions.js.map
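
The factory exported above is framework-agnostic: it returns a plain Node `(req, res)` handler and only relies on the `inferenceServer` object providing `modelExists()`, `requestInstance()`, and instances that expose `processTextCompletionTask()`. Below is a minimal sketch of wiring it onto a bare `node:http` server. Constructing the `inferenceServer` itself is left as an assumption, since this file does not show the library's top-level API; `serveCompletions` is a hypothetical helper name, not part of the package.

// Minimal sketch: exposing the handler on a plain Node HTTP server (ESM).
// `serveCompletions` is a hypothetical helper, not part of the package's API.
import http from 'node:http';
import { createCompletionHandler } from './completions.js'; // the factory defined in this file

export function serveCompletions(inferenceServer, port = 3000) {
    // `inferenceServer` must come from the library's own top-level API (not shown here);
    // the handler only calls modelExists(), requestInstance(), and
    // instance.processTextCompletionTask() on it.
    const handleCompletion = createCompletionHandler(inferenceServer);
    const server = http.createServer((req, res) => {
        // Route only the OpenAI-style text-completion endpoint; reject everything else.
        if (req.method === 'POST' && req.url === '/v1/completions') {
            handleCompletion(req, res);
            return;
        }
        res.writeHead(404, { 'Content-Type': 'application/json' });
        res.end(JSON.stringify({ error: 'Not found' }));
    });
    server.listen(port);
    return server;
}

Once mounted, the endpoint speaks the OpenAI text-completion wire format implemented above: a POST body such as {"model": "...", "prompt": "...", "stream": true} produces `data:` chunks followed by a terminating `data: [DONE]` event, while non-streaming requests receive a single JSON response with `choices` and `usage`.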