@just-every/ensemble
Version:
LLM provider abstraction layer with unified streaming interface
405 lines • 15.7 kB
JavaScript
import { randomUUID } from 'crypto';
import { getModelFromAgent, getModelProvider } from '../model_providers/model_provider.js';
import { MessageHistory } from '../utils/message_history.js';
import { handleToolCall } from '../utils/tool_execution_manager.js';
import { processToolResult } from '../utils/tool_result_processor.js';
import { verifyOutput } from '../utils/verification.js';
import { waitWhilePaused } from '../utils/pause_controller.js';
import { emitEvent } from '../utils/event_controller.js';
import { convertToThinkingMessage, convertToOutputMessage, convertToFunctionCall, convertToFunctionCallOutput, } from '../utils/message_converter.js';
/**
 * Maximum number of model rounds that may end with an `error` event before
 * ensembleRequest stops re-running the model. The counter is never reset, so
 * this bounds the total number of error rounds per request, not consecutive ones.
 */
const MAX_ERROR_ATTEMPTS = 5;
/**
 * Runs an agent request as an async stream of events.
 *
 * Repeatedly executes model rounds (see executeRound). The loop continues while a round
 * produced tool calls (within the round/total limits) or ended with an error
 * (up to MAX_ERROR_ATTEMPTS error rounds). After the loop it optionally runs
 * output verification.
 *
 * NOTE: the conversation array is mutated in place. A 'Begin.' user message is
 * pushed when it is empty, and a system message carrying `agent.instructions`
 * may be unshifted. `agent.modelSettings.tool_choice` is deleted after the
 * first round that makes tool calls.
 *
 * @param {Array} messages - Conversation messages; used as the history when
 *   `agent.historyThread` is not set.
 * @param {object} [agent={}] - Agent configuration. Fields read here: historyThread,
 *   instructions, maxToolCalls (default 200), maxToolCallRoundsPerTurn
 *   (default Infinity), modelSettings.tool_choice, verifier.
 * @yields {object} Every event from each round, any verification events, an
 *   `error` event if anything throws, and always a final `stream_end` event.
 */
