UNPKG

@just-every/ensemble

Version:

LLM provider abstraction layer with unified streaming interface

405 lines 15.7 kB
import { randomUUID } from 'crypto';
import { getModelFromAgent, getModelProvider } from '../model_providers/model_provider.js';
import { MessageHistory } from '../utils/message_history.js';
import { handleToolCall } from '../utils/tool_execution_manager.js';
import { processToolResult } from '../utils/tool_result_processor.js';
import { verifyOutput } from '../utils/verification.js';
import { waitWhilePaused } from '../utils/pause_controller.js';
import { emitEvent } from '../utils/event_controller.js';
import {
    convertToThinkingMessage,
    convertToOutputMessage,
    convertToFunctionCall,
    convertToFunctionCallOutput,
} from '../utils/message_converter.js';

/** Once this many rounds have produced an `error` event, the request loop stops retrying on errors. */
const MAX_ERROR_ATTEMPTS = 5;

/**
 * Drives a full request as an async generator: runs `executeRound` repeatedly,
 * re-yielding every event, until a round produces neither tool calls nor an
 * error, or one of the limits (error rounds, tool-call rounds, total tool
 * calls) is reached.
 *
 * Side effects visible in this function:
 * - The conversation array (`agent.historyThread` if set, otherwise `messages`)
 *   is mutated in place: a `'Begin.'` user message is pushed when it is empty,
 *   and `agent.instructions` is unshifted as a system message unless the first
 *   message already carries the same (trimmed) text.
 * - `agent.modelSettings.tool_choice` is deleted after the first round that
 *   contained tool calls.
 *
 * Any exception is converted into a yielded `error` event, and a `stream_end`
 * event is always yielded last.
 *
 * @param {Array<object>} messages - Conversation messages; used only when `agent.historyThread` is not set.
 * @param {object} [agent={}] - Agent configuration (instructions, tools, limits, verifier, callbacks).
 * @yields {object} Provider and lifecycle events.
 */
export async function* ensembleRequest(messages, agent = {}) {
    const conversationHistory = agent?.historyThread || messages;
    if (conversationHistory.length === 0) {
        conversationHistory.push({
            type: 'message',
            role: 'user',
            content: 'Begin.',
        });
    }
    if (agent.instructions) {
        const firstMsg = conversationHistory[0];
        const alreadyHasInstructions =
            firstMsg &&
            'content' in firstMsg &&
            typeof firstMsg.content === 'string' &&
            firstMsg.content.trim() === agent.instructions.trim();
        if (!alreadyHasInstructions) {
            conversationHistory.unshift({
                type: 'message',
                role: 'system',
                content: agent.instructions,
            });
        }
    }
    const history = new MessageHistory(conversationHistory, {
        compactToolCalls: true,
        preserveSystemMessages: true,
        compactionThreshold: 0.7,
    });
    try {
        let totalToolCalls = 0;
        let toolCallRounds = 0;
        let errorRounds = 0;
        const maxToolCalls = agent?.maxToolCalls ?? 200;
        const maxRounds = agent?.maxToolCallRoundsPerTurn ?? Infinity;
        let hasToolCalls = false;
        let hasError = false;
        let lastMessageContent = '';
        // Models used so far are passed back to getModelFromAgent on every round.
        const modelHistory = [];
        do {
            hasToolCalls = false;
            hasError = false;
            const model = await getModelFromAgent(agent, 'reasoning_mini', modelHistory);
            modelHistory.push(model);
            const stream = executeRound(model, agent, history, totalToolCalls, maxToolCalls);
            for await (const event of stream) {
                yield event;
                switch (event.type) {
                    case 'message_complete': {
                        const messageEvent = event;
                        if (messageEvent.content) {
                            lastMessageContent = messageEvent.content;
                        }
                        break;
                    }
                    case 'tool_start': {
                        // Counted whether or not executeRound actually ran the tool.
                        hasToolCalls = true;
                        ++totalToolCalls;
                        break;
                    }
                    case 'error': {
                        hasError = true;
                        break;
                    }
                }
            }
            if (hasToolCalls) {
                ++toolCallRounds;
                // tool_choice only applies until the first round with tool calls.
                if (agent.modelSettings?.tool_choice) {
                    delete agent.modelSettings.tool_choice;
                }
            }
            if (hasError) {
                ++errorRounds;
            }
        } while (
            (hasError && errorRounds < MAX_ERROR_ATTEMPTS) ||
            (hasToolCalls && toolCallRounds < maxRounds && totalToolCalls < maxToolCalls)
        );
        if (hasToolCalls && toolCallRounds >= maxRounds) {
            console.log('[ensembleRequest] Tool call rounds limit reached');
        } else if (hasToolCalls && totalToolCalls >= maxToolCalls) {
            console.log('[ensembleRequest] Total tool calls limit reached');
        }
        if (agent?.verifier && lastMessageContent) {
            const verificationResult = await performVerification(
                agent,
                lastMessageContent,
                await history.getMessages()
            );
            if (verificationResult) {
                for await (const event of verificationResult) {
                    yield event;
                }
            }
        }
    } catch (err) {
        // Null-safe reads: a thrown null/undefined must not raise a second
        // error here and hide the failure from the consumer.
        yield {
            type: 'error',
            error: err?.message || 'Unknown error',
            code: err?.code,
            details: err?.details,
            recoverable: err?.recoverable,
            timestamp: new Date().toISOString(),
        };
    } finally {
        yield {
            type: 'stream_end',
            timestamp: new Date().toISOString(),
        };
    }
}

