UNPKG

@just-every/ensemble

Version:

LLM provider abstraction layer with unified streaming interface

200 lines 8.21 kB
import { BaseModelProvider } from './base_provider.js';
import { v4 as uuidv4 } from 'uuid';
import { costTracker } from '../index.js';
import { hasEventHandler } from '../utils/event_controller.js';

// Shared defaults. They are declared before `testProviderConfig` because `const`
// bindings are not hoisted.
const DEFAULT_STREAMING_DELAY_MS = 50;
const DEFAULT_CHUNK_SIZE = 5;
const DEFAULT_ERROR_MESSAGE = 'Simulated error from test provider';

/**
 * Builds a fresh object holding the default test-provider settings.
 * Both the initial `testProviderConfig` and `resetTestProviderConfig` use it,
 * so the two sets of defaults cannot drift apart.
 * @returns {object} A new config object with default values.
 */
function createDefaultTestProviderConfig() {
    return {
        streamingDelay: DEFAULT_STREAMING_DELAY_MS,
        shouldError: false,
        errorMessage: DEFAULT_ERROR_MESSAGE,
        simulateRateLimit: false,
        fixedResponse: undefined,
        fixedThinking: undefined,
        simulateToolCall: false,
        toolName: 'web_search',
        toolArguments: { query: 'test query' },
        tokenUsage: {
            inputTokens: 100,
            outputTokens: 200,
        },
        chunkSize: DEFAULT_CHUNK_SIZE,
    };
}

/**
 * Mutable, module-wide configuration for the test provider.
 * `TestProvider` instances created without an explicit config hold a reference
 * to this object, so changes made here affect them directly.
 */
export const testProviderConfig = createDefaultTestProviderConfig();

/**
 * Restores every field of `testProviderConfig` to its default value.
 * The object is updated in place (its identity is preserved), so providers
 * that reference it, such as the exported `testProvider`, see the reset values.
 */
export function resetTestProviderConfig() {
    Object.assign(testProviderConfig, createDefaultTestProviderConfig());
}

const sleep = (ms) => new Promise(resolve => setTimeout(resolve, ms));

/**
 * Returns the content of the last message whose role is 'user'.
 * Non-string content is JSON-encoded. Returns '' when there is no such message,
 * when it has no `content` key, or when the content cannot be stringified.
 * @param {Array<object>} messages - Conversation history.
 * @returns {string}
 */
function extractLastUserText(messages) {
    const lastUserMessage = messages.filter(m => 'role' in m && m.role === 'user').pop();
    if (!lastUserMessage || !('content' in lastUserMessage)) {
        return '';
    }
    const { content } = lastUserMessage;
    if (typeof content === 'string') {
        return content;
    }
    // JSON.stringify(undefined) returns undefined, which would crash later `.length` reads.
    return JSON.stringify(content) ?? '';
}

/**
 * Looks up the tool that a simulated tool call should target.
 * When `toolName` is set, returns the agent tool with that function name.
 * Otherwise returns the agent's first tool.
 * @param {object} agent - Agent whose tools are resolved via `getToolsFromAgent`.
 * @param {string|undefined} toolName - Desired tool function name, if any.
 * @returns {Promise<object|undefined>} The matching tool, or undefined if none.
 */
async function findSimulatedTool(agent, toolName) {
    // Loaded lazily, as in the original code, so the agent utilities are only
    // imported when tool simulation is actually requested.
    const { getToolsFromAgent } = await import('../utils/agent.js');
    const currentTools = getToolsFromAgent(agent);
    if (!currentTools) {
        return undefined;
    }
    const toolArray = await currentTools;
    if (!toolArray || toolArray.length === 0) {
        return undefined;
    }
    return toolArray.find(tool => (toolName ? tool.definition.function.name === toolName : true));
}

/**
 * Deterministic fake model provider for tests.
 * It streams canned or configured responses and can simulate errors,
 * rate limits, thinking output and tool calls according to its config.
 */
export class TestProvider extends BaseModelProvider {
    config;

    /**
     * @param {object} [config=testProviderConfig] - Behaviour settings. Defaults
     *   to the shared, mutable `testProviderConfig` object.
     */
    constructor(config = testProviderConfig) {
        super('test');
        this.config = config;
    }

