UNPKG

pyb-ts

Version:

PYB-CLI - Minimal AI Agent with multi-model support and CLI interface

591 lines (590 loc) 20.2 kB
import { Box, Newline, Static } from "ink";
import ProjectOnboarding, { markProjectOnboardingComplete } from "@components/ProjectOnboarding";
import { CostThresholdDialog } from "@components/CostThresholdDialog";
import * as React from "react";
import { useEffect, useMemo, useRef, useState, useCallback } from "react";
import { Logo } from "@components/Logo";
import { Message } from "@components/Message";
import { MessageResponse } from "@components/MessageResponse";
import { MessageSelector } from "@components/MessageSelector";
import { PermissionRequest } from "@components/permissions/PermissionRequest";
import PromptInput from "@components/PromptInput";
import { Spinner } from "@components/Spinner";
import { getSystemPrompt } from "@constants/prompts";
import { getContext } from "@context";
import { getTotalCost, useCostSummary } from "@costTracker";
import { useLogStartupTime } from "@hooks/useLogStartupTime";
import { addToHistory } from "@history";
import { useApiKeyVerification } from "@hooks/useApiKeyVerification";
import { useCancelRequest } from "@hooks/useCancelRequest";
import useCanUseTool from "@hooks/useCanUseTool";
import { useLogMessages } from "@hooks/useLogMessages";
import { PermissionProvider } from "@context/PermissionContext";
import { ModeIndicator } from "@components/ModeIndicator";
import { setMessagesGetter, setMessagesSetter, setModelConfigChangeHandler } from "@messages";
import { query } from "@query";
import { getGlobalConfig, saveGlobalConfig } from "@utils/config";
import { getNextAvailableLogForkNumber } from "@utils/log";
import {
  getErroredToolUseMessages,
  getInProgressToolUseIDs,
  getLastAssistantMessageId,
  getToolUseID,
  getUnresolvedToolUseIDs,
  INTERRUPT_MESSAGE,
  isNotEmptyMessage,
  normalizeMessages,
  normalizeMessagesForAPI,
  processUserInput,
  reorderMessages,
} from "@utils/messages";
import { ModelManager } from "@utils/model";
import { clearTerminal } from "@utils/terminal";
import { BinaryFeedback } from "@components/binary-feedback/BinaryFeedback";
import { getMaxThinkingTokens } from "@utils/thinking";
import { getOriginalCwd } from "@utils/state";
import { handleHashCommand } from "@commands/terminalSetup";

// Once getTotalCost() reaches this value, the cost-threshold dialog is shown
// (unless it was already acknowledged).
// NOTE(review): presumably USD — confirm against @costTracker.
const COST_DIALOG_THRESHOLD = 5;

// Shared default for the `mcpClients` prop. A fresh `[]` default on every
// render would change the `messagesJSX` memo dependency each time and force
// the whole message list to be rebuilt.
const NO_MCP_CLIENTS = [];

/**
 * Returns the plain text of an assistant message: the content itself when it
 * is a string, otherwise its "text" blocks joined with newlines.
 */
function getAssistantText(assistantMessage) {
  const { content } = assistantMessage.message;
  if (typeof content === "string") {
    return content;
  }
  return content
    .filter((block) => block.type === "text")
    .map((block) => block.text)
    .join("\n");
}

/**
 * Renders a "progress" message for the message list.
 *
 * Progress messages carry their own `normalizedMessages` and `tools`, and are
 * rendered without animation or a leading dot.
 */
function renderProgressMessage(progress, verbose, debug) {
  const firstBlock = progress.content.message.content[0];
  // TaskTool interrupts use progress messages that are rendered without the
  // MessageResponse wrapper, since <Message /> already adds the margin.
  const isInterrupt = firstBlock?.type === "text" && firstBlock.text === INTERRUPT_MESSAGE;
  const rendered = React.createElement(Message, {
    message: progress.content,
    messages: progress.normalizedMessages,
    addMargin: false,
    tools: progress.tools,
    verbose: verbose ?? false,
    debug,
    erroredToolUseIDs: new Set(),
    inProgressToolUseIDs: new Set(),
    // Non-interrupt progress marks its first content block's id as unresolved.
    unresolvedToolUseIDs: isInterrupt ? new Set() : new Set([firstBlock.id]),
    shouldAnimate: false,
    shouldShowDot: false,
  });
  return isInterrupt ? rendered : React.createElement(MessageResponse, { children: rendered });
}

