pyb-ts
Version:
PYB-CLI - Minimal AI Agent with multi-model support and CLI interface
591 lines (590 loc) • 20.2 kB
JavaScript
import { Box, Newline, Static } from "ink";
import ProjectOnboarding, {
markProjectOnboardingComplete
} from "@components/ProjectOnboarding";
import { CostThresholdDialog } from "@components/CostThresholdDialog";
import * as React from "react";
import { useEffect, useMemo, useRef, useState, useCallback } from "react";
import { Logo } from "@components/Logo";
import { Message } from "@components/Message";
import { MessageResponse } from "@components/MessageResponse";
import { MessageSelector } from "@components/MessageSelector";
import {
PermissionRequest
} from "@components/permissions/PermissionRequest";
import PromptInput from "@components/PromptInput";
import { Spinner } from "@components/Spinner";
import { getSystemPrompt } from "@constants/prompts";
import { getContext } from "@context";
import { getTotalCost, useCostSummary } from "@costTracker";
import { useLogStartupTime } from "@hooks/useLogStartupTime";
import { addToHistory } from "@history";
import { useApiKeyVerification } from "@hooks/useApiKeyVerification";
import { useCancelRequest } from "@hooks/useCancelRequest";
import useCanUseTool from "@hooks/useCanUseTool";
import { useLogMessages } from "@hooks/useLogMessages";
import { PermissionProvider } from "@context/PermissionContext";
import { ModeIndicator } from "@components/ModeIndicator";
import {
setMessagesGetter,
setMessagesSetter,
setModelConfigChangeHandler
} from "@messages";
import {
query
} from "@query";
import { getGlobalConfig, saveGlobalConfig } from "@utils/config";
import { getNextAvailableLogForkNumber } from "@utils/log";
import {
getErroredToolUseMessages,
getInProgressToolUseIDs,
getLastAssistantMessageId,
getToolUseID,
getUnresolvedToolUseIDs,
INTERRUPT_MESSAGE,
isNotEmptyMessage,
normalizeMessages,
normalizeMessagesForAPI,
processUserInput,
reorderMessages
} from "@utils/messages";
import { ModelManager } from "@utils/model";
import { clearTerminal } from "@utils/terminal";
import { BinaryFeedback } from "@components/binary-feedback/BinaryFeedback";
import { getMaxThinkingTokens } from "@utils/thinking";
import { getOriginalCwd } from "@utils/state";
import { handleHashCommand } from "@commands/terminalSetup";
/**
 * Top-level interactive REPL screen rendered with Ink.
 *
 * Owns the conversation state (messages, fork number, loading/abort state,
 * tool / permission / binary-feedback overlays, cost-threshold dialog) and
 * forwards PromptInput submissions to `query()`.
 *
 * @param {object} props
 * @param {Array} props.commands - Commands forwarded to input processing and queries.
 * @param {boolean} props.safeMode - When true, bypass-permissions mode is unavailable.
 * @param {boolean} [props.debug=false] - Draws debug borders around rendered messages.
 * @param {number} [props.initialForkNumber=0] - Seed for the log fork number.
 * @param {string} [props.initialPrompt] - Optional prompt submitted once on mount.
 * @param {string} props.messageLogName - Log name used for message logging/forking.
 * @param {boolean} props.shouldShowPromptInput - Whether the prompt input may be shown.
 * @param {Array} props.tools - Tools forwarded to queries and message rendering.
 * @param {boolean} [props.verbose] - CLI override; falls back to the global config's `verbose`.
 * @param {Array} [props.initialMessages] - Messages to seed the conversation with.
 * @param {Array} [props.mcpClients=[]] - Forwarded to the logo.
 * @param {boolean} [props.isDefaultModel=true] - Forwarded to the logo.
 * @param {*} [props.initialUpdateVersion] - Update-banner version shown in the logo, if any.
 * @param {*} [props.initialUpdateCommands] - Update-banner commands shown in the logo, if any.
 */
