UNPKG

@just-every/ensemble

Version:

LLM provider abstraction layer with unified streaming interface

527 lines 21.9 kB
import { randomUUID } from 'crypto';
import { getModelFromAgent, getModelProvider } from '../model_providers/model_provider.js';
import { MessageHistory } from '../utils/message_history.js';
import { handleToolCall } from '../utils/tool_execution_manager.js';
import { processToolResult } from '../utils/tool_result_processor.js';
import { verifyOutput, setEnsembleRequestFunction } from '../utils/verification.js';
import { setEnsembleRequestFunction as setImageToTextFunction } from '../utils/image_to_text.js';
import { waitWhilePaused } from '../utils/pause_controller.js';
import { emitEvent } from '../utils/event_controller.js';
import { createTraceContext } from '../utils/trace_context.js';
import {
    convertToThinkingMessage,
    convertToOutputMessage,
    convertToFunctionCall,
    convertToFunctionCallOutput,
} from '../utils/message_converter.js';
import { truncateLargeValues } from '../utils/truncate_utils.js';

/** Number of request rounds that may end in an error before the turn stops retrying. */
const MAX_ERROR_ATTEMPTS = 5;

// Hand this module's entry point to the verification and image-to-text helpers so
// they can issue their own model requests. Function declarations are hoisted, so
// referencing ensembleRequest before its definition below is safe.
setEnsembleRequestFunction(ensembleRequest);
setImageToTextFunction(ensembleRequest);

/**
 * Run one conversational turn for `agent`, streaming every event back to the caller.
 *
 * The turn consists of one or more request rounds. Another round is started while
 * the model keeps calling (non-terminal) tools, within `agent.maxToolCalls`
 * (default 200) and `agent.maxToolCallRoundsPerTurn` (default unlimited), or while
 * rounds end in an error, up to MAX_ERROR_ATTEMPTS error rounds. If
 * `agent.verifier` is set, the last message is verified after the rounds finish.
 * A `stream_end` event is always the final event yielded.
 *
 * Side effects on caller-owned objects:
 * - When `agent.instructions` is not already present as a system message, it is
 *   unshifted onto the history array and `agent.instructions` is cleared.
 * - `agent.modelSettings.tool_choice` is deleted after the first round that calls tools.
 *
 * @param {Array<object>} messages - Conversation history, used only when
 *   `agent.historyThread` is not set. May be mutated (see above).
 * @param {object} [agent={}] - Agent configuration, limits and callbacks.
 * @returns {AsyncGenerator<object>} Provider, tool, trace-lifecycle and error events.
 */
