@aj-archipelago/cortex

Cortex is a GraphQL API for AI. It provides a simple, extensible interface for using AI services from OpenAI, Azure and others.
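
For orientation, here is a minimal sketch of what calling a Cortex pathway over GraphQL can look like from JavaScript. The endpoint URL, pathway name (chat), and result field below are illustrative assumptions, not definitions taken from this package; consult the Cortex documentation for the actual schema.

// Hypothetical client call - the endpoint, pathway name, and fields are
// assumptions for illustration only.
const response = await fetch("http://localhost:4000/graphql", {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify({
        query: `query Chat($text: String) {
            chat(text: $text) {
                result
            }
        }`,
        variables: { text: "Hello, Cortex!" },
    }),
});
const { data } = await response.json();
console.log(data.chat.result);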

513 lines (449 loc) 16.9 kB
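
Before the source itself, a rough usage sketch of the message conversion implemented below. It shows how convertMessagesToClaudeVertex behaves on an OpenAI-style conversation that includes tool calls; the plugin instance is an assumption here, since Cortex normally constructs and configures plugins itself rather than having callers instantiate them by hand.

// Hypothetical usage sketch - `plugin` is assumed to be a configured
// Claude3VertexPlugin instance.
const { system, modifiedMessages } = await plugin.convertMessagesToClaudeVertex([
    { role: "system", content: "You are a helpful assistant." },
    { role: "user", content: "What's the weather in Doha?" },
    {
        role: "assistant",
        tool_calls: [{
            id: "call_1",
            type: "function",
            function: { name: "get_weather", arguments: '{"city":"Doha"}' },
        }],
    },
    { role: "tool", tool_call_id: "call_1", content: '{"tempC":41}' },
]);
// system           -> "You are a helpful assistant."
// modifiedMessages -> three alternating messages: a user text block, an assistant
//                     { type: "tool_use" } block, and a user { type: "tool_result" }
//                     block. When these messages later pass through
//                     getRequestParameters, a tools entry for "get_weather" is
//                     generated automatically if none was supplied in parameters.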
import OpenAIVisionPlugin from "./openAiVisionPlugin.js";
import logger from "../../lib/logger.js";
import axios from 'axios';

// Convert a single message content item (string, text, tool_use, tool_result,
// or image_url) into the content block format expected by the Claude API.
async function convertContentItem(item, maxImageSize, plugin) {
    let imageUrl = "";
    try {
        switch (typeof item) {
            case "string":
                return item ? { type: "text", text: item } : null;
            case "object":
                switch (item.type) {
                    case "text":
                        return item.text ? { type: "text", text: item.text } : null;
                    case "tool_use":
                        return {
                            type: "tool_use",
                            id: item.id,
                            name: item.name,
                            input: typeof item.input === 'string' ? { query: item.input } : item.input
                        };
                    case "tool_result":
                        return {
                            type: "tool_result",
                            tool_use_id: item.tool_use_id,
                            content: item.content
                        };
                    case "image_url":
                        imageUrl = item.url || item.image_url?.url || item.image_url;
                        if (!imageUrl) {
                            logger.warn("Could not parse image URL from content - skipping image content.");
                            return null;
                        }
                        try {
                            // First validate the image URL
                            if (!await plugin.validateImageUrl(imageUrl)) {
                                return null;
                            }
                            // Then fetch and convert to base64 if needed
                            const urlData = imageUrl.startsWith("data:") ? imageUrl : await fetchImageAsDataURL(imageUrl);
                            if (!urlData) {
                                return null;
                            }
                            const base64Image = urlData.split(",")[1];
                            // Calculate actual decoded size of base64 data
                            const base64Size = Buffer.from(base64Image, 'base64').length;
                            if (base64Size > maxImageSize) {
                                logger.warn(`Image size ${base64Size} bytes exceeds maximum allowed size ${maxImageSize} - skipping image content.`);
                                return null;
                            }
                            const [, mimeType = "image/jpeg"] = urlData.match(/data:([a-zA-Z0-9]+\/[a-zA-Z0-9-.+]+).*,.*/) || [];
                            return {
                                type: "image",
                                source: {
                                    type: "base64",
                                    media_type: mimeType,
                                    data: base64Image,
                                },
                            };
                        } catch (error) {
                            logger.error(`Failed to process image: ${error.message}`);
                            return null;
                        }
                    default:
                        return null;
                }
            default:
                return null;
        }
    } catch (e) {
        logger.warn(`Error converting content item: ${e}`);
        return null;
    }
}

// Fetch image and convert to base64 data URL
async function fetchImageAsDataURL(imageUrl) {
    try {
        // Get the actual image data
        const dataResponse = await axios.get(imageUrl, {
            timeout: 30000,
            responseType: 'arraybuffer',
            maxRedirects: 5
        });
        const contentType = dataResponse.headers['content-type'];
        const base64Image = Buffer.from(dataResponse.data).toString('base64');
        return `data:${contentType};base64,${base64Image}`;
    } catch (e) {
        logger.error(`Failed to fetch image: ${imageUrl}. 
${e}`);
        throw e;
    }
}

// Plugin for Claude 3 models served through Google Vertex AI.
class Claude3VertexPlugin extends OpenAIVisionPlugin {
    // Normalize Claude responses into the OpenAI-style shapes the rest of the
    // pipeline expects: tool_use blocks become tool_calls, text blocks become strings.
    parseResponse(data) {
        if (!data) {
            return data;
        }
        const { content } = data;
        // Handle tool use responses from Claude
        if (content && Array.isArray(content)) {
            const toolUses = content.filter(item => item.type === "tool_use");
            if (toolUses.length > 0) {
                return {
                    role: "assistant",
                    content: "",
                    tool_calls: toolUses.map(toolUse => ({
                        id: toolUse.id,
                        type: "function",
                        function: {
                            name: toolUse.name,
                            arguments: JSON.stringify(toolUse.input)
                        }
                    }))
                };
            }
            // Handle regular text responses
            const textContent = content.find(item => item.type === "text");
            if (textContent) {
                return textContent.text;
            }
        }
        return data;
    }

    // This code converts messages to the format required by the Claude Vertex API
    async convertMessagesToClaudeVertex(messages) {
        // Create a deep copy of the input messages
        const messagesCopy = JSON.parse(JSON.stringify(messages));
        let system = "";
        let imageCount = 0;
        const maxImages = 20; // Claude allows up to 20 images per request

        // Extract system messages
        const systemMessages = messagesCopy.filter(message => message.role === "system");
        if (systemMessages.length > 0) {
            system = systemMessages.map(message => {
                if (Array.isArray(message.content)) {
                    // For content arrays, extract text content and join
                    return message.content
                        .filter(item => item.type === 'text')
                        .map(item => item.text)
                        .join("\n");
                }
                return message.content;
            }).join("\n");
        }

        // Filter out system messages and empty messages
        let modifiedMessages = messagesCopy
            .filter(message => message.role !== "system")
            .map(message => {
                // Handle OpenAI tool calls format conversion to Claude format
                if (message.tool_calls) {
                    return {
                        role: message.role,
                        content: message.tool_calls.map(toolCall => ({
                            type: "tool_use",
                            id: toolCall.id,
                            name: toolCall.function.name,
                            input: JSON.parse(toolCall.function.arguments)
                        }))
                    };
                }
                // Handle OpenAI tool response format conversion to Claude format
                if (message.role === "tool") {
                    return {
                        role: "user",
                        content: [{
                            type: "tool_result",
                            tool_use_id: message.tool_call_id,
                            content: message.content
                        }]
                    };
                }
                return { ...message };
            })
            .filter(message => {
                // Filter out messages with empty content
                if (!message.content) return false;
                if (Array.isArray(message.content) && message.content.length === 0) return false;
                return true;
            });

        // Combine consecutive messages from the same author
        const combinedMessages = modifiedMessages.reduce((acc, message) => {
            if (acc.length === 0 || message.role !== acc[acc.length - 1].role) {
                acc.push({ ...message });
            } else {
                const lastMessage = acc[acc.length - 1];
                if (Array.isArray(lastMessage.content) && Array.isArray(message.content)) {
                    lastMessage.content = [...lastMessage.content, ...message.content];
                } else if (Array.isArray(lastMessage.content)) {
                    lastMessage.content.push({ type: 'text', text: message.content });
                } else if (Array.isArray(message.content)) {
                    lastMessage.content = [{ type: 'text', text: lastMessage.content }, ...message.content];
                } else {
                    lastMessage.content += "\n" + message.content;
                }
            }
            return acc;
        }, []);

        // Ensure an odd number of messages
        const finalMessages = combinedMessages.length % 2 === 0
            ? combinedMessages.slice(1)
            : combinedMessages;

        // Convert content items
        const claude3Messages = await Promise.all(
            finalMessages.map(async (message) => {
                const contentArray = Array.isArray(message.content)
                    ? message.content
                    : [message.content];
                const claude3Content = await Promise.all(contentArray.map(async item => {
                    const convertedItem = await convertContentItem(item, this.getModelMaxImageSize(), this);
                    // Track image count
                    if (convertedItem?.type === 'image') {
                        imageCount++;
                        if (imageCount > maxImages) {
                            logger.warn(`Maximum number of images (${maxImages}) exceeded - skipping additional images.`);
                            return null;
                        }
                    }
                    return convertedItem;
                }));
                return {
                    role: message.role,
                    content: claude3Content.filter(Boolean),
                };
            })
        );

        return {
            system,
            modifiedMessages: claude3Messages,
        };
    }

    async getRequestParameters(text, parameters, prompt) {
        const requestParameters = await super.getRequestParameters(
            text,
            parameters,
            prompt
        );
        const { system, modifiedMessages } = await this.convertMessagesToClaudeVertex(requestParameters.messages);
        requestParameters.system = system;
        requestParameters.messages = modifiedMessages;

        // Convert OpenAI tools format to Claude format if present
        if (parameters.tools) {
            requestParameters.tools = parameters.tools.map(tool => {
                if (tool.type === 'function') {
                    return {
                        name: tool.function.name,
                        description: tool.function.description,
                        input_schema: {
                            type: "object",
                            properties: tool.function.parameters.properties,
                            required: tool.function.parameters.required || []
                        }
                    };
                }
                return tool;
            });
        }

        if (parameters.tool_choice) {
            // Convert OpenAI tool_choice format to Claude format
            if (typeof parameters.tool_choice === 'string') {
                // Handle string values: auto, required, none
                if (parameters.tool_choice === 'required') {
                    requestParameters.tool_choice = { type: 'any' }; // OpenAI's 'required' maps to Claude's 'any'
                } else if (parameters.tool_choice === 'auto') {
                    requestParameters.tool_choice = { type: 'auto' };
                } else if (parameters.tool_choice === 'none') {
                    requestParameters.tool_choice = { type: 'none' };
                }
            } else if (parameters.tool_choice.type === "function") {
                // Handle function-specific tool choice
                requestParameters.tool_choice = {
                    type: "tool",
                    name: parameters.tool_choice.function.name
                };
            }
        }

        // If there are function calls in messages, generate tools block
        if (modifiedMessages?.some(msg =>
            Array.isArray(msg.content) &&
            msg.content.some(item => item.type === 'tool_use')
        )) {
            const toolsMap = new Map();
            // First add any existing tools from parameters to the map
            if (requestParameters.tools) {
                requestParameters.tools.forEach(tool => {
                    toolsMap.set(tool.name, tool);
                });
            }
            // Collect all unique tool uses from messages, only adding if not already present
            modifiedMessages.forEach(msg => {
                if (Array.isArray(msg.content)) {
                    msg.content.forEach(item => {
                        if (item.type === 'tool_use' && !toolsMap.has(item.name)) {
                            toolsMap.set(item.name, {
                                name: item.name,
                                description: `Tool for ${item.name}`,
                                input_schema: {
                                    type: "object",
                                    properties: item.input
                                        ? Object.keys(item.input).reduce((acc, key) => {
                                            acc[key] = {
                                                type: typeof item.input[key] === 'string' ? 'string' : 'object',
                                                description: `Parameter ${key} for ${item.name}`
                                            };
                                            return acc;
                                        }, {})
                                        : {},
                                    required: item.input ? Object.keys(item.input) : []
                                }
                            });
                        }
                    });
                }
            });
            // Update the tools array with the combined unique tools
            requestParameters.tools = Array.from(toolsMap.values());
        }

        requestParameters.max_tokens = this.getModelMaxReturnTokens();
        requestParameters.anthropic_version = "vertex-2023-10-16";
        return requestParameters;
    }

    // Override the logging function to display the messages and responses
    logRequestData(data, responseData, prompt) {
        const { stream, messages, system } = data;
        if (system) {
            const { length, units } = this.getLength(system);
            logger.info(`[system messages sent containing ${length} ${units}]`);
            logger.verbose(`${this.shortenContent(system)}`);
        }
        if (messages && messages.length > 1) {
            logger.info(`[chat request sent containing ${messages.length} messages]`);
            let totalLength = 0;
            let totalUnits;
            messages.forEach((message, index) => {
                const content = Array.isArray(message.content)
                    ? message.content.map((item) => {
                        if (item.source && item.source.type === 'base64') {
                            item.source.data = '* base64 data truncated for log *';
                        }
                        return JSON.stringify(item);
                    }).join(", ")
                    : message.content;
                const { length, units } = this.getLength(content);
                const preview = this.shortenContent(content);
                logger.verbose(
                    `message ${index + 1}: role: ${message.role}, ${units}: ${length}, content: "${preview}"`
                );
                totalLength += length;
                totalUnits = units;
            });
            logger.info(`[chat request contained ${totalLength} ${totalUnits}]`);
        } else {
            const message = messages[0];
            const content = Array.isArray(message.content)
                ? message.content.map((item) => JSON.stringify(item)).join(", ")
                : message.content;
            const { length, units } = this.getLength(content);
            logger.info(`[request sent containing ${length} ${units}]`);
            logger.verbose(`${this.shortenContent(content)}`);
        }
        if (stream) {
            logger.info(`[response received as an SSE stream]`);
        } else {
            const parsedResponse = this.parseResponse(responseData);
            if (typeof parsedResponse === 'string') {
                const { length, units } = this.getLength(parsedResponse);
                logger.info(`[response received containing ${length} ${units}]`);
                logger.verbose(`${this.shortenContent(parsedResponse)}`);
            } else {
                logger.info(`[response received containing object]`);
                logger.verbose(`${JSON.stringify(parsedResponse)}`);
            }
        }
        prompt && prompt.debugInfo && (prompt.debugInfo += `\n${JSON.stringify(data)}`);
    }

    async execute(text, parameters, prompt, cortexRequest) {
        const requestParameters = await this.getRequestParameters(
            text,
            parameters,
            prompt,
            cortexRequest
        );
        const { stream } = parameters;
        cortexRequest.data = {
            ...(cortexRequest.data || {}),
            ...requestParameters,
        };
        cortexRequest.params = {}; // query params
        cortexRequest.stream = stream;
        cortexRequest.urlSuffix = cortexRequest.stream
            ? ":streamRawPredict?alt=sse"
            : ":rawPredict";

        const gcpAuthTokenHelper = this.config.get("gcpAuthTokenHelper");
        const authToken = await gcpAuthTokenHelper.getAccessToken();
        cortexRequest.auth.Authorization = `Bearer ${authToken}`;

        return this.executeRequest(cortexRequest);
    }

    // Translate Anthropic-style SSE stream events into OpenAI
    // chat.completion.chunk objects for downstream consumers.
    processStreamEvent(event, requestProgress) {
        const eventData = JSON.parse(event.data);
        const baseOpenAIResponse = {
            id: eventData.message?.id || `chatcmpl-${Date.now()}`,
            object: "chat.completion.chunk",
            created: Math.floor(Date.now() / 1000),
            model: this.modelName,
            choices: [{
                index: 0,
                delta: {},
                finish_reason: null
            }]
        };

        switch (eventData.type) {
            case "message_start":
                // Initial message with role
                baseOpenAIResponse.choices[0].delta = { role: "assistant", content: "" };
                requestProgress.data = JSON.stringify(baseOpenAIResponse);
                break;
            case "content_block_delta":
                if (eventData.delta.type === "text_delta") {
                    baseOpenAIResponse.choices[0].delta = { content: eventData.delta.text };
                    requestProgress.data = JSON.stringify(baseOpenAIResponse);
                }
                break;
            case "message_stop":
                baseOpenAIResponse.choices[0].delta = {};
                baseOpenAIResponse.choices[0].finish_reason = "stop";
                requestProgress.data = JSON.stringify(baseOpenAIResponse);
                requestProgress.progress = 1;
                break;
            case "error":
                baseOpenAIResponse.choices[0].delta = {
                    content: `\n\n*** ${eventData.error.message || eventData.error} ***`
                };
                baseOpenAIResponse.choices[0].finish_reason = "error";
                requestProgress.data = JSON.stringify(baseOpenAIResponse);
                requestProgress.progress = 1;
                break;
            // Ignore other event types as they don't map to OpenAI format
            case "content_block_start":
            case "content_block_stop":
            case "message_delta":
            case "ping":
                break;
        }

        return requestProgress;
    }
}

export default Claude3VertexPlugin;