/**
 * Interactive REPL screen: renders the conversation, the prompt input and the
 * permission / binary-feedback / cost dialogs, and drives model queries.
 *
 * @param {object} props
 * @param {*} props.commands - Forwarded to processUserInput, query and PromptInput.
 * @param {boolean} props.safeMode - When true, bypass-permissions mode is made
 *   unavailable; also forwarded in the query options.
 * @param {boolean} [props.debug=false] - Draws debug borders around the layout.
 * @param {number} [props.initialForkNumber=0] - Seed for picking the log fork number.
 * @param {string} [props.initialPrompt] - If set, processed and queried on mount.
 * @param {string} props.messageLogName - Log name used for fork numbering and message logging.
 * @param {boolean} props.shouldShowPromptInput - Whether the prompt input may be shown.
 * @param {*} props.tools - Tools available to the model.
 * @param {boolean} [props.verbose] - Overrides the global `verbose` config when not nullish.
 * @param {Array} [props.initialMessages] - Messages to start the conversation with.
 * @param {Array} [props.mcpClients] - Passed to <Logo />.
 * @param {boolean} [props.isDefaultModel=true] - Passed to <Logo />.
 * @param {string} [props.initialUpdateVersion] - Version shown in the update banner.
 * @param {*} [props.initialUpdateCommands] - Commands shown in the update banner.
 */
