UNPKG

@copilotkit/react-core

Version:

<img src="https://github.com/user-attachments/assets/0a6b64d9-e193-4940-a3f6-60334ac34084" alt="banner" style="border-radius: 12px; border: 2px solid #d6d4fa;" />

709 lines (707 loc) • 29 kB
import { processActionsForRuntimeRequest } from "./chunk-4CEQJ2X6.mjs";
import { useCopilotRuntimeClient } from "./chunk-Q3MCVRO3.mjs";
import { useAsyncCallback, useErrorToast } from "./chunk-N4WEHORG.mjs";
import { useToast } from "./chunk-EFL5OBKN.mjs";
import { useCopilotContext } from "./chunk-YHT6CWIY.mjs";
import { __async, __spreadProps, __spreadValues } from "./chunk-SKC7AJIV.mjs";

// src/hooks/use-chat.ts
import { useCallback, useEffect, useRef } from "react";
import { flushSync } from "react-dom";
import {
  COPILOT_CLOUD_PUBLIC_API_KEY_HEADER,
  randomId,
  parseJson,
  CopilotKitError,
  CopilotKitErrorCode
} from "@copilotkit/shared";
import {
  TextMessage,
  ResultMessage,
  convertMessagesToGqlInput,
  filterAdjacentAgentStateMessages,
  filterAgentStateMessages,
  convertGqlOutputToMessages,
  MessageStatusCode,
  MessageRole,
  Role,
  CopilotRequestType,
  loadMessagesFromJsonRepresentation,
  CopilotRuntimeClient,
  langGraphInterruptEvent,
  MetaEventName,
  ActionExecutionMessage
} from "@copilotkit/runtime-client-gql";

/**
 * Chat hook: streams a response from the runtime client, mirrors the streamed
 * messages / agent state into the caller-provided setters, runs frontend
 * actions through `onFunctionCall`, and re-runs the completion when a follow-up
 * is required.
 *
 * @param {object} options - State, setters and callbacks supplied by the caller
 *   (destructured below). `options.forwardedParameters` is also read when the
 *   request is built.
 * @returns {{append: Function, reload: Function, stop: Function, runChatCompletion: Function}}
 */
