UNPKG

pyb-ts

Version:

PYB-CLI - Minimal AI Agent with multi-model support and CLI interface

471 lines (468 loc) 14.2 kB
import { messagePairValidForBinaryFeedback, shouldUseBinaryFeedback } from "@components/binary-feedback/utils";
import { formatSystemPromptWithContext, queryLLM } from "@services/claude";
import { emitReminderEvent } from "@services/systemReminder";
import { all } from "@utils/generators";
import { logError } from "@utils/log";
import { debug as debugLogger, markPhase, getCurrentRequest, logUserFriendly } from "./utils/debugLogger.js";
import {
  createAssistantMessage,
  createProgressMessage,
  createToolResultStopMessage,
  createUserMessage,
  INTERRUPT_MESSAGE,
  INTERRUPT_MESSAGE_FOR_TOOL_USE,
  normalizeMessagesForAPI
} from "@utils/messages";
import { BashTool } from "@tools/BashTool/BashTool";
import { globalMemoryHook } from "@utils/memoryRecorder";
import { getCwd } from "./utils/state.js";
import { checkAutoCompact } from "./utils/autoCompactCore.js";

/**
 * Concurrency limit handed to `all()` when every requested tool is read-only.
 * NOTE(review): presumably the maximum number of tool generators `all()`
 * drives at once — confirm against `@utils/generators`.
 */
const MAX_TOOL_USE_CONCURRENCY = 10;

/**
 * Requests the next assistant message, optionally sampling two candidates so
 * a binary-feedback callback can choose between them.
 *
 * Binary feedback is used only when `process.env.USER_TYPE` is "ant", a
 * `getBinaryFeedbackResponse` callback was supplied, and
 * `shouldUseBinaryFeedback()` resolves truthy; otherwise exactly one request
 * is made.
 *
 * @param {object} toolUseContext - Its `abortController.signal` is checked after the request(s).
 * @param {() => Promise<object>} getAssistantResponse - Performs one LLM request.
 * @param {Function} [getBinaryFeedbackResponse] - Picks between two candidate messages.
 * @returns {Promise<{message: object|null, shouldSkipPermissionCheck: boolean}>}
 *   `message` is null when the request was aborted. When binary feedback
 *   actually runs, the callback's own return value is passed through.
 */
async function queryWithBinaryFeedback(toolUseContext, getAssistantResponse, getBinaryFeedbackResponse) {
  if (process.env.USER_TYPE !== "ant" || !getBinaryFeedbackResponse || !(await shouldUseBinaryFeedback())) {
    const assistantMessage = await getAssistantResponse();
    if (toolUseContext.abortController.signal.aborted) {
      return { message: null, shouldSkipPermissionCheck: false };
    }
    return { message: assistantMessage, shouldSkipPermissionCheck: false };
  }

  const [m1, m2] = await Promise.all([getAssistantResponse(), getAssistantResponse()]);
  if (toolUseContext.abortController.signal.aborted) {
    return { message: null, shouldSkipPermissionCheck: false };
  }
  // If either candidate is an API error, use the other one without asking.
  if (m2.isApiErrorMessage) {
    return { message: m1, shouldSkipPermissionCheck: false };
  }
  if (m1.isApiErrorMessage) {
    return { message: m2, shouldSkipPermissionCheck: false };
  }
  if (!messagePairValidForBinaryFeedback(m1, m2)) {
    return { message: m1, shouldSkipPermissionCheck: false };
  }
  return await getBinaryFeedbackResponse(m1, m2);
}

/**
 * Returns a NEW array equal to `conversation` except that `reminders` is
 * attached to the most recent message of type "user": prepended when that
 * message's content is a string, appended as a trailing text block otherwise
 * (non-array, non-string content is replaced by just that text block).
 * The input array and its messages are never mutated.
 */
function withRemindersOnLastUserMessage(conversation, reminders) {
  const updated = [...conversation];
  for (let i = updated.length - 1; i >= 0; i--) {
    const msg = updated[i];
    if (msg?.type !== "user") {
      continue;
    }
    const { content } = msg.message;
    updated[i] = {
      ...msg,
      message: {
        ...msg.message,
        content:
          typeof content === "string"
            ? reminders + content
            : [...(Array.isArray(content) ? content : []), { type: "text", text: reminders }]
      }
    };
    break;
  }
  return updated;
}