export async function* ensembleRequest(messages, agent = {}) {
// An explicit history thread takes precedence over the passed-in messages.
const conversationHistory = agent?.historyThread || messages;
// Never send an empty conversation: seed it with a placeholder user turn.
if (conversationHistory.length === 0) {
conversationHistory.push({
type: 'message',
role: 'user',
content: 'Begin.',
});
}
if (agent.instructions) {
// Avoid inserting the instructions twice when the first message already
// carries the same text (e.g. on a re-run over the same array).
const firstMsg = conversationHistory[0];
const alreadyHasInstructions = firstMsg &&
'content' in firstMsg &&
typeof firstMsg.content === 'string' &&
firstMsg.content.trim() === agent.instructions.trim();
if (!alreadyHasInstructions) {
conversationHistory.unshift({
type: 'message',
role: 'system',
content: agent.instructions,
});
}
}
// Option semantics are defined in ../utils/message_history.js (not visible here).
const history = new MessageHistory(conversationHistory, {
compactToolCalls: true,
preserveSystemMessages: true,
compactionThreshold: 0.7,
});
try {
// totalToolCalls counts every tool_start event yielded by executeRound,
// including calls that executeRound skipped because of the limit.
let totalToolCalls = 0;
let toolCallRounds = 0;
let errorRounds = 0;
const maxToolCalls = agent?.maxToolCalls ?? 200;
const maxRounds = agent?.maxToolCallRoundsPerTurn ?? Infinity;
let hasToolCalls = false;
let hasError = false;
let lastMessageContent = '';
// Models already tried in this request are passed back to getModelFromAgent;
// presumably so it can choose a different model on later rounds (confirm in model_provider.js).
const modelHistory = [];
do {
hasToolCalls = false;
hasError = false;
const model = await getModelFromAgent(agent, 'reasoning_mini', modelHistory);
modelHistory.push(model);
const stream = executeRound(model, agent, history, totalToolCalls, maxToolCalls);
for await (const event of stream) {
// Forward everything first, then track the state that drives the loop.
yield event;
switch (event.type) {
case 'message_complete': {
const messageEvent = event;
if (messageEvent.content) {
lastMessageContent = messageEvent.content;
}
break;
}
case 'tool_start': {
hasToolCalls = true;
++totalToolCalls;
break;
}
case 'error': {
hasError = true;
break;
}
}
}
if (hasToolCalls) {
++toolCallRounds;
// Drop any forced tool_choice after the first tool round (mutates the
// caller's agent). Presumably this stops every later round from being
// forced into another tool call.
if (agent.modelSettings?.tool_choice) {
delete agent.modelSettings.tool_choice;
}
}
if (hasError) {
++errorRounds;
}
// Loop again to retry an errored round, or to let the model react to tool
// results while both the per-turn round limit and the total call limit allow.
} while ((hasError && errorRounds < MAX_ERROR_ATTEMPTS) ||
(hasToolCalls && toolCallRounds < maxRounds && totalToolCalls < maxToolCalls));
if (hasToolCalls && toolCallRounds >= maxRounds) {
console.log('[ensembleRequest] Tool call rounds limit reached');
}
else if (hasToolCalls && totalToolCalls >= maxToolCalls) {
console.log('[ensembleRequest] Total tool calls limit reached');
}
if (agent?.verifier && lastMessageContent) {
// performVerification is an async generator. `await` just returns the generator
// object, so the truthiness check below always passes.
const verificationResult = await performVerification(agent, lastMessageContent, await history.getMessages());
if (verificationResult) {
for await (const event of verificationResult) {
yield event;
}
}
}
}
catch (err) {
// Any thrown error becomes a single `error` event instead of rejecting the stream.
const error = err;
yield {
type: 'error',
error: error.message || 'Unknown error',
code: error.code,
details: error.details,
recoverable: error.recoverable,
timestamp: new Date().toISOString(),
};
}
finally {
// NOTE(review): yielding inside `finally` means a consumer that stops early
// (calling return(), e.g. via `break`) receives this event and leaves the
// generator suspended here instead of finished. Confirm this is intended.
yield {
type: 'stream_end',
timestamp: new Date().toISOString(),
};
}
}
/**
 * Builds a display string for a tool call's arguments.
 *
 * The JSON arguments are re-serialised with a 2-space indent. Keys are ordered
 * as in the matching tool's `definition.function.parameters.properties`; keys
 * that schema does not declare are omitted.
 *
 * @param {object} toolCall - Tool call whose `function.arguments` is a JSON string (may be empty).
 * @param {Array|undefined} tools - The agent's tool list, searched by `definition.function.name`.
 * @returns {string|undefined} The formatted arguments, or undefined in any of these cases:
 *   - no matching tool with a properties schema exists;
 *   - the arguments are not a JSON object;
 *   - anything throws (logged at debug level, since this is display-only and best-effort).
 */
function formatToolArguments(toolCall, tools) {
    try {
        const tool = tools?.find(t => t.definition.function.name === toolCall.function.name);
        if (!(tool && 'definition' in tool && tool.definition.function.parameters.properties)) {
            return undefined;
        }
        const parsedArgs = JSON.parse(toolCall.function.arguments || '{}');
        if (typeof parsedArgs !== 'object' || parsedArgs === null || Array.isArray(parsedArgs)) {
            return undefined;
        }
        const orderedArgs = {};
        for (const param of Object.keys(tool.definition.function.parameters.properties)) {
            if (param in parsedArgs) {
                orderedArgs[param] = parsedArgs[param];
            }
        }
        return JSON.stringify(orderedArgs, null, 2);
    }
    catch (error) {
        console.debug('Failed to format tool arguments:', error);
        return undefined;
    }
}
/**
 * Runs one model turn for `agent` and yields its events.
 *
 * Sequence:
 *   1. Yields `agent_start` and passes it to emitEvent.
 *   2. Lets `agent.onRequest` replace the agent and messages, then waits while paused.
 *   3. Streams the provider response and forwards every event; `tool_start`
 *      events gain `tool_call.function.arguments_formatted`.
 *   4. Records thinking/content messages and function calls in `history`, and
 *      starts tool calls concurrently.
 *   5. After the stream ends, yields `tool_done` plus `response_output` for each
 *      tool result.
 *   6. Passes `agent_done` to emitEvent (it is not yielded).
 *   7. Finally yields any events delivered to `agent.onToolEvent` during the round.
 *
 * @param {*} model - Model identifier handed to the provider.
 * @param {object} agent - Agent configuration. `agent.onToolEvent` is overwritten.
 * @param {object} history - Exposes `getMessages(model)` and `add(message)`.
 * @param {number} currentToolCalls - Tool calls already counted before this round.
 * @param {number} maxToolCalls - Overall tool-call limit. It also counts calls made
 *   earlier in this same round. Calls beyond it are skipped: a warning is logged,
 *   the call is not executed, and it is not added to history.
 * @yields {object} Stream events, response_output records, and tool_done events.
 */
