@aj-archipelago/cortex

Cortex is a GraphQL API for AI. It provides a simple, extensible interface for using AI services from OpenAI, Azure, and others. The file below, rest.js, implements the package's REST layer: a plain endpoint under /rest/ for each pathway, plus OpenAI-compatible /v1/completions, /v1/chat/completions, and /v1/models routes.
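
Whether a pathway is mounted at /rest/<name> or served through the OpenAI-compatible routes is driven by a handful of fields on the pathway object. As a minimal sketch (the pathway name and values here are hypothetical; only the fields rest.js actually reads are shown):

// Hypothetical pathway definition -- just the fields rest.js consumes.
const summarizePathway = {
    disabled: false,                          // skipped entirely when true
    emulateOpenAIChatModel: 'gpt-3.5-turbo',  // serve this model id via /v1/chat/completions
    useSingleTokenStream: false,              // re-chunk streamed text token-by-token when true
    typeDef: (pathway) => ({
        // restDefinition maps request-body parameters onto GraphQL variables
        restDefinition: [
            { name: 'text', type: 'String' },
            { name: 'stream', type: 'Boolean' },
        ],
    }),
};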

// rest.js
// Implement the REST endpoints for the pathways

import pubsub from './pubsub.js';
import { requestState } from './requestState.js';
import { v4 as uuidv4 } from 'uuid';
import logger from '../lib/logger.js';
import { getSingleTokenChunks } from './chunker.js';
import axios from 'axios';

const getOllamaModels = async (ollamaUrl) => {
    try {
        const response = await axios.get(`${ollamaUrl}/api/tags`);
        return response.data.models.map(model => ({
            id: `ollama-${model.name}`,
            object: 'model',
            owned_by: 'ollama',
            permission: ''
        }));
    } catch (error) {
        logger.error(`Error fetching Ollama models: ${error.message}`);
        return [];
    }
};

const chunkTextIntoTokens = (() => {
    let partialToken = '';
    return (text, isLast = false, useSingleTokenStream = false) => {
        const tokens = useSingleTokenStream ? getSingleTokenChunks(partialToken + text) : [text];
        if (isLast) {
            partialToken = '';
            return tokens;
        }
        partialToken = useSingleTokenStream ? tokens.pop() : '';
        return tokens;
    };
})();

const processRestRequest = async (server, req, pathway, name, parameterMap = {}) => {
    const fieldVariableDefs = pathway.typeDef(pathway).restDefinition || [];

    const convertType = (value, type) => {
        if (type === 'Boolean') {
            return Boolean(value);
        } else if (type === 'Int') {
            return parseInt(value, 10);
        } else if (type === '[MultiMessage]' && Array.isArray(value)) {
            return value.map(msg => ({
                ...msg,
                content: Array.isArray(msg.content) ? msg.content.map(item => JSON.stringify(item)) : msg.content
            }));
        } else {
            return value;
        }
    };

    const variables = fieldVariableDefs.reduce((acc, variableDef) => {
        const requestBodyParamName = Object.keys(parameterMap).includes(variableDef.name)
            ? parameterMap[variableDef.name]
            : variableDef.name;

        if (Object.prototype.hasOwnProperty.call(req.body, requestBodyParamName)) {
            acc[variableDef.name] = convertType(req.body[requestBodyParamName], variableDef.type);
        }
        return acc;
    }, {});

    const variableParams = fieldVariableDefs.map(({ name, type }) => `$${name}: ${type}`).join(', ');
    const queryArgs = fieldVariableDefs.map(({ name }) => `${name}: $${name}`).join(', ');

    const query = `
        query ${name}(${variableParams}) {
            ${name}(${queryArgs}) {
                contextId
                previousResult
                result
            }
        }
    `;

    const result = await server.executeOperation({ query, variables });

    // if we're streaming and there are errors, we return a standard error code
    if (Boolean(req.body.stream)) {
        if (result?.body?.singleResult?.errors) {
            return `[ERROR] ${result.body.singleResult.errors[0].message.split(';')[0]}`;
        }
    }

    // otherwise errors can just be returned as a string
    const resultText = result?.body?.singleResult?.data?.[name]?.result || result?.body?.singleResult?.errors?.[0]?.message || "";
    return resultText;
};

const processIncomingStream = (requestId, res, jsonResponse, pathway) => {
    const useSingleTokenStream = pathway.useSingleTokenStream || false;

    const startStream = (res) => {
        // Set the headers for streaming
        res.setHeader('Content-Type', 'text/event-stream');
        res.setHeader('Cache-Control', 'no-cache');
        res.setHeader('Connection', 'keep-alive');
        res.flushHeaders();
    };

    const finishStream = (res, jsonResponse) => {
        // Send the last partial token if it exists
        const lastTokens = chunkTextIntoTokens('', true, useSingleTokenStream);
        if (lastTokens.length > 0) {
            lastTokens.forEach(token => {
                fillJsonResponse(jsonResponse, token, null);
                sendStreamData(jsonResponse);
            });
        }

        // If we haven't sent the stop message yet, do it now
        if (jsonResponse.choices?.[0]?.finish_reason !== "stop") {
            let jsonEndStream = JSON.parse(JSON.stringify(jsonResponse));
            if (jsonEndStream.object === 'text_completion') {
                jsonEndStream.choices[0].index = 0;
                jsonEndStream.choices[0].finish_reason = "stop";
                jsonEndStream.choices[0].text = "";
            } else {
                jsonEndStream.choices[0].finish_reason = "stop";
                jsonEndStream.choices[0].index = 0;
                jsonEndStream.choices[0].delta = {};
            }
            sendStreamData(jsonEndStream);
        }

        sendStreamData('[DONE]');
        res.end();
    };

    const sendStreamData = (data) => {
        const dataString = (data === '[DONE]') ? data : JSON.stringify(data);
        if (!res.writableEnded) {
            res.write(`data: ${dataString}\n\n`);
            logger.debug(`REST SEND: data: ${dataString}`);
        }
    };

    const fillJsonResponse = (jsonResponse, inputText, _finishReason) => {
        jsonResponse.choices[0].finish_reason = null;
        if (jsonResponse.object === 'text_completion') {
            jsonResponse.choices[0].text = inputText;
        } else {
            jsonResponse.choices[0].delta.content = inputText;
        }
        return jsonResponse;
    };

    startStream(res);

    // If the requestId is an error message, we can't continue
    if (requestId.startsWith('[ERROR]')) {
        fillJsonResponse(jsonResponse, requestId, "stop");
        sendStreamData(jsonResponse);
        finishStream(res, jsonResponse);
        return;
    }

    let subscription;

    subscription = pubsub.subscribe('REQUEST_PROGRESS', (data) => {
        const safeUnsubscribe = async () => {
            if (subscription) {
                try {
                    const subPromiseResult = await subscription;
                    if (subPromiseResult && pubsub.subscriptions?.[subPromiseResult]) {
                        pubsub.unsubscribe(subPromiseResult);
                    }
                } catch (error) {
                    logger.warn(`Pubsub unsubscribe threw error: ${error}`);
                }
            }
        };

        const processStringData = (stringData) => {
            if (progress === 1 && stringData.trim() === "[DONE]") {
                fillJsonResponse(jsonResponse, stringData, "stop");
                safeUnsubscribe();
                finishStream(res, jsonResponse);
                return;
            }
            chunkTextIntoTokens(stringData, false, useSingleTokenStream).forEach(token => {
                fillJsonResponse(jsonResponse, token, null);
                sendStreamData(jsonResponse);
            });
            if (progress === 1) {
                safeUnsubscribe();
                finishStream(res, jsonResponse);
            }
        };

        if (data.requestProgress.requestId !== requestId) return;

        logger.debug(`REQUEST_PROGRESS received progress: ${data.requestProgress.progress}, data: ${data.requestProgress.data}`);

        const { progress, data: progressData } = data.requestProgress;

        try {
            const messageJson = JSON.parse(progressData);

            if (typeof messageJson === 'string') {
                processStringData(messageJson);
                return;
            }

            if (messageJson.error) {
                logger.error(`Stream error REST: ${messageJson?.error?.message || 'unknown error'}`);
                safeUnsubscribe();
                finishStream(res, jsonResponse);
                return;
            }

            let content = '';
            if (messageJson.choices) {
                const { text, delta } = messageJson.choices[0];
                content = messageJson.object === 'text_completion' ? text : delta.content;
            } else if (messageJson.candidates) {
                content = messageJson.candidates[0].content.parts[0].text;
            } else if (messageJson.content) {
                content = messageJson.content?.[0]?.text || '';
            } else {
                content = messageJson;
            }

            chunkTextIntoTokens(content, false, useSingleTokenStream).forEach(token => {
                fillJsonResponse(jsonResponse, token, null);
                sendStreamData(jsonResponse);
            });
        } catch (error) {
            logger.debug(`progressData not JSON: ${progressData}`);
            if (typeof progressData === 'string') {
                processStringData(progressData);
            } else {
                fillJsonResponse(jsonResponse, progressData, "stop");
                sendStreamData(jsonResponse);
            }
        }

        if (progress === 1) {
            safeUnsubscribe();
            finishStream(res, jsonResponse);
        }
    });

    // Fire the resolver for the async requestProgress
    logger.info(`Rest Endpoint starting async requestProgress, requestId: ${requestId}`);
    const { resolver, args } = requestState[requestId];
    requestState[requestId].useRedis = false;
    requestState[requestId].started = true;
    resolver && resolver(args);

    return subscription;
};

function buildRestEndpoints(pathways, app, server, config) {
    if (config.get('enableRestEndpoints')) {
        const openAIChatModels = {};
        const openAICompletionModels = {};

        // Create normal REST endpoints or emulate OpenAI API per pathway
        for (const [name, pathway] of Object.entries(pathways)) {
            // Only expose endpoints for enabled pathways that explicitly want to expose a REST endpoint
            if (pathway.disabled) continue;

            // The pathway can either emulate an OpenAI endpoint or be a normal REST endpoint
            if (pathway.emulateOpenAIChatModel || pathway.emulateOpenAICompletionModel) {
                if (pathway.emulateOpenAIChatModel) {
                    openAIChatModels[pathway.emulateOpenAIChatModel] = name;
                }
                if (pathway.emulateOpenAICompletionModel) {
                    openAICompletionModels[pathway.emulateOpenAICompletionModel] = name;
                }
            } else {
                app.post(`/rest/${name}`, async (req, res) => {
                    const resultText = await processRestRequest(server, req, pathway, name);
                    res.send(resultText);
                });
            }
        }

        // Create OpenAI compatible endpoints
        app.post('/v1/completions', async (req, res) => {
            const modelName = req.body.model || 'gpt-3.5-turbo';
            let pathwayName;
            if (modelName.startsWith('ollama-')) {
                pathwayName = 'sys_ollama_completion';
                req.body.ollamaModel = modelName.replace('ollama-', '');
            } else {
                pathwayName = openAICompletionModels[modelName] || openAICompletionModels['*'];
            }

            if (!pathwayName) {
                res.status(404).json({
                    error: `Model ${modelName} not found.`,
                });
                return;
            }

            const pathway = pathways[pathwayName];
            const parameterMap = { text: 'prompt' };
            const resultText = await processRestRequest(server, req, pathway, pathwayName, parameterMap);

            const jsonResponse = {
                id: `cmpl`,
                object: "text_completion",
                created: Date.now(),
                model: req.body.model,
                choices: [
                    {
                        text: resultText,
                        index: 0,
                        logprobs: null,
                        finish_reason: "stop"
                    }
                ],
            };

            // eslint-disable-next-line no-extra-boolean-cast
            if (Boolean(req.body.stream)) {
                jsonResponse.id = `cmpl-${resultText}`;
                jsonResponse.choices[0].finish_reason = null;
                processIncomingStream(resultText, res, jsonResponse, pathway);
            } else {
                const requestId = uuidv4();
                jsonResponse.id = `cmpl-${requestId}`;
                res.json(jsonResponse);
            }
        });

        app.post('/v1/chat/completions', async (req, res) => {
            const modelName = req.body.model || 'gpt-3.5-turbo';
            let pathwayName;
            if (modelName.startsWith('ollama-')) {
                pathwayName = 'sys_ollama_chat';
                req.body.ollamaModel = modelName.replace('ollama-', '');
            } else {
                pathwayName = openAIChatModels[modelName] || openAIChatModels['*'];
            }

            if (!pathwayName) {
                res.status(404).json({
                    error: `Model ${modelName} not found.`,
                });
                return;
            }

            const pathway = pathways[pathwayName];
            const resultText = await processRestRequest(server, req, pathway, pathwayName);

            const jsonResponse = {
                id: `chatcmpl`,
                object: "chat.completion",
                created: Date.now(),
                model: req.body.model,
                choices: [
                    {
                        message: {
                            role: "assistant",
                            content: resultText
                        },
                        index: 0,
                        finish_reason: "stop"
                    }
                ],
            };

            // eslint-disable-next-line no-extra-boolean-cast
            if (Boolean(req.body.stream)) {
                jsonResponse.id = `chatcmpl-${resultText}`;
                jsonResponse.choices[0] = {
                    delta: {
                        role: "assistant",
                        content: resultText
                    },
                    finish_reason: null
                };
                jsonResponse.object = "chat.completion.chunk";
                processIncomingStream(resultText, res, jsonResponse, pathway);
            } else {
                const requestId = uuidv4();
                jsonResponse.id = `chatcmpl-${requestId}`;
                res.json(jsonResponse);
            }
        });

        app.get('/v1/models', async (req, res) => {
            const openAIModels = { ...openAIChatModels, ...openAICompletionModels };
            const defaultModelId = 'gpt-3.5-turbo';
            let models = [];

            // Get standard OpenAI-compatible models, filtering out our internal pathway models
            models = Object.entries(openAIModels)
                .filter(([modelId]) => !['ollama-chat', 'ollama-completion'].includes(modelId))
                .map(([modelId]) => {
                    if (modelId.includes('*')) {
                        modelId = defaultModelId;
                    }
                    return {
                        id: modelId,
                        object: 'model',
                        owned_by: 'openai',
                        permission: '',
                    };
                });

            // Get Ollama models if configured
            if (config.get('ollamaUrl')) {
                const ollamaModels = await getOllamaModels(config.get('ollamaUrl'));
                models = [...models, ...ollamaModels];
            }

            // Filter out duplicates and sort
            models = models
                .filter((model, index, self) => {
                    return index === self.findIndex((m) => m.id === model.id);
                })
                .sort((a, b) => a.id.localeCompare(b.id));

            res.json({
                data: models,
                object: 'list',
            });
        });
    }
}

export { buildRestEndpoints };
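
Because the emulated routes follow the OpenAI wire format, any OpenAI-style client can talk to them. A short consumer sketch, assuming Node 18+ for the built-in fetch; the base URL and model id are placeholders, not part of the package:

// Hypothetical client for a Cortex server at a placeholder address.
const BASE_URL = 'http://localhost:4000';

// List the models the server advertises (GET /v1/models).
const listModels = async () => {
    const res = await fetch(`${BASE_URL}/v1/models`);
    const { data } = await res.json();
    return data.map((model) => model.id);
};

// Stream a chat completion (POST /v1/chat/completions with stream: true).
// The response is server-sent events: `data: <json>` lines separated by
// blank lines and terminated by `data: [DONE]`, matching what
// processIncomingStream writes above.
const streamChat = async (model, messages) => {
    const res = await fetch(`${BASE_URL}/v1/chat/completions`, {
        method: 'POST',
        headers: { 'Content-Type': 'application/json' },
        body: JSON.stringify({ model, messages, stream: true }),
    });

    const decoder = new TextDecoder();
    let buffer = '';
    for await (const chunk of res.body) {
        buffer += decoder.decode(chunk, { stream: true });
        const events = buffer.split('\n\n');
        buffer = events.pop(); // keep any partial event for the next chunk
        for (const event of events) {
            if (!event.startsWith('data: ')) continue;
            const payload = event.slice('data: '.length);
            if (payload === '[DONE]') return;
            const json = JSON.parse(payload);
            const piece = json.choices?.[0]?.delta?.content;
            if (typeof piece === 'string') process.stdout.write(piece);
        }
    }
};

console.log(await listModels());
await streamChat('gpt-3.5-turbo', [{ role: 'user', content: 'Hello!' }]);

Non-streaming calls are simpler still: omit stream and read the single JSON body, where the text lands in choices[0].message.content (chat) or choices[0].text (completions), as built in the handlers above.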