export async function* ensembleRequest(messages, agent = {}) {
    const conversationHistory = agent?.historyThread || messages;
    if (agent.instructions) {
        const alreadyHasInstructions = conversationHistory.some(msg => {
            return (msg.type === 'message' &&
                msg.role === 'system' &&
                'content' in msg &&
                typeof msg.content === 'string' &&
                msg.content.trim() === agent.instructions.trim());
        });
        if (!alreadyHasInstructions) {
            const instructionsMessage = {
                type: 'message',
                role: 'system',
                content: agent.instructions,
                id: randomUUID(),
            };
            conversationHistory.unshift(instructionsMessage);
            yield {
                type: 'response_output',
                message: instructionsMessage,
                request_id: randomUUID(),
            };
            // Presumably cleared so a reused agent does not prepend the same
            // instructions again on a later call.
            agent.instructions = undefined;
        }
    }

    const history = new MessageHistory(conversationHistory, {
        compactToolCalls: true,
        preserveSystemMessages: true,
        compactionThreshold: 0.7,
    });
    const trace = createTraceContext(agent, 'chat');

    let totalToolCalls = 0;
    let toolCallRounds = 0;
    let errorRounds = 0;
    let turnStatus = 'completed';
    let turnEndReason = 'completed';
    let turnError;
    const maxToolCalls = agent?.maxToolCalls ?? 200;
    const maxRounds = agent?.maxToolCallRoundsPerTurn ?? Infinity;
    let hasToolCalls = false;
    let hasError = false;
    let lastMessageContent = '';
    // Models used by earlier rounds of this turn; passed to model selection.
    const modelHistory = [];

    await trace.emitTurnStart({
        input_messages: conversationHistory,
    });

    try {
        do {
            hasToolCalls = false;
            hasError = false;
            let currentRoundRequestId;
            const currentRoundMessages = [];
            const currentRoundErrors = [];
            let currentRoundToolCalls = 0;
            let currentRoundRequestDuration;
            let currentRoundDurationWithTools;
            let currentRoundRequestCost;

            const model = await getModelFromAgent(agent, 'reasoning_mini', modelHistory);
            modelHistory.push(model);

            const stream = executeRound(model, agent, history, totalToolCalls, maxToolCalls, trace);
            for await (const event of stream) {
                yield event;
                switch (event.type) {
                    case 'agent_start': {
                        currentRoundRequestId = event.request_id;
                        break;
                    }
                    case 'message_complete': {
                        if (event.content) {
                            lastMessageContent = event.content;
                            currentRoundMessages.push(event.content);
                        }
                        break;
                    }
                    case 'tool_start': {
                        const toolCall = event.tool_call;
                        // Only events that carry a tool call are counted, matching
                        // executeRound, which ignores tool_start events without one.
                        if (toolCall) {
                            const toolName = toolCall.function.name;
                            currentRoundToolCalls += 1;
                            await trace.emitToolStart(event.request_id, toolCall.id, {
                                tool_name: toolName,
                                arguments: toolCall.function.arguments,
                                arguments_formatted: toolCall.function.arguments_formatted,
                            });
                            // task_complete / task_fatal_error end the task, so they
                            // do not by themselves trigger a follow-up round.
                            if (toolName !== 'task_complete' && toolName !== 'task_fatal_error') {
                                hasToolCalls = true;
                            }
                            ++totalToolCalls;
                        }
                        break;
                    }
                    case 'tool_done': {
                        if (event.tool_call) {
                            await trace.emitToolDone(event.request_id, event.tool_call.id, {
                                tool_name: event.tool_call.function.name,
                                call_id: event.result?.call_id,
                                output: event.result?.output,
                                error: event.result?.error,
                            });
                        }
                        break;
                    }
                    case 'agent_done': {
                        currentRoundRequestDuration = event.request_duration;
                        currentRoundDurationWithTools = event.duration_with_tools;
                        currentRoundRequestCost = event.request_cost;
                        break;
                    }
                    case 'error': {
                        hasError = true;
                        if (event.error) {
                            currentRoundErrors.push(String(event.error));
                        }
                        break;
                    }
                }
            }

            if (hasToolCalls) {
                ++toolCallRounds;
                // A forced tool_choice applies to the first tool-calling round only.
                if (agent.modelSettings?.tool_choice) {
                    delete agent.modelSettings.tool_choice;
                }
            }
            if (hasError) {
                ++errorRounds;
            }

            const willRetryForError = hasError && errorRounds < MAX_ERROR_ATTEMPTS;
            const willContinueForTools = hasToolCalls && toolCallRounds < maxRounds && totalToolCalls < maxToolCalls;
            const willContinue = willRetryForError || willContinueForTools;

            let requestStatus = 'completed';
            if (hasError) {
                requestStatus = willContinue ? 'error_retrying' : 'error';
            }
            else if (hasToolCalls) {
                requestStatus = willContinue ? 'waiting_for_followup_request' : 'tool_limit_reached';
            }

            if (currentRoundRequestId) {
                await trace.emitRequestEnd(currentRoundRequestId, {
                    status: requestStatus,
                    will_continue: willContinue,
                    tool_calls: currentRoundToolCalls,
                    final_response: currentRoundMessages.length > 0 ? currentRoundMessages.join('\n') : undefined,
                    errors: currentRoundErrors.length > 0 ? currentRoundErrors : undefined,
                    request_duration_ms: currentRoundRequestDuration,
                    duration_with_tools_ms: currentRoundDurationWithTools,
                    request_cost: currentRoundRequestCost,
                });
            }
        } while ((hasError && errorRounds < MAX_ERROR_ATTEMPTS) ||
            (hasToolCalls && toolCallRounds < maxRounds && totalToolCalls < maxToolCalls));

        if (hasToolCalls && toolCallRounds >= maxRounds) {
            console.log('[ensembleRequest] Tool call rounds limit reached');
            turnEndReason = 'max_tool_call_rounds_reached';
        }
        else if (hasToolCalls && totalToolCalls >= maxToolCalls) {
            console.log('[ensembleRequest] Total tool calls limit reached');
            turnEndReason = 'max_tool_calls_reached';
        }
        else if (hasError && errorRounds >= MAX_ERROR_ATTEMPTS) {
            turnStatus = 'error';
            turnEndReason = 'max_error_attempts_reached';
        }

        if (agent?.verifier && lastMessageContent) {
            const verificationResult = await performVerification(agent, lastMessageContent, await history.getMessages());
            if (verificationResult) {
                for await (const event of verificationResult) {
                    yield event;
                }
            }
        }
    }
    catch (err) {
        // Anything can be thrown (strings, null, plain objects), so read fields
        // defensively: the error report itself must never throw.
        const errorMessage = err?.message || (typeof err === 'string' && err) || 'Unknown error';
        turnStatus = 'error';
        turnEndReason = 'exception';
        turnError = errorMessage;
        yield {
            type: 'error',
            error: errorMessage,
            code: err?.code,
            details: err?.details,
            recoverable: err?.recoverable,
            timestamp: new Date().toISOString(),
        };
    }
    finally {
        await trace.emitTurnEnd(turnStatus, turnEndReason, {
            error: turnError,
            tool_call_rounds: toolCallRounds,
            total_tool_calls: totalToolCalls,
            error_rounds: errorRounds,
        });
        yield {
            type: 'stream_end',
            timestamp: new Date().toISOString(),
        };
    }
}

