UNPKG

@copilotkit/react-core

Version:

<div align="center"> <a href="https://copilotkit.ai" target="_blank"> <img src="https://github.com/copilotkit/copilotkit/raw/main/assets/banner.png" alt="CopilotKit Logo"> </a>

547 lines (545 loc) 21 kB
import {
  processActionsForRuntimeRequest
} from "./chunk-4CEQJ2X6.mjs";
import {
  useCopilotRuntimeClient
} from "./chunk-BKTARDXX.mjs";
import {
  useAsyncCallback,
  useErrorToast
} from "./chunk-22ENANUU.mjs";
import {
  __async,
  __spreadProps,
  __spreadValues
} from "./chunk-SKC7AJIV.mjs";

// src/hooks/use-chat.ts
import { useCallback, useRef } from "react";
import {
  COPILOT_CLOUD_PUBLIC_API_KEY_HEADER,
  randomId,
  parseJson
} from "@copilotkit/shared";
import {
  TextMessage,
  ResultMessage,
  convertMessagesToGqlInput,
  filterAdjacentAgentStateMessages,
  filterAgentStateMessages,
  convertGqlOutputToMessages,
  MessageStatusCode,
  MessageRole,
  Role,
  CopilotRequestType,
  loadMessagesFromJsonRepresentation,
  CopilotRuntimeClient,
  langGraphInterruptEvent,
  MetaEventName,
  ActionExecutionMessage
} from "@copilotkit/runtime-client-gql";

/**
 * Chat hook: sends the conversation to the Copilot runtime, consumes the
 * streamed response, executes frontend actions and keeps the message list,
 * agent session, run id and extensions in sync through the setters passed in
 * `options`.
 *
 * @param options Object carrying the chat state and its setters (`messages`,
 *   `setMessages`, `actions`, `copilotConfig`, `chatAbortControllerRef`, ...);
 *   see the destructuring below for the full list of fields that are read.
 * @returns `{ append, reload, stop, runChatCompletion }`.
 */
