pyb-ts
Version:
PYB-CLI - Minimal AI Agent with multi-model support and CLI interface
471 lines (468 loc) • 14.2 kB
JavaScript
import {
messagePairValidForBinaryFeedback,
shouldUseBinaryFeedback
} from "@components/binary-feedback/utils";
import {
formatSystemPromptWithContext,
queryLLM
} from "@services/claude";
import { emitReminderEvent } from "@services/systemReminder";
import { all } from "@utils/generators";
import { logError } from "@utils/log";
import {
debug as debugLogger,
markPhase,
getCurrentRequest,
logUserFriendly
} from "./utils/debugLogger.js";
import {
createAssistantMessage,
createProgressMessage,
createToolResultStopMessage,
createUserMessage,
INTERRUPT_MESSAGE,
INTERRUPT_MESSAGE_FOR_TOOL_USE,
normalizeMessagesForAPI
} from "@utils/messages";
import { BashTool } from "@tools/BashTool/BashTool";
import { globalMemoryHook } from "@utils/memoryRecorder";
import { getCwd } from "./utils/state.js";
import { checkAutoCompact } from "./utils/autoCompactCore.js";
// Cap handed to `all()` (second argument) when runToolsConcurrently fans out
// the tool_use blocks of one assistant turn — presumably the maximum number of
// tool generators `all()` drives at once; see @utils/generators to confirm.
const MAX_TOOL_USE_CONCURRENCY = 10;
/**
 * Obtain the assistant reply for the current turn.
 *
 * Normally a single response is requested. Only when USER_TYPE is "ant", a
 * `getBinaryFeedbackResponse` chooser was supplied and
 * `shouldUseBinaryFeedback()` agrees, two responses are requested in parallel:
 * an API-error candidate is discarded in favour of the other one, a pair that
 * `messagePairValidForBinaryFeedback` rejects falls back to the first
 * candidate, and otherwise the chooser's result is returned unchanged.
 *
 * @param toolUseContext object whose `abortController.signal` is checked
 *   after the response(s) arrive.
 * @param getAssistantResponse zero-argument async function returning one
 *   assistant message.
 * @param getBinaryFeedbackResponse optional async `(m1, m2)` chooser.
 * @returns {Promise<{message: object|null, shouldSkipPermissionCheck: boolean}>}
 *   `message` is null when the request was aborted.
 */
async function queryWithBinaryFeedback(toolUseContext, getAssistantResponse, getBinaryFeedbackResponse) {
  const isAborted = () => toolUseContext.abortController.signal.aborted;
  const choose = (message) => ({ message, shouldSkipPermissionCheck: false });

  // `&&` keeps the original short-circuit order: shouldUseBinaryFeedback() is
  // only awaited once the cheaper checks have passed.
  const binaryFeedbackEnabled =
    process.env.USER_TYPE === "ant" &&
    getBinaryFeedbackResponse &&
    (await shouldUseBinaryFeedback());

  if (!binaryFeedbackEnabled) {
    const single = await getAssistantResponse();
    return isAborted() ? choose(null) : choose(single);
  }

  const [first, second] = await Promise.all([
    getAssistantResponse(),
    getAssistantResponse()
  ]);
  if (isAborted()) return choose(null);
  if (second.isApiErrorMessage) return choose(first);
  if (first.isApiErrorMessage) return choose(second);
  if (!messagePairValidForBinaryFeedback(first, second)) return choose(first);
  return await getBinaryFeedbackResponse(first, second);
}
/**
 * Drive one agent turn and, recursively, every follow-up turn triggered by
 * tool use.
 *
 * The turn proceeds in order:
 *  1. optionally auto-compact the history;
 *  2. build the full system prompt plus reminder text;
 *  3. request the assistant reply and yield it;
 *  4. run the tool_use blocks it contains and yield their messages;
 *  5. recurse with the history extended by the assistant message and the
 *     tool results, ordered to match the tool_use blocks.
 *
 * The caller's `messages` array is never modified. Reminder text is injected
 * into a copy used only for the current model request, so reminders do not
 * accumulate on older user messages across recursive turns.
 *
 * @param {Array} messages conversation history (read-only).
 * @param systemPrompt base prompt handed to formatSystemPromptWithContext.
 * @param context extra context handed to formatSystemPromptWithContext.
 * @param canUseTool permission callback forwarded to every tool run.
 * @param toolUseContext carries `options` (tools, model, maxThinkingTokens,
 *   safeMode), `abortController` and `agentId`.
 * @param getBinaryFeedbackResponse optional chooser forwarded to
 *   queryWithBinaryFeedback.
 * @yields assistant messages, tool progress messages, tool_result user
 *   messages, and an interrupt assistant message when aborted.
 */