function REPL({
  commands,
  safeMode,
  debug = false,
  initialForkNumber = 0,
  initialPrompt,
  messageLogName,
  shouldShowPromptInput,
  tools,
  verbose: verboseFromCLI,
  initialMessages,
  mcpClients = [],
  isDefaultModel = true,
  initialUpdateVersion,
  initialUpdateCommands
}) {
  // Verbosity is resolved once on mount: the CLI flag wins over the global config.
  const [verboseConfig] = useState(() => verboseFromCLI ?? getGlobalConfig().verbose);
  const verbose = verboseConfig;
  // Lazy initializer: React only uses the initial value on the first render,
  // so don't recompute the next available fork number on every render.
  const [forkNumber, setForkNumber] = useState(
    () => getNextAvailableLogForkNumber(messageLogName, initialForkNumber, 0)
  );
  // When set to a message array, the effect below bumps the fork number and
  // replaces the conversation with those messages on the next render.
  const [
    forkConvoWithMessagesOnTheNextRender,
    setForkConvoWithMessagesOnTheNextRender
  ] = useState(null);
  const [abortController, setAbortController] = useState(null);
  const [isLoading, setIsLoading] = useState(false);
  const [toolJSX, setToolJSX] = useState(null);
  const [toolUseConfirm, setToolUseConfirm] = useState(null);
  const [messages, setMessages] = useState(initialMessages ?? []);
  const [inputValue, setInputValue] = useState("");
  const [inputMode, setInputMode] = useState("prompt");
  const [submitCount, setSubmitCount] = useState(0);
  const [isMessageSelectorVisible, setIsMessageSelectorVisible] = useState(false);
  const [showCostDialog, setShowCostDialog] = useState(false);
  // Lazy initializer so the global config is only read for the initial value.
  const [haveShownCostDialog, setHaveShownCostDialog] = useState(
    () => getGlobalConfig().hasAcknowledgedCostThreshold
  );
  const [binaryFeedbackContext, setBinaryFeedbackContext] = useState(null);
  const updateAvailableVersion = initialUpdateVersion ?? null;
  const updateCommands = initialUpdateCommands ?? null;
  // Shows the BinaryFeedback overlay for (m1, m2); the returned promise
  // settles when the overlay invokes the stored `resolve` callback.
  const getBinaryFeedbackResponse = useCallback(
    (m1, m2) => new Promise((resolvePromise) => {
      setBinaryFeedbackContext({
        m1,
        m2,
        resolve: resolvePromise
      });
    }),
    []
  );
  const readFileTimestamps = useRef({});
  const { status: apiKeyStatus, reverify } = useApiKeyVerification();

  // Cancels the in-flight request. A pending permission prompt is aborted via
  // its own onAbort; otherwise the active AbortController (if any) is triggered.
  function onCancel() {
    if (!isLoading) {
      return;
    }
    setIsLoading(false);
    if (toolUseConfirm) {
      toolUseConfirm.onAbort();
    } else if (abortController && !abortController.signal.aborted) {
      abortController.abort();
    }
  }
  useCancelRequest(
    setToolJSX,
    setToolUseConfirm,
    setBinaryFeedbackContext,
    onCancel,
    isLoading,
    isMessageSelectorVisible,
    abortController?.signal
  );
  // Apply a pending fork: bump the fork number and swap in the forked messages.
  useEffect(() => {
    if (forkConvoWithMessagesOnTheNextRender) {
      setForkNumber((_) => _ + 1);
      setForkConvoWithMessagesOnTheNextRender(null);
      setMessages(forkConvoWithMessagesOnTheNextRender);
    }
  }, [forkConvoWithMessagesOnTheNextRender]);
  // Re-checked whenever messages change: show the cost dialog once the value
  // returned by getTotalCost() reaches 5, unless it was already shown/acknowledged.
  useEffect(() => {
    const totalCost = getTotalCost();
    if (totalCost >= 5 && !showCostDialog && !haveShownCostDialog) {
      setShowCostDialog(true);
    }
  }, [messages, showCostDialog, haveShownCostDialog]);
  const canUseTool = useCanUseTool(setToolUseConfirm);

  /**
   * Runs once on mount. It re-verifies the API key. If an initial prompt was
   * given, it processes the prompt and streams the model's reply into
   * `messages`. Loading and abort state are always cleared afterwards, even
   * when processing throws.
   */
  async function onInit() {
    reverify();
    if (!initialPrompt) {
      return;
    }
    setIsLoading(true);
    const newAbortController = new AbortController();
    setAbortController(newAbortController);
    try {
      const newMessages = await processUserInput(
        initialPrompt,
        "prompt",
        setToolJSX,
        {
          abortController: newAbortController,
          options: {
            commands,
            forkNumber,
            messageLogName,
            tools,
            verbose,
            maxThinkingTokens: 0
          },
          messageId: getLastAssistantMessageId(messages),
          setForkConvoWithMessagesOnTheNextRender,
          readFileTimestamps: readFileTimestamps.current
        },
        null
      );
      if (newMessages.length) {
        for (const message of newMessages) {
          if (message.type === "user") {
            addToHistory(initialPrompt);
          }
        }
        setMessages((_) => [..._, ...newMessages]);
        const lastMessage = newMessages[newMessages.length - 1];
        if (lastMessage.type === "assistant") {
          // The input was fully handled locally; `finally` clears loading state.
          return;
        }
        const conversation = [...messages, ...newMessages];
        const [systemPrompt, context, maxThinkingTokens] = await Promise.all([
          getSystemPrompt(),
          getContext(),
          getMaxThinkingTokens(conversation)
        ]);
        for await (const message of query(
          conversation,
          systemPrompt,
          context,
          canUseTool,
          {
            options: {
              commands,
              forkNumber,
              messageLogName,
              tools,
              verbose,
              safeMode,
              maxThinkingTokens
            },
            messageId: getLastAssistantMessageId(conversation),
            readFileTimestamps: readFileTimestamps.current,
            abortController: newAbortController,
            setToolJSX
          },
          getBinaryFeedbackResponse
        )) {
          setMessages((oldMessages) => [...oldMessages, message]);
        }
      } else {
        addToHistory(initialPrompt);
      }
      setHaveShownCostDialog(
        getGlobalConfig().hasAcknowledgedCostThreshold || false
      );
    } finally {
      // Always leave the loading state, even if processing or the query threw.
      setIsLoading(false);
      setAbortController(null);
    }
  }

  /**
   * Appends `newMessages` to the conversation and, unless the last one is an
   * assistant message (or there are none), streams a model reply for it.
   *
   * @param {Array} newMessages - Messages produced by processing user input.
   * @param {AbortController} [passedAbortController] - Controller to reuse; a new
   *   one is created and stored in state when omitted.
   */
  async function onQuery(newMessages, passedAbortController) {
    const controllerToUse = passedAbortController || new AbortController();
    if (!passedAbortController) {
      setAbortController(controllerToUse);
    }
    const isKodingRequest = newMessages.length > 0 && newMessages[0].type === "user" && "options" in newMessages[0] && newMessages[0].options?.isKodingRequest === true;
    setMessages((oldMessages) => [...oldMessages, ...newMessages]);
    markProjectOnboardingComplete();
    const lastMessage = newMessages[newMessages.length - 1];
    // Nothing to send to the model: either no messages at all, or the input was
    // already answered locally with an assistant message.
    if (!lastMessage || lastMessage.type === "assistant") {
      setAbortController(null);
      setIsLoading(false);
      return;
    }
    try {
      const conversation = [...messages, lastMessage];
      const [systemPrompt, context, maxThinkingTokens] = await Promise.all([
        getSystemPrompt(),
        getContext(),
        getMaxThinkingTokens(conversation)
      ]);
      let lastAssistantMessage = null;
      for await (const message of query(
        conversation,
        systemPrompt,
        context,
        canUseTool,
        {
          options: {
            commands,
            forkNumber,
            messageLogName,
            tools,
            verbose,
            safeMode,
            maxThinkingTokens,
            // If this came from Koding mode, pass that along
            isKodingRequest: isKodingRequest || void 0
          },
          messageId: getLastAssistantMessageId(conversation),
          readFileTimestamps: readFileTimestamps.current,
          abortController: controllerToUse,
          setToolJSX
        },
        getBinaryFeedbackResponse
      )) {
        setMessages((oldMessages) => [...oldMessages, message]);
        if (message.type === "assistant") {
          lastAssistantMessage = message;
        }
      }
      // Koding requests: hand the final assistant text to handleHashCommand.
      if (isKodingRequest && lastAssistantMessage) {
        try {
          const rawContent = lastAssistantMessage.message.content;
          const content = typeof rawContent === "string" ? rawContent : rawContent.filter((block) => block.type === "text").map((block) => block.text).join("\n");
          if (content && content.trim().length > 0) {
            handleHashCommand(content);
          }
        } catch (error) {
          console.error("Error saving response to project docs:", error);
        }
      }
    } finally {
      // Always leave the loading state, even if the query threw.
      setIsLoading(false);
    }
  }
  useCostSummary();
  // Expose the current messages and their setter to the @messages module.
  useEffect(() => {
    const getMessages = () => messages;
    setMessagesGetter(getMessages);
    setMessagesSetter(setMessages);
  }, [messages]);
  // A model-config change bumps the fork number.
  useEffect(() => {
    setModelConfigChangeHandler(() => {
      setForkNumber((prev) => prev + 1);
    });
  }, []);
  useLogMessages(messages, messageLogName, forkNumber);
  useLogStartupTime();
  useEffect(() => {
    onInit();
  }, []);
  const normalizedMessages = useMemo(
    () => normalizeMessages(messages).filter(isNotEmptyMessage),
    [messages]
  );
  const unresolvedToolUseIDs = useMemo(
    () => getUnresolvedToolUseIDs(normalizedMessages),
    [normalizedMessages]
  );
  const inProgressToolUseIDs = useMemo(
    () => getInProgressToolUseIDs(normalizedMessages),
    [normalizedMessages]
  );
  const erroredToolUseIDs = useMemo(
    () => new Set(
      getErroredToolUseMessages(normalizedMessages).map(
        (_) => _.message.content[0].id
      )
    ),
    [normalizedMessages]
  );
  // Builds the rendered transcript. Each entry is tagged "static" (rendered
  // through Ink's <Static>) or "transient" (re-rendered on every update).
  const messagesJSX = useMemo(() => {
    // Renders the inner content of a progress message. TaskTool interrupts are
    // rendered without the MessageResponse wrapper, since <Message /> already
    // adds the margin. Other progress content is wrapped, and its first
    // content block's id is passed as unresolved.
    const renderProgress = (progress) => {
      const firstBlock = progress.content.message.content[0];
      const isInterrupt = firstBlock?.type === "text" && firstBlock.text === INTERRUPT_MESSAGE;
      if (isInterrupt) {
        return React.createElement(Message, {
          message: progress.content,
          messages: progress.normalizedMessages,
          addMargin: false,
          tools: progress.tools,
          verbose: verbose ?? false,
          debug,
          erroredToolUseIDs: new Set(),
          inProgressToolUseIDs: new Set(),
          unresolvedToolUseIDs: new Set(),
          shouldAnimate: false,
          shouldShowDot: false
        });
      }
      return React.createElement(MessageResponse, {
        children: React.createElement(Message, {
          message: progress.content,
          messages: progress.normalizedMessages,
          addMargin: false,
          tools: progress.tools,
          verbose: verbose ?? false,
          debug,
          erroredToolUseIDs: new Set(),
          inProgressToolUseIDs: new Set(),
          unresolvedToolUseIDs: new Set([firstBlock.id]),
          shouldAnimate: false,
          shouldShowDot: false
        })
      });
    };
    return [
      {
        type: "static",
        jsx: React.createElement(
          Box,
          { flexDirection: "column", key: `logo${forkNumber}` },
          React.createElement(Logo, {
            mcpClients,
            isDefaultModel,
            updateBannerVersion: updateAvailableVersion,
            updateBannerCommands: updateCommands
          }),
          React.createElement(ProjectOnboarding, { workspaceDir: getOriginalCwd() })
        )
      },
      ...reorderMessages(normalizedMessages).map((_) => {
        const toolUseID = getToolUseID(_);
        const message = _.type === "progress" ? renderProgress(_) : React.createElement(Message, {
          message: _,
          messages: normalizedMessages,
          addMargin: true,
          tools,
          verbose,
          debug,
          erroredToolUseIDs,
          inProgressToolUseIDs,
          shouldAnimate: !toolJSX && !toolUseConfirm && !isMessageSelectorVisible && (!toolUseID || inProgressToolUseIDs.has(toolUseID)),
          shouldShowDot: true,
          unresolvedToolUseIDs
        });
        const type = shouldRenderStatically(
          _,
          normalizedMessages,
          unresolvedToolUseIDs
        ) ? "static" : "transient";
        if (debug) {
          return {
            type,
            jsx: React.createElement(
              Box,
              {
                borderStyle: "single",
                borderColor: type === "static" ? "green" : "red",
                key: _.uuid,
                width: "100%"
              },
              message
            )
          };
        }
        return {
          type,
          jsx: React.createElement(Box, { key: _.uuid, width: "100%" }, message)
        };
      })
    ];
  }, [
    forkNumber,
    normalizedMessages,
    tools,
    verbose,
    debug,
    erroredToolUseIDs,
    inProgressToolUseIDs,
    toolJSX,
    toolUseConfirm,
    isMessageSelectorVisible,
    unresolvedToolUseIDs,
    mcpClients,
    isDefaultModel,
    updateAvailableVersion,
    updateCommands
  ]);
  // The cost dialog is deferred while a request is loading.
  const showingCostDialog = !isLoading && showCostDialog;
  return React.createElement(PermissionProvider, {
    isBypassPermissionsModeAvailable: !safeMode,
    children: React.createElement(
      React.Fragment,
      null,
      React.createElement(ModeIndicator, null),
      React.createElement(
        React.Fragment,
        { key: `static-messages-${forkNumber}` },
        React.createElement(Static, {
          items: messagesJSX.filter((_) => _.type === "static"),
          children: (item) => item.jsx
        })
      ),
      messagesJSX.filter((_) => _.type === "transient").map((_) => _.jsx),
      React.createElement(
        Box,
        {
          borderColor: "red",
          borderStyle: debug ? "single" : void 0,
          flexDirection: "column",
          width: "100%"
        },
        !toolJSX && !toolUseConfirm && !binaryFeedbackContext && isLoading && React.createElement(Spinner, null),
        toolJSX ? toolJSX.jsx : null,
        !toolJSX && binaryFeedbackContext && !isMessageSelectorVisible && React.createElement(BinaryFeedback, {
          m1: binaryFeedbackContext.m1,
          m2: binaryFeedbackContext.m2,
          resolve: (result) => {
            binaryFeedbackContext.resolve(result);
            setTimeout(() => setBinaryFeedbackContext(null), 0);
          },
          verbose,
          normalizedMessages,
          tools,
          debug,
          erroredToolUseIDs,
          inProgressToolUseIDs,
          unresolvedToolUseIDs
        }),
        !toolJSX && toolUseConfirm && !isMessageSelectorVisible && !binaryFeedbackContext && React.createElement(PermissionRequest, {
          toolUseConfirm,
          onDone: () => setToolUseConfirm(null),
          verbose
        }),
        !toolJSX && !toolUseConfirm && !isMessageSelectorVisible && !binaryFeedbackContext && showingCostDialog && React.createElement(CostThresholdDialog, {
          onDone: () => {
            setShowCostDialog(false);
            setHaveShownCostDialog(true);
            const projectConfig = getGlobalConfig();
            saveGlobalConfig({
              ...projectConfig,
              hasAcknowledgedCostThreshold: true
            });
          }
        }),
        !toolUseConfirm && !toolJSX?.shouldHidePromptInput && shouldShowPromptInput && !isMessageSelectorVisible && !binaryFeedbackContext && !showingCostDialog && React.createElement(
          React.Fragment,
          null,
          React.createElement(PromptInput, {
            commands,
            forkNumber,
            messageLogName,
            tools,
            isDisabled: apiKeyStatus === "invalid",
            isLoading,
            onQuery,
            debug,
            verbose,
            messages,
            setToolJSX,
            input: inputValue,
            onInputChange: setInputValue,
            mode: inputMode,
            onModeChange: setInputMode,
            submitCount,
            onSubmitCountChange: setSubmitCount,
            setIsLoading,
            setAbortController,
            onShowMessageSelector: () => setIsMessageSelectorVisible((prev) => !prev),
            setForkConvoWithMessagesOnTheNextRender,
            readFileTimestamps: readFileTimestamps.current,
            abortController,
            onModelChange: () => setForkNumber((prev) => prev + 1)
          })
        )
      ),
      isMessageSelectorVisible && React.createElement(MessageSelector, {
        erroredToolUseIDs,
        unresolvedToolUseIDs,
        messages: normalizeMessagesForAPI(messages),
        // Rewinds the conversation to just before the selected message and,
        // for string-content messages, pre-fills the input with its text.
        onSelect: async (message) => {
          setIsMessageSelectorVisible(false);
          if (!messages.includes(message)) {
            return;
          }
          onCancel();
          setImmediate(async () => {
            await clearTerminal();
            setMessages([]);
            setForkConvoWithMessagesOnTheNextRender(
              messages.slice(0, messages.indexOf(message))
            );
            if (typeof message.message.content === "string") {
              setInputValue(message.message.content);
            }
          });
        },
        onEscape: () => setIsMessageSelectorVisible(false),
        tools
      }),
      React.createElement(Newline, null)
    )
  });
}
/**
 * Decides whether a normalized message belongs in the "static" group (rendered
 * through Ink's <Static>) rather than the transient, re-rendered group.
 *
 * @param {object} message - A normalized message; "user", "assistant" and
 *   "progress" types are classified.
 * @param {Array} messages - All normalized messages; searched for the progress
 *   message that matches the message's tool-use ID.
 * @param {Set} unresolvedToolUseIDs - Tool-use IDs that are still unresolved.
 * @returns {boolean|undefined} true when static rendering is safe, false when it
 *   is not, undefined for any other message type.
 */