function useChat(options) {
  const {
    messages,
    setMessages,
    makeSystemMessageCallback,
    copilotConfig,
    setIsLoading,
    initialMessages,
    isLoading,
    actions,
    onFunctionCall,
    onCoAgentStateRender,
    setCoagentStatesWithRef,
    coagentStatesRef,
    agentSession,
    setAgentSession,
    threadId,
    setThreadId,
    runId,
    setRunId,
    chatAbortControllerRef,
    agentLock,
    extensions,
    setExtensions,
    langGraphInterruptAction,
    setLangGraphInterruptAction
  } = options;
  const runChatCompletionRef = useRef();
  const addErrorToast = useErrorToast();
  // Mirror the latest props into refs so the async completion loop reads
  // current values rather than the ones captured when it was created.
  const agentSessionRef = useRef(agentSession);
  agentSessionRef.current = agentSession;
  const runIdRef = useRef(runId);
  runIdRef.current = runId;
  const extensionsRef = useRef(extensions);
  extensionsRef.current = extensions;
  const publicApiKey = copilotConfig.publicApiKey;
  // Caller-supplied headers plus the public API key header when a key is set.
  const headers = __spreadValues(
    __spreadValues({}, copilotConfig.headers || {}),
    publicApiKey ? { [COPILOT_CLOUD_PUBLIC_API_KEY_HEADER]: publicApiKey } : {}
  );
  const runtimeClient = useCopilotRuntimeClient({
    url: copilotConfig.chatApiEndpoint,
    publicApiKey: copilotConfig.publicApiKey,
    headers,
    credentials: copilotConfig.credentials
  });
  /**
   * Runs one completion round for `previousMessages`: streams the response,
   * executes trailing actions, and recurses (via `runChatCompletionRef`) when
   * a follow-up round is required. Resolves to the messages produced by this
   * round (an empty array when the stream finished after an abort).
   */
  const runChatCompletion = useAsyncCallback(
    (previousMessages) => __async(this, null, function* () {
      var _a, _b, _c, _d, _e, _f, _g, _h, _i, _j, _k, _l, _m, _n, _o;
      setIsLoading(true);
      const interruptEvent = langGraphInterruptAction == null ? void 0 : langGraphInterruptAction.event;
      // Warn when a message is sent while an unanswered interrupt is pending
      // for the current agent session.
      if (
        (interruptEvent == null ? void 0 : interruptEvent.name) === MetaEventName.LangGraphInterruptEvent &&
        (interruptEvent == null ? void 0 : interruptEvent.value) &&
        !(interruptEvent == null ? void 0 : interruptEvent.response) &&
        agentSessionRef.current
      ) {
        addErrorToast([
          new Error(
            "A message was sent while interrupt is active. This will cause failure on the agent side"
          )
        ]);
      }
      // Placeholder assistant message shown until the first chunk arrives.
      let newMessages = [
        new TextMessage({
          content: "",
          role: Role.Assistant
        })
      ];
      chatAbortControllerRef.current = new AbortController();
      setMessages([...previousMessages, ...newMessages]);
      const systemMessage = makeSystemMessageCallback();
      const messagesWithContext = [systemMessage, ...initialMessages || [], ...previousMessages];
      const isAgentRun = agentSessionRef.current !== null;
      const stream = runtimeClient.asStream(
        runtimeClient.generateCopilotResponse({
          data: __spreadProps(__spreadValues(__spreadProps(__spreadValues({
            frontend: {
              actions: processActionsForRuntimeRequest(actions),
              url: window.location.href
            },
            threadId,
            runId: runIdRef.current,
            extensions: extensionsRef.current,
            metaEvents: composeAndFlushMetaEventsInput([langGraphInterruptAction == null ? void 0 : langGraphInterruptAction.event]),
            messages: convertMessagesToGqlInput(filterAgentStateMessages(messagesWithContext))
          }, copilotConfig.cloud ? {
            // Only forward guardrail rules when topic restriction is enabled.
            cloud: __spreadValues({}, ((_c = (_b = (_a = copilotConfig.cloud.guardrails) == null ? void 0 : _a.input) == null ? void 0 : _b.restrictToTopic) == null ? void 0 : _c.enabled) ? {
              guardrails: {
                inputValidationRules: {
                  allowList: copilotConfig.cloud.guardrails.input.restrictToTopic.validTopics,
                  denyList: copilotConfig.cloud.guardrails.input.restrictToTopic.invalidTopics
                }
              }
            } : {})
          } : {}), {
            metadata: {
              requestType: CopilotRequestType.Chat
            }
          }), agentSessionRef.current ? {
            agentSession: agentSessionRef.current
          } : {}), {
            agentStates: Object.values(coagentStatesRef.current).map((state) => {
              var _a2;
              return {
                agentName: state.name,
                state: JSON.stringify(state.state),
                configurable: JSON.stringify((_a2 = state.configurable) != null ? _a2 : {})
              };
            }),
            forwardedParameters: options.forwardedParameters || {}
          }),
          properties: copilotConfig.properties,
          signal: (_d = chatAbortControllerRef.current) == null ? void 0 : _d.signal
        })
      );
      // `restrictToTopic` is guarded here as well (matching the request
      // builder above) so a config with `guardrails.input` but no
      // `restrictToTopic` does not throw before the try/finally below.
      const guardrailsEnabled = ((_o = (_g = (_f = (_e = copilotConfig.cloud) == null ? void 0 : _e.guardrails) == null ? void 0 : _f.input) == null ? void 0 : _g.restrictToTopic) == null ? void 0 : _o.enabled) || false;
      const reader = stream.getReader();
      let executedCoAgentStateRenders = [];
      let followUp = void 0;
      let messages2 = [];
      let syncedMessages = [];
      let interruptMessages = [];
      try {
        while (true) {
          let done, value;
          try {
            const readResult = yield reader.read();
            done = readResult.done;
            value = readResult.value;
          } catch (readError) {
            // A failed read ends streaming; whatever was received so far is
            // still processed below.
            break;
          }
          if (done) {
            if (chatAbortControllerRef.current.signal.aborted) {
              return [];
            }
            break;
          }
          if (!(value == null ? void 0 : value.generateCopilotResponse)) {
            continue;
          }
          runIdRef.current = value.generateCopilotResponse.runId || null;
          extensionsRef.current = CopilotRuntimeClient.removeGraphQLTypename(
            value.generateCopilotResponse.extensions || {}
          );
          setRunId(runIdRef.current);
          setExtensions(extensionsRef.current);
          let rawMessagesResponse = value.generateCopilotResponse.messages;
          ((_i = (_h = value.generateCopilotResponse) == null ? void 0 : _h.metaEvents) != null ? _i : []).forEach((ev) => {
            if (ev.name === MetaEventName.LangGraphInterruptEvent) {
              let eventValue = langGraphInterruptEvent(ev).value;
              eventValue = parseJson(eventValue, eventValue);
              setLangGraphInterruptAction({
                event: __spreadProps(__spreadValues({}, langGraphInterruptEvent(ev)), {
                  value: eventValue
                })
              });
            }
            if (ev.name === MetaEventName.CopilotKitLangGraphInterruptEvent) {
              const data = ev.data;
              rawMessagesResponse = [...rawMessagesResponse, ...data.messages];
              interruptMessages = convertGqlOutputToMessages(
                // @ts-ignore
                filterAdjacentAgentStateMessages(data.messages)
              );
            }
          });
          messages2 = convertGqlOutputToMessages(
            filterAdjacentAgentStateMessages(rawMessagesResponse)
          );
          if (messages2.length === 0) {
            continue;
          }
          newMessages = [];
          if (
            ((_j = value.generateCopilotResponse.status) == null ? void 0 : _j.__typename) === "FailedResponseStatus" &&
            value.generateCopilotResponse.status.reason === "GUARDRAILS_VALIDATION_FAILED"
          ) {
            // Guardrails rejected the input: show the reason as the
            // assistant reply and stop reading the stream.
            newMessages = [
              new TextMessage({
                role: MessageRole.Assistant,
                content: ((_k = value.generateCopilotResponse.status.details) == null ? void 0 : _k.guardrailsReason) || ""
              })
            ];
            setMessages([...previousMessages, ...newMessages]);
            break;
          } else {
            newMessages = [...messages2];
            for (const message of messages2) {
              if (
                message.isAgentStateMessage() &&
                !message.active &&
                !executedCoAgentStateRenders.includes(message.id) &&
                onCoAgentStateRender
              ) {
                // With guardrails on, hold back renders until a response
                // status has been received.
                if (guardrailsEnabled && value.generateCopilotResponse.status === void 0) {
                  break;
                }
                yield onCoAgentStateRender({
                  name: message.agentName,
                  nodeName: message.nodeName,
                  state: message.state
                });
                executedCoAgentStateRenders.push(message.id);
              }
            }
            const lastAgentStateMessage = [...messages2].reverse().find((message) => message.isAgentStateMessage());
            if (lastAgentStateMessage) {
              if (lastAgentStateMessage.state.messages && lastAgentStateMessage.state.messages.length > 0) {
                syncedMessages = loadMessagesFromJsonRepresentation(
                  lastAgentStateMessage.state.messages
                );
              }
              setCoagentStatesWithRef((prevAgentStates) => __spreadProps(__spreadValues({}, prevAgentStates), {
                [lastAgentStateMessage.agentName]: {
                  name: lastAgentStateMessage.agentName,
                  state: lastAgentStateMessage.state,
                  running: lastAgentStateMessage.running,
                  active: lastAgentStateMessage.active,
                  threadId: lastAgentStateMessage.threadId,
                  nodeName: lastAgentStateMessage.nodeName,
                  runId: lastAgentStateMessage.runId
                }
              }));
              if (lastAgentStateMessage.running) {
                setAgentSession({
                  threadId: lastAgentStateMessage.threadId,
                  agentName: lastAgentStateMessage.agentName,
                  nodeName: lastAgentStateMessage.nodeName
                });
              } else {
                if (agentLock) {
                  // Locked to one agent: start a fresh session for it.
                  setAgentSession({
                    threadId: randomId(),
                    agentName: agentLock,
                    nodeName: void 0
                  });
                } else {
                  setAgentSession(null);
                }
              }
            }
          }
          if (newMessages.length > 0) {
            setMessages([...previousMessages, ...newMessages]);
          }
        }
        let finalMessages = constructFinalMessages(
          [...syncedMessages, ...interruptMessages],
          previousMessages,
          newMessages
        );
        let didExecuteAction = false;
        if (onFunctionCall) {
          // Collect the trailing run of non-pending action execution / result
          // messages, preserving their order.
          const lastMessages = [];
          for (let i = finalMessages.length - 1; i >= 0; i--) {
            const message = finalMessages[i];
            if (
              (message.isActionExecutionMessage() || message.isResultMessage()) &&
              message.status.code !== MessageStatusCode.Pending
            ) {
              lastMessages.unshift(message);
            } else {
              break;
            }
          }
          for (const message of lastMessages) {
            setMessages(finalMessages);
            const action = actions.find(
              (action2) => action2.name === message.name
            );
            const currentResultMessagePairedFeAction = message.isResultMessage() ? getPairedFeAction(actions, message) : null;
            // Executes `action2` for `message2` and splices the result
            // message in directly after `message2` within `finalMessages`.
            const executeActionFromMessage = (action2, message2) => __async(this, null, function* () {
              followUp = action2 == null ? void 0 : action2.followUp;
              const resultMessage = yield executeAction({
                onFunctionCall,
                previousMessages,
                message: message2,
                chatAbortControllerRef,
                onError: (error) => {
                  addErrorToast([error]);
                  console.error(`Failed to execute action ${message2.name}: ${error}`);
                }
              });
              didExecuteAction = true;
              const messageIndex = finalMessages.findIndex((msg) => msg.id === message2.id);
              finalMessages.splice(messageIndex + 1, 0, resultMessage);
              return resultMessage;
            });
            if (action && message.isActionExecutionMessage()) {
              const resultMessage = yield executeActionFromMessage(action, message);
              const pairedFeAction = getPairedFeAction(actions, resultMessage);
              if (pairedFeAction) {
                const newExecutionMessage = new ActionExecutionMessage({
                  name: pairedFeAction.name,
                  arguments: parseJson(resultMessage.result, resultMessage.result),
                  status: message.status,
                  createdAt: message.createdAt,
                  parentMessageId: message.parentMessageId
                });
                // NOTE(review): unlike the branch below, newExecutionMessage
                // is not pushed to finalMessages here, so the findIndex in
                // executeActionFromMessage presumably yields -1 and the
                // result is spliced in at index 0 - confirm this is intended.
                yield executeActionFromMessage(pairedFeAction, newExecutionMessage);
              }
            } else if (message.isResultMessage() && currentResultMessagePairedFeAction) {
              const newExecutionMessage = new ActionExecutionMessage({
                name: currentResultMessagePairedFeAction.name,
                arguments: parseJson(message.result, message.result),
                status: message.status,
                createdAt: message.createdAt
              });
              finalMessages.push(newExecutionMessage);
              yield executeActionFromMessage(
                currentResultMessagePairedFeAction,
                newExecutionMessage
              );
            }
          }
          setMessages(finalMessages);
        }
        if (
          // if followUp is not explicitly false
          followUp !== false &&
          // and we executed an action
          (didExecuteAction ||
            // the last message is a server side result
            !isAgentRun && finalMessages.length && finalMessages[finalMessages.length - 1].isResultMessage()) &&
          // the user did not stop generation
          !((_l = chatAbortControllerRef.current) == null ? void 0 : _l.signal.aborted)
        ) {
          yield new Promise((resolve) => setTimeout(resolve, 10));
          return yield runChatCompletionRef.current(finalMessages);
        } else if ((_m = chatAbortControllerRef.current) == null ? void 0 : _m.signal.aborted) {
          // Aborted: drop action executions that are not immediately followed
          // by their result message.
          const repairedMessages = finalMessages.filter((message, actionExecutionIndex) => {
            if (message.isActionExecutionMessage()) {
              return finalMessages.find(
                (msg, resultIndex) => msg.isResultMessage() && msg.actionExecutionId === message.id && resultIndex === actionExecutionIndex + 1
              );
            }
            return true;
          });
          const repairedMessageIds = repairedMessages.map((message) => message.id);
          setMessages(repairedMessages);
          if ((_n = agentSessionRef.current) == null ? void 0 : _n.nodeName) {
            setAgentSession({
              threadId: agentSessionRef.current.threadId,
              agentName: agentSessionRef.current.agentName,
              nodeName: "__end__"
            });
          }
          return newMessages.filter((message) => repairedMessageIds.includes(message.id));
        } else {
          return newMessages.slice();
        }
      } finally {
        setIsLoading(false);
      }
    }),
    // NOTE(review): threadId, langGraphInterruptAction, agentLock and the
    // run id / extension setters are read above but not listed here - confirm
    // whether useAsyncCallback memoizes on this array.
    [
      messages,
      setMessages,
      makeSystemMessageCallback,
      copilotConfig,
      setIsLoading,
      initialMessages,
      isLoading,
      actions,
      onFunctionCall,
      onCoAgentStateRender,
      setCoagentStatesWithRef,
      coagentStatesRef,
      agentSession,
      setAgentSession
    ]
  );
  runChatCompletionRef.current = runChatCompletion;
  const runChatCompletionAndHandleFunctionCall = useAsyncCallback(
    (messages2) => __async(this, null, function* () {
      yield runChatCompletionRef.current(messages2);
    }),
    [messages]
  );
  /**
   * Converts answered interrupt events into request meta-event inputs
   * (stringifying non-string `value` / `response`) and clears the stored
   * interrupt action for each one. Events that are missing, unanswered or of
   * another type are skipped.
   */
  const composeAndFlushMetaEventsInput = useCallback(
    (metaEvents) => {
      return metaEvents.reduce((acc, event) => {
        if (!event) return acc;
        switch (event.name) {
          case MetaEventName.LangGraphInterruptEvent:
            if (event.response) {
              setLangGraphInterruptAction(null);
              const value = event.value;
              return [
                ...acc,
                {
                  name: event.name,
                  value: typeof value === "string" ? value : JSON.stringify(value),
                  response: typeof event.response === "string" ? event.response : JSON.stringify(event.response)
                }
              ];
            }
            return acc;
          default:
            return acc;
        }
      }, []);
    },
    [setLangGraphInterruptAction]
  );
  /**
   * Appends `message` to the conversation and, unless `options2.followUp` is
   * falsy (it defaults to true when null/undefined), runs a completion.
   * Does nothing while a completion is loading.
   */
  const append = useAsyncCallback(
    (message, options2) => __async(this, null, function* () {
      var _a;
      if (isLoading) {
        return;
      }
      const newMessages = [...messages, message];
      setMessages(newMessages);
      const followUp = (_a = options2 == null ? void 0 : options2.followUp) != null ? _a : true;
      if (followUp) {
        return runChatCompletionAndHandleFunctionCall(newMessages);
      }
    }),
    [isLoading, messages, setMessages, runChatCompletionAndHandleFunctionCall]
  );
  /**
   * Re-runs the completion for the current conversation, first dropping the
   * last message when it is an assistant text message. Does nothing while
   * loading or when there are no messages.
   */
  const reload = useAsyncCallback(() => __async(this, null, function* () {
    if (isLoading || messages.length === 0) {
      return;
    }
    let newMessages = [...messages];
    const lastMessage = messages[messages.length - 1];
    if (lastMessage.isTextMessage() && lastMessage.role === "assistant") {
      newMessages = newMessages.slice(0, -1);
    }
    setMessages(newMessages);
    return runChatCompletionAndHandleFunctionCall(newMessages);
  }), [isLoading, messages, setMessages, runChatCompletionAndHandleFunctionCall]);
  /** Aborts the in-flight completion, if an abort controller is present. */
  const stop = () => {
    var _a;
    (_a = chatAbortControllerRef.current) == null ? void 0 : _a.abort("Stop was called");
  };
  return {
    append,
    reload,
    stop,
    runChatCompletion: () => runChatCompletionRef.current(messages)
  };
}