/**
 * Runs one agent turn: asks the model for a response and executes any
 * requested tools. If tools ran, it recurses with the tool results appended so
 * the model can continue.
 *
 * Yields, in order: the assistant message, then the progress and result
 * messages produced by the tools, then the messages of the recursive turn.
 * An interrupt assistant message is yielded instead if the abort signal fires.
 * The caller's `messages` array is never mutated.
 *
 * @param {object[]} messages - Conversation so far.
 * @param {*} systemPrompt - Passed to `formatSystemPromptWithContext`.
 * @param {*} context - Passed to `formatSystemPromptWithContext`.
 * @param {Function} canUseTool - Permission callback `(tool, input, context, assistantMessage)`.
 * @param {object} toolUseContext - Options, abort controller and agent id for this run.
 * @param {Function} [getBinaryFeedbackResponse] - See `queryWithBinaryFeedback`.
 */
async function* query(messages, systemPrompt, context, canUseTool, toolUseContext, getBinaryFeedbackResponse) {
  markPhase("QUERY_INIT");
  const { messages: processedMessages, wasCompacted } = await checkAutoCompact(messages, toolUseContext);
  const baseConversation = wasCompacted ? processedMessages : messages;

  markPhase("SYSTEM_PROMPT_BUILD");
  const { systemPrompt: fullSystemPrompt, reminders } = formatSystemPromptWithContext(
    systemPrompt,
    context,
    toolUseContext.agentId
  );
  emitReminderEvent("session:startup", {
    agentId: toolUseContext.agentId,
    messages: baseConversation.length,
    timestamp: Date.now()
  });
  // Copy-on-write: reminders go into a fresh array so the caller's history is untouched.
  const conversation =
    reminders && baseConversation.length > 0
      ? withRemindersOnLastUserMessage(baseConversation, reminders)
      : baseConversation;

  markPhase("LLM_PREPARATION");
  const getAssistantResponse = () =>
    queryLLM(
      normalizeMessagesForAPI(conversation),
      fullSystemPrompt,
      toolUseContext.options.maxThinkingTokens,
      toolUseContext.options.tools,
      toolUseContext.abortController.signal,
      {
        safeMode: toolUseContext.options.safeMode ?? false,
        model: toolUseContext.options.model || "main",
        prependCLISysprompt: true,
        toolUseContext
      }
    );

  const result = await queryWithBinaryFeedback(toolUseContext, getAssistantResponse, getBinaryFeedbackResponse);
  if (toolUseContext.abortController.signal.aborted || result.message === null) {
    yield createAssistantMessage(INTERRUPT_MESSAGE);
    return;
  }

  const { message: assistantMessage, shouldSkipPermissionCheck } = result;
  yield assistantMessage;

  const toolUseMessages = assistantMessage.message.content.filter((block) => block.type === "tool_use");
  if (toolUseMessages.length === 0) {
    return;
  }

  // Tools run in parallel only when every requested tool is known and read-only.
  const canRunConcurrently = toolUseMessages.every(
    (toolUse) => toolUseContext.options.tools.find((t) => t.name === toolUse.name)?.isReadOnly()
  );
  const runTools = canRunConcurrently ? runToolsConcurrently : runToolsSerially;

  const toolResults = [];
  for await (const message of runTools(
    toolUseMessages,
    assistantMessage,
    canUseTool,
    toolUseContext,
    shouldSkipPermissionCheck
  )) {
    yield message;
    if (message.type === "user") {
      toolResults.push(message);
    }
  }

  if (toolUseContext.abortController.signal.aborted) {
    yield createAssistantMessage(INTERRUPT_MESSAGE_FOR_TOOL_USE);
    return;
  }

  // Concurrent runs finish in arbitrary order; restore the order in which the
  // assistant requested the tools. A tool_result block references its request
  // through `tool_use_id` (it has no `id` of its own).
  const requestPosition = (resultMessage) => {
    const toolUseId = resultMessage.message.content[0]?.tool_use_id;
    return toolUseMessages.findIndex((toolUse) => toolUse.id === toolUseId);
  };
  const orderedToolResults = [...toolResults].sort((a, b) => requestPosition(a) - requestPosition(b));

  yield* query(
    [...conversation, assistantMessage, ...orderedToolResults],
    systemPrompt,
    context,
    canUseTool,
    toolUseContext,
    getBinaryFeedbackResponse
  );
}

/**
 * Runs every tool use through `all()` with `MAX_TOOL_USE_CONCURRENCY`,
 * re-yielding the messages the tool generators produce.
 * NOTE(review): presumably messages from different tools interleave — confirm
 * against `@utils/generators`.
 */
async function* runToolsConcurrently(toolUseMessages, assistantMessage, canUseTool, toolUseContext, shouldSkipPermissionCheck) {
  yield* all(
    toolUseMessages.map((toolUse) =>
      runToolUse(
        toolUse,
        new Set(toolUseMessages.map((_) => _.id)),
        assistantMessage,
        canUseTool,
        toolUseContext,
        shouldSkipPermissionCheck
      )
    ),
    MAX_TOOL_USE_CONCURRENCY
  );
}

