@aj-archipelago/cortex
Cortex is a GraphQL API for AI. It provides a simple, extensible interface for using AI services from OpenAI, Azure and others.
import test from 'ava';
import serverFactory from '../index.js';
import { PathwayResolver } from '../server/pathwayResolver.js';
import OpenAIChatPlugin from '../server/plugins/openAiChatPlugin.js';
import GeminiChatPlugin from '../server/plugins/geminiChatPlugin.js';
import Gemini15ChatPlugin from '../server/plugins/gemini15ChatPlugin.js';
import Claude3VertexPlugin from '../server/plugins/claude3VertexPlugin.js';
import { config } from '../config.js';
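
// Shared server lifecycle: a single Cortex test server is started before the
// suite and stopped afterwards. CORTEX_ENABLE_REST is set before serverFactory()
// runs so the server comes up with its REST interface enabled.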
let testServer;

test.before(async () => {
    process.env.CORTEX_ENABLE_REST = 'true';
    const { server, startServer } = await serverFactory();
    startServer && await startServer();
    testServer = server;
});

test.after.always('cleanup', async () => {
    if (testServer) {
        await testServer.stop();
    }
});

// Helper function to create a PathwayResolver with a specific plugin
function createResolverWithPlugin(pluginClass, modelName = 'test-model') {
    // Map plugin classes to their corresponding model types
    const pluginToModelType = {
        OpenAIChatPlugin: 'OPENAI-VISION',
        GeminiChatPlugin: 'GEMINI-VISION',
        Gemini15ChatPlugin: 'GEMINI-1.5-VISION',
        Claude3VertexPlugin: 'CLAUDE-3-VERTEX'
    };

    const modelType = pluginToModelType[pluginClass.name];
    if (!modelType) {
        throw new Error(`Unknown plugin class: ${pluginClass.name}`);
    }

    const pathway = {
        name: 'test-pathway',
        model: modelName,
        prompt: 'test prompt'
    };

    const model = {
        name: modelName,
        type: modelType
    };

    const resolver = new PathwayResolver({
        config,
        pathway,
        args: {},
        endpoints: { [modelName]: model }
    });

    resolver.modelExecutor.plugin = new pluginClass(pathway, model);

    return resolver;
}
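
// Each streaming test below feeds the plugin a synthetic server-sent event
// (an object whose `data` property is a JSON string in the provider's wire
// format) and inspects the object returned by processStreamEvent: `data`
// carries the chunk to pass along, and `progress` becomes 1 once the stream
// has finished.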

// Test OpenAI Chat Plugin Streaming
test('OpenAI Chat Plugin - processStreamEvent handles content chunks correctly', async t => {
    const resolver = createResolverWithPlugin(OpenAIChatPlugin);
    const plugin = resolver.modelExecutor.plugin;

    // Test regular content chunk
    const contentEvent = {
        data: JSON.stringify({
            id: 'test-id',
            choices: [{
                delta: { content: 'test content' },
                finish_reason: null
            }]
        })
    };

    let progress = plugin.processStreamEvent(contentEvent, {});
    t.is(progress.data, contentEvent.data);
    t.falsy(progress.progress);

    // Test stream end
    const endEvent = {
        data: JSON.stringify({
            id: 'test-id',
            choices: [{
                delta: {},
                finish_reason: 'stop'
            }]
        })
    };

    progress = plugin.processStreamEvent(endEvent, {});
    t.is(progress.progress, 1);
});
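
// Gemini streams candidates[].content.parts rather than OpenAI-style
// choices[].delta, so the assertions below walk that structure before
// comparing the streamed text.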

// Test Gemini Chat Plugin Streaming
test('Gemini Chat Plugin - processStreamEvent handles content chunks correctly', async t => {
    const resolver = createResolverWithPlugin(GeminiChatPlugin);
    const plugin = resolver.modelExecutor.plugin;

    // Test regular content chunk
    const contentEvent = {
        data: JSON.stringify({
            candidates: [{
                content: {
                    parts: [{ text: 'test content' }]
                },
                finishReason: null
            }]
        })
    };

    let progress = plugin.processStreamEvent(contentEvent, {});
    t.truthy(progress.data, 'Should have data');

    const parsedData = JSON.parse(progress.data);
    t.truthy(parsedData.candidates, 'Should have candidates array');
    t.truthy(parsedData.candidates[0].content, 'Should have content object');
    t.truthy(parsedData.candidates[0].content.parts, 'Should have parts array');
    t.is(parsedData.candidates[0].content.parts[0].text, 'test content', 'Content should match');
    t.falsy(progress.progress);

    // Test stream end with STOP
    const endEvent = {
        data: JSON.stringify({
            candidates: [{
                content: {
                    parts: [{ text: '' }]
                },
                finishReason: 'STOP'
            }]
        })
    };

    progress = plugin.processStreamEvent(endEvent, {});
    t.is(progress.progress, 1);
});
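
// Gemini 1.5 can block a candidate on safety grounds mid-stream; the plugin is
// expected to surface a "Response blocked" message in the stream data and end
// the stream immediately.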

// Test Gemini 15 Chat Plugin Streaming
test('Gemini 15 Chat Plugin - processStreamEvent handles safety blocks', async t => {
    const resolver = createResolverWithPlugin(Gemini15ChatPlugin);
    const plugin = resolver.modelExecutor.plugin;

    // Test safety block
    const safetyEvent = {
        data: JSON.stringify({
            candidates: [{
                safetyRatings: [{ blocked: true }]
            }]
        })
    };

    const progress = plugin.processStreamEvent(safetyEvent, {});
    t.true(progress.data.includes('Response blocked'));
    t.is(progress.progress, 1);
});
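
// Anthropic's streaming protocol emits typed events (message_start,
// content_block_delta, message_stop). The Claude 3 Vertex plugin translates
// them into OpenAI-style choices/delta chunks, which is what the assertions
// below check.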

// Test Claude 3 Vertex Plugin Streaming
test('Claude 3 Vertex Plugin - processStreamEvent handles message types', async t => {
    const resolver = createResolverWithPlugin(Claude3VertexPlugin);
    const plugin = resolver.modelExecutor.plugin;

    // Test message start
    const startEvent = {
        data: JSON.stringify({
            type: 'message_start',
            message: { id: 'test-id' }
        })
    };

    let progress = plugin.processStreamEvent(startEvent, {});
    t.true(JSON.parse(progress.data).choices[0].delta.role === 'assistant');

    // Test content block
    const contentEvent = {
        data: JSON.stringify({
            type: 'content_block_delta',
            delta: {
                type: 'text_delta',
                text: 'test content'
            }
        })
    };

    progress = plugin.processStreamEvent(contentEvent, {});
    t.true(JSON.parse(progress.data).choices[0].delta.content === 'test content');

    // Test message stop
    const stopEvent = {
        data: JSON.stringify({
            type: 'message_stop'
        })
    };

    progress = plugin.processStreamEvent(stopEvent, {});
    t.is(progress.progress, 1);
});
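
// The suite is written for ava; with ava configured in package.json it can be
// run with `npx ava` (or as part of `npm test`) from the repository root.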