/**
 * Builds the message list to keep after a completion round.
 *
 * Without synced messages the result is `previousMessages` followed by
 * `newMessages`. With synced messages, the synced list is the base and every
 * agent state message from `previousMessages` / `newMessages` is re-inserted
 * right after the message that preceded it, when that predecessor's id is
 * found in the synced list.
 *
 * @returns A new array; the inputs are not mutated.
 */
function constructFinalMessages(syncedMessages, previousMessages, newMessages) {
  const finalMessages = syncedMessages.length > 0 ? [...syncedMessages] : [...previousMessages, ...newMessages];
  if (syncedMessages.length > 0) {
    const messagesWithAgentState = [...previousMessages, ...newMessages];
    let previousMessageId = void 0;
    for (const message of messagesWithAgentState) {
      if (message.isAgentStateMessage()) {
        const index = finalMessages.findIndex((msg) => msg.id === previousMessageId);
        if (index !== -1) {
          finalMessages.splice(index + 1, 0, message);
        }
      }
      previousMessageId = message.id;
    }
  }
  return finalMessages;
}

// Result reported for an action whose execution was cut short by an abort.
const ABORTED_BY_USER_RESULT = "Operation was aborted by the user";

/**
 * Turns a thrown value into plain JSON data, including an Error's own
 * non-enumerable properties. Falls back to `{ message }` when the value
 * cannot be serialized.
 */