    /**
     * Streams a simulated response as a series of events.
     * On error it yields a single 'error' event. Otherwise it yields events in this order:
     * 'message_start'; an optional thinking 'message_delta'; an optional 'tool_start';
     * chunked 'message_delta' events; 'message_complete'; and finally 'cost_update',
     * which is emitted only when no event handler is registered.
     * @param {Array<object>} messages - Conversation history.
     * @param {string} model - Model id, also used for cost tracking.
     * @param {object} [agent] - Agent whose tools are used when `simulateToolCall` is set.
     */
    async *createResponseStream(messages, model, agent) {
        console.log(`[TestProvider] Creating response stream for model: ${model}`);
        const userMessageContent = extractLastUserText(messages);
        // `??` (not `||`) so an explicit 0 disables the delay instead of becoming 50 ms.
        const delayMs = this.config.streamingDelay ?? DEFAULT_STREAMING_DELAY_MS;
        // `??` so an explicitly configured 0-token usage is honoured rather than estimated.
        const inputTokenCount = this.config.tokenUsage?.inputTokens ??
            Math.max(50, Math.ceil(userMessageContent.length / 4));

        if (this.config.simulateRateLimit) {
            yield {
                type: 'error',
                error: '429 Too Many Requests: The server is currently processing too many requests. Please try again later.',
            };
            return;
        }
        if (this.config.shouldError) {
            yield {
                type: 'error',
                error: this.config.errorMessage || DEFAULT_ERROR_MESSAGE,
            };
            return;
        }

        let response = this.config.fixedResponse
            ? this.config.fixedResponse
            : this.generateResponse(userMessageContent);

        const messageId = uuidv4();
        yield {
            type: 'message_start',
            message_id: messageId,
            content: '',
        };

        if (this.config.fixedThinking) {
            yield {
                type: 'message_delta',
                message_id: messageId,
                content: '',
                thinking_content: this.config.fixedThinking,
                thinking_signature: '(Simulated thinking)',
            };
            await sleep(delayMs);
        }

        if (this.config.simulateToolCall && agent) {
            const availableTool = await findSimulatedTool(agent, this.config.toolName);
            if (availableTool) {
                const toolCall = {
                    id: uuidv4(),
                    type: 'function',
                    function: {
                        name: availableTool.definition.function.name,
                        arguments: JSON.stringify(this.config.toolArguments || {
                            query: userMessageContent.slice(0, 50),
                        }),
                    },
                };
                yield {
                    type: 'tool_start',
                    tool_call: toolCall,
                };
                await sleep(delayMs);
                response = `I've used the ${toolCall.function.name} tool to help answer your question.\n\n${response}`;
            }
        }

        // A non-positive chunk size would never advance `position` (an infinite loop),
        // so such values fall back to the default.
        const configuredChunkSize = this.config.chunkSize;
        const chunkSize = configuredChunkSize > 0 ? configuredChunkSize : DEFAULT_CHUNK_SIZE;
        let position = 0;
        let chunkIndex = 0;
        while (position < response.length) {
            const chunk = response.slice(position, position + chunkSize);
            position += chunkSize;
            chunkIndex += 1;
            yield {
                type: 'message_delta',
                message_id: messageId,
                content: chunk,
                order: chunkIndex,
            };
            await sleep(delayMs);
        }

        yield {
            type: 'message_complete',
            message_id: messageId,
            content: response,
        };

        const outputTokenCount = this.config.tokenUsage?.outputTokens ?? Math.ceil(response.length / 4);
        const calculatedUsage = costTracker.addUsage({
            model,
            input_tokens: inputTokenCount,
            output_tokens: outputTokenCount,
        });
        if (!hasEventHandler()) {
            yield {
                type: 'cost_update',
                usage: {
                    ...calculatedUsage,
                    total_tokens: inputTokenCount + outputTokenCount,
                },
            };
        }
    }

    /**
     * Chooses a canned reply using simple keyword matching on the input.
     * Rules are checked in order and the first match wins.
     * @param {string} input - User message text.
     * @returns {string}
     */
    generateResponse(input) {
        const lowercaseInput = input.toLowerCase();
        // Match "hi" only as a whole word; a plain substring test would also
        // match words such as "this", "which" or "something".
        if (lowercaseInput.includes('hello') || /\bhi\b/.test(lowercaseInput)) {
            return "Hello! I'm a test AI model. How can I help you today?";
        }
        if (lowercaseInput.includes('help')) {
            return "I'm here to help! What do you need assistance with?";
        }
        if (lowercaseInput.includes('error') || lowercaseInput.includes('problem')) {
            return "I understand you're experiencing an issue. Let me help troubleshoot the problem.";
        }
        if (lowercaseInput.includes('json') || lowercaseInput.includes('person')) {
            return '{"name": "John Doe", "age": 30}';
        }
        if (lowercaseInput.includes('test')) {
            return 'This is a test response. The test provider is working correctly!';
        }
        if (lowercaseInput.includes('weather')) {
            return 'The weather is sunny and 72°F.';
        }
        if (lowercaseInput.includes('?')) {
            return "That's an interesting question. As a test model, I'm designed to provide simulated responses for testing purposes.";
        }
        return `I've received your message: "${input.slice(0, 50)}${input.length > 50 ? '...' : ''}". This is a simulated response from the test provider.`;
    }

    /**
     * Produces deterministic pseudo-embeddings derived from character codes.
     * Each component is sin(charCode * (i + 1) * 0.01) * 0.5 + 0.5, which lies in [0, 1].
     * @param {string|string[]} input - Text, or a list of texts.
     * @param {string} model - Unused by this provider.
     * @param {{dimension?: number}} [opts] - `dimension` defaults to 384.
     * @returns {Promise<number[]|number[][]>} One vector, or one vector per input text.
     */
    async createEmbedding(input, model, opts) {
        const generateVector = (text) => {
            const dimension = opts?.dimension || 384;
            const vector = new Array(dimension);
            for (let i = 0; i < dimension; i++) {
                // For empty text, charCodeAt returns NaN and `|| 0` maps it to 0.
                const charCode = text.charCodeAt(i % text.length) || 0;
                vector[i] = Math.sin(charCode * (i + 1) * 0.01) * 0.5 + 0.5;
            }
            return vector;
        };
        if (Array.isArray(input)) {
            return input.map(text => generateVector(text));
        }
        return generateVector(input);
    }
}

export const testProvider = new TestProvider();
//# sourceMappingURL=test_provider.js.map