async function* executeRound(model, agent, history, currentToolCalls, maxToolCalls) {
    const requestId = randomUUID();
    const startTime = Date.now();
    let totalCost = 0;
    // Tool calls dispatched during this round. The limit check must include
    // these, otherwise one round could run any number of calls past the limit.
    let roundToolCalls = 0;
    let messages = await history.getMessages(model);
    // Guard against an empty history: `'content' in undefined` would throw.
    const firstMessage = messages[0];
    const agentStartEvent = {
        type: 'agent_start',
        request_id: requestId,
        input: firstMessage && 'content' in firstMessage && typeof firstMessage.content === 'string'
            ? firstMessage.content
            : undefined,
        timestamp: new Date().toISOString(),
        agent: {
            agent_id: agent.agent_id,
            name: agent.name,
            parent_id: agent.parent_id,
            model: agent.model || model,
            modelClass: agent.modelClass,
            cwd: agent.cwd,
            modelScores: agent.modelScores,
            disabledModels: agent.disabledModels,
        },
    };
    yield agentStartEvent;
    await emitEvent(agentStartEvent, agent, model);
    if (agent.onRequest) {
        [agent, messages] = await agent.onRequest(agent, messages);
    }
    await waitWhilePaused(100, agent.abortSignal);
    const provider = getModelProvider(model);
    const stream = 'createResponseStreamWithRetry' in provider
        ? provider.createResponseStreamWithRetry(messages, model, agent)
        : provider.createResponseStream(messages, model, agent);
    const toolPromises = [];
    const toolCallFormattedArgs = new Map();
    const toolEventBuffer = [];
    // Events reported through onToolEvent are buffered and yielded after the round.
    // NOTE(review): this replaces any caller-supplied agent.onToolEvent for good.
    agent.onToolEvent = async (event) => {
        toolEventBuffer.push(event);
    };
    for await (let event of stream) {
        if (event.type === 'tool_start' && event.tool_call) {
            const toolCall = event.tool_call;
            const argumentsFormatted = formatToolArguments(toolCall, agent.tools);
            if (argumentsFormatted) {
                // Remembered so the matching tool_done event can carry it too.
                toolCallFormattedArgs.set(toolCall.id, argumentsFormatted);
            }
            event = {
                ...event,
                tool_call: {
                    ...toolCall,
                    function: {
                        ...toolCall.function,
                        arguments_formatted: argumentsFormatted,
                    },
                },
            };
        }
        yield event;
        await emitEvent(event, agent, model);
        switch (event.type) {
            case 'cost_update': {
                if (event.usage?.cost) {
                    totalCost += event.usage.cost;
                }
                break;
            }
            case 'message_complete': {
                const messageEvent = event;
                if (messageEvent.thinking_content ||
                    (!messageEvent.content && messageEvent.message_id)) {
                    const thinkingMessage = convertToThinkingMessage(messageEvent, model);
                    if (agent.onThinking) {
                        await agent.onThinking(thinkingMessage);
                    }
                    history.add(thinkingMessage);
                    yield {
                        type: 'response_output',
                        message: thinkingMessage,
                    };
                }
                if (messageEvent.content) {
                    const contentMessage = convertToOutputMessage(messageEvent, model, 'completed');
                    if (agent.onResponse) {
                        await agent.onResponse(contentMessage);
                    }
                    history.add(contentMessage);
                    yield {
                        type: 'response_output',
                        message: contentMessage,
                    };
                }
                break;
            }
            case 'tool_start': {
                if (!event.tool_call) {
                    break;
                }
                if (currentToolCalls + roundToolCalls >= maxToolCalls) {
                    console.warn(`Tool call limit reached (${maxToolCalls}). Skipping tool calls.`);
                    break;
                }
                ++roundToolCalls;
                const toolCall = event.tool_call;
                const functionCall = convertToFunctionCall(toolCall, model, 'completed');
                toolPromises.push(processToolCall(toolCall, agent));
                history.add(functionCall);
                yield {
                    type: 'response_output',
                    message: functionCall,
                };
                break;
            }
            case 'error': {
                console.error('[executeRound] Error event:', event.error);
                break;
            }
        }
    }
    const request_duration = Date.now() - startTime;
    const toolResults = await Promise.all(toolPromises);
    for (const toolResult of toolResults) {
        const formattedArgs = toolCallFormattedArgs.get(toolResult.toolCall.id);
        const toolCallWithFormattedArgs = formattedArgs
            ? {
                ...toolResult.toolCall,
                function: {
                    ...toolResult.toolCall.function,
                    arguments_formatted: formattedArgs,
                },
            }
            : toolResult.toolCall;
        const toolDoneEvent = {
            type: 'tool_done',
            tool_call: toolCallWithFormattedArgs,
            result: {
                call_id: toolResult.call_id || toolResult.id,
                output: toolResult.output,
                error: toolResult.error,
            },
        };
        yield toolDoneEvent;
        await emitEvent(toolDoneEvent, agent, model);
        const functionOutput = convertToFunctionCallOutput(toolResult, model, 'completed');
        history.add(functionOutput);
        yield {
            type: 'response_output',
            message: functionOutput,
        };
    }
    const duration_with_tools = Date.now() - startTime;
    await emitEvent({
        type: 'agent_done',
        request_id: requestId,
        request_cost: totalCost > 0 ? totalCost : undefined,
        request_duration,
        duration_with_tools,
        timestamp: new Date().toISOString(),
    }, agent, model);
    for (const bufferedEvent of toolEventBuffer) {
        yield bufferedEvent;
    }
}
/**
 * Verifies `output` with `agent.verifier`, retrying the request on failure.
 *
 * On a pass, yields a "✓ Output verified" message_delta. On a failure with
 * attempts remaining, it does the following:
 *   1. Yields a "Retrying..." delta.
 *   2. Re-runs ensembleRequest over `messages` plus the rejected output and a
 *      developer correction note. The retry agent has no verifier, so
 *      verification does not nest.
 *   3. Forwards the retry's events.
 *   4. Re-verifies the retry's last message_complete content.
 * When no attempts remain, it yields a final "❌" delta.
 *
 * The retry's own `stream_end` event is NOT forwarded. The caller
 * (ensembleRequest) emits the single terminal `stream_end` for the outer stream.
 * Forwarding the nested one would signal end-of-stream to consumers before the
 * re-verification result arrives.
 *
 * @param {object} agent - Agent with `verifier` and optional `maxVerificationAttempts` (default 2).
 * @param {string} output - Assistant output to verify.
 * @param {Array} messages - Conversation that produced `output`.
 * @param {number} [attempt=0] - Zero-based attempt index.
 * @yields {object} message_delta notices and the forwarded retry events.
 */