async function* query(messages, systemPrompt, context, canUseTool, toolUseContext, getBinaryFeedbackResponse) {
  markPhase("QUERY_INIT");
  const { messages: processedMessages, wasCompacted } = await checkAutoCompact(
    messages,
    toolUseContext
  );
  const history = wasCompacted ? processedMessages : messages;

  markPhase("SYSTEM_PROMPT_BUILD");
  const { systemPrompt: fullSystemPrompt, reminders } = formatSystemPromptWithContext(systemPrompt, context, toolUseContext.agentId);
  emitReminderEvent("session:startup", {
    agentId: toolUseContext.agentId,
    messages: history.length,
    timestamp: Date.now()
  });
  // Reminders go only into the request copy; `history` stays clean for recursion.
  const messagesForRequest = injectRemindersIntoLastUserMessage(history, reminders);

  markPhase("LLM_PREPARATION");
  const getAssistantResponse = () =>
    queryLLM(
      normalizeMessagesForAPI(messagesForRequest),
      fullSystemPrompt,
      toolUseContext.options.maxThinkingTokens,
      toolUseContext.options.tools,
      toolUseContext.abortController.signal,
      {
        safeMode: toolUseContext.options.safeMode ?? false,
        model: toolUseContext.options.model || "main",
        prependCLISysprompt: true,
        toolUseContext
      }
    );
  const result = await queryWithBinaryFeedback(
    toolUseContext,
    getAssistantResponse,
    getBinaryFeedbackResponse
  );
  if (toolUseContext.abortController.signal.aborted || result.message === null) {
    yield createAssistantMessage(INTERRUPT_MESSAGE);
    return;
  }
  const { message: assistantMessage, shouldSkipPermissionCheck } = result;
  yield assistantMessage;

  const toolUseMessages = assistantMessage.message.content.filter(
    (block) => block.type === "tool_use"
  );
  if (toolUseMessages.length === 0) {
    return;
  }
  // Tools run in parallel only if every requested tool is known and read-only.
  const canRunConcurrently = toolUseMessages.every(
    (toolUse) => toolUseContext.options.tools.find((t) => t.name === toolUse.name)?.isReadOnly()
  );
  const runTools = canRunConcurrently ? runToolsConcurrently : runToolsSerially;
  const toolResults = [];
  for await (const message of runTools(
    toolUseMessages,
    assistantMessage,
    canUseTool,
    toolUseContext,
    shouldSkipPermissionCheck
  )) {
    yield message;
    if (message.type === "user") {
      toolResults.push(message);
    }
  }
  if (toolUseContext.abortController.signal.aborted) {
    yield createAssistantMessage(INTERRUPT_MESSAGE_FOR_TOOL_USE);
    return;
  }
  yield* query(
    [...history, assistantMessage, ...orderToolResultsByToolUse(toolResults, toolUseMessages)],
    systemPrompt,
    context,
    canUseTool,
    toolUseContext,
    getBinaryFeedbackResponse
  );
}

/**
 * Return `messages` with `reminders` attached to the last user message.
 *
 * String content gets the reminder text prepended. Array content gets a
 * trailing `{ type: "text" }` block. Any other content is replaced by just
 * that block. When something is injected, a shallow copy is returned;
 * otherwise the original array is returned. Neither the input array nor its
 * messages are modified.
 */
function injectRemindersIntoLastUserMessage(messages, reminders) {
  if (!reminders) {
    return messages;
  }
  for (let i = messages.length - 1; i >= 0; i--) {
    const candidate = messages[i];
    if (candidate?.type !== "user") {
      continue;
    }
    const { content } = candidate.message;
    const withReminders = [...messages];
    withReminders[i] = {
      ...candidate,
      message: {
        ...candidate.message,
        content: typeof content === "string"
          ? reminders + content
          : [...(Array.isArray(content) ? content : []), { type: "text", text: reminders }]
      }
    };
    return withReminders;
  }
  return messages;
}

/**
 * Sort tool_result user messages into the order of the tool_use blocks that
 * produced them, since concurrent runs finish in arbitrary order.
 *
 * Each message is keyed by the `tool_use_id` of its first content block.
 * tool_result blocks have no `id` field, which the old code read. Messages
 * with no matching tool_use sort first, as findIndex's -1 did. Returns a new
 * array.
 */
