@copilotkit/runtime
Version:
<div align="center"> <a href="https://copilotkit.ai" target="_blank"> <img src="https://github.com/copilotkit/copilotkit/raw/main/assets/banner.png" alt="CopilotKit Logo"> </a>
1,259 lines (1,127 loc) • 43.6 kB
text/typescript
/**
* <Callout type="info">
* This is the reference for the `CopilotRuntime` class. For more information and example code snippets, please see [Concept: Copilot Runtime](/concepts/copilot-runtime).
* </Callout>
*
* ## Usage
*
* ```tsx
* import { CopilotRuntime } from "@copilotkit/runtime";
*
* const copilotKit = new CopilotRuntime();
* ```
*/
import {
Action,
actionParametersToJsonSchema,
Parameter,
ResolvedCopilotKitError,
CopilotKitApiDiscoveryError,
randomId,
CopilotKitError,
CopilotKitLowLevelError,
CopilotKitAgentDiscoveryError,
CopilotKitMisuseError,
} from "@copilotkit/shared";
import {
CopilotServiceAdapter,
EmptyAdapter,
RemoteChain,
RemoteChainParameters,
} from "../../service-adapters";
import { MessageInput } from "../../graphql/inputs/message.input";
import { ActionInput } from "../../graphql/inputs/action.input";
import { RuntimeEventSource, RuntimeEventTypes } from "../../service-adapters/events";
import { convertGqlInputToMessages } from "../../service-adapters/conversion";
import { Message } from "../../graphql/types/converted";
import { ForwardedParametersInput } from "../../graphql/inputs/forwarded-parameters.input";
import {
isRemoteAgentAction,
RemoteAgentAction,
EndpointType,
setupRemoteActions,
EndpointDefinition,
CopilotKitEndpoint,
LangGraphPlatformEndpoint,
} from "./remote-actions";
import { GraphQLContext } from "../integrations/shared";
import { AgentSessionInput } from "../../graphql/inputs/agent-session.input";
import { from } from "rxjs";
import { AgentStateInput } from "../../graphql/inputs/agent-state.input";
import { ActionInputAvailability } from "../../graphql/types/enums";
import { createHeaders } from "./remote-action-constructors";
import { Agent } from "../../graphql/types/agents-response.type";
import { ExtensionsInput } from "../../graphql/inputs/extensions.input";
import { ExtensionsResponse } from "../../graphql/types/extensions-response.type";
import { LoadAgentStateResponse } from "../../graphql/types/load-agent-state-response.type";
import { Client as LangGraphClient } from "@langchain/langgraph-sdk";
import { langchainMessagesToCopilotKit } from "./remote-lg-action";
import { MetaEventInput } from "../../graphql/inputs/meta-event.input";
import {
CopilotObservabilityConfig,
LLMRequestData,
LLMResponseData,
LLMErrorData,
} from "../observability";
import { AbstractAgent } from "@ag-ui/client";
import { MessageRole } from "../../graphql/types/enums";
// +++ MCP Imports +++
import {
MCPClient,
MCPEndpointConfig,
MCPTool,
convertMCPToolsToActions,
generateMcpToolInstructions,
} from "./mcp-tools-utils";
// Define the function type alias here or import if defined elsewhere
type CreateMCPClientFunction = (config: MCPEndpointConfig) => Promise<MCPClient>;
// --- MCP Imports ---
export interface CopilotRuntimeRequest {
serviceAdapter: CopilotServiceAdapter;
messages: MessageInput[];
actions: ActionInput[];
agentSession?: AgentSessionInput;
agentStates?: AgentStateInput[];
outputMessagesPromise: Promise<Message[]>;
threadId?: string;
runId?: string;
publicApiKey?: string;
graphqlContext: GraphQLContext;
forwardedParameters?: ForwardedParametersInput;
url?: string;
extensions?: ExtensionsInput;
metaEvents?: MetaEventInput[];
}
interface CopilotRuntimeResponse {
threadId: string;
runId?: string;
eventSource: RuntimeEventSource;
serverSideActions: Action<any>[];
actionInputsWithoutAgents: ActionInput[];
extensions?: ExtensionsResponse;
}
type ActionsConfiguration<T extends Parameter[] | [] = []> =
| Action<T>[]
| ((ctx: { properties: any; url?: string }) => Action<T>[]);
interface OnBeforeRequestOptions {
threadId?: string;
runId?: string;
inputMessages: Message[];
properties: any;
url?: string;
}
type OnBeforeRequestHandler = (options: OnBeforeRequestOptions) => void | Promise<void>;
interface OnAfterRequestOptions {
threadId: string;
runId?: string;
inputMessages: Message[];
outputMessages: Message[];
properties: any;
url?: string;
}
type OnAfterRequestHandler = (options: OnAfterRequestOptions) => void | Promise<void>;
interface Middleware {
/**
* A function that is called before the request is processed.
*/
onBeforeRequest?: OnBeforeRequestHandler;
/**
* A function that is called after the request is processed.
*/
onAfterRequest?: OnAfterRequestHandler;
}
type AgentWithEndpoint = Agent & { endpoint: EndpointDefinition };
export interface CopilotRuntimeConstructorParams<T extends Parameter[] | [] = []> {
/**
* Middleware to be used by the runtime.
*
* ```ts
* onBeforeRequest: (options: {
* threadId?: string;
* runId?: string;
* inputMessages: Message[];
* properties: any;
* }) => void | Promise<void>;
* ```
*
* ```ts
* onAfterRequest: (options: {
* threadId?: string;
* runId?: string;
* inputMessages: Message[];
* outputMessages: Message[];
* properties: any;
* }) => void | Promise<void>;
* ```
*/
middleware?: Middleware;
/*
* A list of server side actions that can be executed. Will be ignored when remoteActions are set
*/
actions?: ActionsConfiguration<T>;
/*
* Deprecated: Use `remoteEndpoints`.
*/
remoteActions?: CopilotKitEndpoint[];
/*
* A list of remote actions that can be executed.
*/
remoteEndpoints?: EndpointDefinition[];
/*
* An array of LangServer URLs.
*/
langserve?: RemoteChainParameters[];
/*
* A map of agent names to AGUI agents.
* Example agent config:
* ```ts
* import { AbstractAgent } from "@ag-ui/client";
* // ...
* agents: {
* "support": new CustomerSupportAgent(),
* "technical": new TechnicalAgent()
* }
* ```
*/
agents?: Record<string, AbstractAgent>;
/*
* Delegates agent state processing to the service adapter.
*
* When enabled, individual agent state requests will not be processed by the agent itself.
* Instead, all processing will be handled by the service adapter.
*/
delegateAgentProcessingToServiceAdapter?: boolean;
/**
* Configuration for LLM request/response logging.
* Requires publicApiKey from CopilotKit component to be set:
*
* ```tsx
* <CopilotKit publicApiKey="ck_pub_..." />
* ```
*
* Example logging config:
* ```ts
* logging: {
* enabled: true, // Enable or disable logging
* progressive: true, // Set to false for buffered logging
* logger: {
* logRequest: (data) => langfuse.trace({ name: "LLM Request", input: data }),
* logResponse: (data) => langfuse.trace({ name: "LLM Response", output: data }),
* logError: (errorData) => langfuse.trace({ name: "LLM Error", metadata: errorData }),
* },
* }
* ```
*/
observability_c?: CopilotObservabilityConfig;
/**
* Configuration for connecting to Model Context Protocol (MCP) servers.
* Allows fetching and using tools defined on external MCP-compliant servers.
* Requires providing the `createMCPClient` function during instantiation.
* @experimental
*/
mcpServers?: MCPEndpointConfig[];
/**
* A function that creates an MCP client instance for a given endpoint configuration.
* This function is responsible for using the appropriate MCP client library
* (e.g., `@copilotkit/runtime`, `ai`) to establish a connection.
* Required if `mcpServers` is provided.
*
* ```typescript
* import { experimental_createMCPClient } from "ai"; // Import from vercel ai library
* // ...
* const runtime = new CopilotRuntime({
* mcpServers: [{ endpoint: "..." }],
* async createMCPClient(config) {
* return await experimental_createMCPClient({
* transport: {
* type: "sse",
* url: config.endpoint,
* headers: config.apiKey
* ? { Authorization: `Bearer ${config.apiKey}` }
* : undefined,
* },
* });
* }
* });
* ```
*/
createMCPClient?: CreateMCPClientFunction;
}
export class CopilotRuntime<const T extends Parameter[] | [] = []> {
public actions: ActionsConfiguration<T>;
public agents: Record<string, AbstractAgent>;
public remoteEndpointDefinitions: EndpointDefinition[];
private langserve: Promise<Action<any>>[] = [];
private onBeforeRequest?: OnBeforeRequestHandler;
private onAfterRequest?: OnAfterRequestHandler;
private delegateAgentProcessingToServiceAdapter: boolean;
private observability?: CopilotObservabilityConfig;
private availableAgents: Pick<AgentWithEndpoint, "name" | "id">[];
// +++ MCP Properties +++
private readonly mcpServersConfig?: MCPEndpointConfig[];
private mcpActionCache = new Map<string, Action<any>[]>();
// --- MCP Properties ---
// +++ MCP Client Factory +++
private readonly createMCPClientImpl?: CreateMCPClientFunction;
// --- MCP Client Factory ---
constructor(params?: CopilotRuntimeConstructorParams<T>) {
if (
params?.actions &&
params?.remoteEndpoints &&
params?.remoteEndpoints.some((e) => e.type === EndpointType.LangGraphPlatform)
) {
console.warn("Actions set in runtime instance will not be available for the agent");
}
this.actions = params?.actions || [];
this.availableAgents = [];
for (const chain of params?.langserve || []) {
const remoteChain = new RemoteChain(chain);
this.langserve.push(remoteChain.toAction());
}
this.remoteEndpointDefinitions = params?.remoteEndpoints ?? params?.remoteActions ?? [];
this.onBeforeRequest = params?.middleware?.onBeforeRequest;
this.onAfterRequest = params?.middleware?.onAfterRequest;
this.delegateAgentProcessingToServiceAdapter =
params?.delegateAgentProcessingToServiceAdapter || false;
this.observability = params?.observability_c;
this.agents = params?.agents ?? {};
// +++ MCP Initialization +++
this.mcpServersConfig = params?.mcpServers;
this.createMCPClientImpl = params?.createMCPClient;
// Validate: If mcpServers are provided, createMCPClient must also be provided
if (this.mcpServersConfig && this.mcpServersConfig.length > 0 && !this.createMCPClientImpl) {
throw new CopilotKitMisuseError({
message:
"MCP Integration Error: `mcpServers` were provided, but the `createMCPClient` function was not passed to the CopilotRuntime constructor. " +
"Please provide an implementation for `createMCPClient`.",
});
}
// Warning if actions are defined alongside LangGraph platform (potentially MCP too?)
if (
params?.actions &&
(params?.remoteEndpoints?.some((e) => e.type === EndpointType.LangGraphPlatform) ||
this.mcpServersConfig?.length)
) {
console.warn(
"Local 'actions' defined in CopilotRuntime might not be available to remote agents (LangGraph, MCP). Consider defining actions closer to the agent implementation if needed.",
);
}
}
// +++ MCP Instruction Injection Method +++
private injectMCPToolInstructions(
messages: MessageInput[],
currentActions: Action<any>[],
): MessageInput[] {
// Filter the *passed-in* actions for MCP tools
const mcpActionsForRequest = currentActions.filter((action) => (action as any)._isMCPTool);
if (!mcpActionsForRequest || mcpActionsForRequest.length === 0) {
return messages; // No MCP tools for this specific request
}
// Create a map to deduplicate tools by name (keeping the last one if duplicates exist)
const uniqueMcpTools = new Map<string, Action<any>>();
// Add all MCP tools to the map with their names as keys
mcpActionsForRequest.forEach((action) => {
uniqueMcpTools.set(action.name, action);
});
// Format instructions from the unique tools map
// Convert Action objects to MCPTool format for the instruction generator
const toolsMap: Record<string, MCPTool> = {};
Array.from(uniqueMcpTools.values()).forEach((action) => {
toolsMap[action.name] = {
description: action.description || "",
schema: action.parameters
? {
parameters: {
properties: action.parameters.reduce(
(acc, p) => ({
...acc,
[p.name]: { type: p.type, description: p.description },
}),
{},
),
required: action.parameters.filter((p) => p.required).map((p) => p.name),
},
}
: {},
execute: async () => ({}), // Placeholder, not used for instructions
};
});
// Generate instructions using the exported helper
const mcpToolInstructions = generateMcpToolInstructions(toolsMap);
if (!mcpToolInstructions) {
return messages; // No MCP tools to describe
}
const instructions =
"You have access to the following tools provided by external Model Context Protocol (MCP) servers:\n" +
mcpToolInstructions +
"\nUse them when appropriate to fulfill the user's request.";
const systemMessageIndex = messages.findIndex((msg) => msg.textMessage?.role === "system");
const newMessages = [...messages]; // Create a mutable copy
if (systemMessageIndex !== -1) {
const existingMsg = newMessages[systemMessageIndex];
if (existingMsg.textMessage) {
existingMsg.textMessage.content =
(existingMsg.textMessage.content ? existingMsg.textMessage.content + "\n\n" : "") +
instructions;
}
} else {
newMessages.unshift({
id: randomId(),
createdAt: new Date(),
textMessage: {
role: MessageRole.system,
content: instructions,
},
actionExecutionMessage: undefined,
resultMessage: undefined,
agentStateMessage: undefined,
});
}
return newMessages;
}
async processRuntimeRequest(request: CopilotRuntimeRequest): Promise<CopilotRuntimeResponse> {
const {
serviceAdapter,
messages: rawMessages,
actions: clientSideActionsInput,
threadId,
runId,
outputMessagesPromise,
graphqlContext,
forwardedParameters,
url,
extensions,
agentSession,
agentStates,
publicApiKey,
} = request;
const eventSource = new RuntimeEventSource();
// Track request start time for logging
const requestStartTime = Date.now();
// For storing streamed chunks if progressive logging is enabled
const streamedChunks: any[] = [];
try {
if (agentSession && !this.delegateAgentProcessingToServiceAdapter) {
return await this.processAgentRequest(request);
}
if (serviceAdapter instanceof EmptyAdapter) {
throw new CopilotKitMisuseError({
message: `Invalid adapter configuration: EmptyAdapter is only meant to be used with agent lock mode.
For non-agent components like useCopilotChatSuggestions, CopilotTextarea, or CopilotTask,
please use an LLM adapter instead.`,
});
}
// +++ Get Server Side Actions (including dynamic MCP) EARLY +++
const serverSideActions = await this.getServerSideActions(request);
// --- Get Server Side Actions (including dynamic MCP) EARLY ---
// Filter raw messages *before* injection
const filteredRawMessages = rawMessages.filter((message) => !message.agentStateMessage);
// +++ Inject MCP Instructions based on current actions +++
const messagesWithInjectedInstructions = this.injectMCPToolInstructions(
filteredRawMessages,
serverSideActions,
);
const inputMessages = convertGqlInputToMessages(messagesWithInjectedInstructions);
// --- Inject MCP Instructions based on current actions ---
// Log LLM request if logging is enabled
if (this.observability?.enabled && publicApiKey) {
try {
const requestData: LLMRequestData = {
threadId,
runId,
model: forwardedParameters?.model,
messages: inputMessages,
actions: clientSideActionsInput,
forwardedParameters,
timestamp: requestStartTime,
provider: this.detectProvider(serviceAdapter),
};
await this.observability.hooks.handleRequest(requestData);
} catch (error) {
console.error("Error logging LLM request:", error);
}
}
const serverSideActionsInput: ActionInput[] = serverSideActions.map((action) => ({
name: action.name,
description: action.description,
jsonSchema: JSON.stringify(actionParametersToJsonSchema(action.parameters)),
}));
const actionInputs = flattenToolCallsNoDuplicates([
...serverSideActionsInput,
...clientSideActionsInput.filter(
// Filter remote actions from CopilotKit core loop
(action) => action.available !== ActionInputAvailability.remote,
),
]);
await this.onBeforeRequest?.({
threadId,
runId,
inputMessages,
properties: graphqlContext.properties,
url,
});
const result = await serviceAdapter.process({
messages: inputMessages,
actions: actionInputs,
threadId,
runId,
eventSource,
forwardedParameters,
extensions,
agentSession,
agentStates,
});
// for backwards compatibility, we deal with the case that no threadId is provided
// by the frontend, by using the threadId from the response
const nonEmptyThreadId = threadId ?? result.threadId;
outputMessagesPromise
.then((outputMessages) => {
this.onAfterRequest?.({
threadId: nonEmptyThreadId,
runId: result.runId,
inputMessages,
outputMessages,
properties: graphqlContext.properties,
url,
});
})
.catch((_error) => {});
// After getting the response, log it if logging is enabled
if (this.observability?.enabled && publicApiKey) {
try {
outputMessagesPromise
.then((outputMessages) => {
const responseData: LLMResponseData = {
threadId: result.threadId,
runId: result.runId,
model: forwardedParameters?.model,
// Use collected chunks for progressive mode or outputMessages for regular mode
output: this.observability.progressive ? streamedChunks : outputMessages,
latency: Date.now() - requestStartTime,
timestamp: Date.now(),
provider: this.detectProvider(serviceAdapter),
// Indicate this is the final response
isFinalResponse: true,
};
try {
this.observability.hooks.handleResponse(responseData);
} catch (logError) {
console.error("Error logging LLM response:", logError);
}
})
.catch((error) => {
console.error("Failed to get output messages for logging:", error);
});
} catch (error) {
console.error("Error setting up logging for LLM response:", error);
}
}
// Add progressive logging if enabled
if (this.observability?.enabled && this.observability.progressive && publicApiKey) {
// Keep reference to original stream function
const originalStream = eventSource.stream.bind(eventSource);
// Wrap the stream function to intercept events
eventSource.stream = async (callback) => {
await originalStream(async (eventStream$) => {
// Create subscription to capture streaming events
eventStream$.subscribe({
next: (event) => {
// Only log content chunks
if (event.type === RuntimeEventTypes.TextMessageContent) {
// Store the chunk
streamedChunks.push(event.content);
// Log each chunk separately for progressive mode
try {
const progressiveData: LLMResponseData = {
threadId: threadId || "",
runId,
model: forwardedParameters?.model,
output: event.content,
latency: Date.now() - requestStartTime,
timestamp: Date.now(),
provider: this.detectProvider(serviceAdapter),
isProgressiveChunk: true,
};
// Use Promise to handle async logger without awaiting
Promise.resolve()
.then(() => {
this.observability.hooks.handleResponse(progressiveData);
})
.catch((error) => {
console.error("Error in progressive logging:", error);
});
} catch (error) {
console.error("Error preparing progressive log data:", error);
}
}
},
});
// Call the original callback with the event stream
await callback(eventStream$);
});
};
}
return {
threadId: nonEmptyThreadId,
runId: result.runId,
eventSource,
serverSideActions,
actionInputsWithoutAgents: actionInputs.filter(
(action) =>
// TODO-AGENTS: do not exclude ALL server side actions
!serverSideActions.find((serverSideAction) => serverSideAction.name == action.name),
// !isRemoteAgentAction(
// serverSideActions.find((serverSideAction) => serverSideAction.name == action.name),
// ),
),
extensions: result.extensions,
};
} catch (error) {
// Log error if logging is enabled
if (this.observability?.enabled && publicApiKey) {
try {
const errorData: LLMErrorData = {
threadId,
runId,
model: forwardedParameters?.model,
error: error instanceof Error ? error : String(error),
timestamp: Date.now(),
latency: Date.now() - requestStartTime,
provider: this.detectProvider(serviceAdapter),
};
await this.observability.hooks.handleError(errorData);
} catch (logError) {
console.error("Error logging LLM error:", logError);
}
}
if (error instanceof CopilotKitError) {
throw error;
}
console.error("Error getting response:", error);
eventSource.sendErrorMessageToChat();
throw error;
}
}
async discoverAgentsFromEndpoints(graphqlContext: GraphQLContext): Promise<AgentWithEndpoint[]> {
const agents: Promise<AgentWithEndpoint[]> = this.remoteEndpointDefinitions.reduce(
async (acc: Promise<Agent[]>, endpoint) => {
const agents = await acc;
if (endpoint.type === EndpointType.LangGraphPlatform) {
const propertyHeaders = graphqlContext.properties.authorization
? { authorization: `Bearer ${graphqlContext.properties.authorization}` }
: null;
const client = new LangGraphClient({
apiUrl: endpoint.deploymentUrl,
apiKey: endpoint.langsmithApiKey,
defaultHeaders: { ...propertyHeaders },
});
let data: Array<{ assistant_id: string; graph_id: string }> | { detail: string } = [];
try {
data = await client.assistants.search();
if (data && "detail" in data && (data.detail as string).toLowerCase() === "not found") {
throw new CopilotKitAgentDiscoveryError({ availableAgents: this.availableAgents });
}
} catch (e) {
throw new CopilotKitMisuseError({
message: `
Failed to find or contact remote endpoint at url ${endpoint.deploymentUrl}.
Make sure the API is running and that it's indeed a LangGraph platform url.
See more: https://docs.copilotkit.ai/troubleshooting/common-issues`,
});
}
const endpointAgents = data.map((entry) => ({
name: entry.graph_id,
id: entry.assistant_id,
description: "",
endpoint,
}));
return [...agents, ...endpointAgents];
}
interface InfoResponse {
agents?: Array<{
name: string;
description: string;
}>;
}
const cpkEndpoint = endpoint as CopilotKitEndpoint;
const fetchUrl = `${endpoint.url}/info`;
try {
const response = await fetch(fetchUrl, {
method: "POST",
headers: createHeaders(cpkEndpoint.onBeforeRequest, graphqlContext),
body: JSON.stringify({ properties: graphqlContext.properties }),
});
if (!response.ok) {
if (response.status === 404) {
throw new CopilotKitApiDiscoveryError({ url: fetchUrl });
}
throw new ResolvedCopilotKitError({
status: response.status,
url: fetchUrl,
isRemoteEndpoint: true,
});
}
const data: InfoResponse = await response.json();
const endpointAgents = (data?.agents ?? []).map((agent) => ({
name: agent.name,
description: agent.description ?? "" ?? "",
id: randomId(), // Required by Agent type
endpoint,
}));
return [...agents, ...endpointAgents];
} catch (error) {
if (error instanceof CopilotKitError) {
throw error;
}
throw new CopilotKitLowLevelError({ error: error as Error, url: fetchUrl });
}
},
Promise.resolve([]),
);
this.availableAgents = ((await agents) ?? []).map((a) => ({ name: a.name, id: a.id }));
return agents;
}
async loadAgentState(
graphqlContext: GraphQLContext,
threadId: string,
agentName: string,
): Promise<LoadAgentStateResponse> {
const agentsWithEndpoints = await this.discoverAgentsFromEndpoints(graphqlContext);
const agentWithEndpoint = agentsWithEndpoints.find((agent) => agent.name === agentName);
if (!agentWithEndpoint) {
throw new Error("Agent not found");
}
if (agentWithEndpoint.endpoint.type === EndpointType.LangGraphPlatform) {
const propertyHeaders = graphqlContext.properties.authorization
? { authorization: `Bearer ${graphqlContext.properties.authorization}` }
: null;
const client = new LangGraphClient({
apiUrl: agentWithEndpoint.endpoint.deploymentUrl,
apiKey: agentWithEndpoint.endpoint.langsmithApiKey,
defaultHeaders: { ...propertyHeaders },
});
let state: any = {};
try {
state = (await client.threads.getState(threadId)).values as any;
} catch (error) {}
if (Object.keys(state).length === 0) {
return {
threadId: threadId || "",
threadExists: false,
state: JSON.stringify({}),
messages: JSON.stringify([]),
};
} else {
const { messages, ...stateWithoutMessages } = state;
const copilotkitMessages = langchainMessagesToCopilotKit(messages);
return {
threadId: threadId || "",
threadExists: true,
state: JSON.stringify(stateWithoutMessages),
messages: JSON.stringify(copilotkitMessages),
};
}
} else if (
agentWithEndpoint.endpoint.type === EndpointType.CopilotKit ||
!("type" in agentWithEndpoint.endpoint)
) {
const cpkEndpoint = agentWithEndpoint.endpoint as CopilotKitEndpoint;
const fetchUrl = `${cpkEndpoint.url}/agents/state`;
try {
const response = await fetch(fetchUrl, {
method: "POST",
headers: createHeaders(cpkEndpoint.onBeforeRequest, graphqlContext),
body: JSON.stringify({
properties: graphqlContext.properties,
threadId,
name: agentName,
}),
});
if (!response.ok) {
if (response.status === 404) {
throw new CopilotKitApiDiscoveryError({ url: fetchUrl });
}
throw new ResolvedCopilotKitError({
status: response.status,
url: fetchUrl,
isRemoteEndpoint: true,
});
}
const data: LoadAgentStateResponse = await response.json();
return {
...data,
state: JSON.stringify(data.state),
messages: JSON.stringify(data.messages),
};
} catch (error) {
if (error instanceof CopilotKitError) {
throw error;
}
throw new CopilotKitLowLevelError({ error, url: fetchUrl });
}
} else {
throw new Error(`Unknown endpoint type: ${(agentWithEndpoint.endpoint as any).type}`);
}
}
private async processAgentRequest(
request: CopilotRuntimeRequest,
): Promise<CopilotRuntimeResponse> {
const {
messages: rawMessages,
outputMessagesPromise,
graphqlContext,
agentSession,
threadId: threadIdFromRequest,
metaEvents,
publicApiKey,
forwardedParameters,
} = request;
const { agentName, nodeName } = agentSession;
// Track request start time for observability
const requestStartTime = Date.now();
// For storing streamed chunks if progressive logging is enabled
const streamedChunks: any[] = [];
// for backwards compatibility, deal with the case when no threadId is provided
const threadId = threadIdFromRequest ?? agentSession.threadId;
const serverSideActions = await this.getServerSideActions(request);
const messages = convertGqlInputToMessages(rawMessages);
const currentAgent = serverSideActions.find(
(action) => action.name === agentName && isRemoteAgentAction(action),
) as RemoteAgentAction;
if (!currentAgent) {
throw new CopilotKitAgentDiscoveryError({ agentName, availableAgents: this.availableAgents });
}
// Filter actions to include:
// 1. Regular (non-agent) actions
// 2. Other agents' actions (but prevent self-calls to avoid infinite loops)
const availableActionsForCurrentAgent: ActionInput[] = serverSideActions
.filter(
(action) =>
// Case 1: Keep all regular (non-agent) actions
!isRemoteAgentAction(action) ||
// Case 2: For agent actions, keep all except self (prevent infinite loops)
(isRemoteAgentAction(action) && action.name !== agentName) /* prevent self-calls */,
)
.map((action) => ({
name: action.name,
description: action.description,
jsonSchema: JSON.stringify(actionParametersToJsonSchema(action.parameters)),
}));
const allAvailableActions = flattenToolCallsNoDuplicates([
...availableActionsForCurrentAgent,
...request.actions,
]);
// Log agent request if observability is enabled
if (this.observability?.enabled && publicApiKey) {
try {
const requestData: LLMRequestData = {
threadId,
runId: undefined,
model: forwardedParameters?.model,
messages,
actions: allAvailableActions,
forwardedParameters,
timestamp: requestStartTime,
provider: "agent",
agentName, // Add agent-specific context
nodeName,
};
await this.observability.hooks.handleRequest(requestData);
} catch (error) {
console.error("Error logging agent request:", error);
}
}
await this.onBeforeRequest?.({
threadId,
runId: undefined,
inputMessages: messages,
properties: graphqlContext.properties,
});
try {
const eventSource = new RuntimeEventSource();
const stream = await currentAgent.remoteAgentHandler({
name: agentName,
threadId,
nodeName,
metaEvents,
actionInputsWithoutAgents: allAvailableActions,
});
// Add progressive observability if enabled
if (this.observability?.enabled && this.observability.progressive && publicApiKey) {
// Wrap the stream function to intercept events for observability without changing core logic
const originalStream = eventSource.stream.bind(eventSource);
eventSource.stream = async (callback) => {
await originalStream(async (eventStream$) => {
// Create subscription to capture streaming events
eventStream$.subscribe({
next: (event) => {
// Only log content chunks
if (event.type === RuntimeEventTypes.TextMessageContent) {
// Store the chunk
streamedChunks.push(event.content);
// Log each chunk separately for progressive mode
try {
const progressiveData: LLMResponseData = {
threadId: threadId || "",
runId: undefined,
model: forwardedParameters?.model,
output: event.content,
latency: Date.now() - requestStartTime,
timestamp: Date.now(),
provider: "agent",
isProgressiveChunk: true,
agentName,
nodeName,
};
// Use Promise to handle async logger without awaiting
Promise.resolve()
.then(() => {
this.observability.hooks.handleResponse(progressiveData);
})
.catch((error) => {
console.error("Error in progressive agent logging:", error);
});
} catch (error) {
console.error("Error preparing progressive agent log data:", error);
}
}
},
});
// Call the original callback with the event stream
await callback(eventStream$);
});
};
}
eventSource.stream(async (eventStream$) => {
from(stream).subscribe({
next: (event) => eventStream$.next(event),
error: (err) => {
console.error("Error in stream", err);
// Log error with observability if enabled
if (this.observability?.enabled && publicApiKey) {
try {
const errorData: LLMErrorData = {
threadId,
runId: undefined,
model: forwardedParameters?.model,
error: err instanceof Error ? err : String(err),
timestamp: Date.now(),
latency: Date.now() - requestStartTime,
provider: "agent",
agentName,
nodeName,
};
this.observability.hooks.handleError(errorData);
} catch (logError) {
console.error("Error logging agent error:", logError);
}
}
eventStream$.error(err);
eventStream$.complete();
},
complete: () => eventStream$.complete(),
});
});
// Log final agent response when outputs are available
if (this.observability?.enabled && publicApiKey) {
outputMessagesPromise
.then((outputMessages) => {
const responseData: LLMResponseData = {
threadId,
runId: undefined,
model: forwardedParameters?.model,
// Use collected chunks for progressive mode or outputMessages for regular mode
output: this.observability.progressive ? streamedChunks : outputMessages,
latency: Date.now() - requestStartTime,
timestamp: Date.now(),
provider: "agent",
isFinalResponse: true,
agentName,
nodeName,
};
try {
this.observability.hooks.handleResponse(responseData);
} catch (logError) {
console.error("Error logging agent response:", logError);
}
})
.catch((error) => {
console.error("Failed to get output messages for agent logging:", error);
});
}
outputMessagesPromise
.then((outputMessages) => {
this.onAfterRequest?.({
threadId,
runId: undefined,
inputMessages: messages,
outputMessages,
properties: graphqlContext.properties,
});
})
.catch((_error) => {});
return {
threadId,
runId: undefined,
eventSource,
serverSideActions,
actionInputsWithoutAgents: allAvailableActions,
};
} catch (error) {
// Log error with observability if enabled
if (this.observability?.enabled && publicApiKey) {
try {
const errorData: LLMErrorData = {
threadId,
runId: undefined,
model: forwardedParameters?.model,
error: error instanceof Error ? error : String(error),
timestamp: Date.now(),
latency: Date.now() - requestStartTime,
provider: "agent",
agentName,
nodeName,
};
await this.observability.hooks.handleError(errorData);
} catch (logError) {
console.error("Error logging agent error:", logError);
}
}
console.error("Error getting response:", error);
throw error;
}
}
private async getServerSideActions(request: CopilotRuntimeRequest): Promise<Action<any>[]> {
const { graphqlContext, messages: rawMessages, agentStates, url } = request;
// --- Standard Action Fetching (unchanged) ---
const inputMessages = convertGqlInputToMessages(rawMessages);
const langserveFunctions: Action<any>[] = [];
for (const chainPromise of this.langserve) {
try {
const chain = await chainPromise;
langserveFunctions.push(chain);
} catch (error) {
console.error("Error loading langserve chain:", error);
}
}
const remoteEndpointDefinitions = this.remoteEndpointDefinitions.map(
(endpoint) => ({ ...endpoint, type: resolveEndpointType(endpoint) }) as EndpointDefinition,
);
const remoteActions = await setupRemoteActions({
remoteEndpointDefinitions,
graphqlContext,
messages: inputMessages,
agentStates,
frontendUrl: url,
agents: this.agents,
metaEvents: request.metaEvents,
});
const configuredActions =
typeof this.actions === "function"
? this.actions({ properties: graphqlContext.properties, url })
: this.actions;
// --- Standard Action Fetching (unchanged) ---
// +++ Dynamic MCP Action Fetching +++
const requestSpecificMCPActions: Action<any>[] = [];
if (this.createMCPClientImpl) {
// 1. Determine effective MCP endpoints for this request
const baseEndpoints = this.mcpServersConfig || [];
// Assuming frontend passes config via properties.mcpServers
const requestEndpoints = (graphqlContext.properties?.mcpServers ||
graphqlContext.properties?.mcpEndpoints ||
[]) as MCPEndpointConfig[];
// Merge and deduplicate endpoints based on URL
const effectiveEndpointsMap = new Map<string, MCPEndpointConfig>();
// First add base endpoints (from runtime configuration)
[...baseEndpoints].forEach((ep) => {
if (ep && ep.endpoint) {
effectiveEndpointsMap.set(ep.endpoint, ep);
}
});
// Then add request endpoints (from frontend), which will override duplicates
[...requestEndpoints].forEach((ep) => {
if (ep && ep.endpoint) {
effectiveEndpointsMap.set(ep.endpoint, ep);
}
});
const effectiveEndpoints = Array.from(effectiveEndpointsMap.values());
// 2. Fetch/Cache actions for effective endpoints
for (const config of effectiveEndpoints) {
const endpointUrl = config.endpoint;
let actionsForEndpoint: Action<any>[] | undefined = this.mcpActionCache.get(endpointUrl);
if (!actionsForEndpoint) {
// Not cached, fetch now
let client: MCPClient | null = null;
try {
client = await this.createMCPClientImpl(config);
const tools = await client.tools();
actionsForEndpoint = convertMCPToolsToActions(tools, endpointUrl);
this.mcpActionCache.set(endpointUrl, actionsForEndpoint); // Store in cache
} catch (error) {
console.error(
`MCP: Failed to fetch tools from endpoint ${endpointUrl}. Skipping. Error:`,
error,
);
actionsForEndpoint = []; // Assign empty array on error to prevent re-fetching constantly
this.mcpActionCache.set(endpointUrl, actionsForEndpoint); // Cache the failure (empty array)
}
}
requestSpecificMCPActions.push(...(actionsForEndpoint || []));
}
}
// --- Dynamic MCP Action Fetching ---
// Combine all action sources, including the dynamically fetched MCP actions
return [
...configuredActions,
...langserveFunctions,
...remoteActions,
...requestSpecificMCPActions,
];
}
// Add helper method to detect provider
private detectProvider(serviceAdapter: CopilotServiceAdapter): string | undefined {
const adapterName = serviceAdapter.constructor.name;
if (adapterName.includes("OpenAI")) return "openai";
if (adapterName.includes("Anthropic")) return "anthropic";
if (adapterName.includes("Google")) return "google";
if (adapterName.includes("Groq")) return "groq";
if (adapterName.includes("LangChain")) return "langchain";
return undefined;
}
}
export function flattenToolCallsNoDuplicates(toolsByPriority: ActionInput[]): ActionInput[] {
let allTools: ActionInput[] = [];
const allToolNames: string[] = [];
for (const tool of toolsByPriority) {
if (!allToolNames.includes(tool.name)) {
allTools.push(tool);
allToolNames.push(tool.name);
}
}
return allTools;
}
// The two functions below are "factory functions", meant to create the action objects that adhere to the expected interfaces
export function copilotKitEndpoint(config: Omit<CopilotKitEndpoint, "type">): CopilotKitEndpoint {
return {
...config,
type: EndpointType.CopilotKit,
};
}
export function langGraphPlatformEndpoint(
config: Omit<LangGraphPlatformEndpoint, "type">,
): LangGraphPlatformEndpoint {
return {
...config,
type: EndpointType.LangGraphPlatform,
};
}
export function resolveEndpointType(endpoint: EndpointDefinition) {
if (!endpoint.type) {
if ("deploymentUrl" in endpoint && "agents" in endpoint) {
return EndpointType.LangGraphPlatform;
} else {
return EndpointType.CopilotKit;
}
}
return endpoint.type;
}