/**
 * Render a tool call's JSON arguments for display, with keys ordered as in the
 * tool's declared `parameters.properties`. Arguments not declared in the schema
 * are omitted. Returns undefined when the tool or its schema is not found, when
 * the arguments are not a JSON object, or when anything throws (logged at debug).
 *
 * @param {Array<object>|undefined} tools - The agent's tool list.
 * @param {object} toolCall - Tool call carrying `function.name` / `function.arguments`.
 * @returns {string|undefined} Pretty-printed JSON, or undefined.
 */
function formatToolArguments(tools, toolCall) {
    try {
        const tool = tools?.find(t => t.definition.function.name === toolCall.function.name);
        if (!tool || !('definition' in tool) || !tool.definition.function.parameters.properties) {
            return undefined;
        }
        const parsedArgs = JSON.parse(toolCall.function.arguments || '{}');
        if (typeof parsedArgs !== 'object' || parsedArgs === null || Array.isArray(parsedArgs)) {
            return undefined;
        }
        const orderedArgs = {};
        for (const param of Object.keys(tool.definition.function.parameters.properties)) {
            if (param in parsedArgs) {
                orderedArgs[param] = parsedArgs[param];
            }
        }
        return JSON.stringify(orderedArgs, null, 2);
    }
    catch (error) {
        console.debug('Failed to format tool arguments:', error);
        return undefined;
    }
}