function orderToolResultsByToolUse(toolResults, toolUseMessages) {
  const positionById = new Map();
  toolUseMessages.forEach((toolUse, index) => {
    if (!positionById.has(toolUse.id)) {
      positionById.set(toolUse.id, index);
    }
  });
  const positionOf = (resultMessage) => {
    const { content } = resultMessage.message;
    const firstBlock = Array.isArray(content) ? content[0] : undefined;
    return positionById.get(firstBlock?.tool_use_id) ?? -1;
  };
  return [...toolResults].sort((a, b) => positionOf(a) - positionOf(b));
}
/**
 * Start one runToolUse generator per tool_use block of the turn and merge
 * their output streams through `all()`, passing MAX_TOOL_USE_CONCURRENCY as
 * its limit. The interleaving and concurrency semantics live in `all`
 * (@utils/generators).
 *
 * Every tool receives its own Set holding the ids of all tool_use blocks in
 * the turn (its siblings).
 */
async function* runToolsConcurrently(toolUseMessages, assistantMessage, canUseTool, toolUseContext, shouldSkipPermissionCheck) {
  const toolRuns = [];
  for (const toolUse of toolUseMessages) {
    const siblingIds = new Set(toolUseMessages.map((block) => block.id));
    toolRuns.push(
      runToolUse(
        toolUse,
        siblingIds,
        assistantMessage,
        canUseTool,
        toolUseContext,
        shouldSkipPermissionCheck
      )
    );
  }
  yield* all(toolRuns, MAX_TOOL_USE_CONCURRENCY);
}
/**
 * Run the tool_use blocks of the turn one after another, forwarding every
 * message of a tool before the next one starts. Each call receives a fresh
 * Set of all tool_use ids in the turn (its siblings).
 */
async function* runToolsSerially(toolUseMessages, assistantMessage, canUseTool, toolUseContext, shouldSkipPermissionCheck) {
  for (let index = 0; index < toolUseMessages.length; index += 1) {
    const siblingIds = new Set(toolUseMessages.map(({ id }) => id));
    yield* runToolUse(
      toolUseMessages[index],
      siblingIds,
      assistantMessage,
      canUseTool,
      toolUseContext,
      shouldSkipPermissionCheck
    );
  }
}
/**
 * Execute a single tool_use block and yield every message it produces.
 *
 * Behavior:
 *  - The tool is looked up by name in `toolUseContext.options.tools`; an
 *    unknown name yields an error tool_result.
 *  - Otherwise the work is delegated to checkPermissionsAndCallTool and its
 *    messages are forwarded.
 *  - An abort detected before the tool starts, or between yielded messages,
 *    is turned into a tool_result built by createToolResultStopMessage.
 *  - Exceptions raised inside the guarded section are logged and reported as
 *    an error tool_result instead of propagating.
 *
 * The diagnostics emitted before that guarded section tolerate a missing
 * `input`. Otherwise a throw there would abort the whole turn and leave this
 * tool_use without any tool_result.
 *
 * @param toolUse tool_use content block (`id`, `name`, `input`).
 * @param {Set} siblingToolUseIDs ids of all tool_use blocks in the same turn.
 * @param assistantMessage assistant message that contained `toolUse`.
 * @param canUseTool permission callback.
 * @param toolUseContext carries `options.tools` and `abortController`.
 * @param shouldSkipPermissionCheck forwarded to checkPermissionsAndCallTool.
 * @yields progress messages and tool_result user messages.
 */