async function* performVerification(agent, output, messages, attempt = 0) {
    if (!agent.verifier)
        return;
    const maxAttempts = agent.maxVerificationAttempts || 2;
    const verification = await verifyOutput(agent.verifier, output, messages);
    if (verification.status === 'pass') {
        yield {
            type: 'message_delta',
            content: '\n\n✓ Output verified',
        };
        return;
    }
    if (attempt < maxAttempts - 1) {
        yield {
            type: 'message_delta',
            content: `\n\n⚠️ Verification failed: ${verification.reason}\n\nRetrying...`,
        };
        const retryMessages = [
            ...messages,
            {
                type: 'message',
                role: 'assistant',
                content: output,
                status: 'completed',
            },
            {
                type: 'message',
                role: 'developer',
                content: `Verification failed: ${verification.reason}\n\nPlease correct your response.`,
            },
        ];
        const retryAgent = {
            ...agent,
            verifier: undefined,
            historyThread: retryMessages,
        };
        let retryOutput = '';
        for await (const event of ensembleRequest(retryMessages, retryAgent)) {
            // Swallow the nested request's terminal marker; the outer stream is still running.
            if (event.type === 'stream_end') {
                continue;
            }
            yield event;
            if (event.type === 'message_complete' && 'content' in event) {
                retryOutput = event.content;
            }
        }
        if (retryOutput) {
            yield* performVerification(agent, retryOutput, messages, attempt + 1);
        }
    }
    else {
        yield {
            type: 'message_delta',
            content: `\n\n❌ Verification failed after ${maxAttempts} attempts: ${verification.reason}`,
        };
    }
}
/**
 * Executes one tool call against the agent's registered tools and packages
 * the outcome as a result record.
 *
 * Optional callbacks on `agent` are awaited in this order:
 *   - onToolCall(toolCall): before lookup. A throw here rejects the returned promise.
 *   - onToolResult(result): after a successful run.
 *   - onToolError(result): after any failure inside the guarded section.
 *
 * @param {object} toolCall - Tool call with `id`, optional `call_id`, and `function.name`.
 * @param {object} agent - Agent providing `tools` and the optional callbacks above.
 * @returns {Promise<object>} On success, `{ toolCall, id, call_id, output }`. On failure,
 *   `{ toolCall, id, call_id, error }`, where `error` reads "Tool execution failed: …".
 *   Failures in the guarded section resolve to that record instead of rejecting;
 *   this includes a throwing onToolResult.
 */