function shouldRenderStatically(message, messages, unresolvedToolUseIDs) {
  const kind = message.type;
  if (kind === "progress") {
    // A progress message stays transient while any sibling tool use is pending.
    return !intersects(unresolvedToolUseIDs, message.siblingToolUseIDs);
  }
  if (kind !== "user" && kind !== "assistant") {
    return undefined;
  }
  const id = getToolUseID(message);
  if (!id) {
    // Plain messages with no tool use never change once rendered.
    return true;
  }
  if (unresolvedToolUseIDs.has(id)) {
    return false;
  }
  const matchingProgress = messages.find(
    (candidate) => candidate.type === "progress" && candidate.toolUseID === id
  );
  return matchingProgress ? !intersects(unresolvedToolUseIDs, matchingProgress.siblingToolUseIDs) : true;
}
/**
 * Reports whether two Sets share at least one element.
 *
 * @param {Set} a - First set; its `size` is checked first.
 * @param {Set} b - Second set; only touched when `a` is non-empty.
 * @returns {boolean} true if some element of `a` is also in `b`.
 */
function intersects(a, b) {
  // Short-circuit on empty sets, exactly like the size checks of the original.
  if (!(a.size > 0 && b.size > 0)) {
    return false;
  }
  for (const element of a) {
    if (b.has(element)) {
      return true;
    }
  }
  return false;
}
// Module export: the REPL component is this file's only export.
export {
REPL
};
//# sourceMappingURL=REPL.js.map