function useChat(options) {
  const {
    messages,
    setMessages,
    makeSystemMessageCallback,
    copilotConfig,
    setIsLoading,
    initialMessages,
    isLoading,
    actions,
    onFunctionCall,
    onCoAgentStateRender,
    setCoagentStatesWithRef,
    coagentStatesRef,
    agentSession,
    setAgentSession,
    threadId,
    setThreadId,
    runId,
    setRunId,
    chatAbortControllerRef,
    agentLock,
    extensions,
    setExtensions,
    langGraphInterruptAction,
    setLangGraphInterruptAction,
    disableSystemMessage = false
  } = options;
  // Holds the latest runChatCompletion so recursive follow-ups and the other
  // callbacks always invoke the current closure.
  const runChatCompletionRef = useRef();
  const addErrorToast = useErrorToast();
  const { setBannerError } = useToast();
  const { onError } = useCopilotContext();

  // Builds a UI-sourced trace event and forwards it to the context `onError`.
  // A failure inside the handler is logged and never propagated.
  const traceUIError = (error, originalError) => __async(this, null, function* () {
    try {
      const traceEvent = {
        type: "error",
        timestamp: Date.now(),
        context: {
          source: "ui",
          request: {
            operation: "useChatCompletion",
            url: copilotConfig.chatApiEndpoint,
            startTime: Date.now()
          },
          technical: {
            environment: "browser",
            userAgent: typeof navigator !== "undefined" ? navigator.userAgent : void 0,
            stackTrace: originalError instanceof Error ? originalError.stack : void 0
          }
        },
        error
      };
      yield onError(traceEvent);
    } catch (traceError) {
      console.error("Error in use-chat onError handler:", traceError);
    }
  });

  // Refs mirror the latest props so the long-running async callback reads
  // current values instead of the ones captured when it was created.
  const agentSessionRef = useRef(agentSession);
  agentSessionRef.current = agentSession;
  const runIdRef = useRef(runId);
  runIdRef.current = runId;
  const extensionsRef = useRef(extensions);
  extensionsRef.current = extensions;

  const publicApiKey = copilotConfig.publicApiKey;
  const headers = __spreadValues(
    __spreadValues({}, copilotConfig.headers || {}),
    publicApiKey ? { [COPILOT_CLOUD_PUBLIC_API_KEY_HEADER]: publicApiKey } : {}
  );
  const { showDevConsole } = useCopilotContext();
  const runtimeClient = useCopilotRuntimeClient({
    url: copilotConfig.chatApiEndpoint,
    publicApiKey: copilotConfig.publicApiKey,
    headers,
    credentials: copilotConfig.credentials,
    showDevConsole,
    onError
  });
  // Messages appended while a completion is in flight; flushed by the effect
  // below once `isLoading` turns false.
  const pendingAppendsRef = useRef([]);

  const runChatCompletion = useAsyncCallback(
    (previousMessages) => __async(this, null, function* () {
      var _a, _b, _c, _d, _e, _f, _g, _h, _i, _j, _k, _l, _m, _n, _o, _p, _q, _r, _s, _t;
      setIsLoading(true);

      // Warn when a message is sent while an unanswered interrupt is active.
      const interruptEvent = langGraphInterruptAction == null ? void 0 : langGraphInterruptAction.event;
      if ((interruptEvent == null ? void 0 : interruptEvent.name) === MetaEventName.LangGraphInterruptEvent && (interruptEvent == null ? void 0 : interruptEvent.value) && !(interruptEvent == null ? void 0 : interruptEvent.response) && agentSessionRef.current) {
        addErrorToast([
          new Error(
            "A message was sent while interrupt is active. This will cause failure on the agent side"
          )
        ]);
      }

      // Placeholder assistant message shown until the stream delivers content.
      let newMessages = [
        new TextMessage({
          content: "",
          role: Role.Assistant
        })
      ];
      chatAbortControllerRef.current = new AbortController();
      setMessages([...previousMessages, ...newMessages]);

      const messagesWithContext = disableSystemMessage
        ? [...initialMessages || [], ...previousMessages]
        : [makeSystemMessageCallback(), ...initialMessages || [], ...previousMessages];

      // Resolve MCP servers: top-level config wins over `properties.mcpServers`.
      const finalProperties = __spreadValues({}, copilotConfig.properties || {});
      let mcpServersToUse = null;
      if (copilotConfig.mcpServers && Array.isArray(copilotConfig.mcpServers) && copilotConfig.mcpServers.length > 0) {
        mcpServersToUse = copilotConfig.mcpServers;
      } else if (((_a = copilotConfig.properties) == null ? void 0 : _a.mcpServers) && Array.isArray(copilotConfig.properties.mcpServers) && copilotConfig.properties.mcpServers.length > 0) {
        mcpServersToUse = copilotConfig.properties.mcpServers;
      }
      if (mcpServersToUse) {
        finalProperties.mcpServers = mcpServersToUse;
        // NOTE(review): this writes back onto the caller's config object.
        copilotConfig.mcpServers = mcpServersToUse;
      }

      const isAgentRun = agentSessionRef.current !== null;
      const stream = runtimeClient.asStream(
        runtimeClient.generateCopilotResponse({
          data: __spreadProps(__spreadValues(__spreadProps(__spreadValues({
            frontend: {
              actions: processActionsForRuntimeRequest(actions),
              url: window.location.href
            },
            threadId,
            runId: runIdRef.current,
            extensions: extensionsRef.current,
            metaEvents: composeAndFlushMetaEventsInput([langGraphInterruptAction == null ? void 0 : langGraphInterruptAction.event]),
            messages: convertMessagesToGqlInput(filterAgentStateMessages(messagesWithContext))
          }, copilotConfig.cloud ? {
            cloud: __spreadValues({}, ((_d = (_c = (_b = copilotConfig.cloud.guardrails) == null ? void 0 : _b.input) == null ? void 0 : _c.restrictToTopic) == null ? void 0 : _d.enabled) ? {
              guardrails: {
                inputValidationRules: {
                  allowList: copilotConfig.cloud.guardrails.input.restrictToTopic.validTopics,
                  denyList: copilotConfig.cloud.guardrails.input.restrictToTopic.invalidTopics
                }
              }
            } : {})
          } : {}), {
            metadata: {
              requestType: CopilotRequestType.Chat
            }
          }), agentSessionRef.current ? {
            agentSession: agentSessionRef.current
          } : {}), {
            agentStates: Object.values(coagentStatesRef.current).map((state) => {
              const stateObject = {
                agentName: state.name,
                state: JSON.stringify(state.state)
              };
              if (state.config !== void 0) {
                stateObject.config = JSON.stringify(state.config);
              }
              return stateObject;
            }),
            forwardedParameters: options.forwardedParameters || {}
          }),
          properties: finalProperties,
          signal: (_e = chatAbortControllerRef.current) == null ? void 0 : _e.signal
        })
      );

      // Null-safe on `restrictToTopic`, matching the request-building code above.
      const restrictToTopic = (_h = (_g = (_f = copilotConfig.cloud) == null ? void 0 : _f.guardrails) == null ? void 0 : _g.input) == null ? void 0 : _h.restrictToTopic;
      const guardrailsEnabled = (restrictToTopic == null ? void 0 : restrictToTopic.enabled) || false;

      const reader = stream.getReader();
      let executedCoAgentStateRenders = [];
      let followUp = void 0;
      let messages2 = [];
      let syncedMessages = [];
      let interruptMessages = [];
      try {
        while (true) {
          let done, value;
          try {
            const readResult = yield reader.read();
            done = readResult.done;
            value = readResult.value;
          } catch (readError) {
            // A failed read ends streaming; whatever was received so far is
            // processed below.
            break;
          }
          if (done) {
            if (chatAbortControllerRef.current.signal.aborted) {
              return [];
            }
            break;
          }
          if (!(value == null ? void 0 : value.generateCopilotResponse)) {
            continue;
          }

          runIdRef.current = value.generateCopilotResponse.runId || null;
          extensionsRef.current = CopilotRuntimeClient.removeGraphQLTypename(
            value.generateCopilotResponse.extensions || {}
          );
          setRunId(runIdRef.current);
          setExtensions(extensionsRef.current);

          let rawMessagesResponse = value.generateCopilotResponse.messages;
          const metaEvents = (_j = (_i = value.generateCopilotResponse) == null ? void 0 : _i.metaEvents) != null ? _j : [];
          metaEvents.forEach((ev) => {
            if (ev.name === MetaEventName.LangGraphInterruptEvent) {
              let eventValue = langGraphInterruptEvent(ev).value;
              eventValue = parseJson(eventValue, eventValue);
              setLangGraphInterruptAction({
                event: __spreadProps(__spreadValues({}, langGraphInterruptEvent(ev)), {
                  value: eventValue
                })
              });
            }
            if (ev.name === MetaEventName.CopilotKitLangGraphInterruptEvent) {
              const data = ev.data;
              rawMessagesResponse = [...rawMessagesResponse, ...data.messages];
              interruptMessages = convertGqlOutputToMessages(
                // @ts-ignore
                filterAdjacentAgentStateMessages(data.messages)
              );
            }
          });

          messages2 = convertGqlOutputToMessages(
            filterAdjacentAgentStateMessages(rawMessagesResponse)
          );
          newMessages = [];

          // Guardrails rejection: show the reason as the assistant reply and stop.
          if (((_k = value.generateCopilotResponse.status) == null ? void 0 : _k.__typename) === "FailedResponseStatus" && value.generateCopilotResponse.status.reason === "GUARDRAILS_VALIDATION_FAILED") {
            const guardrailsReason = ((_l = value.generateCopilotResponse.status.details) == null ? void 0 : _l.guardrailsReason) || "";
            newMessages = [
              new TextMessage({
                role: MessageRole.Assistant,
                content: guardrailsReason
              })
            ];
            const guardrailsError = new CopilotKitError({
              message: `Guardrails validation failed: ${guardrailsReason}`,
              code: CopilotKitErrorCode.MISUSE
            });
            yield traceUIError(guardrailsError, {
              statusReason: value.generateCopilotResponse.status.reason,
              statusDetails: value.generateCopilotResponse.status.details
            });
            setMessages([...previousMessages, ...newMessages]);
            break;
          }

          if (((_m = value.generateCopilotResponse.status) == null ? void 0 : _m.__typename) === "FailedResponseStatus" && value.generateCopilotResponse.status.reason === "UNKNOWN_ERROR") {
            // Rebuild a structured error, keeping the original code / severity /
            // visibility when the details carry them.
            const errorMessage = ((_n = value.generateCopilotResponse.status.details) == null ? void 0 : _n.description) || "An unknown error occurred";
            const statusDetails = value.generateCopilotResponse.status.details;
            const originalError = (statusDetails == null ? void 0 : statusDetails.originalError) || (statusDetails == null ? void 0 : statusDetails.error);
            const originalCode = (originalError == null ? void 0 : originalError.code) || ((_o = originalError == null ? void 0 : originalError.extensions) == null ? void 0 : _o.code);
            const originalSeverity = (originalError == null ? void 0 : originalError.severity) || ((_p = originalError == null ? void 0 : originalError.extensions) == null ? void 0 : _p.severity);
            const originalVisibility = (originalError == null ? void 0 : originalError.visibility) || ((_q = originalError == null ? void 0 : originalError.extensions) == null ? void 0 : _q.visibility);
            let errorCode = CopilotKitErrorCode.NETWORK_ERROR;
            if (originalCode && Object.values(CopilotKitErrorCode).includes(originalCode)) {
              errorCode = originalCode;
            }
            const structuredError = new CopilotKitError({
              message: errorMessage,
              code: errorCode,
              severity: originalSeverity,
              visibility: originalVisibility
            });
            setBannerError(structuredError);
            yield traceUIError(structuredError, {
              statusReason: value.generateCopilotResponse.status.reason,
              statusDetails: value.generateCopilotResponse.status.details,
              originalErrorCode: originalCode,
              preservedStructure: !!originalCode
            });
            setIsLoading(false);
            break;
          } else if (messages2.length > 0) {
            newMessages = [...messages2];
            for (const message of messages2) {
              if (message.isAgentStateMessage() && !message.active && !executedCoAgentStateRenders.includes(message.id) && onCoAgentStateRender) {
                // With guardrails on, stop rendering agent state until the
                // response carries a status.
                if (guardrailsEnabled && value.generateCopilotResponse.status === void 0) {
                  break;
                }
                yield onCoAgentStateRender({
                  name: message.agentName,
                  nodeName: message.nodeName,
                  state: message.state
                });
                executedCoAgentStateRenders.push(message.id);
              }
            }

            const lastAgentStateMessage = [...messages2].reverse().find((message) => message.isAgentStateMessage());
            if (lastAgentStateMessage) {
              if (lastAgentStateMessage.state.messages && lastAgentStateMessage.state.messages.length > 0) {
                syncedMessages = loadMessagesFromJsonRepresentation(
                  lastAgentStateMessage.state.messages
                );
              }
              setCoagentStatesWithRef((prevAgentStates) => {
                var _a2;
                return __spreadProps(__spreadValues({}, prevAgentStates), {
                  [lastAgentStateMessage.agentName]: {
                    name: lastAgentStateMessage.agentName,
                    state: lastAgentStateMessage.state,
                    running: lastAgentStateMessage.running,
                    active: lastAgentStateMessage.active,
                    threadId: lastAgentStateMessage.threadId,
                    nodeName: lastAgentStateMessage.nodeName,
                    runId: lastAgentStateMessage.runId,
                    // Preserve existing config from previous state
                    config: (_a2 = prevAgentStates[lastAgentStateMessage.agentName]) == null ? void 0 : _a2.config
                  }
                });
              });
              if (lastAgentStateMessage.running) {
                setAgentSession({
                  threadId: lastAgentStateMessage.threadId,
                  agentName: lastAgentStateMessage.agentName,
                  nodeName: lastAgentStateMessage.nodeName
                });
              } else {
                if (agentLock) {
                  setAgentSession({
                    threadId: randomId(),
                    agentName: agentLock,
                    nodeName: void 0
                  });
                } else {
                  setAgentSession(null);
                }
              }
            }
          }

          if (newMessages.length > 0) {
            setMessages([...previousMessages, ...newMessages]);
          }
        }

        let finalMessages = constructFinalMessages(
          [...syncedMessages, ...interruptMessages],
          previousMessages,
          newMessages
        );
        let didExecuteAction = false;

        // Runs one action through executeAction, inserts its result directly
        // after the triggering message, and records whether a follow-up is wanted.
        const executeActionFromMessage = (currentAction, actionMessage) => __async(this, null, function* () {
          var _a2;
          const isInterruptAction = interruptMessages.find((m) => m.id === actionMessage.id);
          followUp = (_a2 = currentAction == null ? void 0 : currentAction.followUp) != null ? _a2 : !isInterruptAction;
          if (currentAction == null ? void 0 : currentAction._setActivatingMessageId) {
            currentAction._setActivatingMessageId(actionMessage.id);
          }
          const resultMessage = yield executeAction({
            onFunctionCall,
            message: actionMessage,
            chatAbortControllerRef,
            onError: (error) => {
              addErrorToast([error]);
              console.error(`Failed to execute action ${actionMessage.name}: ${error}`);
            },
            setMessages,
            getFinalMessages: () => finalMessages,
            isRenderAndWait: (currentAction == null ? void 0 : currentAction._isRenderAndWait) || false
          });
          didExecuteAction = true;
          const messageIndex = finalMessages.findIndex((msg) => msg.id === actionMessage.id);
          finalMessages.splice(messageIndex + 1, 0, resultMessage);
          if (currentAction == null ? void 0 : currentAction._isRenderAndWait) {
            const messagesForImmediateUpdate = [...finalMessages];
            flushSync(() => {
              setMessages(messagesForImmediateUpdate);
            });
          }
          if (currentAction == null ? void 0 : currentAction._setActivatingMessageId) {
            currentAction._setActivatingMessageId(null);
          }
          return resultMessage;
        });

        if (onFunctionCall) {
          // Collect the trailing run of non-pending action/result messages,
          // skipping over agent-state messages.
          const lastMessages = [];
          for (let i = finalMessages.length - 1; i >= 0; i--) {
            const message = finalMessages[i];
            if ((message.isActionExecutionMessage() || message.isResultMessage()) && message.status.code !== MessageStatusCode.Pending) {
              lastMessages.unshift(message);
            } else if (!message.isAgentStateMessage()) {
              break;
            }
          }
          for (const message of lastMessages) {
            setMessages(finalMessages);
            const action = actions.find(
              (action2) => action2.name === message.name
            );
            if (action && action.available === "frontend") {
              continue;
            }
            const currentResultMessagePairedFeAction = message.isResultMessage() ? getPairedFeAction(actions, message) : null;
            if (action && message.isActionExecutionMessage()) {
              const isRenderAndWaitAction = (action == null ? void 0 : action._isRenderAndWait) || false;
              // A render-and-wait action that already has a result is not re-run.
              const alreadyProcessed = isRenderAndWaitAction && finalMessages.some(
                (fm) => fm.isResultMessage() && fm.actionExecutionId === message.id
              );
              if (!alreadyProcessed) {
                const resultMessage = yield executeActionFromMessage(
                  action,
                  message
                );
                const pairedFeAction = getPairedFeAction(actions, resultMessage);
                if (pairedFeAction) {
                  const newExecutionMessage = new ActionExecutionMessage({
                    name: pairedFeAction.name,
                    arguments: parseJson(resultMessage.result, resultMessage.result),
                    status: message.status,
                    createdAt: message.createdAt,
                    parentMessageId: message.parentMessageId
                  });
                  yield executeActionFromMessage(pairedFeAction, newExecutionMessage);
                }
              }
            } else if (message.isResultMessage() && currentResultMessagePairedFeAction) {
              const newExecutionMessage = new ActionExecutionMessage({
                name: currentResultMessagePairedFeAction.name,
                arguments: parseJson(message.result, message.result),
                status: message.status,
                createdAt: message.createdAt
              });
              finalMessages.push(newExecutionMessage);
              yield executeActionFromMessage(
                currentResultMessagePairedFeAction,
                newExecutionMessage
              );
            }
          }
          setMessages(finalMessages);
        }

        if (followUp !== false && (didExecuteAction || // the last message is a server side result
        !isAgentRun && finalMessages.length && finalMessages[finalMessages.length - 1].isResultMessage()) && // the user did not stop generation
        !((_r = chatAbortControllerRef.current) == null ? void 0 : _r.signal.aborted)) {
          yield new Promise((resolve) => setTimeout(resolve, 10));
          return yield runChatCompletionRef.current(finalMessages);
        } else if ((_s = chatAbortControllerRef.current) == null ? void 0 : _s.signal.aborted) {
          // Drop action executions that are not immediately followed by their result.
          const repairedMessages = finalMessages.filter((message, actionExecutionIndex) => {
            if (message.isActionExecutionMessage()) {
              return finalMessages.find(
                (msg, resultIndex) => msg.isResultMessage() && msg.actionExecutionId === message.id && resultIndex === actionExecutionIndex + 1
              );
            }
            return true;
          });
          const repairedMessageIds = repairedMessages.map((message) => message.id);
          setMessages(repairedMessages);
          if ((_t = agentSessionRef.current) == null ? void 0 : _t.nodeName) {
            setAgentSession({
              threadId: agentSessionRef.current.threadId,
              agentName: agentSessionRef.current.agentName,
              nodeName: "__end__"
            });
          }
          return newMessages.filter((message) => repairedMessageIds.includes(message.id));
        } else {
          return newMessages.slice();
        }
      } finally {
        setIsLoading(false);
      }
    }),
    [
      messages,
      setMessages,
      makeSystemMessageCallback,
      copilotConfig,
      setIsLoading,
      initialMessages,
      isLoading,
      actions,
      onFunctionCall,
      onCoAgentStateRender,
      setCoagentStatesWithRef,
      coagentStatesRef,
      agentSession,
      setAgentSession,
      disableSystemMessage
    ]
  );
  runChatCompletionRef.current = runChatCompletion;

  const runChatCompletionAndHandleFunctionCall = useAsyncCallback(
    (messages2) => __async(this, null, function* () {
      yield runChatCompletionRef.current(messages2);
    }),
    [messages]
  );

  // Flush messages that were appended while a completion was running.
  useEffect(() => {
    if (!isLoading && pendingAppendsRef.current.length > 0) {
      const pending = pendingAppendsRef.current.splice(0);
      const followUp = pending.some((p) => p.followUp);
      const newMessages = [...messages, ...pending.map((p) => p.message)];
      setMessages(newMessages);
      if (followUp) {
        runChatCompletionAndHandleFunctionCall(newMessages);
      }
    }
  }, [isLoading, messages, setMessages, runChatCompletionAndHandleFunctionCall]);

  // Converts answered interrupt events into request input and clears the
  // pending interrupt action; unanswered or unknown events are skipped.
  const composeAndFlushMetaEventsInput = useCallback(
    (metaEvents) => {
      return metaEvents.reduce((acc, event) => {
        if (!event)
          return acc;
        switch (event.name) {
          case MetaEventName.LangGraphInterruptEvent:
            if (event.response) {
              setLangGraphInterruptAction(null);
              const value = event.value;
              return [
                ...acc,
                {
                  name: event.name,
                  value: typeof value === "string" ? value : JSON.stringify(value),
                  response: typeof event.response === "string" ? event.response : JSON.stringify(event.response)
                }
              ];
            }
            return acc;
          default:
            return acc;
        }
      }, []);
    },
    [setLangGraphInterruptAction]
  );

  // Appends a message; queues it instead when a completion is in flight.
  const append = useAsyncCallback(
    (message, options2) => __async(this, null, function* () {
      var _a;
      const followUp = (_a = options2 == null ? void 0 : options2.followUp) != null ? _a : true;
      if (isLoading) {
        pendingAppendsRef.current.push({ message, followUp });
        return;
      }
      const newMessages = [...messages, message];
      setMessages(newMessages);
      if (followUp) {
        return runChatCompletionAndHandleFunctionCall(newMessages);
      }
    }),
    [isLoading, messages, setMessages, runChatCompletionAndHandleFunctionCall]
  );

  // Regenerates an assistant message: truncates history to the last user
  // message before it and re-runs the completion.
  const reload = useAsyncCallback(
    (reloadMessageId) => __async(this, null, function* () {
      if (isLoading || messages.length === 0) {
        return;
      }
      const reloadMessageIndex = messages.findIndex((msg) => msg.id === reloadMessageId);
      if (reloadMessageIndex === -1) {
        console.warn(`Message with id ${reloadMessageId} not found`);
        return;
      }
      const reloadMessageRole = messages[reloadMessageIndex].role;
      if (reloadMessageRole !== MessageRole.Assistant) {
        console.warn(`Regenerate cannot be performed on ${reloadMessageRole} role`);
        return;
      }
      let historyCutoff = [];
      if (messages.length > 2) {
        const lastUserMessageBeforeRegenerate = messages.slice(0, reloadMessageIndex).reverse().find(
          (msg) => (
            // @ts-expect-error -- message has role
            msg.role === MessageRole.User
          )
        );
        // No preceding user message: keep the empty history instead of
        // dereferencing `undefined`.
        if (lastUserMessageBeforeRegenerate) {
          const indexOfLastUserMessageBeforeRegenerate = messages.findIndex(
            (msg) => msg.id === lastUserMessageBeforeRegenerate.id
          );
          historyCutoff = messages.slice(0, indexOfLastUserMessageBeforeRegenerate + 1);
        }
      }
      setMessages(historyCutoff);
      return runChatCompletionAndHandleFunctionCall(historyCutoff);
    }),
    [isLoading, messages, setMessages, runChatCompletionAndHandleFunctionCall]
  );

  // Aborts the in-flight request, if any.
  const stop = () => {
    var _a;
    (_a = chatAbortControllerRef.current) == null ? void 0 : _a.abort("Stop was called");
  };

  return {
    append,
    reload,
    stop,
    runChatCompletion: () => runChatCompletionRef.current(messages)
  };
}