function REPL({
  commands,
  safeMode,
  debug = false,
  initialForkNumber = 0,
  initialPrompt,
  messageLogName,
  shouldShowPromptInput,
  tools,
  verbose: verboseFromCLI,
  initialMessages,
  mcpClients = NO_MCP_CLIENTS,
  isDefaultModel = true,
  initialUpdateVersion,
  initialUpdateCommands,
}) {
  // CLI flag wins over the global config; captured once at mount.
  const [verbose] = useState(() => verboseFromCLI ?? getGlobalConfig().verbose);
  // Lazy initializer: the fork-number lookup only needs to run on mount.
  const [forkNumber, setForkNumber] = useState(() =>
    getNextAvailableLogForkNumber(messageLogName, initialForkNumber, 0),
  );
  const [forkConvoWithMessagesOnTheNextRender, setForkConvoWithMessagesOnTheNextRender] =
    useState(null);
  const [abortController, setAbortController] = useState(null);
  const [isLoading, setIsLoading] = useState(false);
  const [toolJSX, setToolJSX] = useState(null);
  const [toolUseConfirm, setToolUseConfirm] = useState(null);
  const [messages, setMessages] = useState(initialMessages ?? []);
  const [inputValue, setInputValue] = useState("");
  const [inputMode, setInputMode] = useState("prompt");
  const [submitCount, setSubmitCount] = useState(0);
  const [isMessageSelectorVisible, setIsMessageSelectorVisible] = useState(false);
  const [showCostDialog, setShowCostDialog] = useState(false);
  const [haveShownCostDialog, setHaveShownCostDialog] = useState(
    () => getGlobalConfig().hasAcknowledgedCostThreshold,
  );
  const [binaryFeedbackContext, setBinaryFeedbackContext] = useState(null);
  const updateAvailableVersion = initialUpdateVersion ?? null;
  const updateCommands = initialUpdateCommands ?? null;

  // Shows the BinaryFeedback chooser for two candidate messages and resolves
  // with whatever the chooser passes to `resolve`.
  const getBinaryFeedbackResponse = useCallback(
    (m1, m2) =>
      new Promise((resolvePromise) => {
        setBinaryFeedbackContext({ m1, m2, resolve: resolvePromise });
      }),
    [],
  );

  const readFileTimestamps = useRef({});
  const { status: apiKeyStatus, reverify } = useApiKeyVerification();

  // Cancels the in-flight request: a pending permission prompt is aborted via
  // its own handler, otherwise the active AbortController is fired.
  function onCancel() {
    if (!isLoading) {
      return;
    }
    setIsLoading(false);
    if (toolUseConfirm) {
      toolUseConfirm.onAbort();
    } else if (abortController && !abortController.signal.aborted) {
      abortController.abort();
    }
  }

  useCancelRequest(
    setToolJSX,
    setToolUseConfirm,
    setBinaryFeedbackContext,
    onCancel,
    isLoading,
    isMessageSelectorVisible,
    abortController?.signal,
  );

  // Forking: bump the fork number and replace the conversation.
  useEffect(() => {
    if (forkConvoWithMessagesOnTheNextRender) {
      setForkNumber((n) => n + 1);
      setForkConvoWithMessagesOnTheNextRender(null);
      setMessages(forkConvoWithMessagesOnTheNextRender);
    }
  }, [forkConvoWithMessagesOnTheNextRender]);

  useEffect(() => {
    const totalCost = getTotalCost();
    if (totalCost >= COST_DIALOG_THRESHOLD && !showCostDialog && !haveShownCostDialog) {
      setShowCostDialog(true);
    }
  }, [messages, showCostDialog, haveShownCostDialog]);

  const canUseTool = useCanUseTool(setToolUseConfirm);

  /**
   * Streams a model response for `history`, appending every yielded message to
   * state. `extraOptions` are merged after the standard query options.
   * Resolves to the last assistant message produced, or null if none.
   */
  async function runQuery(history, controller, extraOptions = {}) {
    const [systemPrompt, context, maxThinkingTokens] = await Promise.all([
      getSystemPrompt(),
      getContext(),
      getMaxThinkingTokens(history),
    ]);
    let lastAssistantMessage = null;
    for await (const message of query(
      history,
      systemPrompt,
      context,
      canUseTool,
      {
        options: {
          commands,
          forkNumber,
          messageLogName,
          tools,
          verbose,
          safeMode,
          maxThinkingTokens,
          ...extraOptions,
        },
        messageId: getLastAssistantMessageId(history),
        readFileTimestamps: readFileTimestamps.current,
        abortController: controller,
        setToolJSX,
      },
      getBinaryFeedbackResponse,
    )) {
      setMessages((oldMessages) => [...oldMessages, message]);
      if (message.type === "assistant") {
        lastAssistantMessage = message;
      }
    }
    return lastAssistantMessage;
  }

  // Runs once on mount: re-verifies the API key and, if an initial prompt was
  // supplied, processes it and queries the model.
  async function onInit() {
    reverify();
    if (!initialPrompt) {
      return;
    }
    setIsLoading(true);
    const newAbortController = new AbortController();
    setAbortController(newAbortController);
    try {
      const newMessages = await processUserInput(
        initialPrompt,
        "prompt",
        setToolJSX,
        {
          abortController: newAbortController,
          options: {
            commands,
            forkNumber,
            messageLogName,
            tools,
            verbose,
            maxThinkingTokens: 0,
          },
          messageId: getLastAssistantMessageId(messages),
          setForkConvoWithMessagesOnTheNextRender,
          readFileTimestamps: readFileTimestamps.current,
        },
        null,
      );

      if (newMessages.length === 0) {
        addToHistory(initialPrompt);
      } else {
        if (newMessages.some((message) => message.type === "user")) {
          addToHistory(initialPrompt);
        }
        setMessages((oldMessages) => [...oldMessages, ...newMessages]);
        // An assistant reply (e.g. from a local command) needs no model round-trip.
        if (newMessages[newMessages.length - 1].type === "assistant") {
          return;
        }
        await runQuery([...messages, ...newMessages], newAbortController);
      }
      setHaveShownCostDialog(getGlobalConfig().hasAcknowledgedCostThreshold || false);
    } finally {
      // Always leave the loading state, even if processing or the query throws.
      setIsLoading(false);
      setAbortController(null);
    }
  }

  // Called by PromptInput with the messages produced from a submitted prompt.
  async function onQuery(newMessages, passedAbortController) {
    const controllerToUse = passedAbortController || new AbortController();
    if (!passedAbortController) {
      setAbortController(controllerToUse);
    }
    const isKodingRequest =
      newMessages.length > 0 &&
      newMessages[0].type === "user" &&
      "options" in newMessages[0] &&
      newMessages[0].options?.isKodingRequest === true;

    setMessages((oldMessages) => [...oldMessages, ...newMessages]);
    markProjectOnboardingComplete();

    const lastMessage = newMessages[newMessages.length - 1];
    if (lastMessage.type === "assistant") {
      setAbortController(null);
      setIsLoading(false);
      return;
    }

    try {
      const lastAssistantMessage = await runQuery([...messages, lastMessage], controllerToUse, {
        // If this came from Koding mode, pass that along.
        isKodingRequest: isKodingRequest || undefined,
      });

      if (isKodingRequest && lastAssistantMessage) {
        try {
          const content = getAssistantText(lastAssistantMessage);
          if (content.trim().length > 0) {
            // Awaited so that an async failure is caught below rather than
            // surfacing as an unhandled rejection.
            await handleHashCommand(content);
          }
        } catch (error) {
          console.error("Error saving response to project docs:", error);
        }
      }
    } finally {
      setIsLoading(false);
    }
  }

  useCostSummary();

  // Expose the current messages and setter through the @messages module.
  useEffect(() => {
    setMessagesGetter(() => messages);
    setMessagesSetter(setMessages);
  }, [messages]);

  useEffect(() => {
    setModelConfigChangeHandler(() => {
      setForkNumber((prev) => prev + 1);
    });
  }, []);

  useLogMessages(messages, messageLogName, forkNumber);
  useLogStartupTime();

  useEffect(() => {
    onInit();
  }, []);

  const normalizedMessages = useMemo(
    () => normalizeMessages(messages).filter(isNotEmptyMessage),
    [messages],
  );
  const unresolvedToolUseIDs = useMemo(
    () => getUnresolvedToolUseIDs(normalizedMessages),
    [normalizedMessages],
  );
  const inProgressToolUseIDs = useMemo(
    () => getInProgressToolUseIDs(normalizedMessages),
    [normalizedMessages],
  );
  const erroredToolUseIDs = useMemo(
    () =>
      new Set(
        getErroredToolUseMessages(normalizedMessages).map((msg) => msg.message.content[0].id),
      ),
    [normalizedMessages],
  );

  // Each entry is { type: "static" | "transient", jsx }. Static entries go into
  // Ink's <Static>; transient ones are re-rendered normally.
  const messagesJSX = useMemo(() => {
    const logoEntry = {
      type: "static",
      jsx: React.createElement(
        Box,
        { flexDirection: "column", key: `logo${forkNumber}` },
        React.createElement(Logo, {
          mcpClients,
          isDefaultModel,
          updateBannerVersion: updateAvailableVersion,
          updateBannerCommands: updateCommands,
        }),
        React.createElement(ProjectOnboarding, { workspaceDir: getOriginalCwd() }),
      ),
    };

    const messageEntries = reorderMessages(normalizedMessages).map((msg) => {
      const toolUseID = getToolUseID(msg);
      const rendered =
        msg.type === "progress"
          ? renderProgressMessage(msg, verbose, debug)
          : React.createElement(Message, {
              message: msg,
              messages: normalizedMessages,
              addMargin: true,
              tools,
              verbose,
              debug,
              erroredToolUseIDs,
              inProgressToolUseIDs,
              shouldAnimate:
                !toolJSX &&
                !toolUseConfirm &&
                !isMessageSelectorVisible &&
                (!toolUseID || inProgressToolUseIDs.has(toolUseID)),
              shouldShowDot: true,
              unresolvedToolUseIDs,
            });
      const type = shouldRenderStatically(msg, normalizedMessages, unresolvedToolUseIDs)
        ? "static"
        : "transient";
      const boxProps = debug
        ? {
            borderStyle: "single",
            borderColor: type === "static" ? "green" : "red",
            key: msg.uuid,
            width: "100%",
          }
        : { key: msg.uuid, width: "100%" };
      return { type, jsx: React.createElement(Box, boxProps, rendered) };
    });

    return [logoEntry, ...messageEntries];
  }, [
    forkNumber,
    normalizedMessages,
    tools,
    verbose,
    debug,
    erroredToolUseIDs,
    inProgressToolUseIDs,
    toolJSX,
    toolUseConfirm,
    isMessageSelectorVisible,
    unresolvedToolUseIDs,
    mcpClients,
    isDefaultModel,
    updateAvailableVersion,
    updateCommands,
  ]);

  const showingCostDialog = !isLoading && showCostDialog;

  return React.createElement(PermissionProvider, {
    isBypassPermissionsModeAvailable: !safeMode,
    children: React.createElement(
      React.Fragment,
      null,
      React.createElement(ModeIndicator, null),
      React.createElement(
        React.Fragment,
        { key: `static-messages-${forkNumber}` },
        React.createElement(Static, {
          items: messagesJSX.filter((entry) => entry.type === "static"),
          children: (item) => item.jsx,
        }),
      ),
      messagesJSX.filter((entry) => entry.type === "transient").map((entry) => entry.jsx),
      React.createElement(
        Box,
        {
          borderColor: "red",
          borderStyle: debug ? "single" : undefined,
          flexDirection: "column",
          width: "100%",
        },
        !toolJSX &&
          !toolUseConfirm &&
          !binaryFeedbackContext &&
          isLoading &&
          React.createElement(Spinner, null),
        toolJSX ? toolJSX.jsx : null,
        !toolJSX &&
          binaryFeedbackContext &&
          !isMessageSelectorVisible &&
          React.createElement(BinaryFeedback, {
            m1: binaryFeedbackContext.m1,
            m2: binaryFeedbackContext.m2,
            resolve: (result) => {
              binaryFeedbackContext.resolve(result);
              setTimeout(() => setBinaryFeedbackContext(null), 0);
            },
            verbose,
            normalizedMessages,
            tools,
            debug,
            erroredToolUseIDs,
            inProgressToolUseIDs,
            unresolvedToolUseIDs,
          }),
        !toolJSX &&
          toolUseConfirm &&
          !isMessageSelectorVisible &&
          !binaryFeedbackContext &&
          React.createElement(PermissionRequest, {
            toolUseConfirm,
            onDone: () => setToolUseConfirm(null),
            verbose,
          }),
        !toolJSX &&
          !toolUseConfirm &&
          !isMessageSelectorVisible &&
          !binaryFeedbackContext &&
          showingCostDialog &&
          React.createElement(CostThresholdDialog, {
            onDone: () => {
              setShowCostDialog(false);
              setHaveShownCostDialog(true);
              const config = getGlobalConfig();
              saveGlobalConfig({ ...config, hasAcknowledgedCostThreshold: true });
            },
          }),
        !toolUseConfirm &&
          !toolJSX?.shouldHidePromptInput &&
          shouldShowPromptInput &&
          !isMessageSelectorVisible &&
          !binaryFeedbackContext &&
          !showingCostDialog &&
          React.createElement(
            React.Fragment,
            null,
            React.createElement(PromptInput, {
              commands,
              forkNumber,
              messageLogName,
              tools,
              isDisabled: apiKeyStatus === "invalid",
              isLoading,
              onQuery,
              debug,
              verbose,
              messages,
              setToolJSX,
              input: inputValue,
              onInputChange: setInputValue,
              mode: inputMode,
              onModeChange: setInputMode,
              submitCount,
              onSubmitCountChange: setSubmitCount,
              setIsLoading,
              setAbortController,
              onShowMessageSelector: () => setIsMessageSelectorVisible((prev) => !prev),
              setForkConvoWithMessagesOnTheNextRender,
              readFileTimestamps: readFileTimestamps.current,
              abortController,
              onModelChange: () => setForkNumber((prev) => prev + 1),
            }),
          ),
      ),
      isMessageSelectorVisible &&
        React.createElement(MessageSelector, {
          erroredToolUseIDs,
          unresolvedToolUseIDs,
          messages: normalizeMessagesForAPI(messages),
          // Rewinds the conversation to just before the selected message and
          // pre-fills the prompt with its text when that text is a string.
          onSelect: async (message) => {
            setIsMessageSelectorVisible(false);
            if (!messages.includes(message)) {
              return;
            }
            onCancel();
            setImmediate(async () => {
              await clearTerminal();
              setMessages([]);
              setForkConvoWithMessagesOnTheNextRender(
                messages.slice(0, messages.indexOf(message)),
              );
              if (typeof message.message.content === "string") {
                setInputValue(message.message.content);
              }
            });
          },
          onEscape: () => setIsMessageSelectorVisible(false),
          tools,
        }),
      React.createElement(Newline, null),
    ),
  });
}