/**
 * Execute a single request round: send the current history to `model`, relay
 * the provider's events, record messages and function calls into `history`, run
 * the requested tools concurrently, and record their outputs.
 *
 * Event order: `agent_start`, provider events (plus `response_output` events for
 * messages added to history), `tool_done` per tool result, `agent_done`, and then
 * any events delivered through `agent.onToolEvent` during the round.
 *
 * @param {string} model - Model chosen for this round.
 * @param {object} agent - Agent configuration. `agent.onToolEvent` is overwritten.
 * @param {MessageHistory} history - Turn history; appended to in place.
 * @param {number} currentToolCalls - Tool calls already counted this turn before the round.
 * @param {number} maxToolCalls - Per-turn tool-call limit; calls beyond it are skipped.
 * @param {object} trace - Trace context from createTraceContext.
 * @returns {AsyncGenerator<object>}
 */
async function* executeRound(model, agent, history, currentToolCalls, maxToolCalls, trace) {
    const requestId = randomUUID();
    const startTime = Date.now();
    let totalCost = 0;
    let messages = await history.getMessages(model);

    // The history may be empty, so the first message is not guaranteed to exist.
    const firstMessage = messages[0];
    const agentStartEvent = {
        type: 'agent_start',
        request_id: requestId,
        input: typeof firstMessage?.content === 'string' ? firstMessage.content : undefined,
        timestamp: new Date().toISOString(),
        agent: {
            agent_id: agent.agent_id,
            name: agent.name,
            parent_id: agent.parent_id,
            model: agent.model || model,
            modelClass: agent.modelClass,
            cwd: agent.cwd,
            modelScores: agent.modelScores,
            disabledModels: agent.disabledModels,
            tags: agent.tags,
        },
    };
    yield agentStartEvent;
    await emitEvent(agentStartEvent, agent, model);

    // onRequest may substitute both the agent and the outgoing messages.
    if (agent.onRequest) {
        [agent, messages] = await agent.onRequest(agent, messages);
    }
    await waitWhilePaused(100, agent.abortSignal);

    const provider = getModelProvider(model);
    await trace.emitRequestStart(requestId, {
        agent_id: agent.agent_id,
        provider: provider.provider_id,
        model,
        payload: {
            messages,
            model_settings: agent.modelSettings,
            tool_names: agent.tools?.map(tool => tool.definition.function.name) || [],
        },
    });

    const stream = 'createResponseStreamWithRetry' in provider
        ? provider.createResponseStreamWithRetry(messages, model, agent, requestId)
        : provider.createResponseStream(messages, model, agent, requestId);

    const toolPromises = [];
    const toolCallFormattedArgs = new Map();
    // Events reported through onToolEvent are held back and yielded after agent_done.
    const toolEventBuffer = [];
    agent.onToolEvent = async (event) => {
        toolEventBuffer.push(event);
    };
    // Tool calls accepted so far this turn, including ones started in this round,
    // so the per-turn limit also holds within a single round.
    let acceptedToolCalls = currentToolCalls;

    for await (const providerEvent of stream) {
        let event = { ...providerEvent, request_id: requestId };
        if (event.type === 'tool_start' && event.tool_call) {
            const toolCall = event.tool_call;
            const argumentsFormatted = formatToolArguments(agent.tools, toolCall);
            if (argumentsFormatted) {
                toolCallFormattedArgs.set(toolCall.id, argumentsFormatted);
            }
            event = {
                ...event,
                tool_call: {
                    ...toolCall,
                    function: {
                        ...toolCall.function,
                        arguments_formatted: argumentsFormatted,
                    },
                },
            };
        }

        yield event;
        await emitEvent(event, agent, model);

        switch (event.type) {
            case 'cost_update': {
                if (event.usage?.cost) {
                    totalCost += event.usage.cost;
                }
                break;
            }
            case 'message_complete': {
                // A message without content but with an id is still recorded as a
                // thinking message.
                if (event.thinking_content || (!event.content && event.message_id)) {
                    const thinkingMessage = convertToThinkingMessage(event, model);
                    if (agent.onThinking) {
                        await agent.onThinking(thinkingMessage);
                    }
                    history.add(thinkingMessage);
                    yield {
                        type: 'response_output',
                        message: thinkingMessage,
                        request_id: requestId,
                    };
                }
                if (event.content) {
                    const contentMessage = convertToOutputMessage(event, model, 'completed');
                    if (agent.onResponse) {
                        await agent.onResponse(contentMessage);
                    }
                    history.add(contentMessage);
                    yield {
                        type: 'response_output',
                        message: contentMessage,
                        request_id: requestId,
                    };
                }
                break;
            }
            case 'tool_start': {
                if (!event.tool_call) {
                    break;
                }
                if (acceptedToolCalls >= maxToolCalls) {
                    console.warn(`Tool call limit reached (${maxToolCalls}). Skipping tool calls.`);
                    break;
                }
                acceptedToolCalls += 1;
                const toolCall = event.tool_call;
                const functionCall = convertToFunctionCall(toolCall, model, 'completed');
                const toolPromise = processToolCall(toolCall, agent);
                // Nothing awaits this promise until the provider stream finishes, so
                // mark it handled now to avoid an unhandled-rejection report if it
                // rejects early. Promise.all below still rethrows the rejection.
                toolPromise.catch(() => { });
                toolPromises.push(toolPromise);
                history.add(functionCall);
                yield {
                    type: 'response_output',
                    message: functionCall,
                    request_id: requestId,
                };
                break;
            }
            case 'error': {
                console.error('[executeRound] Error event:', truncateLargeValues(event.error));
                break;
            }
        }
    }

    const request_duration = Date.now() - startTime;

    const toolResults = await Promise.all(toolPromises);
    for (const toolResult of toolResults) {
        const toolName = toolResult.toolCall.function.name;
        const isSpecialTool = toolName === 'task_complete' || toolName === 'task_fatal_error';
        const formattedArgs = toolCallFormattedArgs.get(toolResult.toolCall.id);
        const toolCallWithFormattedArgs = formattedArgs
            ? {
                ...toolResult.toolCall,
                function: {
                    ...toolResult.toolCall.function,
                    arguments_formatted: formattedArgs,
                },
            }
            : toolResult.toolCall;
        const toolDoneEvent = {
            type: 'tool_done',
            request_id: requestId,
            tool_call: toolCallWithFormattedArgs,
            result: {
                call_id: toolResult.call_id || toolResult.id,
                output: toolResult.output,
                error: toolResult.error,
            },
        };
        yield toolDoneEvent;
        await emitEvent(toolDoneEvent, agent, model);
        // Outputs of task_complete / task_fatal_error are not added to the history.
        if (!isSpecialTool) {
            const functionOutput = convertToFunctionCallOutput(toolResult, model, 'completed');
            history.add(functionOutput);
            yield {
                type: 'response_output',
                message: functionOutput,
                request_id: requestId,
            };
        }
    }

    const duration_with_tools = Date.now() - startTime;
    const agentDoneEvent = {
        type: 'agent_done',
        request_id: requestId,
        request_cost: totalCost > 0 ? totalCost : undefined,
        request_duration,
        duration_with_tools,
        timestamp: new Date().toISOString(),
    };
    yield agentDoneEvent;
    await emitEvent(agentDoneEvent, agent, model);

    for (const bufferedEvent of toolEventBuffer) {
        yield { ...bufferedEvent, request_id: requestId };
    }
}