/**
 * Runs a single model round: streams the provider response, starts tool calls
 * as they arrive, records messages in `history`, and finally yields the tool
 * results.
 *
 * Event flow, as implemented below:
 * 1. `agent_start` is yielded and emitted.
 * 2. Every provider event is yielded and emitted; `tool_start` events are
 *    first enriched with `function.arguments_formatted`.
 * 3. Each message added to `history` is additionally yielded as `response_output`.
 * 4. After the stream ends, one `tool_done` plus one `response_output` is
 *    yielded per tool result.
 * 5. `agent_done` is emitted through `emitEvent` only (it is not yielded).
 * 6. Events collected through `agent.onToolEvent` during the round are yielded last.
 *
 * Note: this function assigns `agent.onToolEvent`, replacing any handler that
 * was previously set on the agent object.
 *
 * @param {string} model - Model identifier returned by `getModelFromAgent`.
 * @param {object} agent - Agent configuration and callbacks.
 * @param {MessageHistory} history - Conversation history for this request.
 * @param {number} currentToolCalls - Tool calls counted before this round started.
 * @param {number} maxToolCalls - Overall tool-call budget for the request.
 * @yields {object} Provider, tool and response events.
 */
async function* executeRound(model, agent, history, currentToolCalls, maxToolCalls) {
    const requestId = randomUUID();
    const startTime = Date.now();
    let totalCost = 0;
    let messages = await history.getMessages(model);
    // Null-safe: `'content' in messages[0]` would throw on an empty list or a
    // non-object first entry.
    const firstMessage = messages[0];
    const agentStartEvent = {
        type: 'agent_start',
        request_id: requestId,
        input: typeof firstMessage?.content === 'string' ? firstMessage.content : undefined,
        timestamp: new Date().toISOString(),
        agent: {
            agent_id: agent.agent_id,
            name: agent.name,
            parent_id: agent.parent_id,
            model: agent.model || model,
            modelClass: agent.modelClass,
            cwd: agent.cwd,
            modelScores: agent.modelScores,
            disabledModels: agent.disabledModels,
        },
    };
    yield agentStartEvent;
    await emitEvent(agentStartEvent, agent, model);
    if (agent.onRequest) {
        // The hook may substitute both the agent and the outgoing messages.
        [agent, messages] = await agent.onRequest(agent, messages);
    }
    // NOTE(review): presumably blocks while the pause controller is paused; confirm in pause_controller.js.
    await waitWhilePaused(100, agent.abortSignal);
    const provider = getModelProvider(model);
    const stream =
        'createResponseStreamWithRetry' in provider
            ? provider.createResponseStreamWithRetry(messages, model, agent)
            : provider.createResponseStream(messages, model, agent);
    const toolPromises = [];
    const toolCallFormattedArgs = new Map();
    const toolEventBuffer = [];
    // Buffer tool-emitted events so they can be yielded once the round completes.
    agent.onToolEvent = async (event) => {
        toolEventBuffer.push(event);
    };
    for await (let event of stream) {
        if (event.type === 'tool_start') {
            const toolEvent = event;
            if (toolEvent.tool_call) {
                const toolCall = toolEvent.tool_call;
                let argumentsFormatted;
                try {
                    const tool = agent.tools?.find(
                        (t) => t.definition.function.name === toolCall.function.name
                    );
                    if (tool && 'definition' in tool && tool.definition.function.parameters.properties) {
                        const parsedArgs = JSON.parse(toolCall.function.arguments || '{}');
                        if (typeof parsedArgs === 'object' && parsedArgs !== null && !Array.isArray(parsedArgs)) {
                            // Re-order arguments to follow the tool definition's parameter
                            // order; keys the definition does not declare are omitted.
                            const paramNames = Object.keys(tool.definition.function.parameters.properties);
                            const orderedArgs = {};
                            for (const param of paramNames) {
                                if (param in parsedArgs) {
                                    orderedArgs[param] = parsedArgs[param];
                                }
                            }
                            argumentsFormatted = JSON.stringify(orderedArgs, null, 2);
                        }
                    }
                } catch (error) {
                    // Formatting is best-effort; the tool call itself still proceeds.
                    console.debug('Failed to format tool arguments:', error);
                }
                if (argumentsFormatted) {
                    toolCallFormattedArgs.set(toolCall.id, argumentsFormatted);
                }
                const modifiedEvent = {
                    ...event,
                    tool_call: {
                        ...toolCall,
                        function: {
                            ...toolCall.function,
                            arguments_formatted: argumentsFormatted,
                        },
                    },
                };
                event = modifiedEvent;
            }
        }
        yield event;
        await emitEvent(event, agent, model);
        switch (event.type) {
            case 'cost_update': {
                const costEvent = event;
                if (costEvent.usage?.cost) {
                    totalCost += costEvent.usage.cost;
                }
                break;
            }
            case 'message_complete': {
                const messageEvent = event;
                if (messageEvent.thinking_content || (!messageEvent.content && messageEvent.message_id)) {
                    const thinkingMessage = convertToThinkingMessage(messageEvent, model);
                    if (agent.onThinking) {
                        await agent.onThinking(thinkingMessage);
                    }
                    history.add(thinkingMessage);
                    yield {
                        type: 'response_output',
                        message: thinkingMessage,
                    };
                }
                if (messageEvent.content) {
                    const contentMessage = convertToOutputMessage(messageEvent, model, 'completed');
                    if (agent.onResponse) {
                        await agent.onResponse(contentMessage);
                    }
                    history.add(contentMessage);
                    yield {
                        type: 'response_output',
                        message: contentMessage,
                    };
                }
                break;
            }
            case 'tool_start': {
                const toolEvent = event;
                if (!toolEvent.tool_call) {
                    break;
                }
                // `currentToolCalls` is a snapshot from the start of the round, so
                // also subtract the calls already started in this round; otherwise
                // one round could exceed the overall budget by any amount.
                const remainingCalls = maxToolCalls - currentToolCalls - toolPromises.length;
                if (remainingCalls <= 0) {
                    console.warn(`Tool call limit reached (${maxToolCalls}). Skipping tool calls.`);
                    break;
                }
                const toolCall = toolEvent.tool_call;
                const functionCall = convertToFunctionCall(toolCall, model, 'completed');
                // Start the tool right away; results are awaited after the stream ends.
                toolPromises.push(processToolCall(toolCall, agent));
                history.add(functionCall);
                yield {
                    type: 'response_output',
                    message: functionCall,
                };
                break;
            }
            case 'error': {
                console.error('[executeRound] Error event:', event.error);
                break;
            }
        }
    }
    // Time spent streaming the model response, before waiting on tool results.
    const request_duration = Date.now() - startTime;
    const toolResults = await Promise.all(toolPromises);
    for (const toolResult of toolResults) {
        const formattedArgs = toolCallFormattedArgs.get(toolResult.toolCall.id);
        const toolCallWithFormattedArgs = formattedArgs
            ? {
                  ...toolResult.toolCall,
                  function: {
                      ...toolResult.toolCall.function,
                      arguments_formatted: formattedArgs,
                  },
              }
            : toolResult.toolCall;
        const toolDoneEvent = {
            type: 'tool_done',
            tool_call: toolCallWithFormattedArgs,
            result: {
                call_id: toolResult.call_id || toolResult.id,
                output: toolResult.output,
                error: toolResult.error,
            },
        };
        yield toolDoneEvent;
        await emitEvent(toolDoneEvent, agent, model);
        const functionOutput = convertToFunctionCallOutput(toolResult, model, 'completed');
        history.add(functionOutput);
        yield {
            type: 'response_output',
            message: functionOutput,
        };
    }
    const duration_with_tools = Date.now() - startTime;
    await emitEvent(
        {
            type: 'agent_done',
            request_id: requestId,
            request_cost: totalCost > 0 ? totalCost : undefined,
            request_duration,
            duration_with_tools,
            timestamp: new Date().toISOString(),
        },
        agent,
        model
    );
    for (const bufferedEvent of toolEventBuffer) {
        yield bufferedEvent;
    }
}