async function* runToolUse(toolUse, siblingToolUseIDs, assistantMessage, canUseTool, toolUseContext, shouldSkipPermissionCheck) {
  const currentRequest = getCurrentRequest();
  debugLogger.flow("TOOL_USE_START", {
    toolName: toolUse.name,
    toolUseID: toolUse.id,
    // JSON.stringify(undefined) returns undefined, so fall back before .length.
    inputSize: (JSON.stringify(toolUse.input) ?? "").length,
    siblingToolCount: siblingToolUseIDs.size,
    shouldSkipPermissionCheck: !!shouldSkipPermissionCheck,
    requestId: currentRequest?.id
  });
  logUserFriendly(
    "TOOL_EXECUTION",
    {
      toolName: toolUse.name,
      action: "Starting",
      target: toolUse.input ? Object.keys(toolUse.input).join(", ") : ""
    },
    currentRequest?.id
  );
  const toolName = toolUse.name;
  const tool = toolUseContext.options.tools.find((t) => t.name === toolName);
  if (!tool) {
    debugLogger.error("TOOL_NOT_FOUND", {
      requestedTool: toolName,
      availableTools: toolUseContext.options.tools.map((t) => t.name),
      toolUseID: toolUse.id,
      requestId: currentRequest?.id
    });
    yield createUserMessage([
      {
        type: "tool_result",
        content: `Error: No such tool available: ${toolName}`,
        is_error: true,
        tool_use_id: toolUse.id
      }
    ]);
    return;
  }
  const toolInput = toolUse.input;
  debugLogger.flow("TOOL_VALIDATION_START", {
    toolName: tool.name,
    toolUseID: toolUse.id,
    // Object.keys throws on null/undefined; the schema check later reports bad input properly.
    inputKeys: Object.keys(toolInput ?? {}),
    requestId: currentRequest?.id
  });
  try {
    if (toolUseContext.abortController.signal.aborted) {
      debugLogger.flow("TOOL_USE_CANCELLED_BEFORE_START", {
        toolName: tool.name,
        toolUseID: toolUse.id,
        abortReason: "AbortController signal",
        requestId: currentRequest?.id
      });
      const message = createUserMessage([
        createToolResultStopMessage(toolUse.id)
      ]);
      yield message;
      return;
    }
    let hasProgressMessages = false;
    for await (const message of checkPermissionsAndCallTool(
      tool,
      toolUse.id,
      siblingToolUseIDs,
      toolInput,
      toolUseContext,
      canUseTool,
      assistantMessage,
      shouldSkipPermissionCheck
    )) {
      if (toolUseContext.abortController.signal.aborted) {
        debugLogger.flow("TOOL_USE_CANCELLED_DURING_EXECUTION", {
          toolName: tool.name,
          toolUseID: toolUse.id,
          hasProgressMessages,
          abortReason: "AbortController signal during execution",
          requestId: currentRequest?.id
        });
        // Forward a trailing progress update only if progress was already shown.
        if (hasProgressMessages && message.type === "progress") {
          yield message;
        }
        const cancelMessage = createUserMessage([
          createToolResultStopMessage(toolUse.id)
        ]);
        yield cancelMessage;
        return;
      }
      if (message.type === "progress") {
        hasProgressMessages = true;
      }
      yield message;
    }
  } catch (e) {
    logError(e);
    const errorMessage = createUserMessage([
      {
        type: "tool_result",
        content: `Tool execution failed: ${e instanceof Error ? e.message : String(e)}`,
        is_error: true,
        tool_use_id: toolUse.id
      }
    ]);
    yield errorMessage;
  }
}
/**
 * Normalize a tool's input before validation, permission checks and
 * execution.
 *
 * For BashTool the input is parsed with BashTool.inputSchema. A redundant
 * leading `cd <cwd> && ` is removed, since the command already runs in the
 * current directory. `timeout` is kept only when truthy. The prefix is
 * stripped only at the very start of the command. Removing the same text
 * mid-command (e.g. after an earlier `pushd`/`cd`) would change which
 * directory the rest of the command runs in.
 *
 * Other tools get their input back unchanged (same reference).
 *
 * @param tool the tool object (compared by identity against BashTool).
 * @param input raw tool_use input.
 * @returns normalized input.
 */
function normalizeToolInput(tool, input) {
  switch (tool) {
    case BashTool: {
      const { command, timeout } = BashTool.inputSchema.parse(input);
      const redundantCdPrefix = `cd ${getCwd()} && `;
      return {
        command: command.startsWith(redundantCdPrefix)
          ? command.slice(redundantCdPrefix.length)
          : command,
        ...(timeout ? { timeout } : {})
      };
    }
    default:
      return input;
  }
}
/**
 * Validate, authorize and run one tool call, yielding the messages it
 * produces.
 *
 * Pipeline:
 *  1. Schema-check `input` with `tool.inputSchema.safeParse`; failures become
 *     an error tool_result (the View tool gets a more specific hint when its
 *     input is empty or missing).
 *  2. Normalize the input and run the optional `tool.validateInput`.
 *  3. Ask `canUseTool`, unless `shouldSkipPermissionCheck` is set.
 *  4. Call `tool.call` and forward its "progress" items as progress messages.
 *
 * The first "result" item becomes the final tool_result and ends the
 * generator. `globalMemoryHook` is notified before execution and after the
 * result or an error. Errors thrown in that section are logged and reported
 * as an error tool_result via formatError.
 *
 * @param tool tool object (`name`, `inputSchema`, `call`, optional `validateInput`).
 * @param toolUseID id of the tool_use block being answered.
 * @param {Set} siblingToolUseIDs ids of all tool_use blocks in the same turn.
 * @param input raw tool input (may be missing on malformed responses).
 * @param context tool-use context passed to validation, permission check and tool.
 * @param canUseTool async permission callback returning `{ result, message }`.
 * @param assistantMessage assistant message containing the tool_use.
 * @param shouldSkipPermissionCheck skip `canUseTool` when truthy.
 * @yields progress messages and exactly one tool_result user message when the
 *   tool produces a result or fails.
 */