/**
 * Decides whether a normalized message can be rendered inside Ink's <Static>
 * (never re-rendered) or must stay transient.
 *
 * A message without a tool-use id is static. A tool-use message is transient
 * while its id is unresolved, or while any sibling of its matching progress
 * message is unresolved. Progress messages follow the sibling rule directly.
 * Any other message type yields undefined (treated as transient by callers).
 */
function shouldRenderStatically(message, messages, unresolvedToolUseIDs) {
  switch (message.type) {
    case "user":
    case "assistant": {
      const toolUseID = getToolUseID(message);
      if (!toolUseID) {
        return true;
      }
      if (unresolvedToolUseIDs.has(toolUseID)) {
        return false;
      }
      const correspondingProgressMessage = messages.find(
        (m) => m.type === "progress" && m.toolUseID === toolUseID,
      );
      if (!correspondingProgressMessage) {
        return true;
      }
      return !intersects(unresolvedToolUseIDs, correspondingProgressMessage.siblingToolUseIDs);
    }
    case "progress":
      return !intersects(unresolvedToolUseIDs, message.siblingToolUseIDs);
  }
}

/** True when the two Sets share at least one element. */
function intersects(a, b) {
  return a.size > 0 && b.size > 0 && [...a].some((item) => b.has(item));
}

export { REPL };
//# sourceMappingURL=REPL.js.map