/**
 * Runs the tool uses one after another in request order, yielding each
 * tool's messages before starting the next tool.
 */
async function* runToolsSerially(toolUseMessages, assistantMessage, canUseTool, toolUseContext, shouldSkipPermissionCheck) {
  for (const toolUse of toolUseMessages) {
    yield* runToolUse(
      toolUse,
      new Set(toolUseMessages.map((_) => _.id)),
      assistantMessage,
      canUseTool,
      toolUseContext,
      shouldSkipPermissionCheck
    );
  }
}

/**
 * Executes a single tool_use block and yields its progress and result messages.
 *
 * Never throws for tool failures. An unknown tool, cancellation, or an
 * exception from the tool pipeline is reported as a `tool_result` user
 * message tied to `toolUse.id`.
 *
 * @param {object} toolUse - A tool_use content block (`id`, `name`, `input`).
 * @param {Set<string>} siblingToolUseIDs - IDs of all tool uses in the same assistant message.
 * @param {object} assistantMessage - The assistant message that requested the tool.
 * @param {Function} canUseTool - Permission callback.
 * @param {object} toolUseContext - Options (incl. `tools`) and abort controller.
 * @param {boolean} [shouldSkipPermissionCheck] - When truthy, `canUseTool` is not consulted.
 */
async function* runToolUse(toolUse, siblingToolUseIDs, assistantMessage, canUseTool, toolUseContext, shouldSkipPermissionCheck) {
  const requestId = getCurrentRequest()?.id;
  const toolName = toolUse.name;
  const toolInput = toolUse.input;

  debugLogger.flow("TOOL_USE_START", {
    toolName,
    toolUseID: toolUse.id,
    // JSON.stringify(undefined) returns undefined; guard so logging cannot throw.
    inputSize: JSON.stringify(toolInput)?.length ?? 0,
    siblingToolCount: siblingToolUseIDs.size,
    shouldSkipPermissionCheck: !!shouldSkipPermissionCheck,
    requestId
  });
  logUserFriendly(
    "TOOL_EXECUTION",
    {
      toolName,
      action: "Starting",
      target: toolInput ? Object.keys(toolInput).join(", ") : ""
    },
    requestId
  );

  const tool = toolUseContext.options.tools.find((t) => t.name === toolName);
  if (!tool) {
    debugLogger.error("TOOL_NOT_FOUND", {
      requestedTool: toolName,
      availableTools: toolUseContext.options.tools.map((t) => t.name),
      toolUseID: toolUse.id,
      requestId
    });
    yield createUserMessage([
      {
        type: "tool_result",
        content: `Error: No such tool available: ${toolName}`,
        is_error: true,
        tool_use_id: toolUse.id
      }
    ]);
    return;
  }

  debugLogger.flow("TOOL_VALIDATION_START", {
    toolName: tool.name,
    toolUseID: toolUse.id,
    inputKeys: Object.keys(toolInput ?? {}),
    requestId
  });

  try {
    if (toolUseContext.abortController.signal.aborted) {
      debugLogger.flow("TOOL_USE_CANCELLED_BEFORE_START", {
        toolName: tool.name,
        toolUseID: toolUse.id,
        abortReason: "AbortController signal",
        requestId
      });
      yield createUserMessage([createToolResultStopMessage(toolUse.id)]);
      return;
    }

    let hasProgressMessages = false;
    for await (const message of checkPermissionsAndCallTool(
      tool,
      toolUse.id,
      siblingToolUseIDs,
      toolInput,
      toolUseContext,
      canUseTool,
      assistantMessage,
      shouldSkipPermissionCheck
    )) {
      if (toolUseContext.abortController.signal.aborted) {
        debugLogger.flow("TOOL_USE_CANCELLED_DURING_EXECUTION", {
          toolName: tool.name,
          toolUseID: toolUse.id,
          hasProgressMessages,
          abortReason: "AbortController signal during execution",
          requestId
        });
        // Forward a pending progress update only if progress has already been shown.
        if (hasProgressMessages && message.type === "progress") {
          yield message;
        }
        yield createUserMessage([createToolResultStopMessage(toolUse.id)]);
        return;
      }
      if (message.type === "progress") {
        hasProgressMessages = true;
      }
      yield message;
    }
  } catch (e) {
    logError(e);
    yield createUserMessage([
      {
        type: "tool_result",
        content: `Tool execution failed: ${e instanceof Error ? e.message : String(e)}`,
        is_error: true,
        tool_use_id: toolUse.id
      }
    ]);
  }
}

