UNPKG

@copilotkit/react-core

Version:

<img src="https://github.com/user-attachments/assets/0a6b64d9-e193-4940-a3f6-60334ac34084" alt="banner" style="border-radius: 12px; border: 2px solid #d6d4fa;" />

1,123 lines (996 loc) • 39.5 kB
import React, { useCallback, useEffect, useRef } from "react"; import { flushSync } from "react-dom"; import { FunctionCallHandler, COPILOT_CLOUD_PUBLIC_API_KEY_HEADER, CoAgentStateRenderHandler, randomId, parseJson, CopilotKitError, CopilotKitErrorCode, } from "@copilotkit/shared"; import { Message, TextMessage, ResultMessage, convertMessagesToGqlInput, filterAdjacentAgentStateMessages, filterAgentStateMessages, convertGqlOutputToMessages, MessageStatusCode, MessageRole, Role, CopilotRequestType, ForwardedParametersInput, loadMessagesFromJsonRepresentation, ExtensionsInput, CopilotRuntimeClient, langGraphInterruptEvent, MetaEvent, MetaEventName, ActionExecutionMessage, CopilotKitLangGraphInterruptEvent, LangGraphInterruptEvent, MetaEventInput, AgentStateInput, } from "@copilotkit/runtime-client-gql"; import { CopilotApiConfig } from "../context"; import { FrontendAction, processActionsForRuntimeRequest } from "../types/frontend-action"; import { CoagentState } from "../types/coagent-state"; import { AgentSession, useCopilotContext } from "../context/copilot-context"; import { useCopilotRuntimeClient } from "./use-copilot-runtime-client"; import { useAsyncCallback, useErrorToast } from "../components/error-boundary/error-utils"; import { useToast } from "../components/toast/toast-provider"; import { LangGraphInterruptAction, LangGraphInterruptActionSetter, } from "../types/interrupt-action"; export type UseChatOptions = { /** * System messages of the chat. Defaults to an empty array. */ initialMessages?: Message[]; /** * Callback function to be called when a function call is received. * If the function returns a `ChatRequest` object, the request will be sent * automatically to the API and will be used to update the chat. */ onFunctionCall?: FunctionCallHandler; /** * Callback function to be called when a coagent action is received. */ onCoAgentStateRender?: CoAgentStateRenderHandler; /** * Function definitions to be sent to the API. */ actions: FrontendAction<any>[]; /** * The CopilotKit API configuration. */ copilotConfig: CopilotApiConfig; /** * The current list of messages in the chat. */ messages: Message[]; /** * The setState-powered method to update the chat messages. */ setMessages: React.Dispatch<React.SetStateAction<Message[]>>; /** * A callback to get the latest system message. */ makeSystemMessageCallback: () => TextMessage; /** * Whether the API request is in progress */ isLoading: boolean; /** * setState-powered method to update the isChatLoading value */ setIsLoading: React.Dispatch<React.SetStateAction<boolean>>; /** * The current list of coagent states. */ coagentStatesRef: React.RefObject<Record<string, CoagentState>>; /** * setState-powered method to update the agent states */ setCoagentStatesWithRef: React.Dispatch<React.SetStateAction<Record<string, CoagentState>>>; /** * The current agent session. */ agentSession: AgentSession | null; /** * setState-powered method to update the agent session */ setAgentSession: React.Dispatch<React.SetStateAction<AgentSession | null>>; /** * The forwarded parameters. */ forwardedParameters?: Pick<ForwardedParametersInput, "temperature">; /** * The current thread ID. */ threadId: string; /** * set the current thread ID */ setThreadId: (threadId: string) => void; /** * The current run ID. */ runId: string | null; /** * set the current run ID */ setRunId: (runId: string | null) => void; /** * The global chat abort controller. */ chatAbortControllerRef: React.MutableRefObject<AbortController | null>; /** * The agent lock. */ agentLock: string | null; /** * The extensions. */ extensions: ExtensionsInput; /** * The setState-powered method to update the extensions. */ setExtensions: React.Dispatch<React.SetStateAction<ExtensionsInput>>; langGraphInterruptAction: LangGraphInterruptAction | null; setLangGraphInterruptAction: LangGraphInterruptActionSetter; }; export type UseChatHelpers = { /** * Append a user message to the chat list. This triggers the API call to fetch * the assistant's response. * @param message The message to append */ append: (message: Message, options?: AppendMessageOptions) => Promise<void>; /** * Reload the last AI chat response for the given chat history. If the last * message isn't from the assistant, it will request the API to generate a * new response. */ reload: (messageId: string) => Promise<void>; /** * Abort the current request immediately, keep the generated tokens if any. */ stop: () => void; /** * Run the chat completion. */ runChatCompletion: () => Promise<Message[]>; }; export interface AppendMessageOptions { /** * Whether to run the chat completion after appending the message. Defaults to `true`. */ followUp?: boolean; /** * Whether to clear the suggestions after appending the message. Defaults to `true`. */ clearSuggestions?: boolean; } export function useChat(options: UseChatOptions): UseChatHelpers { const { messages, setMessages, makeSystemMessageCallback, copilotConfig, setIsLoading, initialMessages, isLoading, actions, onFunctionCall, onCoAgentStateRender, setCoagentStatesWithRef, coagentStatesRef, agentSession, setAgentSession, threadId, setThreadId, runId, setRunId, chatAbortControllerRef, agentLock, extensions, setExtensions, langGraphInterruptAction, setLangGraphInterruptAction, } = options; const runChatCompletionRef = useRef<(previousMessages: Message[]) => Promise<Message[]>>(); const addErrorToast = useErrorToast(); const { setBannerError } = useToast(); // Get onError from context since it's not part of copilotConfig const { onError } = useCopilotContext(); // Add tracing functionality to use-chat const traceUIError = async (error: CopilotKitError, originalError?: any) => { // Just check if onError and publicApiKey are defined if (!onError || !copilotConfig?.publicApiKey) return; try { const traceEvent = { type: "error" as const, timestamp: Date.now(), context: { source: "ui" as const, request: { operation: "useChatCompletion", url: copilotConfig.chatApiEndpoint, startTime: Date.now(), }, technical: { environment: "browser", userAgent: typeof navigator !== "undefined" ? navigator.userAgent : undefined, stackTrace: originalError instanceof Error ? originalError.stack : undefined, }, }, error, }; await onError(traceEvent); } catch (traceError) { console.error("Error in use-chat onError handler:", traceError); } }; // We need to keep a ref of coagent states and session because of renderAndWait - making sure // the latest state is sent to the API // This is a workaround and needs to be addressed in the future const agentSessionRef = useRef<AgentSession | null>(agentSession); agentSessionRef.current = agentSession; const runIdRef = useRef<string | null>(runId); runIdRef.current = runId; const extensionsRef = useRef<ExtensionsInput>(extensions); extensionsRef.current = extensions; const publicApiKey = copilotConfig.publicApiKey; const headers = { ...(copilotConfig.headers || {}), ...(publicApiKey ? { [COPILOT_CLOUD_PUBLIC_API_KEY_HEADER]: publicApiKey } : {}), }; const { showDevConsole } = useCopilotContext(); const runtimeClient = useCopilotRuntimeClient({ url: copilotConfig.chatApiEndpoint, publicApiKey: copilotConfig.publicApiKey, headers, credentials: copilotConfig.credentials, showDevConsole, }); const pendingAppendsRef = useRef<{ message: Message; followUp: boolean }[]>([]); const runChatCompletion = useAsyncCallback( async (previousMessages: Message[]): Promise<Message[]> => { setIsLoading(true); const interruptEvent = langGraphInterruptAction?.event; // In case an interrupt event exist and valid but has no response yet, we cannot process further messages to an agent if ( interruptEvent?.name === MetaEventName.LangGraphInterruptEvent && interruptEvent?.value && !interruptEvent?.response && agentSessionRef.current ) { addErrorToast([ new Error( "A message was sent while interrupt is active. This will cause failure on the agent side", ), ]); } // this message is just a placeholder. It will disappear once the first real message // is received let newMessages: Message[] = [ new TextMessage({ content: "", role: Role.Assistant, }), ]; chatAbortControllerRef.current = new AbortController(); setMessages([...previousMessages, ...newMessages]); const systemMessage = makeSystemMessageCallback(); const messagesWithContext = [systemMessage, ...(initialMessages || []), ...previousMessages]; // ----- Set mcpServers in properties ----- // Create a copy of properties to avoid modifying the original object const finalProperties = { ...(copilotConfig.properties || {}) }; // Look for mcpServers in either direct property or properties let mcpServersToUse = null; // First check direct mcpServers property if ( copilotConfig.mcpServers && Array.isArray(copilotConfig.mcpServers) && copilotConfig.mcpServers.length > 0 ) { mcpServersToUse = copilotConfig.mcpServers; } // Then check mcpServers in properties else if ( copilotConfig.properties?.mcpServers && Array.isArray(copilotConfig.properties.mcpServers) && copilotConfig.properties.mcpServers.length > 0 ) { mcpServersToUse = copilotConfig.properties.mcpServers; } // Apply the mcpServers to properties if found if (mcpServersToUse) { // Set in finalProperties finalProperties.mcpServers = mcpServersToUse; // Also set in copilotConfig directly for future use copilotConfig.mcpServers = mcpServersToUse; } // ------------------------------------------------------------- const isAgentRun = agentSessionRef.current !== null; const stream = runtimeClient.asStream( runtimeClient.generateCopilotResponse({ data: { frontend: { actions: processActionsForRuntimeRequest(actions), url: window.location.href, }, threadId: threadId, runId: runIdRef.current, extensions: extensionsRef.current, metaEvents: composeAndFlushMetaEventsInput([langGraphInterruptAction?.event]), messages: convertMessagesToGqlInput(filterAgentStateMessages(messagesWithContext)), ...(copilotConfig.cloud ? { cloud: { ...(copilotConfig.cloud.guardrails?.input?.restrictToTopic?.enabled ? { guardrails: { inputValidationRules: { allowList: copilotConfig.cloud.guardrails.input.restrictToTopic.validTopics, denyList: copilotConfig.cloud.guardrails.input.restrictToTopic.invalidTopics, }, }, } : {}), }, } : {}), metadata: { requestType: CopilotRequestType.Chat, }, ...(agentSessionRef.current ? { agentSession: agentSessionRef.current, } : {}), agentStates: Object.values(coagentStatesRef.current!).map((state) => { const stateObject: AgentStateInput = { agentName: state.name, state: JSON.stringify(state.state), }; if (state.config !== undefined) { stateObject.config = JSON.stringify(state.config); } return stateObject; }), forwardedParameters: options.forwardedParameters || {}, }, properties: finalProperties, signal: chatAbortControllerRef.current?.signal, }), ); const guardrailsEnabled = copilotConfig.cloud?.guardrails?.input?.restrictToTopic.enabled || false; const reader = stream.getReader(); let executedCoAgentStateRenders: string[] = []; let followUp: FrontendAction["followUp"] = undefined; let messages: Message[] = []; let syncedMessages: Message[] = []; let interruptMessages: Message[] = []; try { while (true) { let done, value; try { const readResult = await reader.read(); done = readResult.done; value = readResult.value; } catch (readError) { break; } if (done) { if (chatAbortControllerRef.current.signal.aborted) { return []; } break; } if (!value?.generateCopilotResponse) { continue; } runIdRef.current = value.generateCopilotResponse.runId || null; // in the output, graphql inserts __typename, which leads to an error when sending it along // as input to the next request. extensionsRef.current = CopilotRuntimeClient.removeGraphQLTypename( value.generateCopilotResponse.extensions || {}, ); // setThreadId(threadIdRef.current); setRunId(runIdRef.current); setExtensions(extensionsRef.current); let rawMessagesResponse = value.generateCopilotResponse.messages; const metaEvents: MetaEvent[] | undefined = value.generateCopilotResponse?.metaEvents ?? []; (metaEvents ?? []).forEach((ev) => { if (ev.name === MetaEventName.LangGraphInterruptEvent) { let eventValue = langGraphInterruptEvent(ev as LangGraphInterruptEvent).value; eventValue = parseJson(eventValue, eventValue); setLangGraphInterruptAction({ event: { ...langGraphInterruptEvent(ev as LangGraphInterruptEvent), value: eventValue, }, }); } if (ev.name === MetaEventName.CopilotKitLangGraphInterruptEvent) { const data = (ev as CopilotKitLangGraphInterruptEvent).data; // @ts-expect-error -- same type of messages rawMessagesResponse = [...rawMessagesResponse, ...data.messages]; interruptMessages = convertGqlOutputToMessages( // @ts-ignore filterAdjacentAgentStateMessages(data.messages), ); } }); messages = convertGqlOutputToMessages( filterAdjacentAgentStateMessages(rawMessagesResponse), ); newMessages = []; // Handle error statuses BEFORE checking if there are messages // (errors can come in chunks with no messages) // request failed, display error message and quit if ( value.generateCopilotResponse.status?.__typename === "FailedResponseStatus" && value.generateCopilotResponse.status.reason === "GUARDRAILS_VALIDATION_FAILED" ) { const guardrailsReason = value.generateCopilotResponse.status.details?.guardrailsReason || ""; newMessages = [ new TextMessage({ role: MessageRole.Assistant, content: guardrailsReason, }), ]; // Trace guardrails validation failure const guardrailsError = new CopilotKitError({ message: `Guardrails validation failed: ${guardrailsReason}`, code: CopilotKitErrorCode.MISUSE, }); await traceUIError(guardrailsError, { statusReason: value.generateCopilotResponse.status.reason, statusDetails: value.generateCopilotResponse.status.details, }); setMessages([...previousMessages, ...newMessages]); break; } // Handle UNKNOWN_ERROR failures (like authentication errors) by routing to banner error system if ( value.generateCopilotResponse.status?.__typename === "FailedResponseStatus" && value.generateCopilotResponse.status.reason === "UNKNOWN_ERROR" ) { const errorMessage = value.generateCopilotResponse.status.details?.description || "An unknown error occurred"; // Try to extract original error information from the response details const statusDetails = value.generateCopilotResponse.status.details; const originalError = statusDetails?.originalError || statusDetails?.error; // Extract structured error information if available (prioritize top-level over extensions) const originalCode = originalError?.code || originalError?.extensions?.code; const originalSeverity = originalError?.severity || originalError?.extensions?.severity; const originalVisibility = originalError?.visibility || originalError?.extensions?.visibility; // Use the original error code if available, otherwise default to NETWORK_ERROR let errorCode = CopilotKitErrorCode.NETWORK_ERROR; if (originalCode && Object.values(CopilotKitErrorCode).includes(originalCode)) { errorCode = originalCode; } // Create a structured CopilotKitError preserving original error information const structuredError = new CopilotKitError({ message: errorMessage, code: errorCode, severity: originalSeverity, visibility: originalVisibility, }); // Display the error in the banner setBannerError(structuredError); // Trace the error for debugging/observability await traceUIError(structuredError, { statusReason: value.generateCopilotResponse.status.reason, statusDetails: value.generateCopilotResponse.status.details, originalErrorCode: originalCode, preservedStructure: !!originalCode, }); // Stop processing and break from the loop setIsLoading(false); break; } // add messages to the chat else if (messages.length > 0) { newMessages = [...messages]; for (const message of messages) { // execute onCoAgentStateRender handler if ( message.isAgentStateMessage() && !message.active && !executedCoAgentStateRenders.includes(message.id) && onCoAgentStateRender ) { // Do not execute a coagent action if guardrails are enabled but the status is not known if (guardrailsEnabled && value.generateCopilotResponse.status === undefined) { break; } // execute coagent action await onCoAgentStateRender({ name: message.agentName, nodeName: message.nodeName, state: message.state, }); executedCoAgentStateRenders.push(message.id); } } const lastAgentStateMessage = [...messages] .reverse() .find((message) => message.isAgentStateMessage()); if (lastAgentStateMessage) { if ( lastAgentStateMessage.state.messages && lastAgentStateMessage.state.messages.length > 0 ) { syncedMessages = loadMessagesFromJsonRepresentation( lastAgentStateMessage.state.messages, ); } setCoagentStatesWithRef((prevAgentStates) => ({ ...prevAgentStates, [lastAgentStateMessage.agentName]: { name: lastAgentStateMessage.agentName, state: lastAgentStateMessage.state, running: lastAgentStateMessage.running, active: lastAgentStateMessage.active, threadId: lastAgentStateMessage.threadId, nodeName: lastAgentStateMessage.nodeName, runId: lastAgentStateMessage.runId, // Preserve existing config from previous state config: prevAgentStates[lastAgentStateMessage.agentName]?.config, }, })); if (lastAgentStateMessage.running) { setAgentSession({ threadId: lastAgentStateMessage.threadId, agentName: lastAgentStateMessage.agentName, nodeName: lastAgentStateMessage.nodeName, }); } else { if (agentLock) { setAgentSession({ threadId: randomId(), agentName: agentLock, nodeName: undefined, }); } else { setAgentSession(null); } } } } if (newMessages.length > 0) { // Update message state setMessages([...previousMessages, ...newMessages]); } } let finalMessages = constructFinalMessages( [...syncedMessages, ...interruptMessages], previousMessages, newMessages, ); let didExecuteAction = false; // ----- Helper function to execute an action and manage its lifecycle ----- const executeActionFromMessage = async ( currentAction: FrontendAction<any>, actionMessage: ActionExecutionMessage, ) => { const isInterruptAction = interruptMessages.find((m) => m.id === actionMessage.id); // Determine follow-up behavior: use action's specific setting if defined, otherwise default based on interrupt status. followUp = currentAction?.followUp ?? !isInterruptAction; // Call _setActivatingMessageId before executing the action for HITL correlation if ((currentAction as any)?._setActivatingMessageId) { (currentAction as any)._setActivatingMessageId(actionMessage.id); } const resultMessage = await executeAction({ onFunctionCall: onFunctionCall!, message: actionMessage, chatAbortControllerRef, onError: (error: Error) => { addErrorToast([error]); // console.error is kept here as it's a genuine error in action execution console.error(`Failed to execute action ${actionMessage.name}: ${error}`); }, setMessages, getFinalMessages: () => finalMessages, isRenderAndWait: (currentAction as any)?._isRenderAndWait || false, }); didExecuteAction = true; const messageIndex = finalMessages.findIndex((msg) => msg.id === actionMessage.id); finalMessages.splice(messageIndex + 1, 0, resultMessage); // If the executed action was a renderAndWaitForResponse type, update messages immediately // to reflect its completion in the UI, making it interactive promptly. if ((currentAction as any)?._isRenderAndWait) { const messagesForImmediateUpdate = [...finalMessages]; flushSync(() => { setMessages(messagesForImmediateUpdate); }); } // Clear _setActivatingMessageId after the action is done if ((currentAction as any)?._setActivatingMessageId) { (currentAction as any)._setActivatingMessageId(null); } return resultMessage; }; // ---------------------------------------------------------------------- // execute regular action executions that are specific to the frontend (last actions) if (onFunctionCall) { // Find consecutive action execution messages at the end const lastMessages = []; for (let i = finalMessages.length - 1; i >= 0; i--) { const message = finalMessages[i]; if ( (message.isActionExecutionMessage() || message.isResultMessage()) && message.status.code !== MessageStatusCode.Pending ) { lastMessages.unshift(message); } else if (!message.isAgentStateMessage()) { break; } } for (const message of lastMessages) { // We update the message state before calling the handler so that the render // function can be called with `executing` state setMessages(finalMessages); const action = actions.find( (action) => action.name === (message as ActionExecutionMessage).name, ); if (action && action.available === "frontend") { // never execute frontend actions continue; } const currentResultMessagePairedFeAction = message.isResultMessage() ? getPairedFeAction(actions, message) : null; // execution message which has an action registered with the hook (remote availability): // execute that action first, and then the "paired FE action" if (action && message.isActionExecutionMessage()) { // For HITL actions, check if they've already been processed to avoid redundant handler calls. const isRenderAndWaitAction = (action as any)?._isRenderAndWait || false; const alreadyProcessed = isRenderAndWaitAction && finalMessages.some( (fm) => fm.isResultMessage() && fm.actionExecutionId === message.id, ); if (alreadyProcessed) { // Skip re-execution if already processed } else { // Call the single, externally defined executeActionFromMessage const resultMessage = await executeActionFromMessage( action, message as ActionExecutionMessage, ); const pairedFeAction = getPairedFeAction(actions, resultMessage); if (pairedFeAction) { const newExecutionMessage = new ActionExecutionMessage({ name: pairedFeAction.name, arguments: parseJson(resultMessage.result, resultMessage.result), status: message.status, createdAt: message.createdAt, parentMessageId: message.parentMessageId, }); // Call the single, externally defined executeActionFromMessage await executeActionFromMessage(pairedFeAction, newExecutionMessage); } } } else if (message.isResultMessage() && currentResultMessagePairedFeAction) { // Actions which are set up in runtime actions array: Grab the result, executed paired FE action with it as args. const newExecutionMessage = new ActionExecutionMessage({ name: currentResultMessagePairedFeAction.name, arguments: parseJson(message.result, message.result), status: message.status, createdAt: message.createdAt, }); finalMessages.push(newExecutionMessage); // Call the single, externally defined executeActionFromMessage await executeActionFromMessage( currentResultMessagePairedFeAction, newExecutionMessage, ); } } setMessages(finalMessages); } // Conditionally run chat completion again if followUp is not explicitly false // and an action was executed or the last message is a server-side result (for non-agent runs). if ( followUp !== false && (didExecuteAction || // the last message is a server side result (!isAgentRun && finalMessages.length && finalMessages[finalMessages.length - 1].isResultMessage())) && // the user did not stop generation !chatAbortControllerRef.current?.signal.aborted ) { // run the completion again and return the result // wait for next tick to make sure all the react state updates // - tried using react-dom's flushSync, but it did not work await new Promise((resolve) => setTimeout(resolve, 10)); return await runChatCompletionRef.current!(finalMessages); } else if (chatAbortControllerRef.current?.signal.aborted) { // filter out all the action execution messages that do not have a consecutive matching result message const repairedMessages = finalMessages.filter((message, actionExecutionIndex) => { if (message.isActionExecutionMessage()) { return finalMessages.find( (msg, resultIndex) => msg.isResultMessage() && msg.actionExecutionId === message.id && resultIndex === actionExecutionIndex + 1, ); } return true; }); const repairedMessageIds = repairedMessages.map((message) => message.id); setMessages(repairedMessages); // LangGraph needs two pieces of information to continue execution: // 1. The threadId // 2. The nodeName it came from // When stopping the agent, we don't know the nodeName the agent would have ended with // Therefore, we set the nodeName to the most reasonable thing we can guess, which // is "__end__" if (agentSessionRef.current?.nodeName) { setAgentSession({ threadId: agentSessionRef.current.threadId, agentName: agentSessionRef.current.agentName, nodeName: "__end__", }); } // only return new messages that were not filtered out return newMessages.filter((message) => repairedMessageIds.includes(message.id)); } else { return newMessages.slice(); } } finally { setIsLoading(false); } }, [ messages, setMessages, makeSystemMessageCallback, copilotConfig, setIsLoading, initialMessages, isLoading, actions, onFunctionCall, onCoAgentStateRender, setCoagentStatesWithRef, coagentStatesRef, agentSession, setAgentSession, ], ); runChatCompletionRef.current = runChatCompletion; const runChatCompletionAndHandleFunctionCall = useAsyncCallback( async (messages: Message[]): Promise<void> => { await runChatCompletionRef.current!(messages); }, [messages], ); useEffect(() => { if (!isLoading && pendingAppendsRef.current.length > 0) { const pending = pendingAppendsRef.current.splice(0); const followUp = pending.some((p) => p.followUp); const newMessages = [...messages, ...pending.map((p) => p.message)]; setMessages(newMessages); if (followUp) { runChatCompletionAndHandleFunctionCall(newMessages); } } }, [isLoading, messages, setMessages, runChatCompletionAndHandleFunctionCall]); // Go over all events and see that they include data that should be returned to the agent const composeAndFlushMetaEventsInput = useCallback( (metaEvents: (MetaEvent | undefined | null)[]) => { return metaEvents.reduce((acc: MetaEventInput[], event) => { if (!event) return acc; switch (event.name) { case MetaEventName.LangGraphInterruptEvent: if (event.response) { // Flush interrupt event from state setLangGraphInterruptAction(null); const value = (event as LangGraphInterruptEvent).value; return [ ...acc, { name: event.name, value: typeof value === "string" ? value : JSON.stringify(value), response: typeof event.response === "string" ? event.response : JSON.stringify(event.response), }, ]; } return acc; default: return acc; } }, []); }, [setLangGraphInterruptAction], ); const append = useAsyncCallback( async (message: Message, options?: AppendMessageOptions): Promise<void> => { const followUp = options?.followUp ?? true; if (isLoading) { pendingAppendsRef.current.push({ message, followUp }); return; } const newMessages = [...messages, message]; setMessages(newMessages); if (followUp) { return runChatCompletionAndHandleFunctionCall(newMessages); } }, [isLoading, messages, setMessages, runChatCompletionAndHandleFunctionCall], ); const reload = useAsyncCallback( async (reloadMessageId: string): Promise<void> => { if (isLoading || messages.length === 0) { return; } const reloadMessageIndex = messages.findIndex((msg) => msg.id === reloadMessageId); if (reloadMessageIndex === -1) { console.warn(`Message with id ${reloadMessageId} not found`); return; } // @ts-expect-error -- message has role const reloadMessageRole = messages[reloadMessageIndex].role; if (reloadMessageRole !== MessageRole.Assistant) { console.warn(`Regenerate cannot be performed on ${reloadMessageRole} role`); return; } let historyCutoff: Message[] = []; if (messages.length > 2) { // message to regenerate from is now first. // Work backwards to find the first the closest user message const lastUserMessageBeforeRegenerate = messages .slice(0, reloadMessageIndex) .reverse() .find( (msg) => // @ts-expect-error -- message has role msg.role === MessageRole.User, ); const indexOfLastUserMessageBeforeRegenerate = messages.findIndex( (msg) => msg.id === lastUserMessageBeforeRegenerate!.id, ); // Include the user message, remove everything after it historyCutoff = messages.slice(0, indexOfLastUserMessageBeforeRegenerate + 1); } setMessages(historyCutoff); return runChatCompletionAndHandleFunctionCall(historyCutoff); }, [isLoading, messages, setMessages, runChatCompletionAndHandleFunctionCall], ); const stop = (): void => { chatAbortControllerRef.current?.abort("Stop was called"); }; return { append, reload, stop, runChatCompletion: () => runChatCompletionRef.current!(messages), }; } function constructFinalMessages( syncedMessages: Message[], previousMessages: Message[], newMessages: Message[], ): Message[] { const finalMessages = syncedMessages.length > 0 ? [...syncedMessages] : [...previousMessages, ...newMessages]; if (syncedMessages.length > 0) { const messagesWithAgentState = [...previousMessages, ...newMessages]; let previousMessageId: string | undefined = undefined; for (const message of messagesWithAgentState) { if (message.isAgentStateMessage()) { // insert this message into finalMessages after the position of previousMessageId const index = finalMessages.findIndex((msg) => msg.id === previousMessageId); if (index !== -1) { finalMessages.splice(index + 1, 0, message); } } previousMessageId = message.id; } } return finalMessages; } async function executeAction({ onFunctionCall, message, chatAbortControllerRef, onError, setMessages, getFinalMessages, isRenderAndWait, }: { onFunctionCall: FunctionCallHandler; message: ActionExecutionMessage; chatAbortControllerRef: React.MutableRefObject<AbortController | null>; onError: (error: Error) => void; setMessages: React.Dispatch<React.SetStateAction<Message[]>>; getFinalMessages: () => Message[]; isRenderAndWait: boolean; }) { let result: any; let error: Error | null = null; const currentMessagesForHandler = getFinalMessages(); // The handler (onFunctionCall) runs its synchronous part here, potentially setting up // renderAndWaitRef.current for HITL actions via useCopilotAction's transformed handler. const handlerReturnedPromise = onFunctionCall({ messages: currentMessagesForHandler, name: message.name, args: message.arguments, }); // For HITL actions, call flushSync immediately after their handler has set up the promise // and before awaiting the promise. This ensures the UI updates to an interactive state. if (isRenderAndWait) { const currentMessagesForRender = getFinalMessages(); flushSync(() => { setMessages([...currentMessagesForRender]); }); } try { result = await Promise.race([ handlerReturnedPromise, // Await the promise returned by the handler new Promise((resolve) => chatAbortControllerRef.current?.signal.addEventListener("abort", () => resolve("Operation was aborted by the user"), ), ), // if the user stopped generation, we also abort consecutive actions new Promise((resolve) => { if (chatAbortControllerRef.current?.signal.aborted) { resolve("Operation was aborted by the user"); } }), ]); } catch (e) { onError(e as Error); } return new ResultMessage({ id: "result-" + message.id, result: ResultMessage.encodeResult( error ? { content: result, error: JSON.parse(JSON.stringify(error, Object.getOwnPropertyNames(error))), } : result, ), actionExecutionId: message.id, actionName: message.name, }); } function getPairedFeAction( actions: FrontendAction<any>[], message: ActionExecutionMessage | ResultMessage, ) { let actionName = null; if (message.isActionExecutionMessage()) { actionName = message.name; } else if (message.isResultMessage()) { actionName = message.actionName; } return actions.find( (action) => (action.name === actionName && action.available === "frontend") || action.pairedAction === actionName, ); }