@aj-archipelago/cortex
Cortex is a GraphQL API for AI. It provides a simple, extensible interface for using AI services from OpenAI, Azure and others.
// ModelPlugin.js
import HandleBars from '../../lib/handleBars.js';
import { executeRequest } from '../../lib/requestExecutor.js';
import { encode } from '../../lib/encodeCache.js';
import { getFirstNToken } from '../chunker.js';
import logger, { obscureUrlParams } from '../../lib/logger.js';
import { config } from '../../config.js';
import axios from 'axios';
const DEFAULT_MAX_TOKENS = 4096;
const DEFAULT_MAX_RETURN_TOKENS = 256;
const DEFAULT_PROMPT_TOKEN_RATIO = 1.0;
const DEFAULT_MAX_IMAGE_SIZE = 20 * 1024 * 1024; // 20MB default
const DEFAULT_ALLOWED_MIME_TYPES = ['image/jpeg', 'image/png', 'image/gif', 'image/webp'];
class ModelPlugin {
constructor(pathway, model) {
this.modelName = model.name;
this.model = model;
this.config = config;
this.environmentVariables = config.getEnv();
this.temperature = pathway.temperature;
this.pathwayPrompt = pathway.prompt;
this.pathwayName = pathway.name;
this.promptParameters = {};
this.isMultiModal = false;
this.allowedMIMETypes = model.allowedMIMETypes || DEFAULT_ALLOWED_MIME_TYPES;
// Make all of the parameters defined on the pathway itself available to the prompt
for (const [k, v] of Object.entries(pathway)) {
this.promptParameters[k] = v?.default ?? v;
}
if (pathway.inputParameters) {
for (const [k, v] of Object.entries(pathway.inputParameters)) {
this.promptParameters[k] = v?.default ?? v;
}
}
this.requestCount = 0;
}
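// Validate that an image URL or data: URI has an allowed MIME type.
// Remote URLs are checked with a HEAD request against the content-type header.
// Illustrative calls (URLs are hypothetical, not from this codebase):
//   await plugin.validateImageUrl('data:image/png;base64,iVBORw0KGgo='); // true
//   await plugin.validateImageUrl('https://example.com/image.svg');      // false unless image/svg+xml is allowed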
async validateImageUrl(url) {
if (url.startsWith('data:')) {
const [, mimeType = ""] = url.match(/data:([a-zA-Z0-9]+\/[a-zA-Z0-9-.+]+).*,.*/) || [];
return this.allowedMIMETypes.includes(mimeType);
}
try {
const headResponse = await axios.head(url, {
timeout: 30000,
maxRedirects: 5
});
const contentType = (headResponse.headers['content-type'] || '').split(';')[0].trim();
if (!contentType || !this.allowedMIMETypes.includes(contentType)) {
logger.warn(`Unsupported image type: ${contentType} - skipping image content.`);
return false;
}
return true;
} catch (e) {
logger.error(`Failed to validate image URL: ${url}. ${e}`);
return false;
}
}
safeGetEncodedLength(data) {
// Guard against non-string input so the "safe" name holds
if (typeof data !== 'string') return 0;
return encode(data).length;
}
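// Truncate a messages array so its total token count fits targetTokenLength.
// Messages are prioritized as: last message, then system messages (newest
// first), then the rest (newest first); original order is restored before
// returning. A sketch of the intended behavior (token counts illustrative):
//   const fitted = plugin.truncateMessagesToTargetLength(
//       [{ role: 'system', content: systemPrompt },
//        { role: 'user', content: veryLongPastedDocument },
//        { role: 'user', content: 'Summarize the above.' }],
//       1000);
//   // => the same three messages, with the long document truncated and
//   //    suffixed with the '[...]' marker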
truncateMessagesToTargetLength(messages, targetTokenLength = null, maxMessageTokenLength = Infinity) {
const truncationMarker = '[...]';
const truncationMarkerTokenLength = encode(truncationMarker).length;
const messageOverhead = 4; // Per-message overhead tokens
const conversationOverhead = 3; // Conversation formatting overhead
// Helper function to truncate text content
const truncateTextContent = (text, maxTokens) => {
if (this.safeGetEncodedLength(text) <= maxTokens) return text;
return getFirstNToken(text, maxTokens - truncationMarkerTokenLength) + truncationMarker;
};
// Helper function to truncate multimodal content
const truncateMultimodalContent = (content, maxTokens) => {
const newContent = [];
let contentTokensUsed = 0;
let truncationAdded = false;
for (let item of content) {
// Convert string items to text objects
if (typeof item === 'string') {
item = { type: 'text', text: item };
}
// Handle text items
if (item.type === 'text') {
if (contentTokensUsed < maxTokens) {
const remainingTokens = maxTokens - contentTokensUsed;
if (this.safeGetEncodedLength(item.text) <= remainingTokens) {
// Text fits completely
newContent.push(item);
contentTokensUsed += this.safeGetEncodedLength(item.text);
} else {
// Truncate text
const truncatedText = getFirstNToken(item.text, remainingTokens);
newContent.push({ type: 'text', text: truncatedText + truncationMarker });
contentTokensUsed += this.safeGetEncodedLength(truncatedText) + truncationMarkerTokenLength;
truncationAdded = true;
break;
}
}
}
// Handle image items - prioritize them but account for their token usage
else if (item.type === 'image_url') {
const imageTokens = 100; // Estimated token count for images
if (contentTokensUsed + imageTokens <= maxTokens) {
newContent.push(item);
contentTokensUsed += imageTokens;
}
}
// Other non-text content
else {
newContent.push(item);
}
}
// Add truncation marker if needed and not already added
if (content.length > newContent.length && !truncationAdded) {
newContent.push({ type: 'text', text: truncationMarker });
contentTokensUsed += truncationMarkerTokenLength;
}
return { content: newContent, tokensUsed: contentTokensUsed };
};
// Helper function to truncate any message content
const truncateMessageContent = (message, availableTokens, maxPerMessageTokens) => {
// Calculate max content tokens (minimum of available tokens or max per message)
const maxContentTokens = Math.min(
availableTokens,
maxPerMessageTokens - message.roleTokens - messageOverhead
);
const messageToAdd = { ...message };
delete messageToAdd.tokenLength;
delete messageToAdd.roleTokens;
delete messageToAdd.contentTokens;
// Keep originalIndex for sorting later
let contentTokensUsed = 0;
// Handle extreme constraints (zero or negative token availability)
if (maxContentTokens <= 0) {
// For extreme constraints, just add truncation marker or empty content
if (typeof message.content === 'string') {
messageToAdd.content = truncationMarker;
contentTokensUsed = truncationMarkerTokenLength;
} else if (Array.isArray(message.content)) {
messageToAdd.content = [{ type: 'text', text: truncationMarker }];
contentTokensUsed = truncationMarkerTokenLength;
}
const totalTokensUsed = message.roleTokens + contentTokensUsed + messageOverhead;
return { message: messageToAdd, tokensUsed: totalTokensUsed };
}
// Truncate text content
if (typeof message.content === 'string') {
// truncateTextContent reserves space for the truncation marker itself
const contentSpace = Math.max(0, maxContentTokens);
messageToAdd.content = truncateTextContent(message.content, contentSpace);
contentTokensUsed = this.safeGetEncodedLength(messageToAdd.content);
}
// Handle multimodal content
else if (Array.isArray(message.content)) {
const result = truncateMultimodalContent(message.content, maxContentTokens);
messageToAdd.content = result.content;
contentTokensUsed = result.tokensUsed;
// Skip message if no content after truncation
if (result.content.length === 0) {
messageToAdd.content = [{ type: 'text', text: truncationMarker }];
contentTokensUsed = truncationMarkerTokenLength;
}
}
const totalTokensUsed = message.roleTokens + contentTokensUsed + messageOverhead;
return { message: messageToAdd, tokensUsed: totalTokensUsed };
};
// If no messages, return empty array
if (!messages || messages.length === 0) return [];
// If there's no target token length, get it from the model
if (!targetTokenLength) {
targetTokenLength = this.getModelMaxPromptTokens();
}
// First check if all messages already fit within the target length
const initialTokenCount = this.countMessagesTokens(messages);
if (initialTokenCount <= targetTokenLength && maxMessageTokenLength === Infinity) {
return messages;
}
// Calculate safety margin
const safetyMarginPercent = targetTokenLength > 1000 ? 0.05 : 0.02; // 5% or 2% for small targets
const safetyMarginMinimum = Math.min(20, Math.floor(targetTokenLength * 0.01)); // At most 1% for minimum
const safetyMargin = Math.max(safetyMarginMinimum, Math.round(targetTokenLength * safetyMarginPercent));
// Adjust targetTokenLength to account for overheads and safety margin
const effectiveTargetLength = Math.max(0, targetTokenLength - conversationOverhead - safetyMargin);
// Calculate token lengths for each message and track original index
const messagesWithTokens = messages.map((message, index) => {
// Count tokens for the role/author
const roleTokens = this.safeGetEncodedLength(message.role || message.author || "");
// Count tokens for content
const tokenLength = this.countMessagesTokens([message]);
return {
...message,
roleTokens: roleTokens,
contentTokens: tokenLength - roleTokens - messageOverhead,
tokenLength: tokenLength,
originalIndex: index // Keep track of original position
};
});
// Sort messages by priority: last message, then system messages (newest first), then others (newest first)
const lastMessage = messagesWithTokens.length > 0 ? messagesWithTokens[messagesWithTokens.length - 1] : null;
const systemMessages = messagesWithTokens
.filter(m => (m.role === 'system' || m.author === 'system') && m !== lastMessage)
.reverse();
const otherMessages = messagesWithTokens
.filter(m => (m.role !== 'system' && m.author !== 'system') && m !== lastMessage)
.reverse();
// Build prioritized array
const prioritizedMessages = [];
if (lastMessage) prioritizedMessages.push(lastMessage);
prioritizedMessages.push(...systemMessages, ...otherMessages);
// Track used tokens and build result
let usedTokens = 0;
const result = [];
// Process messages in priority order
for (const message of prioritizedMessages) {
// Calculate how many tokens we have available
const remainingTokens = effectiveTargetLength - usedTokens;
// If we have very few tokens left, skip this message
const minimumUsableTokens = 10;
if (remainingTokens < minimumUsableTokens) break;
const { message: truncatedMessage, tokensUsed } = truncateMessageContent(
message,
remainingTokens,
maxMessageTokenLength
);
if (truncatedMessage) {
result.push(truncatedMessage);
usedTokens += tokensUsed;
}
// If we're close to target token length, stop processing more messages
const cutoffThreshold = Math.min(20, Math.floor(effectiveTargetLength * 0.01));
if (effectiveTargetLength - usedTokens < cutoffThreshold) break;
}
// Handle edge case: No messages fit within the limit
if (result.length === 0 && prioritizedMessages.length > 0) {
// Force at least one message (highest priority) to fit
const highestPriorityMessage = prioritizedMessages[0];
const availableForContent = effectiveTargetLength - highestPriorityMessage.roleTokens - messageOverhead;
if (availableForContent > truncationMarkerTokenLength) {
const { message: truncatedMessage } = truncateMessageContent(
highestPriorityMessage,
availableForContent,
Infinity // No per-message limit in this case
);
if (truncatedMessage) {
result.push(truncatedMessage);
}
}
}
// Before returning, verify we're under the limit and fix if needed
const finalTokenCount = this.countMessagesTokens(result);
if (finalTokenCount > targetTokenLength && result.length > 0) {
const lastResult = result[result.length - 1];
// Aggressively truncate the last message more
if (typeof lastResult.content === 'string') {
const overage = Math.ceil(finalTokenCount - targetTokenLength + safetyMargin / 2);
const currentLength = this.safeGetEncodedLength(lastResult.content);
const newLength = Math.max(20, currentLength - overage);
lastResult.content = getFirstNToken(lastResult.content, newLength - truncationMarkerTokenLength) + truncationMarker;
}
// For multimodal content, just remove all but the first text item
else if (Array.isArray(lastResult.content)) {
const firstTextIndex = lastResult.content.findIndex(item => item.type === 'text');
if (firstTextIndex >= 0) {
const firstTextItem = lastResult.content[firstTextIndex];
// Keep only this text item and truncate it
const truncatedText = getFirstNToken(firstTextItem.text, 20) + truncationMarker;
lastResult.content = [{ type: 'text', text: truncatedText }];
}
}
}
// Sort by original index to restore original order
result.sort((a, b) => a.originalIndex - b.originalIndex);
// Remove originalIndex property from result objects
return result.map(message => {
const { originalIndex, ...messageWithoutIndex } = message;
return messageWithoutIndex;
});
}
// Convert a messages array to a simple ChatML-formatted string
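// Illustrative input and output:
//   messagesToChatML([{ role: 'user', content: 'Hi' }])
//   // => '<|im_start|>user\nHi\n<|im_end|>\n<|im_start|>assistant\n'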
messagesToChatML(messages, addAssistant = true) {
let output = "";
if (messages && messages.length) {
for (let message of messages) {
const role = message.author || message.role;
const hasContent = message.content || message.content === '';
output += (role && hasContent) ? `<|im_start|>${role}\n${message.content}\n<|im_end|>\n` : `${message}\n`;
}
// You always want the assistant to respond next, so add a
// directive for that
if (addAssistant) {
output += "<|im_start|>assistant\n";
}
}
return output;
}
// Compile the prompt (text and/or messages) with Handlebars and compute its token length
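// Illustrative call with a hypothetical pathway prompt:
//   const { modelPromptText } = plugin.getCompiledPrompt(
//       'Paris', { style: 'brief' }, { prompt: 'Describe {{text}} in a {{style}} way.' });
//   // modelPromptText => 'Describe Paris in a brief way.'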
getCompiledPrompt(text, parameters, prompt) {
const mergeParameters = (promptParameters, parameters) => {
let result = { ...promptParameters };
for (let key in parameters) {
if (parameters[key] !== null) result[key] = parameters[key];
}
return result;
}
const combinedParameters = mergeParameters(this.promptParameters, parameters);
const modelPrompt = this.getModelPrompt(prompt, parameters);
let modelPromptText = '';
try {
modelPromptText = modelPrompt.prompt ? HandleBars.compile(modelPrompt.prompt)({ ...combinedParameters, text }) : '';
} catch (error) {
// If compilation fails, log the error and use the original prompt
logger.warn(`Handlebars compilation failed in getCompiledPrompt: ${error.message}. Using original text.`);
modelPromptText = modelPrompt.prompt || '';
}
const modelPromptMessages = this.getModelPromptMessages(modelPrompt, combinedParameters, text);
const modelPromptMessagesML = this.messagesToChatML(modelPromptMessages);
if (modelPromptMessagesML) {
return { modelPromptMessages, tokenLength: this.safeGetEncodedLength(modelPromptMessagesML), modelPrompt };
} else {
return { modelPromptText, tokenLength: this.safeGetEncodedLength(modelPromptText), modelPrompt };
}
}
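// Token budget accessors. Each value resolves in precedence order: pathway
// prompt parameters, then the model definition, then the default constant.
// Worked example (numbers illustrative): with maxTokenLength 8192 and
// maxReturnTokens 1024, getModelMaxPromptTokens() returns 8192 - 1024 = 7168;
// if no maxReturnTokens is defined anywhere, it returns
// Math.floor(8192 * tokenRatio) instead.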
getModelMaxTokenLength() {
return (this.promptParameters.maxTokenLength ?? this.model.maxTokenLength ?? DEFAULT_MAX_TOKENS);
}
getModelMaxPromptTokens() {
const hasMaxReturnTokens = this.promptParameters.maxReturnTokens !== undefined || this.model.maxReturnTokens !== undefined;
const maxPromptTokens = hasMaxReturnTokens
? this.getModelMaxTokenLength() - this.getModelMaxReturnTokens()
: Math.floor(this.getModelMaxTokenLength() * this.getPromptTokenRatio());
return maxPromptTokens;
}
getModelMaxReturnTokens() {
return (this.promptParameters.maxReturnTokens ?? this.model.maxReturnTokens ?? DEFAULT_MAX_RETURN_TOKENS);
}
getPromptTokenRatio() {
// TODO: Is this the right order of precedence? inputParameters should maybe be second?
return this.promptParameters.inputParameters?.tokenRatio ?? this.promptParameters.tokenRatio ?? DEFAULT_PROMPT_TOKEN_RATIO;
}
getModelPrompt(prompt, parameters) {
if (typeof(prompt) === 'function') {
return prompt(parameters);
} else {
return prompt;
}
}
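// Compile each message's content with Handlebars, then expand bare string
// entries that consist of a single {{placeholder}} into the matching
// parameter value (useful for splicing in an array like chatHistory).
// Hypothetical pathway messages, for illustration:
//   messages: [
//     { role: 'system', content: 'You summarize {{language}} text.' },
//     '{{chatHistory}}', // replaced by the chatHistory parameter's array of messages
//     { role: 'user', content: '{{text}}' },
//   ]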
getModelPromptMessages(modelPrompt, combinedParameters, text) {
if (!modelPrompt.messages) {
return null;
}
// First run handlebars compile on the pathway messages
const compiledMessages = modelPrompt.messages.map((message) => {
if (message.content && typeof message.content === 'string') {
try {
const compileText = HandleBars.compile(message.content);
return {
...message,
content: compileText({ ...combinedParameters, text }),
};
} catch (error) {
// If compilation fails, log the error and return the original content
logger.warn(`Handlebars compilation failed: ${error.message}. Using original text.`);
return message;
}
} else {
return message;
}
});
// Next add in any parameters that are referenced by name in the array
const expandedMessages = compiledMessages.flatMap((message) => {
if (typeof message === 'string') {
try {
const match = message.match(/{{(.+?)}}/);
const placeholder = match ? match[1] : null;
if (placeholder === null) {
return message;
} else {
return combinedParameters[placeholder] || [];
}
} catch (error) {
// If there's an error processing the string, return it as is
logger.warn(`Error processing message placeholder: ${error.message}. Using original text.`);
return message;
}
} else {
return [message];
}
});
// Clean up any null messages if they exist
expandedMessages.forEach((message) => {
if (typeof message === 'object' && message.content === null) {
message.content = '';
}
});
// Flatten content arrays for non-multimodal models
if (!this.isMultiModal) {
expandedMessages.forEach(message => {
if (Array.isArray(message?.content)) {
message.content = message.content.join("\n");
}
});
}
return expandedMessages;
}
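// Build the request URL by treating model.url as a Handlebars template over
// the model definition, environment variables, and config. A hypothetical
// model.url, for illustration:
//   'https://{{apiEndpoint}}/openai/deployments/{{deployment}}/chat/completions'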
requestUrl() {
const generateUrl = HandleBars.compile(this.model.url);
return generateUrl({ ...this.model, ...this.environmentVariables, ...this.config });
}
// Default response parsing
parseResponse(data) { return data; }
// Default simple logging
logRequestStart() {
this.requestCount++;
const logMessage = `>>> [${this.requestId}: ${this.pathwayName}.${this.requestCount}] request`;
const header = '>'.repeat(logMessage.length);
logger.info(header);
logger.info(logMessage);
logger.info(`>>> Making API request to ${obscureUrlParams(this.url)}`);
}
logAIRequestFinished(requestDuration) {
const logMessage = `<<< [${this.requestId}: ${this.pathwayName}] response - complete in ${requestDuration}ms - data:`;
const header = '<'.repeat(logMessage.length);
logger.info(header);
logger.info(logMessage);
}
getLength(data) {
const isProd = config.get('env') === 'production';
let length = 0;
let units = isProd ? 'characters' : 'tokens';
if (data) {
if (isProd || data.length > 5000) {
length = data.length;
units = 'characters';
} else {
length = encode(data).length;
}
}
return {length, units};
}
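// Shorten long content for logging by keeping the first and last maxWords/2
// words. Illustrative example with maxWords = 4 (and log level above debug):
//   shortenContent('one two three four five six', 4) // => 'one two ... five six'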
shortenContent(content, maxWords = 40) {
if (!content || typeof content !== 'string') {
return content;
}
const words = content.split(" ");
if (words.length <= maxWords || logger.level === 'debug') {
return content;
}
return words.slice(0, maxWords / 2).join(" ") +
" ... " +
words.slice(-maxWords / 2).join(" ");
}
logRequestData(data, responseData, prompt) {
const modelInput = data.prompt || data.messages?.[0]?.content || (data.length > 0 && data[0].Text) || null;
if (modelInput) {
const { length, units } = this.getLength(modelInput);
logger.info(`[request sent containing ${length} ${units}]`);
logger.verbose(`${this.shortenContent(modelInput)}`);
}
const responseText = JSON.stringify(responseData);
const { length, units } = this.getLength(responseText);
logger.info(`[response received containing ${length} ${units}]`);
logger.verbose(`${this.shortenContent(responseText)}`);
if (prompt?.debugInfo) {
prompt.debugInfo += `\n${JSON.stringify(data)}`;
}
}
async executeRequest(cortexRequest) {
try {
const { url, data, pathway, requestId, prompt } = cortexRequest;
this.url = url;
this.requestId = requestId;
this.pathwayName = pathway.name;
this.pathwayPrompt = pathway.prompt;
cortexRequest.cache = config.get('enableCache') && (pathway.enableCache || pathway.temperature === 0);
this.logRequestStart();
const response = await executeRequest(cortexRequest);
// Add null check and default values for response
if (!response) {
throw new Error('Request failed - no response received');
}
const { data: responseData, duration: requestDuration } = response;
// Validate response data
if (!responseData) {
throw new Error('Request failed - no data in response');
}
const errorData = Array.isArray(responseData) ? responseData[0] : responseData;
if (errorData && errorData.error) {
const newError = new Error(errorData.error.message);
newError.data = errorData;
throw newError;
}
this.logAIRequestFinished(requestDuration || 0);
const parsedData = this.parseResponse(responseData);
this.logRequestData(data, parsedData, prompt);
return parsedData;
} catch (error) {
// Enhanced error logging
const errorMessage = error?.response?.data?.message
?? error?.response?.data?.error?.message
?? error?.message
?? String(error);
// Log the full error details for debugging
logger.error(`Error in executeRequest for ${this.pathwayName}: ${errorMessage}`);
if (error.response) {
logger.error(`Response status: ${error.response.status}`);
logger.error(`Response headers: ${JSON.stringify(error.response.headers)}`);
}
if (error.data) {
logger.error(`Additional error data: ${JSON.stringify(error.data)}`);
}
if (error.stack) {
logger.error(`Error stack: ${error.stack}`);
}
// Throw a more informative error
throw new Error(`Execution failed for ${this.pathwayName}: ${errorMessage}`);
}
}
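// Handle a single server-sent event from a streaming response: mark the
// request complete on '[DONE]' or a finish reason, and surface in-stream
// errors. An illustrative OpenAI-style event (shape assumed, for example only):
//   { data: '{"choices":[{"delta":{"content":"Hi"},"finish_reason":null}]}' }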
processStreamEvent(event, requestProgress) {
// check for end of stream or in-stream errors
if (event.data.trim() === '[DONE]') {
requestProgress.progress = 1;
} else {
let parsedMessage;
try {
parsedMessage = JSON.parse(event.data);
requestProgress.data = event.data;
} catch (error) {
throw new Error(`Could not parse stream data: ${error}`);
}
// error can be in different places in the message
const streamError = parsedMessage?.error || parsedMessage?.choices?.[0]?.delta?.content?.error || parsedMessage?.choices?.[0]?.text?.error;
if (streamError) {
throw new Error(typeof streamError === 'string' ? streamError : JSON.stringify(streamError));
}
// finish reason can be in different places in the message
const finishReason = parsedMessage?.choices?.[0]?.finish_reason || parsedMessage?.candidates?.[0]?.finishReason;
if (finishReason) {
switch (finishReason.toLowerCase()) {
case 'safety': {
const safetyRatings = JSON.stringify(parsedMessage?.candidates?.[0]?.safetyRatings) || '';
logger.warn(`Request ${this.requestId} was blocked by the safety filter. ${safetyRatings}`);
requestProgress.data = `\n\nResponse blocked by safety filter: ${safetyRatings}`;
requestProgress.progress = 1;
break;
}
default:
requestProgress.progress = 1;
break;
}
}
}
return requestProgress;
}
getModelMaxImageSize() {
return (this.promptParameters.maxImageSize ?? this.model.maxImageSize ?? DEFAULT_MAX_IMAGE_SIZE);
}
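// Estimate the total token count of a messages array: role tokens plus
// content tokens, with ~4 tokens of per-message overhead, ~3 tokens of
// conversation overhead, and a flat 100-token estimate per image. Worked
// example (counts illustrative): { role: 'user', content: 'Hello' } costs
// encode('user').length + encode('Hello').length + 4, plus the final 3,
// so roughly 1 + 1 + 4 + 3 = 9 tokens.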
countMessagesTokens(messages) {
if (!messages || !Array.isArray(messages) || messages.length === 0) {
return 0;
}
let totalTokens = 0;
for (const message of messages) {
// Count tokens for role/author
const role = message.role || message.author || "";
if (role) {
totalTokens += this.safeGetEncodedLength(role);
}
// Count tokens for content
if (typeof message.content === 'string') {
totalTokens += this.safeGetEncodedLength(message.content);
} else if (Array.isArray(message.content)) {
// Handle multimodal content
for (const item of message.content) {
// item can be a string or an object
if (typeof item === 'string') {
totalTokens += this.safeGetEncodedLength(item);
} else if (item.type === 'text') {
totalTokens += this.safeGetEncodedLength(item.text);
} else if (item.type === 'image_url') {
// Most models use ~85-130 tokens per image, but this varies by model
totalTokens += 100;
}
}
}
// Add per-message overhead (typically 3-4 tokens per message)
totalTokens += 4;
}
// Add conversation formatting overhead
totalTokens += 3;
return totalTokens;
}
}
export default ModelPlugin;
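
// Minimal sketch of a provider subclass, assuming the pathway/model shapes
// used above (the class name and response envelope are illustrative):
//
//   class MyProviderPlugin extends ModelPlugin {
//       parseResponse(data) {
//           // unwrap a hypothetical provider-specific envelope
//           return data?.choices?.[0]?.message?.content ?? data;
//       }
//   }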