/**
 * Chooses the final message list. When synced messages exist they replace the
 * local history, and each local agent-state message is re-inserted right after
 * the message that preceded it locally (matched by id); agent-state messages
 * whose predecessor is not in the synced list are dropped.
 *
 * @param {Array} syncedMessages
 * @param {Array} previousMessages
 * @param {Array} newMessages
 * @returns {Array} A new array; the inputs are not mutated.
 */
function constructFinalMessages(syncedMessages, previousMessages, newMessages) {
  const finalMessages = syncedMessages.length > 0 ? [...syncedMessages] : [...previousMessages, ...newMessages];
  if (syncedMessages.length > 0) {
    const messagesWithAgentState = [...previousMessages, ...newMessages];
    let previousMessageId = void 0;
    for (const message of messagesWithAgentState) {
      if (message.isAgentStateMessage()) {
        const index = finalMessages.findIndex((msg) => msg.id === previousMessageId);
        if (index !== -1) {
          finalMessages.splice(index + 1, 0, message);
        }
      }
      previousMessageId = message.id;
    }
  }
  return finalMessages;
}

/**
 * Invokes `onFunctionCall` for an action message and wraps the outcome in a
 * ResultMessage. The handler is raced against the abort signal, so a stop
 * resolves the result to "Operation was aborted by the user".
 *
 * If the handler rejects, `onError` is called and the error is encoded into the
 * result as `{ content, error }`.
 *
 * @param {object} _0 - `{ onFunctionCall, message, chatAbortControllerRef,
 *   onError, setMessages, getFinalMessages, isRenderAndWait }`.
 * @returns {Promise<ResultMessage>}
 */