/**
 * Verifies `output` with `agent.verifier` and, on failure, retries the request
 * with a correction prompt.
 *
 * - On a `'pass'` status a single `message_delta` confirmation is yielded.
 * - On failure with attempts remaining, a notice is yielded, `ensembleRequest`
 *   is re-run with the verifier removed (all of its events are re-yielded), and
 *   the last `message_complete` content of that retry is verified recursively.
 *   If the retry produced no content, the generator ends without further events.
 * - On failure with no attempts remaining, a final failure `message_delta` is yielded.
 *
 * @param {object} agent - Agent configuration; `verifier` and `maxVerificationAttempts` are read.
 * @param {string} output - Model output to verify.
 * @param {Array<object>} messages - Conversation passed to the verifier and used as the retry base.
 * @param {number} [attempt=0] - Zero-based attempt counter.
 * @yields {object} `message_delta` events plus all events of any retry request.
 */
async function* performVerification(agent, output, messages, attempt = 0) {
    if (!agent.verifier) return;
    // `||` means a configured value of 0 also falls back to 2 attempts.
    const maxAttempts = agent.maxVerificationAttempts || 2;
    const verification = await verifyOutput(agent.verifier, output, messages);
    if (verification.status === 'pass') {
        yield {
            type: 'message_delta',
            content: '\n\n✓ Output verified',
        };
        return;
    }
    if (attempt < maxAttempts - 1) {
        yield {
            type: 'message_delta',
            content: `\n\n⚠️ Verification failed: ${verification.reason}\n\nRetrying...`,
        };
        const retryMessages = [
            ...messages,
            {
                type: 'message',
                role: 'assistant',
                content: output,
                status: 'completed',
            },
            {
                type: 'message',
                role: 'developer',
                content: `Verification failed: ${verification.reason}\n\nPlease correct your response.`,
            },
        ];
        // The verifier is removed so the nested request does not verify itself;
        // verification of the retry output happens in the recursive call below.
        const retryAgent = {
            ...agent,
            verifier: undefined,
            historyThread: retryMessages,
        };
        const retryStream = ensembleRequest(retryMessages, retryAgent);
        let retryOutput = '';
        for await (const event of retryStream) {
            yield event;
            if (event.type === 'message_complete' && 'content' in event) {
                retryOutput = event.content;
            }
        }
        if (retryOutput) {
            yield* performVerification(agent, retryOutput, messages, attempt + 1);
        }
    } else {
        yield {
            type: 'message_delta',
            content: `\n\n❌ Verification failed after ${maxAttempts} attempts: ${verification.reason}`,
        };
    }
}