/**
 * Canonicalizes a tool's input before validation and permission checks.
 *
 * For `BashTool`, the input is re-parsed with its schema, the first occurrence
 * of `cd <cwd> && ` is removed from the command, and a falsy `timeout` is
 * dropped. Any other tool's input is returned unchanged (same reference).
 * NOTE(review): `inputSchema.parse` presumably throws on invalid input (zod
 * style); callers here only invoke this after `safeParse` succeeded.
 */
function normalizeToolInput(tool, input) {
  switch (tool) {
    case BashTool: {
      const { command, timeout } = BashTool.inputSchema.parse(input);
      return {
        command: command.replace(`cd ${getCwd()} && `, ""),
        ...(timeout ? { timeout } : {})
      };
    }
    default:
      return input;
  }
}

/**
 * Validates input, checks permission, then runs the tool, translating every
 * outcome into messages.
 *
 * Yields progress messages while the tool runs and exactly one `tool_result`
 * user message for a rejected input or denied permission, for the tool's
 * first "result", or for a thrown error. Memory hooks are called before
 * execution and after it (with the elapsed time, plus the error on failure).
 */
async function* checkPermissionsAndCallTool(tool, toolUseID, siblingToolUseIDs, input, context, canUseTool, assistantMessage, shouldSkipPermissionCheck) {
  // Builds the error-flavoured tool_result message used by every failure path.
  const errorResult = (content) =>
    createUserMessage([{ type: "tool_result", content, is_error: true, tool_use_id: toolUseID }]);

  const isValidInput = tool.inputSchema.safeParse(input);
  if (!isValidInput.success) {
    let errorMessage = `InputValidationError: ${isValidInput.error.message}`;
    // Give the model a concrete hint when View is called without arguments.
    if (tool.name === "View" && Object.keys(input ?? {}).length === 0) {
      errorMessage = `Error: The View tool requires a 'file_path' parameter to specify which file to read. Please provide the absolute path to the file you want to view. For example: {"file_path": "/path/to/file.txt"}`;
    }
    yield errorResult(errorMessage);
    return;
  }

  const normalizedInput = normalizeToolInput(tool, input);
  const isValidCall = await tool.validateInput?.(normalizedInput, context);
  if (isValidCall?.result === false) {
    yield errorResult(isValidCall.message);
    return;
  }

  const permissionResult = shouldSkipPermissionCheck
    ? { result: true }
    : await canUseTool(tool, normalizedInput, context, assistantMessage);
  if (permissionResult.result === false) {
    yield errorResult(permissionResult.message);
    return;
  }

  const toolStartTime = Date.now();
  try {
    await globalMemoryHook.beforeToolExecution(tool, normalizedInput, context);
    for await (const result of tool.call(normalizedInput, context)) {
      if (result.type === "result") {
        const executionTime = Date.now() - toolStartTime;
        await globalMemoryHook.afterToolExecution(tool, normalizedInput, result.data, executionTime, context);
        const resultForAssistant = result.resultForAssistant || String(result.data);
        yield createUserMessage(
          [{ type: "tool_result", content: resultForAssistant, tool_use_id: toolUseID }],
          { data: result.data, resultForAssistant }
        );
        // The first result ends the call; leaving the loop closes the tool's generator.
        return;
      }
      if (result.type === "progress") {
        yield createProgressMessage(
          toolUseID,
          siblingToolUseIDs,
          result.content,
          result.normalizedMessages || [],
          result.tools || []
        );
      }
    }
  } catch (error) {
    const executionTime = Date.now() - toolStartTime;
    await globalMemoryHook.afterToolExecution(tool, normalizedInput, null, executionTime, context, error);
    const content = formatError(error);
    logError(error);
    yield errorResult(content);
  }
}

/**
 * Renders a thrown value as text for a tool_result.
 *
 * For Error instances, the message is joined (newline-separated, empties
 * dropped) with any string `stderr`/`stdout` properties the error carries.
 * Output longer than 10,000 characters keeps the first and last 5,000
 * characters around a truncation notice.
 */
function formatError(error) {
  if (!(error instanceof Error)) {
    return String(error);
  }
  const parts = [error.message];
  if ("stderr" in error && typeof error.stderr === "string") {
    parts.push(error.stderr);
  }
  if ("stdout" in error && typeof error.stdout === "string") {
    parts.push(error.stdout);
  }
  const fullMessage = parts.filter(Boolean).join("\n");

  const maxLength = 1e4;
  if (fullMessage.length <= maxLength) {
    return fullMessage;
  }
  const halfLength = 5e3;
  const start = fullMessage.slice(0, halfLength);
  const end = fullMessage.slice(-halfLength);
  return `${start} ... [${fullMessage.length - maxLength} characters truncated] ... \n${end}`;
}

export { normalizeToolInput, query, runToolUse };
//# sourceMappingURL=query.js.map