/*
 * @aj-archipelago/cortex
 * Version:
 * Cortex is a GraphQL API for AI. It provides a simple, extensible interface for using AI services from OpenAI, Azure, and others.
 * 346 lines (299 loc) • 12.5 kB
 * JavaScript
 */
import test from 'ava';
import Claude3VertexPlugin from '../server/plugins/claude3VertexPlugin.js';
import { mockPathwayResolverMessages } from './mocks.js';
import { config } from '../config.js';
import fs from 'fs';
import path from 'path';
/**
 * Reads a fixture file from `<cwd>/tests/data` as a UTF-8 string.
 *
 * Loading is best-effort: any failure (including a bad filename argument)
 * is logged and a small synthetic string is returned instead, so callers
 * always receive a string.
 *
 * @param {string} filename - Name of the fixture file inside `tests/data`.
 * @returns {string} The file contents, or `'a '` repeated 1000 times on failure.
 */
function loadTestData(filename) {
    let contents;
    try {
        // Path construction stays inside the try so a bad filename also falls back.
        const fixturePath = path.join(process.cwd(), 'tests', 'data', filename);
        contents = fs.readFileSync(fixturePath, 'utf8');
    } catch (readError) {
        console.error(`Error loading test data file ${filename}:`, readError);
        // Small stand-in so callers still have a string to work with.
        contents = 'a '.repeat(1000);
    }
    return contents;
}
// Shared mock pathway and model used to construct every plugin under test.
const pathway = mockPathwayResolverMessages.pathway;
const model = mockPathwayResolverMessages.model;
test('constructor', (t) => {
    const claudePlugin = new Claude3VertexPlugin(pathway, model);
    // t.is compares by identity: the plugin must hold the shared config object itself.
    t.is(claudePlugin.config, config);
    // It must also keep a reference to the pathway's own prompt definition.
    t.is(claudePlugin.pathwayPrompt, mockPathwayResolverMessages.pathway.prompt);
});
test('getRequestParameters', async (t) => {
    const claudePlugin = new Claude3VertexPlugin(pathway, model);
    // Builds one turn in Claude's content-array format with a single text part.
    const textMessage = (role, text) => ({ role, content: [{ type: 'text', text }] });

    const result = await claudePlugin.getRequestParameters(
        'Help me',
        { name: 'John', age: 30, stream: false },
        mockPathwayResolverMessages.pathway.prompt,
    );

    // The mock prompt is expected to expand into a three-turn conversation,
    // with no system prompt and the default request settings.
    const expected = {
        system: '',
        messages: [
            textMessage('user', 'Translate this: Help me'),
            textMessage('assistant', 'Translating: Help me'),
            textMessage('user', 'Nice work!'),
        ],
        max_tokens: claudePlugin.getModelMaxReturnTokens(),
        anthropic_version: 'vertex-2023-10-16',
        stream: false,
        temperature: 0.7,
    };
    t.deepEqual(result, expected);
});
/**
 * Verifies that, with token-length management enabled, an oversized message
 * inside the prompt's chat history is truncated while the final user input
 * is still present in the request.
 *
 * NOTE(review): if tests/data/largecontent.txt is missing, loadTestData falls
 * back to a ~2000-character string. That is probably too short to trigger
 * truncation, so the truncation assertion would then fail — confirm the
 * fixture is present when this test fails.
 */
