@aj-archipelago/cortex

Cortex is a GraphQL API for AI. It provides a simple, extensible interface for using AI services from OpenAI, Azure and others.
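For context on how such an API is typically called, here is a minimal, hypothetical usage sketch. The endpoint URL, the `summarize` pathway, and its `text` argument and `result` field are assumptions for illustration only; the pathways actually exposed depend on your Cortex configuration.

// Hypothetical usage sketch -- the endpoint URL, the `summarize` pathway,
// and its `text`/`result` fields are assumptions, not a documented API.
const response = await fetch('http://localhost:4000/graphql', {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({
        query: 'query Summarize($text: String) { summarize(text: $text) { result } }',
        variables: { text: 'Long article text to summarize...' },
    }),
});
const { data } = await response.json();
console.log(data.summarize.result);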

// ModelPlugin.test.js
import test from 'ava';
import ModelPlugin from '../server/plugins/modelPlugin.js';
import { encode } from '../lib/encodeCache.js';
import { mockPathwayResolverString } from './mocks.js';

const { config, pathway, modelName, model } = mockPathwayResolverString;

const modelPlugin = new ModelPlugin(pathway, model);

const generateMessage = (role, content) => ({ role, content });
const generateStructuredMessage = (role, content) => ({ role, content: [{ type: 'text', text: content }] });

test('truncateMessagesToTargetLength: should not modify messages if already within target length', (t) => {
    const messages = [
        generateMessage('user', 'Hello, how are you?'),
        generateMessage('assistant', 'I am doing well, thank you!'),
    ];
    const targetTokenLength = modelPlugin.countMessagesTokens(messages);
    const result = modelPlugin.truncateMessagesToTargetLength(messages, targetTokenLength);
    t.deepEqual(result, messages);
});

test('truncateMessagesToTargetLength: should prioritize final user message', (t) => {
    const messages = [
        generateMessage('system', 'System message'),
        generateMessage('user', 'First user message'),
        generateMessage('assistant', 'Assistant response'),
        generateMessage('user', 'Final important question that should be preserved'),
    ];

    // Set target length to only fit the final user message plus the minimum safety margin
    const finalUserMsg = messages[messages.length - 1];
    const targetTokenLength = modelPlugin.countMessagesTokens([finalUserMsg]) * 1.1;

    const result = modelPlugin.truncateMessagesToTargetLength(messages, targetTokenLength);

    t.is(result.length, 1, 'Should only keep final user message');
    t.is(result[0].role, 'user', 'Should be a user message');
    t.is(result[0].content, finalUserMsg.content, 'Should preserve final user message content');
});

test('truncateMessagesToTargetLength: should prioritize final user message with tight constraints', (t) => {
    const messages = [
        generateMessage('system', 'System message content that is very long and exceeds the target token length'),
        generateMessage('user', 'Hello, how are you?'),
        generateMessage('assistant', 'I am fine, thank you.'),
        generateMessage('user', 'Final user message'),
    ];

    // Very tight token constraint
    const targetTokenLength = 15;

    const result = modelPlugin.truncateMessagesToTargetLength(messages, targetTokenLength);

    // Should prioritize final user message
    t.is(result.length, 1, 'Should keep only the final user message with tight constraints');
    t.is(result[0].role, 'user', 'Should keep the user message');
    t.is(result[0].content.length <= messages[3].content.length, true, 'User message may be truncated');
});

test('truncateMessagesToTargetLength: should add truncation markers to shortened messages', (t) => {
    // Create a very long message that will definitely be truncated
    const longContent = 'a'.repeat(1000);
    const messages = [
        generateMessage('system', 'System message: ' + longContent),
        generateMessage('user', 'Final user message: ' + longContent),
    ];

    // Set a target token length that will force heavy truncation
    const targetTokenLength = 20;

    const result = modelPlugin.truncateMessagesToTargetLength(messages, targetTokenLength);

    // Verify truncation markers are added
    const expectedMarker = "[...]";

    // Check if at least one message has the truncation marker
    const hasMarker = result.some(msg => msg.content.includes(expectedMarker));
    t.true(hasMarker, 'At least one message should have truncation marker');

    // Verify individual messages
    result.forEach(msg => {
        // Only verify messages that were actually truncated
        if (msg.content.length < 1000) {
            t.true(msg.content.includes(expectedMarker),
                `Truncated ${msg.role} message should include truncation marker`);
        }
    });
});

test('truncateMessagesToTargetLength: should not add truncation markers to messages that fit completely', (t) => {
    const messages = [
        generateMessage('system', 'Short system message'),
        generateMessage('user', 'Short user message'),
        generateMessage('assistant', 'Short assistant message'),
        generateMessage('user', 'Another short user message'),
    ];

    // Set a target token length that allows all messages to fit
    const targetTokenLength = encode(modelPlugin.messagesToChatML(messages, false)).length;

    const result = modelPlugin.truncateMessagesToTargetLength(messages, targetTokenLength);

    // Verify no truncation markers are added
    const expectedMarker = "[...]";

    // None of the messages should have the truncation marker
    const hasMarker = result.some(msg => msg.content.includes(expectedMarker));
    t.false(hasMarker, 'No message should have a truncation marker when all fit completely');

    // Verify content is unchanged
    result.forEach((msg, index) => {
        t.is(msg.content, messages[index].content, `${msg.role} message content should be unchanged`);
    });
});

test('truncateMessagesToTargetLength: should handle extreme token constraints with markers', (t) => {
    // Create a very long message that will definitely be truncated
    const longContent = 'a'.repeat(1000);
    const messages = [
        generateMessage('system', 'System message: ' + longContent),
        generateMessage('user', 'Final user message: ' + longContent),
    ];

    // Extremely tight token constraint
    const targetTokenLength = 30;

    const result = modelPlugin.truncateMessagesToTargetLength(messages, targetTokenLength);

    // Verify result
    t.true(result.length > 0, 'Should have at least one message');

    // The kept message should have the truncation marker
    const expectedMarker = "[...]";
    t.true(result[0].content.includes(expectedMarker),
        'Extremely truncated message should include truncation marker');
});

test('truncateMessagesToTargetLength: should maintain message order', (t) => {
    const messages = [
        generateMessage('system', 'System message'),
        generateMessage('user', 'First user message'),
        generateMessage('assistant', 'Assistant response'),
        generateMessage('user', 'Second user message'),
        generateMessage('assistant', 'Second assistant response'),
        generateMessage('user', 'Final user message'),
    ];

    // Set target length to fit all messages
    const targetTokenLength = encode(modelPlugin.messagesToChatML(messages, false)).length;

    const result = modelPlugin.truncateMessagesToTargetLength(messages, targetTokenLength);

    t.deepEqual(result.map(m => m.role), messages.map(m => m.role), 'Message order should be preserved');
});

test('truncateMessagesToTargetLength: should return messages with [...] if target length is 0', (t) => {
    const messages = [
        generateMessage('user', 'Hello, how are you?'),
        generateMessage('assistant', 'I am doing well, thank you!'),
    ];

    const result = modelPlugin.truncateMessagesToTargetLength(messages, null, 0);

    // Should return all messages but with [...] content
    t.is(result.length, messages.length, 'Should return all messages');

    // Each message should be truncated to just the marker
    result.forEach(msg => {
        t.is(msg.content, '[...]', 'Message content should be just the truncation marker');
    });
});

test('truncateMessagesToTargetLength: should handle structured messages with maxMessageTokenLength=0', (t) => {
    const messages = [
        generateStructuredMessage('user', 'Hello, how are you?'),
        generateStructuredMessage('assistant', 'I am doing well, thank you!'),
    ];

    const result = modelPlugin.truncateMessagesToTargetLength(messages, null, 0);

    // Should return all messages but with [...] content
    t.is(result.length, messages.length, 'Should return all structured messages');

    // Each message should be truncated to just a single content item with the marker
    result.forEach(msg => {
        t.true(Array.isArray(msg.content), 'Content should still be an array');
        t.is(msg.content.length, 1, 'Should have exactly one content item');
        t.is(msg.content[0].type, 'text', 'Content item should be of type text');
        t.is(msg.content[0].text, '[...]', 'Content text should be just the truncation marker');
    });
});

// New tests for maxMessageTokenLength
test('truncateMessagesToTargetLength: should respect maxMessageTokenLength constraint', (t) => {
    // Create messages with different lengths
    const longContent = 'a'.repeat(1000);
    const messages = [
        generateMessage('user', 'Short first message'),
        generateMessage('assistant', longContent),
        generateMessage('user', 'Short final message'),
    ];

    // Set a target that would fit all messages normally
    const targetTokenLength = modelPlugin.countMessagesTokens(messages) + 100;

    // Calculate tokens in the assistant message
    const assistantMsgTokens = modelPlugin.countMessagesTokens([messages[1]]);

    // Set maxMessageTokenLength to be less than the assistant message length
    const maxMessageTokenLength = Math.floor(assistantMsgTokens * 0.3);

    const result = modelPlugin.truncateMessagesToTargetLength(messages, targetTokenLength, maxMessageTokenLength);

    // All messages should be present
    t.is(result.length, 3, 'All messages should be preserved');

    // Only the long message should be truncated
    t.is(result[0].content, messages[0].content, 'First message should be unchanged');
    t.is(result[2].content, messages[2].content, 'Last message should be unchanged');

    // The assistant message should be truncated
    t.true(result[1].content.length < longContent.length, 'Long message should be truncated');
    t.true(result[1].content.includes('[...]'), 'Truncated message should have marker');

    // Calculate tokens in the truncated message
    const truncatedMsgTokens = modelPlugin.countMessagesTokens([result[1]]);

    // Allow small buffer for truncation marker
    t.true(truncatedMsgTokens <= maxMessageTokenLength + 10,
        `Truncated message (${truncatedMsgTokens} tokens) should not exceed maxMessageTokenLength (${maxMessageTokenLength} tokens) by more than buffer`);
});

test('truncateMessagesToTargetLength: should handle very small maxMessageTokenLength', (t) => {
    const messages = [
        generateMessage('system', 'System message that will definitely need to be truncated to fit the maxMessageTokenLength'),
        generateMessage('user', 'This is a user message that will need to be heavily truncated to fit the maxMessageTokenLength'),
    ];

    // Set a large target token length
    const targetTokenLength = 1000;

    // But set a very small maxMessageTokenLength
    const maxMessageTokenLength = 5;

    const result = modelPlugin.truncateMessagesToTargetLength(messages, targetTokenLength, maxMessageTokenLength);

    // All messages should be present but truncated
    t.is(result.length, 2, 'Both messages should be present');

    // Both messages should be truncated to fit the maxMessageTokenLength
    result.forEach(msg => {
        const msgTokens = modelPlugin.safeGetEncodedLength(msg.content);
        t.true(msgTokens <= maxMessageTokenLength + 5,
            `Message (${msgTokens} tokens) should not exceed maxMessageTokenLength (${maxMessageTokenLength}) by more than buffer`);
        t.true(msg.content.includes('[...]'), 'Truncated message should have marker');
    });
});

test('truncateMessagesToTargetLength: should handle both constraints together', (t) => {
    const longContent = 'a'.repeat(500);
    const messages = [
        generateMessage('system', 'System: ' + longContent),
        generateMessage('user', 'User: ' + longContent),
        generateMessage('assistant', 'Assistant: ' + longContent),
        generateMessage('user', 'Final: ' + longContent),
    ];

    // Set a moderate target token length
    const targetTokenLength = 300;

    // And a moderate maxMessageTokenLength
    const maxMessageTokenLength = 100;

    const result = modelPlugin.truncateMessagesToTargetLength(messages, targetTokenLength, maxMessageTokenLength);

    // We should have some messages, but not necessarily all
    t.true(result.length > 0 && result.length <= messages.length, 'Should have some messages');

    // Total token count should be below target
    const totalTokens = modelPlugin.countMessagesTokens(result);
    t.true(totalTokens <= targetTokenLength,
        `Total tokens (${totalTokens}) should not exceed target length (${targetTokenLength})`);

    // Each message should respect maxMessageTokenLength
    result.forEach(msg => {
        const msgTokens = modelPlugin.countMessagesTokens([msg]);
        t.true(msgTokens <= maxMessageTokenLength + 10,
            `Message (${msgTokens} tokens) should not exceed maxMessageTokenLength (${maxMessageTokenLength}) by more than buffer`);
    });
});

test('truncateMessagesToTargetLength: maxMessageTokenLength should not affect unchanged messages', (t) => {
    const messages = [
        generateMessage('system', 'Short system message'),
        generateMessage('user', 'Short user message'),
    ];

    // Calculate tokens in each message
    const systemMsgTokens = modelPlugin.countMessagesTokens([messages[0]]);
    const userMsgTokens = modelPlugin.countMessagesTokens([messages[1]]);

    // Set maxMessageTokenLength above individual message sizes but below their sum
    const maxMessageTokenLength = Math.max(systemMsgTokens, userMsgTokens) + 10;

    // Set target length to fit all messages
    const targetTokenLength = modelPlugin.countMessagesTokens(messages) + 20;

    const result = modelPlugin.truncateMessagesToTargetLength(messages, targetTokenLength, maxMessageTokenLength);

    // All messages should be unchanged
    t.is(result.length, 2, 'Both messages should be present');
    t.is(result[0].content, messages[0].content, 'First message should be unchanged');
    t.is(result[1].content, messages[1].content, 'Second message should be unchanged');

    // No truncation markers
    const hasMarker = result.some(msg => msg.content.includes('[...]'));
    t.false(hasMarker, 'No message should have truncation marker');
});

test('truncateMessagesToTargetLength: should truncate long messages with maxMessageTokenLength', t => {
    const longText = 'A'.repeat(6000);
    const messages = [
        generateMessage('user', longText),
        generateMessage('assistant', 'Response'),
        generateMessage('user', 'Short message')
    ];

    const shortMsgTokens = modelPlugin.countMessagesTokens([{ role: 'user', content: 'Short message' }]);
    const maxMessageTokenLength = shortMsgTokens * 2; // Just enough to force truncation of long messages

    // Large target to ensure only maxMessageTokenLength constraint is active
    const targetTokenLength = 10000;

    const result = modelPlugin.truncateMessagesToTargetLength(messages, targetTokenLength, maxMessageTokenLength);

    // Check that long message was truncated
    const longMsgTokens = modelPlugin.countMessagesTokens([result[0]]);
    t.true(longMsgTokens <= maxMessageTokenLength + 10,
        `Long message (${longMsgTokens} tokens) should be truncated to near maxMessageTokenLength (${maxMessageTokenLength})`);
    t.true(result[0].content.includes('[...]'), 'Truncated message should have truncation marker');

    // Short messages should be unchanged
    t.is(result[1].content, 'Response');
    t.is(result[2].content, 'Short message');
});

test('truncateMessagesToTargetLength: should not truncate image content with maxMessageTokenLength', t => {
    const longText = 'A'.repeat(6000);
    const imageContent = { type: 'image_url', url: 'image.jpg' };
    const longTextContent = { type: 'text', text: longText };

    const messages = [
        generateMessage('user', [imageContent, longTextContent]),
        generateMessage('assistant', 'I see an image')
    ];

    // Calculate tokens for image + some text
    const imageTokens = 100; // Estimate from countMessagesTokens
    const maxMessageTokenLength = imageTokens + 50; // Enough for image but not all text

    // Large target to ensure only maxMessageTokenLength constraint is active
    const targetTokenLength = 10000;

    const result = modelPlugin.truncateMessagesToTargetLength(messages, targetTokenLength, maxMessageTokenLength);

    // Image should be preserved
    t.deepEqual(result[0].content[0], imageContent, 'Image content should be preserved');

    // Text should be truncated
    t.true(result[0].content[1].text.length < longText.length, 'Text content should be truncated');
    t.true(result[0].content[1].text.includes('[...]'), 'Truncated text should have marker');

    // Check overall message length
    const msgTokens = modelPlugin.countMessagesTokens([result[0]]);
    t.true(msgTokens <= maxMessageTokenLength + 10,
        `Message tokens (${msgTokens}) should not exceed maxMessageTokenLength (${maxMessageTokenLength}) by more than buffer`);
});

test('truncateMessagesToTargetLength: should truncate array content with maxMessageTokenLength', t => {
    const longText1 = 'A'.repeat(3000);
    const longText2 = 'B'.repeat(3000);
    const longTextContent1 = { type: 'text', text: longText1 };
    const longTextContent2 = { type: 'text', text: longText2 };

    const messages = [
        generateMessage('user', [longTextContent1, longTextContent2]),
        generateMessage('assistant', 'Response')
    ];

    // Set a moderate maxMessageTokenLength
    const maxMessageTokenLength = 200;

    // Large target to ensure only maxMessageTokenLength constraint is active
    const targetTokenLength = 10000;

    const result = modelPlugin.truncateMessagesToTargetLength(messages, targetTokenLength, maxMessageTokenLength);

    // Check that message was truncated
    const msgTokens = modelPlugin.countMessagesTokens([result[0]]);
    t.true(msgTokens <= maxMessageTokenLength + 10,
        `Message tokens (${msgTokens}) should not exceed maxMessageTokenLength (${maxMessageTokenLength}) by more than buffer`);

    // At least one of the text items should be truncated
    const hasMarker = result[0].content.some(item =>
        (typeof item === 'string' && item.includes('[...]')) ||
        (item.type === 'text' && item.text.includes('[...]')));
    t.true(hasMarker, 'At least one content item should have truncation marker');
});

test('truncateMessagesToTargetLength: should handle mixed message types with maxMessageTokenLength', t => {
    const longText = 'A'.repeat(10000);
    const shortText = 'Short message';
    const imageContent = { type: 'image_url', url: 'image.jpg' };
    const longTextContent = { type: 'text', text: longText };
    const shortTextContent = { type: 'text', text: shortText };

    const messages = [
        generateMessage('user', shortText),
        generateMessage('assistant', longText),
        generateMessage('user', [shortTextContent, imageContent, longTextContent]),
        generateMessage('system', longText)
    ];

    // Calculate reasonable maxMessageTokenLength
    const shortMsgTokens = modelPlugin.countMessagesTokens([{ role: 'user', content: shortText }]);
    const maxMessageTokenLength = 200; // Force truncation of long messages

    // Large target to ensure only maxMessageTokenLength constraint is active
    const targetTokenLength = 10000;

    const result = modelPlugin.truncateMessagesToTargetLength(messages, targetTokenLength, maxMessageTokenLength);

    // Short message should be unchanged
    t.is(result[0].content, shortText, 'Short message should be unchanged');

    // Long text messages should be truncated
    t.true(result[1].content.length < longText.length, 'Long text message should be truncated');
    t.true(result[1].content.includes('[...]'), 'Truncated message should have marker');
    t.true(result[3].content.length < longText.length, 'Long system message should be truncated');

    // Check multimodal message: the image item should be preserved and the
    // long text item (index 2) should be truncated
    t.deepEqual(result[2].content[1], imageContent, 'Image should be preserved');
    if (typeof result[2].content[2] === 'string') {
        t.true(result[2].content[2].length < longText.length, 'Text in multimodal message should be truncated');
    } else if (result[2].content[2] && result[2].content[2].type === 'text') {
        t.true(result[2].content[2].text.length < longText.length, 'Text in multimodal message should be truncated');
    }

    // All messages should respect maxMessageTokenLength
    result.forEach((msg, i) => {
        const msgTokens = modelPlugin.countMessagesTokens([msg]);
        t.true(msgTokens <= maxMessageTokenLength + 10,
            `Message ${i} tokens (${msgTokens}) should not exceed maxMessageTokenLength (${maxMessageTokenLength}) by more than buffer`);
    });
});
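Taken together, the tests above describe the contract of truncateMessagesToTargetLength: messages that already fit are left unchanged, the final user message is preferred when the token budget is tight, truncated content gets a "[...]" marker, and non-text content items such as images are left intact. The sketch below is a standalone illustration of the per-message cap only; it is not the ModelPlugin implementation, and it approximates token counting with a rough characters-per-token heuristic instead of the real encoder from lib/encodeCache.js.

// Standalone illustration only -- NOT the ModelPlugin implementation.
// Token counts are approximated at ~4 characters per token.
const TRUNCATION_MARKER = '[...]';
const CHARS_PER_TOKEN = 4;

const estimateTokens = (text) => Math.ceil(text.length / CHARS_PER_TOKEN);

// Shorten a text string so its estimated token count is at most maxTokens,
// appending the marker whenever content is dropped.
function truncateText(text, maxTokens) {
    if (maxTokens <= 0) return TRUNCATION_MARKER;
    if (estimateTokens(text) <= maxTokens) return text;
    const keepChars = Math.max(0, maxTokens * CHARS_PER_TOKEN - TRUNCATION_MARKER.length);
    return text.slice(0, keepChars) + TRUNCATION_MARKER;
}

// Apply a per-message token cap to plain-string or structured (array) content,
// truncating text items and leaving non-text items (e.g. image_url) untouched.
function truncateMessage(message, maxMessageTokenLength) {
    if (Array.isArray(message.content)) {
        return {
            ...message,
            content: message.content.map((item) =>
                item.type === 'text'
                    ? { ...item, text: truncateText(item.text, maxMessageTokenLength) }
                    : item
            ),
        };
    }
    return { ...message, content: truncateText(message.content, maxMessageTokenLength) };
}

// Example: cap each message at roughly 10 tokens.
const capped = [
    { role: 'system', content: 'You are a helpful assistant. '.repeat(20) },
    { role: 'user', content: 'Hello!' },
].map((msg) => truncateMessage(msg, 10));
console.log(capped);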