function executeAction(_0) {
  return __async(this, arguments, function* ({
    onFunctionCall,
    message,
    chatAbortControllerRef,
    onError,
    setMessages,
    getFinalMessages,
    isRenderAndWait
  }) {
    let result;
    let error = null;
    const currentMessagesForHandler = getFinalMessages();
    const handlerReturnedPromise = onFunctionCall({
      messages: currentMessagesForHandler,
      name: message.name,
      args: message.arguments
    });
    if (isRenderAndWait) {
      const currentMessagesForRender = getFinalMessages();
      flushSync(() => {
        setMessages([...currentMessagesForRender]);
      });
    }
    try {
      result = yield Promise.race([
        handlerReturnedPromise,
        // Await the promise returned by the handler
        new Promise(
          (resolve) => {
            var _a;
            return (_a = chatAbortControllerRef.current) == null ? void 0 : _a.signal.addEventListener(
              "abort",
              () => resolve("Operation was aborted by the user")
            );
          }
        ),
        // if the user stopped generation, we also abort consecutive actions
        new Promise((resolve) => {
          var _a;
          if ((_a = chatAbortControllerRef.current) == null ? void 0 : _a.signal.aborted) {
            resolve("Operation was aborted by the user");
          }
        })
      ]);
    } catch (e) {
      // Record the failure so it is encoded into the result below; previously
      // `error` was never assigned and the error branch was unreachable.
      error = e;
      onError(e);
    }
    return new ResultMessage({
      id: "result-" + message.id,
      result: ResultMessage.encodeResult(
        error ? {
          content: result,
          // Round-trip through JSON with own property names so non-enumerable
          // Error fields (message, stack) are kept.
          error: JSON.parse(JSON.stringify(error, Object.getOwnPropertyNames(error)))
        } : result
      ),
      actionExecutionId: message.id,
      actionName: message.name
    });
  });
}

/**
 * Finds the action paired with a message: either a frontend-available action
 * with the same name, or any action whose `pairedAction` equals that name.
 *
 * @param {Array} actions
 * @param {object} message - An action-execution or result message.
 * @returns {object|undefined} The first matching action, if any.
 */
function getPairedFeAction(actions, message) {
  let actionName = null;
  if (message.isActionExecutionMessage()) {
    actionName = message.name;
  } else if (message.isResultMessage()) {
    actionName = message.actionName;
  }
  return actions.find(
    (action) => action.name === actionName && action.available === "frontend" || action.pairedAction === actionName
  );
}

export {
  useChat
};
//# sourceMappingURL=chunk-LLLCUHOO.mjs.map