/**
 * Executes one tool call and normalises the outcome into a result object.
 *
 * Failures inside the `try` block (no tools configured, tool not found, the
 * tool throwing, or `agent.onToolResult` throwing) are captured in the
 * returned object's `error` field instead of being thrown. Exceptions from
 * `agent.onToolCall` (invoked before the `try`) and from `agent.onToolError`
 * are not caught here and propagate to the caller.
 *
 * @param {object} toolCall - Tool call carrying `id`, optional `call_id` and `function.name`.
 * @param {object} agent - Agent providing `tools` and the optional `onToolCall` / `onToolResult` / `onToolError` hooks.
 * @returns {Promise<object>} `{ toolCall, id, call_id, output }` on success or `{ toolCall, id, call_id, error }` on failure.
 */
async function processToolCall(toolCall, agent) {
    if (agent.onToolCall) {
        await agent.onToolCall(toolCall);
    }
    try {
        if (!agent.tools) {
            throw new Error('No tools available for agent');
        }
        const tool = agent.tools.find((t) => t.definition.function.name === toolCall.function.name);
        if (!tool || !('function' in tool)) {
            throw new Error(`Tool ${toolCall.function.name} not found`);
        }
        const rawResult = await handleToolCall(toolCall, tool, agent);
        const processedResult = await processToolResult(toolCall, rawResult, agent);
        const toolCallResult = {
            toolCall,
            id: toolCall.id,
            call_id: toolCall.call_id || toolCall.id,
            output: processedResult,
        };
        if (agent.onToolResult) {
            await agent.onToolResult(toolCallResult);
        }
        return toolCallResult;
    } catch (error) {
        const errorOutput =
            error instanceof Error
                ? `Tool execution failed: ${error.message}`
                : `Tool execution failed: ${String(error)}`;
        const toolCallResult = {
            toolCall,
            id: toolCall.id,
            call_id: toolCall.call_id || toolCall.id,
            error: errorOutput,
        };
        if (agent.onToolError) {
            await agent.onToolError(toolCallResult);
        }
        return toolCallResult;
    }
}

/**
 * Appends the entries of `thread` from `startIndex` onward to `mainHistory`,
 * mutating `mainHistory` in place.
 *
 * @param {Array} mainHistory - Destination array; modified in place.
 * @param {Array} thread - Source array; not modified.
 * @param {number} startIndex - Index in `thread` of the first entry to copy (passed to `Array.prototype.slice`).
 * @returns {void}
 */
export function mergeHistoryThread(mainHistory, thread, startIndex) {
    const newMessages = thread.slice(startIndex);
    mainHistory.push(...newMessages);
}
//# sourceMappingURL=ensemble_request.js.map