async function processToolCall(toolCall, agent) {
    // Identifying fields shared by the success and failure records; call_id falls back to id.
    const identify = () => ({
        toolCall,
        id: toolCall.id,
        call_id: toolCall.call_id || toolCall.id,
    });
    if (agent.onToolCall) {
        await agent.onToolCall(toolCall);
    }
    try {
        const registry = agent.tools;
        if (!registry) {
            throw new Error('No tools available for agent');
        }
        const wantedName = toolCall.function.name;
        const match = registry.find((candidate) => candidate.definition.function.name === wantedName);
        // A definition-only entry (no executable `function`) counts as missing.
        if (!match || !('function' in match)) {
            throw new Error(`Tool ${wantedName} not found`);
        }
        const raw = await handleToolCall(toolCall, match, agent);
        const output = await processToolResult(toolCall, raw, agent);
        const success = { ...identify(), output };
        if (agent.onToolResult) {
            await agent.onToolResult(success);
        }
        return success;
    }
    catch (caught) {
        const reason = caught instanceof Error ? caught.message : String(caught);
        const failure = { ...identify(), error: `Tool execution failed: ${reason}` };
        if (agent.onToolError) {
            await agent.onToolError(failure);
        }
        return failure;
    }
}
/**
 * Appends the tail of `thread`, from `startIndex` onward, to `mainHistory` in place.
 *
 * `startIndex` follows Array.prototype.slice semantics: undefined copies
 * everything, negative values count from the end, and out-of-range values
 * append nothing.
 *
 * @param {Array} mainHistory - Destination history; mutated.
 * @param {Array} thread - Source messages; not modified.
 * @param {number} startIndex - First index of `thread` to copy.
 */
export function mergeHistoryThread(mainHistory, thread, startIndex) {
    const appended = thread.slice(startIndex);
    mainHistory.push.apply(mainHistory, appended);
}
//# sourceMappingURL=ensemble_request.js.map