async function* checkPermissionsAndCallTool(tool, toolUseID, siblingToolUseIDs, input, context, canUseTool, assistantMessage, shouldSkipPermissionCheck) {
  const isValidInput = tool.inputSchema.safeParse(input);
  if (!isValidInput.success) {
    let errorMessage = `InputValidationError: ${isValidInput.error.message}`;
    // `?? {}` so a missing input yields this message instead of a TypeError from Object.keys.
    if (tool.name === "View" && Object.keys(input ?? {}).length === 0) {
      errorMessage = `Error: The View tool requires a 'file_path' parameter to specify which file to read. Please provide the absolute path to the file you want to view. For example: {"file_path": "/path/to/file.txt"}`;
    }
    yield createUserMessage([
      {
        type: "tool_result",
        content: errorMessage,
        is_error: true,
        tool_use_id: toolUseID
      }
    ]);
    return;
  }
  const normalizedInput = normalizeToolInput(tool, input);
  const isValidCall = await tool.validateInput?.(
    normalizedInput,
    context
  );
  if (isValidCall?.result === false) {
    yield createUserMessage([
      {
        type: "tool_result",
        content: isValidCall.message,
        is_error: true,
        tool_use_id: toolUseID
      }
    ]);
    return;
  }
  const permissionResult = shouldSkipPermissionCheck ? { result: true } : await canUseTool(tool, normalizedInput, context, assistantMessage);
  if (permissionResult.result === false) {
    yield createUserMessage([
      {
        type: "tool_result",
        content: permissionResult.message,
        is_error: true,
        tool_use_id: toolUseID
      }
    ]);
    return;
  }
  const toolStartTime = Date.now();
  try {
    await globalMemoryHook.beforeToolExecution(tool, normalizedInput, context);
    const generator = tool.call(normalizedInput, context);
    for await (const result of generator) {
      switch (result.type) {
        case "result": {
          const executionTime = Date.now() - toolStartTime;
          await globalMemoryHook.afterToolExecution(
            tool,
            normalizedInput,
            result.data,
            executionTime,
            context
          );
          const resultForAssistant = result.resultForAssistant || String(result.data);
          yield createUserMessage(
            [
              {
                type: "tool_result",
                content: resultForAssistant,
                tool_use_id: toolUseID
              }
            ],
            {
              data: result.data,
              resultForAssistant
            }
          );
          // Only the first result counts; returning also closes the tool's generator.
          return;
        }
        case "progress":
          yield createProgressMessage(
            toolUseID,
            siblingToolUseIDs,
            result.content,
            result.normalizedMessages || [],
            result.tools || []
          );
          break;
      }
    }
  } catch (error) {
    const executionTime = Date.now() - toolStartTime;
    await globalMemoryHook.afterToolExecution(
      tool,
      normalizedInput,
      null,
      executionTime,
      context,
      error
    );
    const content = formatError(error);
    logError(error);
    yield createUserMessage([
      {
        type: "tool_result",
        content,
        is_error: true,
        tool_use_id: toolUseID
      }
    ]);
  }
}
/**
 * Render a thrown value as tool_result text.
 *
 * Non-Error values are passed through String(). For Error instances, the
 * message is joined by newlines with any string `stderr` and `stdout`
 * properties, and empty parts are dropped. Text longer than 10,000
 * characters keeps its first and last 5,000 characters around a
 * "... [N characters truncated] ..." marker line.
 */
function formatError(error) {
  if (error instanceof Error) {
    const sections = [error.message];
    for (const stream of ["stderr", "stdout"]) {
      if (stream in error && typeof error[stream] === "string") {
        sections.push(error[stream]);
      }
    }
    const combined = sections.filter(Boolean).join("\n");
    const limit = 1e4;
    if (combined.length > limit) {
      const keep = limit / 2;
      const omitted = combined.length - limit;
      return [
        combined.slice(0, keep),
        `... [${omitted} characters truncated] ...`,
        combined.slice(-keep)
      ].join("\n");
    }
    return combined;
  }
  return String(error);
}
export {
normalizeToolInput,
query,
runToolUse
};
//# sourceMappingURL=query.js.map