@aj-archipelago/cortex
Cortex is a GraphQL API for AI. It provides a simple, extensible interface for using AI services from OpenAI, Azure and others.
// OpenAIChatPlugin.js
import ModelPlugin from './modelPlugin.js';
import logger from '../../lib/logger.js';
import CortexResponse from '../../lib/cortexResponse.js';

class OpenAIChatPlugin extends ModelPlugin {
    constructor(pathway, model) {
        super(pathway, model);
    }

    // Convert PaLM-style prompts (context/examples/author) to the OpenAI messages array format
    convertPalmToOpenAIMessages(context, examples, messages) {
        const openAIMessages = [];

        // Add context as a system message
        if (context) {
            openAIMessages.push({
                role: 'system',
                content: context,
            });
        }

        // Add examples as alternating user/assistant message pairs
        examples.forEach(example => {
            openAIMessages.push({
                role: example.input.author || 'user',
                content: example.input.content,
            });
            openAIMessages.push({
                role: example.output.author || 'assistant',
                content: example.output.content,
            });
        });

        // Add the remaining messages, mapping the PaLM `author` field to `role`
        messages.forEach(message => {
            openAIMessages.push({
                role: message.author,
                content: message.content,
            });
        });

        return openAIMessages;
    }
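
    // Illustrative sketch (not part of the original source): given a PaLM-style
    // prompt, the conversion above would produce something like:
    //
    //   convertPalmToOpenAIMessages(
    //       'You are a helpful assistant.',
    //       [{ input: { content: 'Hi' }, output: { content: 'Hello!' } }],
    //       [{ author: 'user', content: 'What is Cortex?' }]
    //   )
    //   // => [
    //   //   { role: 'system', content: 'You are a helpful assistant.' },
    //   //   { role: 'user', content: 'Hi' },
    //   //   { role: 'assistant', content: 'Hello!' },
    //   //   { role: 'user', content: 'What is Cortex?' }
    //   // ]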

    // Set up parameters specific to the OpenAI Chat API
    getRequestParameters(text, parameters, prompt) {
        const { modelPromptText, modelPromptMessages, tokenLength, modelPrompt } = this.getCompiledPrompt(text, parameters, prompt);
        let { stream, tools, functions } = parameters;

        // Tools and functions may arrive as JSON strings; fall back to empty arrays if parsing fails
        try {
            tools = (tools && typeof tools === 'string' && tools !== '' ? JSON.parse(tools) : tools);
            functions = (functions && typeof functions === 'string' && functions !== '' ? JSON.parse(functions) : functions);
        } catch (e) {
            logger.warn(`Could not parse tools or functions parameter: ${e}`);
            tools = [];
            functions = [];
        }

        // Define the model's max token length
        const modelTargetTokenLength = this.getModelMaxPromptTokens();

        let requestMessages = modelPromptMessages || [{ "role": "user", "content": modelPromptText }];

        // Check if the messages are in PaLM format and convert them to OpenAI format if necessary
        const isPalmFormat = requestMessages.some(message => 'author' in message);
        if (isPalmFormat) {
            const context = modelPrompt.context || '';
            const examples = modelPrompt.examples || [];
            requestMessages = this.convertPalmToOpenAIMessages(context, examples, modelPromptMessages);
        }

        // If the prompt exceeds the model's token limit, drop older messages until it fits
        if (tokenLength > modelTargetTokenLength && this.promptParameters?.manageTokenLength) {
            requestMessages = this.truncateMessagesToTargetLength(requestMessages, modelTargetTokenLength);
        }

        const requestParameters = {
            messages: requestMessages,
            temperature: this.temperature ?? 0.7,
            ...(stream !== undefined ? { stream } : {}),
            ...(tools && tools.length > 0 ? { tools, tool_choice: parameters.tool_choice || 'auto' } : {}),
            ...(functions && functions.length > 0 ? { functions } : {}),
        };

        return requestParameters;
    }
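
    // Illustrative sketch (not part of the original source): a typical result,
    // assuming a streaming chat call with one tool defined, would look roughly like:
    //
    //   {
    //     messages: [{ role: 'user', content: 'Summarize this article...' }],
    //     temperature: 0.7,
    //     stream: true,
    //     tools: [{ type: 'function', function: { name: 'search', ... } }],
    //     tool_choice: 'auto'
    //   }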

    // Assemble and execute the request to the OpenAI Chat API
    async execute(text, parameters, prompt, cortexRequest) {
        const requestParameters = this.getRequestParameters(text, parameters, prompt);
        cortexRequest.data = { ...(cortexRequest.data || {}), ...requestParameters };
        cortexRequest.params = {};
        return this.executeRequest(cortexRequest);
    }

    // Parse the response from the OpenAI Chat API
    parseResponse(data) {
        if (!data) return "";

        const { choices } = data;
        if (!choices || !choices.length) {
            return data;
        }

        // If we got a choices array back with more than one choice, return the whole array
        if (choices.length > 1) {
            return choices;
        }

        const choice = choices[0];
        const message = choice.message;
        if (!message) {
            return null;
        }

        // Create a standardized CortexResponse object
        const cortexResponse = new CortexResponse({
            output_text: message.content || "",
            finishReason: choice.finish_reason || 'stop',
            usage: data.usage || null,
            metadata: {
                model: this.modelName
            }
        });

        // Handle tool calls (current API) or function calls (legacy API)
        if (message.tool_calls) {
            cortexResponse.toolCalls = message.tool_calls;
        } else if (message.function_call) {
            cortexResponse.functionCall = message.function_call;
        }

        return cortexResponse;
    }
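
    // Illustrative sketch (not part of the original source): for a response like
    //   { choices: [{ message: { content: 'Hello!' }, finish_reason: 'stop' }], usage: { total_tokens: 12 } }
    // parseResponse would return a CortexResponse with output_text 'Hello!',
    // finishReason 'stop', and the usage object attached.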

    // Override processStreamEvent to handle the OpenAI Chat streaming format
    processStreamEvent(event, requestProgress) {
        // Check for end of stream or in-stream errors
        if (event.data.trim() === '[DONE]') {
            requestProgress.progress = 1;
        } else {
            let parsedMessage;
            try {
                parsedMessage = JSON.parse(event.data);
            } catch (error) {
                throw new Error(`Could not parse stream data: ${error}`);
            }

            // Errors can be in different places in the message
            const streamError = parsedMessage?.error || parsedMessage?.choices?.[0]?.delta?.content?.error || parsedMessage?.choices?.[0]?.text?.error;
            if (streamError) {
                throw new Error(streamError);
            }

            // Check if this is an empty/idle event that we should skip
            const delta = parsedMessage?.choices?.[0]?.delta;
            const isEmptyEvent = !delta ||
                (Object.keys(delta).length === 0) ||
                (Object.keys(delta).length === 1 && delta.content === '') ||
                (Object.keys(delta).length === 1 && delta.tool_calls && delta.tool_calls.length === 0);

            // The finish reason can be in different places in the message
            const finishReason = parsedMessage?.choices?.[0]?.finish_reason || parsedMessage?.candidates?.[0]?.finishReason;

            // Skip publishing empty events unless they carry a finish reason;
            // returning without setting data prevents publishing
            if (isEmptyEvent && !finishReason) {
                return requestProgress;
            }

            // Set the data for non-empty events or events with a finish reason
            requestProgress.data = event.data;

            if (finishReason) {
                switch (finishReason.toLowerCase()) {
                    case 'safety': {
                        const safetyRatings = JSON.stringify(parsedMessage?.candidates?.[0]?.safetyRatings) || '';
                        logger.warn(`Request ${this.requestId} was blocked by the safety filter. ${safetyRatings}`);
                        requestProgress.data = `\n\nResponse blocked by safety filter: ${safetyRatings}`;
                        requestProgress.progress = 1;
                        break;
                    }
                    default:
                        requestProgress.progress = 1;
                        break;
                }
            }
        }
        return requestProgress;
    }
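
    // Illustrative sketch (not part of the original source): a typical stream
    // would deliver events like the following, in order:
    //   data: {"choices":[{"delta":{"content":"Hel"}}]}          -> published
    //   data: {"choices":[{"delta":{}}]}                         -> skipped (empty)
    //   data: {"choices":[{"delta":{},"finish_reason":"stop"}]}  -> published, progress = 1
    //   data: [DONE]                                             -> progress = 1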

    // Override the logging function to display the messages and responses
    logRequestData(data, responseData, prompt) {
        const { stream, messages } = data;

        if (messages && messages.length > 1) {
            logger.info(`[chat request sent containing ${messages.length} messages]`);
            let totalLength = 0;
            let totalUnits;
            messages.forEach((message, index) => {
                // message.content can be a string or an array of content parts
                const content = message.content === undefined
                    ? JSON.stringify(message)
                    : (Array.isArray(message.content)
                        ? message.content.map(item => JSON.stringify(item)).join(', ')
                        : message.content);
                const { length, units } = this.getLength(content);
                const displayContent = this.shortenContent(content);
                logger.verbose(`message ${index + 1}: role: ${message.role}, ${units}: ${length}, content: "${displayContent}"`);
                totalLength += length;
                totalUnits = units;
            });
            logger.info(`[chat request contained ${totalLength} ${totalUnits}]`);
        } else if (messages && messages.length === 1) {
            // Guard against an empty or missing messages array before indexing into it
            const message = messages[0];
            const content = Array.isArray(message.content)
                ? message.content.map(item => JSON.stringify(item)).join(', ')
                : message.content;
            const { length, units } = this.getLength(content);
            logger.info(`[request sent containing ${length} ${units}]`);
            logger.verbose(`${this.shortenContent(content)}`);
        }

        if (stream) {
            logger.info(`[response received as an SSE stream]`);
        } else if (typeof responseData === 'string') {
            const { length, units } = this.getLength(responseData);
            logger.info(`[response received containing ${length} ${units}]`);
            logger.verbose(`${this.shortenContent(responseData)}`);
        } else {
            logger.info(`[response received containing object]`);
            logger.verbose(`${JSON.stringify(responseData)}`);
        }

        if (prompt?.debugInfo) {
            prompt.debugInfo += `\n${JSON.stringify(data)}`;
        }
    }
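
    // Illustrative sketch (not part of the original source; the "tokens" unit is
    // an assumption about what getLength returns): for a two-message chat request,
    // the log output would look roughly like:
    //   [chat request sent containing 2 messages]
    //   message 1: role: system, tokens: 12, content: "You are a helpful assistant."
    //   message 2: role: user, tokens: 5, content: "Hello!"
    //   [chat request contained 17 tokens]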
}
export default OpenAIChatPlugin;
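
// Illustrative usage sketch (not part of the original source; the pathway and
// cortexRequest shapes here are assumptions based on how ModelPlugin is used):
//
//   const plugin = new OpenAIChatPlugin(pathway, model);
//   const result = await plugin.execute(
//       'Hello!',
//       { stream: false },
//       pathway.prompt,
//       cortexRequest
//   );
//   // result is typically the parsed response, e.g. a CortexResponse object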