function serializeError(error) {
  try {
    return JSON.parse(JSON.stringify(error, Object.getOwnPropertyNames(error)));
  } catch (serializationError) {
    return { message: String(error) };
  }
}

/**
 * Runs `onFunctionCall` for an action execution message and wraps the outcome
 * in a ResultMessage whose id is `"result-" + message.id`.
 *
 * The call is raced against the abort signal in `chatAbortControllerRef`; on
 * abort (including when the signal is already aborted) the result is the
 * string "Operation was aborted by the user". When the handler throws or
 * rejects, `onError` is invoked with the thrown value and the encoded result
 * is `{ content, error }` carrying the serialized error.
 *
 * @returns Promise resolving to the ResultMessage.
 */
function executeAction(_0) {
  return __async(this, arguments, function* ({
    onFunctionCall,
    previousMessages,
    message,
    chatAbortControllerRef,
    onError
  }) {
    let result;
    let error = null;
    let failed = false;
    let signal;
    let onAbort;
    try {
      // Start the handler first so it is the first contender in the race.
      const execution = onFunctionCall({
        messages: previousMessages,
        name: message.name,
        args: message.arguments
      });
      signal = chatAbortControllerRef.current == null ? void 0 : chatAbortControllerRef.current.signal;
      const aborted = new Promise((resolve) => {
        if (!signal) {
          return;
        }
        // if the user stopped generation, we also abort consecutive actions
        if (signal.aborted) {
          resolve(ABORTED_BY_USER_RESULT);
          return;
        }
        onAbort = () => resolve(ABORTED_BY_USER_RESULT);
        signal.addEventListener("abort", onAbort);
      });
      result = yield Promise.race([execution, aborted]);
    } catch (e) {
      // Remember the failure so it is reported in the result message too.
      failed = true;
      error = e;
      onError(e);
    } finally {
      // Detach the listener so it does not outlive this action.
      if (signal && onAbort) {
        signal.removeEventListener("abort", onAbort);
      }
    }
    return new ResultMessage({
      id: "result-" + message.id,
      result: ResultMessage.encodeResult(
        failed ? {
          content: result,
          error: serializeError(error)
        } : result
      ),
      actionExecutionId: message.id,
      actionName: message.name
    });
  });
}

/**
 * Finds the action paired with `message`: either an action with the same
 * name whose `available` is "frontend", or any action whose `pairedAction`
 * equals that name. The name is `message.name` for action execution messages
 * and `message.actionName` for result messages.
 *
 * @returns The first matching action, or undefined.
 */
function getPairedFeAction(actions, message) {
  let actionName = null;
  if (message.isActionExecutionMessage()) {
    actionName = message.name;
  } else if (message.isResultMessage()) {
    actionName = message.actionName;
  }
  return actions.find(
    (action) => (action.name === actionName && action.available === "frontend") || action.pairedAction === actionName
  );
}

export {
  useChat
};
//# sourceMappingURL=chunk-FUO5LKSJ.mjs.map