test('getRequestParameters with long message in chatHistory', async (t) => {
    const plugin = new Claude3VertexPlugin(pathway, model);
    const text = 'Final message';
    // Load long content from file
    const longContent = loadTestData('largecontent.txt');
    // Set up chatHistory with a long message in the middle of the conversation
    const chatHistory = [
        { role: 'user', content: 'Short message' },
        { role: 'assistant', content: 'Short response' },
        { role: 'user', content: longContent },
        { role: 'assistant', content: 'Long content response' },
        { role: 'user', content: 'Final message' },
    ];
    // Create a custom prompt that includes the chatHistory
    const prompt = {
        ...mockPathwayResolverMessages.pathway.prompt,
        messages: chatHistory,
    };
    const parameters = { stream: false };
    plugin.promptParameters.manageTokenLength = true;
    const result = await plugin.getRequestParameters(text, parameters, prompt);

    // Verify we have messages in the result
    t.truthy(result.messages);
    // Default to [] so a missing messages array is reported by the assertions
    // below rather than by a TypeError from .filter()/.find().
    const messages = result.messages ?? [];

    // A user turn whose first content part is text. Optional chaining keeps
    // turns with an empty content array from throwing.
    const isUserTextMessage = (msg) => msg.role === 'user' && msg.content?.[0]?.type === 'text';

    const userMessages = messages.filter(isUserTextMessage);
    t.true(userMessages.length > 0, 'Should include user messages');

    // The truncated long message is shorter than the original but still longer
    // than 100 characters, which rules out the short turns.
    const longMessage = userMessages.find((msg) => {
        const { length } = msg.content[0].text;
        return length < longContent.length && length > 100;
    });
    t.truthy(longMessage, 'Long user message should be truncated');
    // Guarded: without a match, dereferencing longMessage would replace the
    // clear assertion failure above with a TypeError.
    if (longMessage) {
        t.true(
            longMessage.content[0].text.length < longContent.length,
            'Truncated message should be shorter than original',
        );
    }

    // Verify the final user input message is included
    const finalInputMessage = messages.find(
        (msg) => isUserTextMessage(msg) && msg.content[0].text.includes(text),
    );
    t.truthy(finalInputMessage, 'Final user input should be included');

    // t.log attaches the output to this test in AVA's report instead of raw stdout.
    t.log(`Original content length: ${longContent.length} chars`);
    if (longMessage) {
        t.log(`Truncated content length: ${longMessage.content[0].text.length} chars`);
    }
});
test('parseResponse', (t) => {
    const claudePlugin = new Claude3VertexPlugin(pathway, model);

    // A response whose content holds a text part is unwrapped to that string.
    const textReply = { content: [{ type: 'text', text: 'Hello, World!' }] };
    t.is(claudePlugin.parseResponse(textReply), 'Hello, World!');

    // Responses with no text part, or with no content at all, are expected
    // to come back deep-equal to the input.
    const imageOnlyReply = { content: [{ type: 'image', url: 'http://example.com/image.jpg' }] };
    t.deepEqual(claudePlugin.parseResponse(imageOnlyReply), imageOnlyReply);

    const emptyReply = {};
    t.deepEqual(claudePlugin.parseResponse(emptyReply), emptyReply);

    // A null response is passed straight through.
    t.is(claudePlugin.parseResponse(null), null);
});
test('convertMessagesToClaudeVertex text message', async (t) => {
    const plugin = new Claude3VertexPlugin(pathway, model);
    // Test with text messages: the leading system turn is expected to be lifted
    // into `system`, and each remaining string turn converted to a one-part
    // text content array. Both bindings are `const` since neither is reassigned.
    const messages = [
        { role: 'system', content: 'System message' },
        { role: 'user', content: 'User message' },
        { role: 'assistant', content: 'Assistant message' },
        { role: 'user', content: 'User message 2' },
    ];
    const output = await plugin.convertMessagesToClaudeVertex(messages);
    t.deepEqual(output, {
        system: 'System message',
        modifiedMessages: [
            {
                role: 'user',
                content: [
                    {
                        type: 'text',
                        text: 'User message',
                    },
                ],
            },
            {
                role: 'assistant',
                content: [
                    {
                        type: 'text',
                        text: 'Assistant message',
                    },
                ],
            },
            {
                role: 'user',
                content: [
                    {
                        type: 'text',
                        text: 'User message 2',
                    },
                ],
            },
        ],
    });
});
test('convertMessagesToClaudeVertex image_url message', async (t) => {
    const claudePlugin = new Claude3VertexPlugin(pathway, model);
    // A single assistant turn whose content is one image_url object (not an array).
    // The URL must be reachable and serve a supported MIME type (JPEG here).
    // NOTE(review): this presumably means the test needs outbound network access.
    const imageUrl = 'https://static.toiimg.com/thumb/msid-102827471,width-1280,height-720,resizemode-4/102827471.jpg';
    const messages = [{ role: 'assistant', content: { type: 'image_url', image_url: imageUrl } }];

    const output = await claudePlugin.convertMessagesToClaudeVertex(messages);
    const [convertedTurn] = output.modifiedMessages;
    const [imagePart] = convertedTurn.content;
    const { source } = imagePart;
    const encodedImage = source.data;
    const looksLikeBase64 = /^[A-Za-z0-9+/]+={0,2}$/;

    t.is(output.system, '');
    t.is(convertedTurn.role, 'assistant');
    t.is(imagePart.type, 'image');
    t.is(source.type, 'base64');
    t.is(source.media_type, 'image/jpeg');
    // Sanity-check the payload: long enough to be real image data and made of base64 characters.
    t.true(encodedImage.length > 100);
    t.true(looksLikeBase64.test(encodedImage));
});
test('convertMessagesToClaudeVertex unsupported type', async (t) => {
    const claudePlugin = new Claude3VertexPlugin(pathway, model);
    // A content object with an unrecognised type is expected to be dropped,
    // leaving the user turn in place with an empty content array.
    const result = await claudePlugin.convertMessagesToClaudeVertex([
        { role: 'user', content: { type: 'unsupported_type' } },
    ]);
    t.deepEqual(result, { system: '', modifiedMessages: [{ role: 'user', content: [] }] });
});
test('convertMessagesToClaudeVertex empty messages', async (t) => {
    const claudePlugin = new Claude3VertexPlugin(pathway, model);
    // No input turns: expect an empty system prompt and no converted messages.
    const expected = { system: '', modifiedMessages: [] };
    const result = await claudePlugin.convertMessagesToClaudeVertex([]);
    t.deepEqual(result, expected);
});
test('convertMessagesToClaudeVertex system message', async (t) => {
    const claudePlugin = new Claude3VertexPlugin(pathway, model);
    // A lone system turn becomes the top-level `system` string, leaving no messages.
    const result = await claudePlugin.convertMessagesToClaudeVertex([
        { role: 'system', content: 'System message' },
    ]);
    t.deepEqual(result, { system: 'System message', modifiedMessages: [] });
});
test('convertMessagesToClaudeVertex system message with user message', async (t) => {
    const claudePlugin = new Claude3VertexPlugin(pathway, model);
    const conversation = [
        { role: 'system', content: 'System message' },
        { role: 'user', content: 'User message' },
    ];
    // The system turn is lifted out; the user string becomes a one-part text array.
    const expectedUserTurn = { role: 'user', content: [{ type: 'text', text: 'User message' }] };
    const result = await claudePlugin.convertMessagesToClaudeVertex(conversation);
    t.deepEqual(result, { system: 'System message', modifiedMessages: [expectedUserTurn] });
});
test('convertMessagesToClaudeVertex user message with unsupported image type', async (t) => {
    const claudePlugin = new Claude3VertexPlugin(pathway, model);
    // The URL points at a PDF rather than a supported image; the part is expected
    // to be dropped, leaving the user turn with an empty content array.
    const pdfUrl = 'https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf';
    const result = await claudePlugin.convertMessagesToClaudeVertex([
        { role: 'user', content: { type: 'image_url', image_url: pdfUrl } },
    ]);
    t.deepEqual(result, { system: '', modifiedMessages: [{ role: 'user', content: [] }] });
});
test('convertMessagesToClaudeVertex user message with no content', async (t) => {
    const claudePlugin = new Claude3VertexPlugin(pathway, model);
    // A user turn whose content is null is expected to be omitted entirely,
    // not kept with an empty content array.
    const result = await claudePlugin.convertMessagesToClaudeVertex([{ role: 'user', content: null }]);
    t.deepEqual(result, { system: '', modifiedMessages: [] });
});
test('convertMessagesToClaudeVertex with multi-part content array', async (t) => {
    const claudePlugin = new Claude3VertexPlugin(pathway, model);
    const imageUrl = 'https://static.toiimg.com/thumb/msid-102827471,width-1280,height-720,resizemode-4/102827471.jpg';
    // Two text parts followed by one image_url part, all inside a single user turn.
    const parts = [
        { type: 'text', text: 'Hello world' },
        { type: 'text', text: 'Hello2 world2' },
        { type: 'image_url', image_url: imageUrl },
    ];
    const output = await claudePlugin.convertMessagesToClaudeVertex([
        { role: 'system', content: 'System message' },
        { role: 'user', content: parts },
    ]);

    // The system prompt is lifted out and the user role is kept.
    t.is(output.system, 'System message');
    const [userTurn] = output.modifiedMessages;
    t.is(userTurn.role, 'user');

    // Every part survives conversion, in order: 2 text items + 1 image item.
    t.is(userTurn.content.length, 3);
    const [firstText, secondText, imagePart] = userTurn.content;
    t.is(firstText.type, 'text');
    t.is(firstText.text, 'Hello world');
    t.is(secondText.type, 'text');
    t.is(secondText.text, 'Hello2 world2');

    // The image_url part is converted to an inline base64 JPEG source.
    t.is(imagePart.type, 'image');
    t.is(imagePart.source.type, 'base64');
    t.is(imagePart.source.media_type, 'image/jpeg');

    // Sanity-check the payload: long enough to be real image data and made of base64 characters.
    const encodedImage = imagePart.source.data;
    const looksLikeBase64 = /^[A-Za-z0-9+/]+={0,2}$/;
    t.true(encodedImage.length > 100);
    t.true(looksLikeBase64.test(encodedImage));
});