@just-every/ensemble

LLM provider abstraction layer with unified streaming interface

import { BaseModelProvider } from './base_provider.js';
import { v4 as uuidv4 } from 'uuid';
import { costTracker } from '../utils/cost_tracker.js';
import { log_llm_request, log_llm_response, log_llm_error } from '../utils/llm_logger.js';
import { hasEventHandler } from '../utils/event_controller.js';

// Mutable configuration shared by the default test provider instance; tests can tweak it directly.
export const testProviderConfig = {
    streamingDelay: 50,
    shouldError: false,
    errorMessage: 'Simulated error from test provider',
    simulateRateLimit: false,
    fixedResponse: undefined,
    fixedThinking: undefined,
    simulateToolCall: false,
    toolName: 'web_search',
    toolArguments: { query: 'test query' },
    tokenUsage: {
        inputTokens: 100,
        outputTokens: 200,
    },
    chunkSize: 5,
};

// Restore the shared configuration to its defaults (useful between tests).
export function resetTestProviderConfig() {
    testProviderConfig.streamingDelay = 50;
    testProviderConfig.shouldError = false;
    testProviderConfig.errorMessage = 'Simulated error from test provider';
    testProviderConfig.simulateRateLimit = false;
    testProviderConfig.fixedResponse = undefined;
    testProviderConfig.fixedThinking = undefined;
    testProviderConfig.simulateToolCall = false;
    testProviderConfig.toolName = 'web_search';
    testProviderConfig.toolArguments = { query: 'test query' };
    testProviderConfig.tokenUsage = {
        inputTokens: 100,
        outputTokens: 200,
    };
    testProviderConfig.chunkSize = 5;
}

// Small helper used to simulate streaming latency.
const sleep = (ms) => new Promise(resolve => setTimeout(resolve, ms));

// Simulated LLM provider that produces deterministic streams for testing, without network calls.
export class TestProvider extends BaseModelProvider {
    config;
    constructor(config = testProviderConfig) {
        super('test');
        this.config = config;
    }
    // Yields a simulated event stream: message_start, optional thinking and tool events,
    // chunked message_delta events, message_complete, and an optional cost_update.
    async *createResponseStream(messages, model, agent, _requestId) {
        console.log(`[TestProvider] Creating response stream for model: ${model}`);
        const lastUserMessage = messages.filter(m => 'role' in m && m.role === 'user').pop();
        const userMessageContent = lastUserMessage && 'content' in lastUserMessage
            ? typeof lastUserMessage.content === 'string'
                ? lastUserMessage.content
                : JSON.stringify(lastUserMessage.content)
            : '';
        const inputTokenCount = this.config.tokenUsage?.inputTokens ||
            Math.max(50, Math.ceil(userMessageContent.length / 4));
        let response;
        if (this.config.simulateRateLimit) {
            const rateLimitError = '429 Too Many Requests: The server is currently processing too many requests. Please try again later.';
            yield {
                type: 'error',
                error: rateLimitError,
            };
            return;
        }
        if (this.config.shouldError) {
            yield {
                type: 'error',
                error: this.config.errorMessage || 'Simulated error from test provider',
            };
            return;
        }
        if (this.config.fixedResponse) {
            response = this.config.fixedResponse;
        }
        else {
            response = this.generateResponse(userMessageContent);
        }
        const messageId = uuidv4();
        yield {
            type: 'message_start',
            message_id: messageId,
            content: '',
        };
        if (this.config.fixedThinking) {
            yield {
                type: 'message_delta',
                message_id: messageId,
                content: '',
                thinking_content: this.config.fixedThinking,
                thinking_signature: '(Simulated thinking)',
            };
            await sleep(this.config.streamingDelay || 50);
        }
        if (this.config.simulateToolCall && agent) {
            const { getToolsFromAgent } = await import('../utils/agent.js');
            const currentTools = getToolsFromAgent(agent);
            if (currentTools) {
                const toolArray = await currentTools;
                if (toolArray.length > 0) {
                    const availableTool = toolArray.find(tool =>
                        this.config.toolName
                            ? tool.definition.function.name === this.config.toolName
                            : true);
                    if (availableTool) {
                        const toolCall = {
                            id: uuidv4(),
                            type: 'function',
                            function: {
                                name: availableTool.definition.function.name,
                                arguments: JSON.stringify(this.config.toolArguments || {
                                    query: userMessageContent.slice(0, 50),
                                }),
                            },
                        };
                        yield {
                            type: 'tool_start',
                            tool_call: toolCall,
                        };
                        await sleep(this.config.streamingDelay || 50);
                        response = `I've used the ${toolCall.function.name} tool to help answer your question.\n\n${response}`;
                    }
                }
            }
        }
        // Stream the response text in fixed-size chunks.
        const chunkSize = this.config.chunkSize || 5;
        let position = 0;
        while (position < response.length) {
            const chunk = response.slice(position, position + chunkSize);
            position += chunkSize;
            yield {
                type: 'message_delta',
                message_id: messageId,
                content: chunk,
                order: position / chunkSize,
            };
            await sleep(this.config.streamingDelay || 50);
        }
        yield {
            type: 'message_complete',
            message_id: messageId,
            content: response,
        };
        const outputTokenCount = this.config.tokenUsage?.outputTokens || Math.ceil(response.length / 4);
        const calculatedUsage = costTracker.addUsage({
            model,
            input_tokens: inputTokenCount,
            output_tokens: outputTokenCount,
        });
        if (!hasEventHandler()) {
            yield {
                type: 'cost_update',
                usage: {
                    ...calculatedUsage,
                    total_tokens: inputTokenCount + outputTokenCount,
                },
            };
        }
    }
    // Keyword-based canned responses used when no fixedResponse is configured.
    generateResponse(input) {
        const lowercaseInput = input.toLowerCase();
        if (lowercaseInput.includes('hello') || lowercaseInput.includes('hi')) {
            return "Hello! I'm a test AI model. How can I help you today?";
        }
        else if (lowercaseInput.includes('help')) {
            return "I'm here to help! What do you need assistance with?";
        }
        else if (lowercaseInput.includes('error') || lowercaseInput.includes('problem')) {
            return "I understand you're experiencing an issue. Let me help troubleshoot the problem.";
        }
        else if (lowercaseInput.includes('json') || lowercaseInput.includes('person')) {
            return '{"name": "John Doe", "age": 30}';
        }
        else if (lowercaseInput.includes('test')) {
            return 'This is a test response. The test provider is working correctly!';
        }
        else if (lowercaseInput.includes('weather')) {
            return 'The weather is sunny and 72°F.';
        }
        else if (lowercaseInput.includes('?')) {
            return "That's an interesting question. As a test model, I'm designed to provide simulated responses for testing purposes.";
        }
        else {
            return `I've received your message: "${input.slice(0, 50)}${input.length > 50 ? '...' : ''}". This is a simulated response from the test provider.`;
        }
    }
    // Returns deterministic pseudo-embeddings derived from character codes and records estimated usage.
    async createEmbedding(input, model, agent, opts) {
        const requestId = `req_${Date.now()}_${Math.random().toString(36).slice(2, 9)}`;
        let finalRequestId = requestId;
        try {
            const requestParams = {
                model,
                input_length: Array.isArray(input) ? input.length : 1,
                dimension: opts?.dimensions || 384,
            };
            const loggedRequestId = log_llm_request(agent.agent_id || 'test', 'test', model, requestParams, new Date(), requestId, agent.tags);
            finalRequestId = loggedRequestId;
            const generateVector = (text) => {
                const dimension = opts?.dimensions || 384;
                const vector = new Array(dimension);
                for (let i = 0; i < dimension; i++) {
                    const charCode = text.charCodeAt(i % text.length) || 0;
                    const value = Math.sin(charCode * (i + 1) * 0.01) * 0.5 + 0.5;
                    vector[i] = value;
                }
                return vector;
            };
            let result;
            if (Array.isArray(input)) {
                result = input.map(text => generateVector(text));
            }
            else {
                result = generateVector(input);
            }
            const estimatedTokens = typeof input === 'string'
                ? Math.ceil(input.length / 4)
                : input.reduce((sum, text) => sum + Math.ceil(text.length / 4), 0);
            costTracker.addUsage({
                model,
                input_tokens: estimatedTokens,
                output_tokens: 0,
                metadata: {
                    dimensions: opts?.dimensions || 384,
                    type: 'test_embedding',
                },
            });
            log_llm_response(finalRequestId, {
                model,
                dimensions: opts?.dimensions || 384,
                vector_count: Array.isArray(input) ? input.length : 1,
                estimated_tokens: estimatedTokens,
            });
            return result;
        }
        catch (error) {
            log_llm_error(finalRequestId, error);
            throw error;
        }
    }
}

// Default singleton bound to the shared testProviderConfig.
export const testProvider = new TestProvider();
//# sourceMappingURL=test_provider.js.map
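
For context, a minimal usage sketch of the streaming interface is shown below. It is not part of the published file: the model name 'test-model', the request id 'req-example', and the relative import path are illustrative, and it assumes the stream can be consumed without an agent (the agent is only consulted when simulateToolCall is enabled).

// Usage sketch (illustrative; not part of the published module).
import { testProvider, testProviderConfig, resetTestProviderConfig } from './test_provider.js';

async function runExample() {
    // Drive the provider with a fixed response and no artificial delay between chunks.
    testProviderConfig.fixedResponse = 'Hello from the test provider.';
    testProviderConfig.streamingDelay = 0;

    let finalText = '';
    for await (const event of testProvider.createResponseStream(
        [{ role: 'user', content: 'hello' }], // messages
        'test-model',                         // hypothetical model name
        undefined,                            // no agent: tool-call simulation stays off
        'req-example'
    )) {
        if (event.type === 'message_delta') finalText += event.content;
        if (event.type === 'message_complete') finalText = event.content;
        if (event.type === 'error') throw new Error(event.error);
        // cost_update and other event types are ignored in this sketch.
    }
    console.log(finalText); // "Hello from the test provider."

    resetTestProviderConfig(); // restore the shared defaults between tests
}

runExample().catch(console.error);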