/**
 * Check `output` with `agent.verifier`. On failure, re-run the request with the
 * failure reason appended (verifier disabled for the retry) and verify the new
 * output, for up to `agent.maxVerificationAttempts` attempts in total (default 2).
 * Progress is reported through `message_delta` events.
 *
 * @param {object} agent - Agent whose `verifier` is used.
 * @param {string} output - Text to verify.
 * @param {Array<object>} messages - History the output was produced from.
 * @param {number} [attempt=0] - Zero-based attempt index (used by recursion).
 * @returns {AsyncGenerator<object>}
 */
async function* performVerification(agent, output, messages, attempt = 0) {
    if (!agent.verifier)
        return;
    const maxAttempts = agent.maxVerificationAttempts || 2;
    const verification = await verifyOutput(agent.verifier, output, messages);
    if (verification.status === 'pass') {
        yield {
            type: 'message_delta',
            content: '\n\n✓ Output verified',
        };
        return;
    }
    if (attempt < maxAttempts - 1) {
        yield {
            type: 'message_delta',
            content: `\n\n⚠️ Verification failed: ${verification.reason}\n\nRetrying...`,
        };
        const retryMessages = [
            ...messages,
            {
                type: 'message',
                role: 'assistant',
                content: output,
                status: 'completed',
            },
            {
                type: 'message',
                role: 'developer',
                content: `Verification failed: ${verification.reason}\n\nPlease correct your response.`,
            },
        ];
        const retryAgent = {
            ...agent,
            verifier: undefined,
            historyThread: retryMessages,
        };
        const retryStream = ensembleRequest(retryMessages, retryAgent);
        let retryOutput = '';
        for await (const event of retryStream) {
            // The nested request ends with its own stream_end. Forwarding it would
            // signal end-of-stream to consumers before the outer turn has finished.
            if (event.type === 'stream_end') {
                continue;
            }
            yield event;
            if (event.type === 'message_complete' && 'content' in event) {
                retryOutput = event.content;
            }
        }
        if (retryOutput) {
            yield* performVerification(agent, retryOutput, messages, attempt + 1);
        }
    }
    else {
        yield {
            type: 'message_delta',
            content: `\n\n❌ Verification failed after ${maxAttempts} attempts: ${verification.reason}`,
        };
    }
}

/**
 * Run one tool call against the agent's matching tool and normalize the outcome.
 * Failures inside the try (missing tools, unknown tool, handler errors, and errors
 * thrown by `onToolResult`) become a result carrying `error` instead of rejecting.
 * Errors thrown by `onToolCall` or `onToolError` propagate to the caller.
 *
 * @param {object} toolCall - Tool call requested by the model.
 * @param {object} agent - Agent providing `tools` and the optional callbacks.
 * @returns {Promise<object>} `{ toolCall, id, call_id, output }` or `{ toolCall, id, call_id, error }`.
 */
async function processToolCall(toolCall, agent) {
    if (agent.onToolCall) {
        await agent.onToolCall(toolCall);
    }
    try {
        if (!agent.tools) {
            throw new Error('No tools available for agent');
        }
        const tool = agent.tools.find(t => t.definition.function.name === toolCall.function.name);
        if (!tool || !('function' in tool)) {
            throw new Error(`Tool ${toolCall.function.name} not found`);
        }
        const rawResult = await handleToolCall(toolCall, tool, agent);
        const processedResult = await processToolResult(toolCall, rawResult, agent, tool.allowSummary);
        const toolCallResult = {
            toolCall,
            id: toolCall.id,
            call_id: toolCall.call_id || toolCall.id,
            output: processedResult,
        };
        if (agent.onToolResult) {
            await agent.onToolResult(toolCallResult);
        }
        return toolCallResult;
    }
    catch (error) {
        const errorOutput = error instanceof Error
            ? `Tool execution failed: ${error.message}`
            : `Tool execution failed: ${String(error)}`;
        const toolCallResult = {
            toolCall,
            id: toolCall.id,
            call_id: toolCall.call_id || toolCall.id,
            error: errorOutput,
        };
        if (agent.onToolError) {
            await agent.onToolError(toolCallResult);
        }
        return toolCallResult;
    }
}

/**
 * Append the messages of `thread` from `startIndex` onward to `mainHistory` (in place).
 *
 * @param {Array<object>} mainHistory - Destination array; mutated.
 * @param {Array<object>} thread - Source thread.
 * @param {number} startIndex - First index of `thread` to copy.
 */
export function mergeHistoryThread(mainHistory, thread, startIndex) {
    const newMessages = thread.slice(startIndex);
    mainHistory.push(...newMessages);
}
//# sourceMappingURL=ensemble_request.js.map