UNPKG

@copilotkit/runtime

Version:

<img src="https://github.com/user-attachments/assets/0a6b64d9-e193-4940-a3f6-60334ac34084" alt="banner" style="border-radius: 12px; border: 2px solid #d6d4fa;" />

711 lines (627 loc) 24.2 kB
/**
 * <Callout type="info">
 *   This is the reference for the `CopilotRuntime` class. For more information and example code snippets, please see [Concept: Copilot Runtime](/concepts/copilot-runtime).
 * </Callout>
 *
 * ## Usage
 *
 * ```tsx
 * import { CopilotRuntime } from "@copilotkit/runtime";
 *
 * const copilotKit = new CopilotRuntime();
 * ```
 */
import {
  type Action,
  type CopilotErrorHandler,
  CopilotKitMisuseError,
  type MaybePromise,
  type NonEmptyRecord,
  type Parameter,
  readBody,
  getZodParameters,
  type PartialBy,
  isTelemetryDisabled,
} from "@copilotkit/shared";
import type { RunAgentInput } from "@ag-ui/core";
import { aguiToGQL } from "../../graphql/message-conversion/agui-to-gql";
import type { CopilotServiceAdapter, RemoteChainParameters } from "../../service-adapters";
import {
  CopilotRuntime as CopilotRuntimeVNext,
  type CopilotRuntimeOptions,
  type CopilotRuntimeOptions as CopilotRuntimeOptionsVNext,
  InMemoryAgentRunner,
} from "@copilotkitnext/runtime";
import { TelemetryAgentRunner } from "./telemetry-agent-runner";
import telemetry from "../telemetry-client";
import type { MessageInput } from "../../graphql/inputs/message.input";
import type { Message } from "../../graphql/types/converted";
import {
  EndpointType,
  type EndpointDefinition,
  type CopilotKitEndpoint,
  type LangGraphPlatformEndpoint,
} from "./types";
import type { CopilotObservabilityConfig, LLMRequestData, LLMResponseData } from "../observability";
import type { AbstractAgent } from "@ag-ui/client";

// +++ MCP Imports +++
import {
  type MCPClient,
  type MCPEndpointConfig,
  type MCPTool,
  extractParametersFromSchema,
} from "./mcp-tools-utils";
import { BuiltInAgent, type BuiltInAgentConfiguration } from "@copilotkitnext/agent";

// Define the function type alias here or import if defined elsewhere
type CreateMCPClientFunction = (config: MCPEndpointConfig) => Promise<MCPClient>;

type ActionsConfiguration<T extends Parameter[] | [] = []> =
  | Action<T>[]
  | ((ctx: { properties: any; url?: string }) => Action<T>[]);

interface OnBeforeRequestOptions {
  threadId?: string;
  runId?: string;
  inputMessages: Message[];
  properties: any;
  url?: string;
}

type OnBeforeRequestHandler = (options: OnBeforeRequestOptions) => void | Promise<void>;

interface OnAfterRequestOptions {
  threadId: string;
  runId?: string;
  inputMessages: Message[];
  outputMessages: Message[];
  properties: any;
  url?: string;
}

type OnAfterRequestHandler = (options: OnAfterRequestOptions) => void | Promise<void>;

interface OnStopGenerationOptions {
  threadId: string;
  runId?: string;
  url?: string;
  agentName?: string;
  lastMessage: MessageInput;
}

type OnStopGenerationHandler = (options: OnStopGenerationOptions) => void | Promise<void>;

interface Middleware {
  /**
   * A function that is called before the request is processed.
   */
  /**
   * @deprecated This middleware hook is deprecated and will be removed in a future version.
   * Use updated middleware integration methods in CopilotRuntimeVNext instead.
   */
  onBeforeRequest?: OnBeforeRequestHandler;

  /**
   * A function that is called after the request is processed.
   */
  /**
   * @deprecated This middleware hook is deprecated and will be removed in a future version.
   * Use updated middleware integration methods in CopilotRuntimeVNext instead.
   */
  onAfterRequest?: OnAfterRequestHandler;
}

export interface CopilotRuntimeConstructorParams_BASE<T extends Parameter[] | [] = []> {
  /**
   * Middleware to be used by the runtime.
   *
   * ```ts
   * onBeforeRequest: (options: {
   *   threadId?: string;
   *   runId?: string;
   *   inputMessages: Message[];
   *   properties: any;
   * }) => void | Promise<void>;
   * ```
   *
   * ```ts
   * onAfterRequest: (options: {
   *   threadId?: string;
   *   runId?: string;
   *   inputMessages: Message[];
   *   outputMessages: Message[];
   *   properties: any;
   * }) => void | Promise<void>;
   * ```
   */
  /**
   * @deprecated This middleware hook is deprecated and will be removed in a future version.
   * Use updated middleware integration methods in CopilotRuntimeVNext instead.
   */
  middleware?: Middleware;

  /*
   * A list of server side actions that can be executed. Will be ignored when remoteActions are set
   */
  actions?: ActionsConfiguration<T>;

  /*
   * Deprecated: Use `remoteEndpoints`.
   */
  remoteActions?: CopilotKitEndpoint[];

  /*
   * A list of remote actions that can be executed.
   */
  remoteEndpoints?: EndpointDefinition[];

  /*
   * An array of LangServer URLs.
   */
  langserve?: RemoteChainParameters[];

  /*
   * A map of agent names to AGUI agents.
   * Example agent config:
   * ```ts
   * import { AbstractAgent } from "@ag-ui/client";
   * // ...
   * agents: {
   *   "support": new CustomerSupportAgent(),
   *   "technical": new TechnicalAgent()
   * }
   * ```
   */
  agents?: Record<string, AbstractAgent>;

  /*
   * Delegates agent state processing to the service adapter.
   *
   * When enabled, individual agent state requests will not be processed by the agent itself.
   * Instead, all processing will be handled by the service adapter.
   */
  delegateAgentProcessingToServiceAdapter?: boolean;

  // NOTE(review): the example in the doc comment below uses a `logging` / `logger` shape, whereas
  // this file reads `progressive`, `hooks.handleRequest` and `hooks.handleResponse` from the
  // config — confirm the example against `CopilotObservabilityConfig`.
  /**
   * Configuration for LLM request/response logging.
   * Requires publicApiKey from CopilotKit component to be set:
   *
   * ```tsx
   * <CopilotKit publicApiKey="ck_pub_..." />
   * ```
   *
   * Example logging config:
   * ```ts
   * logging: {
   *   enabled: true, // Enable or disable logging
   *   progressive: true, // Set to false for buffered logging
   *   logger: {
   *     logRequest: (data) => langfuse.trace({ name: "LLM Request", input: data }),
   *     logResponse: (data) => langfuse.trace({ name: "LLM Response", output: data }),
   *     logError: (errorData) => langfuse.trace({ name: "LLM Error", metadata: errorData }),
   *   },
   * }
   * ```
   */
  observability_c?: CopilotObservabilityConfig;

  /**
   * Configuration for connecting to Model Context Protocol (MCP) servers.
   * Allows fetching and using tools defined on external MCP-compliant servers.
   * Requires providing the `createMCPClient` function during instantiation.
   * @experimental
   */
  mcpServers?: MCPEndpointConfig[];

  /**
   * A function that creates an MCP client instance for a given endpoint configuration.
   * This function is responsible for using the appropriate MCP client library
   * (e.g., `@copilotkit/runtime`, `ai`) to establish a connection.
   * Required if `mcpServers` is provided.
   *
   * ```typescript
   * import { experimental_createMCPClient } from "ai"; // Import from vercel ai library
   * // ...
   * const runtime = new CopilotRuntime({
   *   mcpServers: [{ endpoint: "..." }],
   *   async createMCPClient(config) {
   *     return await experimental_createMCPClient({
   *       transport: {
   *         type: "sse",
   *         url: config.endpoint,
   *         headers: config.apiKey
   *           ? { Authorization: `Bearer ${config.apiKey}` }
   *           : undefined,
   *       },
   *     });
   *   }
   * });
   * ```
   */
  createMCPClient?: CreateMCPClientFunction;

  /**
   * Optional error handler for comprehensive debugging and observability.
   *
   * **Requires publicApiKey**: Error handling only works when requests include a valid publicApiKey.
   * This is a premium Copilot Cloud feature.
   *
   * @param errorEvent - Structured error event with rich debugging context
   *
   * @example
   * ```typescript
   * const runtime = new CopilotRuntime({
   *   onError: (errorEvent) => {
   *     debugDashboard.capture(errorEvent);
   *   }
   * });
   * ```
   */
  onError?: CopilotErrorHandler;

  onStopGeneration?: OnStopGenerationHandler;

  // /** Optional transcription service for audio processing. */
  // transcriptionService?: CopilotRuntimeOptionsVNext["transcriptionService"];
  // /** Optional *before* middleware – callback function or webhook URL. */
  // beforeRequestMiddleware?: CopilotRuntimeOptionsVNext["beforeRequestMiddleware"];
  // /** Optional *after* middleware – callback function or webhook URL. */
  // afterRequestMiddleware?: CopilotRuntimeOptionsVNext["afterRequestMiddleware"];
}

type BeforeRequestMiddleware = CopilotRuntimeOptionsVNext["beforeRequestMiddleware"];
type AfterRequestMiddleware = CopilotRuntimeOptionsVNext["afterRequestMiddleware"];
type BeforeRequestMiddlewareFn = Exclude<BeforeRequestMiddleware, string>;
type BeforeRequestMiddlewareFnParameters = Parameters<BeforeRequestMiddlewareFn>;
type BeforeRequestMiddlewareFnResult = ReturnType<BeforeRequestMiddlewareFn>;
type AfterRequestMiddlewareFn = Exclude<AfterRequestMiddleware, string>;
type AfterRequestMiddlewareFnParameters = Parameters<AfterRequestMiddlewareFn>;

interface CopilotRuntimeConstructorParams<T extends Parameter[] | [] = []>
  extends Omit<CopilotRuntimeConstructorParams_BASE<T>, "agents">,
    Omit<CopilotRuntimeOptionsVNext, "agents" | "transcriptionService"> {
  /**
   * TODO: un-omit `transcriptionService` above once it's supported
   *
   * This satisfies...
   *  – the optional constraint in `CopilotRuntimeConstructorParams_BASE`
   *  – the `MaybePromise<NonEmptyRecord<T>>` constraint in `CopilotRuntimeOptionsVNext`
   *  – the `Record<string, AbstractAgent>` constraint in both
   */
  agents?: MaybePromise<NonEmptyRecord<Record<string, AbstractAgent>>>;
}

/**
 * Tells a promise (or any thenable) apart from a plain value.
 *
 * Used by the constructor to decide whether the `agents` option has to be resolved before it can
 * be merged: object spread copies own enumerable properties only, so spreading a Promise yields `{}`.
 */
function isPromiseLike(value: unknown): value is PromiseLike<unknown> {
  return (
    value !== null &&
    (typeof value === "object" || typeof value === "function") &&
    typeof (value as { then?: unknown }).then === "function"
  );
}

/**
 * Central runtime object passed to all request handlers.
 *
 * Wraps a lazily created `CopilotRuntimeVNext` instance (see {@link CopilotRuntime.instance}) and
 * translates the legacy constructor options (actions, middleware, MCP servers) into the options
 * that instance is built from.
 */
export class CopilotRuntime<const T extends Parameter[] | [] = []> {
  params?: CopilotRuntimeConstructorParams<T>;
  private observability?: CopilotObservabilityConfig;
  // Cache MCP tools per endpoint to avoid re-fetching repeatedly
  private mcpToolsCache: Map<string, BuiltInAgentConfiguration["tools"]> = new Map();
  private runtimeArgs: CopilotRuntimeOptions;
  // Created on first access of `instance`, hence optional.
  private _instance?: CopilotRuntimeVNext;

  /**
   * @param params - Runtime options. `agents` may be a plain record or a promise of one; either
   *   way it is merged over the agents derived from `remoteEndpoints`.
   * @throws CopilotKitMisuseError when `remoteEndpoints` contains a LangGraph Platform endpoint.
   */
  constructor(
    params?: CopilotRuntimeConstructorParams<T> & PartialBy<CopilotRuntimeOptions, "agents">,
  ) {
    const agents = params?.agents ?? {};
    const endpointAgents = this.assignEndpointsToAgents(params?.remoteEndpoints ?? []);

    // Spreading a Promise would silently drop every agent, so resolve it first when needed.
    const mergedAgents = isPromiseLike(agents)
      ? Promise.resolve(agents).then((resolvedAgents) => ({
          ...endpointAgents,
          ...resolvedAgents,
        }))
      : { ...endpointAgents, ...agents };

    // Determine the base runner (user-provided or default)
    const baseRunner = params?.runner ?? new InMemoryAgentRunner();

    // Wrap with TelemetryAgentRunner unless telemetry is disabled
    // This ensures we always capture agent execution telemetry when enabled,
    // even if the user provides their own custom runner
    const runner = isTelemetryDisabled()
      ? baseRunner
      : new TelemetryAgentRunner({ runner: baseRunner });

    this.runtimeArgs = {
      agents: mergedAgents,
      runner,
      // TODO: add support for transcriptionService from CopilotRuntimeOptionsVNext once it is ready
      // transcriptionService: params?.transcriptionService,
      beforeRequestMiddleware: this.createOnBeforeRequestHandler(params).bind(this),
      afterRequestMiddleware: this.createOnAfterRequestHandler(params).bind(this),
    };
    this.params = params;
    this.observability = params?.observability_c;
  }

  /**
   * The underlying `CopilotRuntimeVNext`, constructed from the current runtime arguments on first
   * access and reused afterwards.
   */
  get instance(): CopilotRuntimeVNext {
    if (!this._instance) {
      this._instance = new CopilotRuntimeVNext(this.runtimeArgs);
    }
    return this._instance;
  }

  /**
   * Validates `remoteEndpoints` and returns the agents derived from them.
   *
   * No endpoint type currently produces an agent, so the returned record is always empty; the
   * method only rejects the deprecated LangGraph Platform endpoints.
   *
   * @throws CopilotKitMisuseError when any endpoint resolves to `EndpointType.LangGraphPlatform`.
   */
  private assignEndpointsToAgents(
    endpoints: NonNullable<CopilotRuntimeConstructorParams<T>["remoteEndpoints"]>,
  ): Record<string, AbstractAgent> {
    const result: Record<string, AbstractAgent> = {};

    const hasLangGraphPlatformEndpoint = endpoints.some(
      (endpoint) => resolveEndpointType(endpoint) === EndpointType.LangGraphPlatform,
    );
    if (hasLangGraphPlatformEndpoint) {
      throw new CopilotKitMisuseError({
        message:
          "LangGraphPlatformEndpoint in remoteEndpoints is deprecated. " +
          'Please use the "agents" option instead with LangGraphAgent from "@copilotkit/runtime/langgraph". ' +
          'Example: agents: { myAgent: new LangGraphAgent({ deploymentUrl: "...", graphId: "..." }) }',
      });
    }

    return result;
  }

  /**
   * Replaces the runtime's `agents` with a promise that resolves to the final agent record.
   *
   * When no agent is configured, a `default` `BuiltInAgent` is created from the adapter's
   * `provider` and `model`. When `actions` are configured, they (plus any MCP tools) are attached
   * to every agent.
   *
   * The promise rejects with `CopilotKitMisuseError` when no agent is configured and the service
   * adapter is missing or is an "EmptyAdapter".
   */
  handleServiceAdapter(serviceAdapter: CopilotServiceAdapter) {
    this.runtimeArgs.agents = Promise.resolve(this.runtimeArgs.agents ?? {}).then(
      async (agents) => {
        let agentsList = agents;
        const isAgentsListEmpty = !Object.keys(agents).length;
        const hasServiceAdapter = Boolean(serviceAdapter);
        const illegalServiceAdapterNames = ["EmptyAdapter"];
        // Only inspect the adapter once we know it exists, so a missing adapter surfaces as the
        // misuse error below instead of a TypeError.
        const serviceAdapterCanBeUsedForAgent =
          hasServiceAdapter && !illegalServiceAdapterNames.includes(serviceAdapter.name);

        if (isAgentsListEmpty && !serviceAdapterCanBeUsedForAgent) {
          throw new CopilotKitMisuseError({
            message:
              "No default agent provided. Please provide a default agent in the runtime config.",
          });
        }

        if (isAgentsListEmpty) {
          agentsList.default = new BuiltInAgent({
            model: `${serviceAdapter.provider}/${serviceAdapter.model}`,
          });
        }

        // NOTE(review): MCP tools are only fetched when `actions` is set, so `mcpServers` without
        // `actions` attaches no tools — confirm this is intended.
        const actions = this.params?.actions;
        if (actions) {
          const mcpTools = await this.getToolsFromMCP();
          agentsList = this.assignToolsToAgents(agents, [
            ...this.getToolsFromActions(actions),
            ...mcpTools,
          ]);
        }

        return agentsList;
      },
    );
  }

  /**
   * Converts the configured actions into agent tool definitions.
   *
   * A function-style configuration is invoked with empty `properties` and no `url`. Each resulting
   * tool carries the action's name, description and Zod parameter schema; its `execute` is a no-op.
   */
  private getToolsFromActions(
    actions: ActionsConfiguration<any>,
  ): BuiltInAgentConfiguration["tools"] {
    // Resolve actions to an array (handle function case)
    const actionsArray =
      typeof actions === "function" ? actions({ properties: {}, url: undefined }) : actions;

    // Convert each Action to a ToolDefinition
    return actionsArray.map((action) => {
      // Convert JSON schema to Zod schema
      const zodSchema = getZodParameters(action.parameters || []);

      return {
        name: action.name,
        description: action.description || "",
        parameters: zodSchema,
        execute: () => Promise.resolve(),
      };
    });
  }

  /**
   * Appends `tools` to the `config.tools` of every agent.
   *
   * The agents are mutated in place (their `config` property is replaced); the returned record is
   * a shallow copy that references the same agent objects. With no tools the input is returned
   * untouched.
   */
  private assignToolsToAgents(
    agents: Record<string, AbstractAgent>,
    tools: BuiltInAgentConfiguration["tools"],
  ): Record<string, AbstractAgent> {
    if (!tools?.length) {
      return agents;
    }

    const enrichedAgents: Record<string, AbstractAgent> = { ...agents };

    for (const [agentId, agent] of Object.entries(enrichedAgents)) {
      // `config` is read and written reflectively because it is not part of the
      // `AbstractAgent` type used here.
      const existingConfig = (Reflect.get(agent, "config") ?? {}) as BuiltInAgentConfiguration;
      const existingTools = existingConfig.tools ?? [];

      const updatedConfig: BuiltInAgentConfiguration = {
        ...existingConfig,
        tools: [...existingTools, ...tools],
      };

      Reflect.set(agent, "config", updatedConfig);
      enrichedAgents[agentId] = agent;
    }

    return enrichedAgents;
  }

  /**
   * Builds the before-request hook handed to the VNext runtime.
   *
   * The hook records a telemetry event, then runs the user's `beforeRequestMiddleware` and the
   * deprecated `middleware.onBeforeRequest` (in that order). Both are awaited so that an async
   * failure rejects the hook instead of becoming an unhandled rejection.
   */
  private createOnBeforeRequestHandler(
    params?: CopilotRuntimeConstructorParams<T> & PartialBy<CopilotRuntimeOptions, "agents">,
  ) {
    return async (hookParams: BeforeRequestMiddlewareFnParameters[0]) => {
      const { request } = hookParams;

      // Capture telemetry for copilot request creation
      const publicApiKey = request.headers.get("x-copilotcloud-public-api-key");
      const body = (await readBody(request)) as RunAgentInput;

      const forwardedProps = body.forwardedProps as
        | {
            cloud?: { guardrails?: unknown };
            metadata?: { requestType?: string };
          }
        | undefined;

      // Get cloud base URL from environment or default
      const cloudBaseUrl = process.env.COPILOT_CLOUD_BASE_URL || "https://api.cloud.copilotkit.ai";

      telemetry.capture("oss.runtime.copilot_request_created", {
        "cloud.guardrails.enabled": forwardedProps?.cloud?.guardrails !== undefined,
        requestType: forwardedProps?.metadata?.requestType ?? "unknown",
        "cloud.api_key_provided": !!publicApiKey,
        ...(publicApiKey ? { "cloud.public_api_key": publicApiKey } : {}),
        "cloud.base_url": cloudBaseUrl,
      });

      // TODO: get public api key and run with expected data
      // if (this.observability?.enabled && this.params.publicApiKey) {
      //   this.logObservabilityBeforeRequest()
      // }

      // TODO: replace hooksParams top argument type with BeforeRequestMiddlewareParameters when exported
      // NOTE(review): the option type also admits a string (see the `Exclude<..., string>` aliases
      // above); a string value cannot be called here — confirm how that case should be handled.
      await params?.beforeRequestMiddleware?.(hookParams);

      if (params?.middleware?.onBeforeRequest) {
        // The legacy hook only receives the user-authored messages.
        const inputMessages = (aguiToGQL(body.messages) as Message[]).filter(
          (msg) => "role" in msg && msg.role === "user",
        );

        await params.middleware.onBeforeRequest({
          threadId: body.threadId,
          runId: body.runId,
          inputMessages,
          properties: body.forwardedProps,
          url: request.url,
        } satisfies OnBeforeRequestOptions);
      }
    };
  }

  /**
   * Builds the after-request hook handed to the VNext runtime.
   *
   * Runs the user's `afterRequestMiddleware` and then the deprecated `middleware.onAfterRequest`
   * (currently invoked with an empty options object). Both are awaited so that an async failure
   * rejects the hook instead of becoming an unhandled rejection.
   */
  private createOnAfterRequestHandler(
    params?: CopilotRuntimeConstructorParams<T> & PartialBy<CopilotRuntimeOptions, "agents">,
  ) {
    return async (hookParams: AfterRequestMiddlewareFnParameters[0]) => {
      // TODO: get public api key and run with expected data
      // if (this.observability?.enabled && publicApiKey) {
      //   this.logObservabilityAfterRequest()
      // }

      // TODO: replace hooksParams top argument type with AfterRequestMiddlewareParameters when exported
      await params?.afterRequestMiddleware?.(hookParams);

      if (params?.middleware?.onAfterRequest) {
        // TODO: provide old expected params here when available
        // @ts-expect-error -- missing arguments.
        await params.middleware.onAfterRequest({});
      }
    };
  }

  // Observability Methods

  /**
   * Log LLM request if observability is enabled.
   *
   * Does nothing when no observability config was provided. Errors raised by the hook are logged
   * to the console and never propagate.
   */
  private async logObservabilityBeforeRequest(requestData: LLMRequestData): Promise<void> {
    const observability = this.observability;
    if (!observability) {
      return;
    }

    try {
      await observability.hooks.handleRequest(requestData);
    } catch (error) {
      console.error("Error logging LLM request:", error);
    }
  }

  /**
   * Log final LLM response after request completes.
   *
   * Waits for `outputMessagesPromise`, then reports either the streamed chunks (progressive mode)
   * or the output messages to the `handleResponse` hook. Failures are logged to the console and
   * never propagate.
   */
  private logObservabilityAfterRequest(
    outputMessagesPromise: Promise<Message[]>,
    baseData: {
      threadId: string;
      runId?: string;
      model?: string;
      provider?: string;
      agentName?: string;
      nodeName?: string;
    },
    streamedChunks: any[],
    requestStartTime: number,
    publicApiKey?: string,
  ): void {
    try {
      outputMessagesPromise
        .then(async (outputMessages) => {
          // Checked here rather than up front so the `.catch` below stays attached to the promise.
          const observability = this.observability;
          if (!observability) {
            return;
          }

          const responseData: LLMResponseData = {
            threadId: baseData.threadId,
            runId: baseData.runId,
            model: baseData.model,
            // Use collected chunks for progressive mode or outputMessages for regular mode
            output: observability.progressive ? streamedChunks : outputMessages,
            latency: Date.now() - requestStartTime,
            timestamp: Date.now(),
            provider: baseData.provider,
            isFinalResponse: true,
            agentName: baseData.agentName,
            nodeName: baseData.nodeName,
          };

          try {
            // Awaited so that an async hook failure is caught here as well.
            await observability.hooks.handleResponse(responseData);
          } catch (logError) {
            console.error("Error logging LLM response:", logError);
          }
        })
        .catch((error) => {
          console.error("Failed to get output messages for logging:", error);
        });
    } catch (error) {
      console.error("Error setting up logging for LLM response:", error);
    }
  }

  /**
   * Resolve MCP tools to BuiltInAgent tool definitions.
   *
   * Optionally accepts request-scoped properties whose `mcpServers` (or `mcpEndpoints`) are merged
   * with the servers configured on the runtime. Tool definitions are cached per endpoint URL; an
   * endpoint that fails is logged, cached as empty and skipped.
   *
   * @returns The tools of all endpoints, de-duplicated by name (the last definition wins).
   * @throws CopilotKitMisuseError when servers are configured but `createMCPClient` is missing.
   */
  private async getToolsFromMCP(options?: {
    properties?: Record<string, unknown>;
  }): Promise<BuiltInAgentConfiguration["tools"]> {
    const runtimeMcpServers = (this.params?.mcpServers ?? []) as MCPEndpointConfig[];
    const createMCPClient = this.params?.createMCPClient as CreateMCPClientFunction | undefined;

    const requestMcpServers = ((
      options?.properties as { mcpServers?: MCPEndpointConfig[] } | undefined
    )?.mcpServers ??
      (options?.properties as { mcpEndpoints?: MCPEndpointConfig[] } | undefined)?.mcpEndpoints ??
      []) as MCPEndpointConfig[];

    // If no runtime config and no request overrides, nothing to do
    const hasAnyServers =
      (runtimeMcpServers?.length ?? 0) > 0 || (requestMcpServers?.length ?? 0) > 0;
    if (!hasAnyServers) {
      return [];
    }

    if (!createMCPClient) {
      // Mirror legacy behavior: when servers are provided without a factory, treat as misconfiguration
      throw new CopilotKitMisuseError({
        message:
          "MCP Integration Error: `mcpServers` were provided, but the `createMCPClient` function was not passed to the CopilotRuntime constructor. Please provide an implementation for `createMCPClient`.",
      });
    }

    // Merge and dedupe endpoints by URL; request-level overrides take precedence
    const effectiveEndpoints = (() => {
      const byUrl = new Map<string, MCPEndpointConfig>();
      for (const ep of runtimeMcpServers) {
        if (ep?.endpoint) byUrl.set(ep.endpoint, ep);
      }
      for (const ep of requestMcpServers) {
        if (ep?.endpoint) byUrl.set(ep.endpoint, ep);
      }
      return Array.from(byUrl.values());
    })();

    const allTools: BuiltInAgentConfiguration["tools"] = [];

    for (const config of effectiveEndpoints) {
      const endpointUrl = config.endpoint;

      // Return cached tool definitions when available
      const cached = this.mcpToolsCache.get(endpointUrl);
      if (cached) {
        allTools.push(...cached);
        continue;
      }

      try {
        const client = await createMCPClient(config);
        const toolsMap = await client.tools();

        const toolDefs: BuiltInAgentConfiguration["tools"] = Object.entries(toolsMap).map(
          ([toolName, tool]: [string, MCPTool]) => {
            const params: Parameter[] = extractParametersFromSchema(tool);
            const zodSchema = getZodParameters(params);

            return {
              name: toolName,
              description: tool.description || `MCP tool: ${toolName} (from ${endpointUrl})`,
              parameters: zodSchema,
              execute: () => Promise.resolve(),
            };
          },
        );

        // Cache per endpoint and add to aggregate
        this.mcpToolsCache.set(endpointUrl, toolDefs);
        allTools.push(...toolDefs);
      } catch (error) {
        console.error(
          `MCP: Failed to fetch tools from endpoint ${endpointUrl}. Skipping. Error:`,
          error,
        );
        // Cache empty to prevent repeated attempts within lifecycle
        this.mcpToolsCache.set(endpointUrl, []);
      }
    }

    // Dedupe tools by name while preserving last-in wins (request overrides)
    const dedupedByName = new Map<string, (typeof allTools)[number]>();
    for (const tool of allTools) {
      dedupedByName.set(tool.name, tool);
    }

    return Array.from(dedupedByName.values());
  }
}

// The two functions below are "factory functions", meant to create the action objects that adhere to the expected interfaces

/** Returns `config` tagged with `type: EndpointType.CopilotKit`. */
export function copilotKitEndpoint(config: Omit<CopilotKitEndpoint, "type">): CopilotKitEndpoint {
  return {
    ...config,
    type: EndpointType.CopilotKit,
  };
}

/** Returns `config` tagged with `type: EndpointType.LangGraphPlatform`. */
export function langGraphPlatformEndpoint(
  config: Omit<LangGraphPlatformEndpoint, "type">,
): LangGraphPlatformEndpoint {
  return {
    ...config,
    type: EndpointType.LangGraphPlatform,
  };
}

/**
 * Returns the endpoint's explicit `type`, or infers one when it is missing: an endpoint with both
 * `deploymentUrl` and `agents` is treated as LangGraph Platform, anything else as CopilotKit.
 */
export function resolveEndpointType(endpoint: EndpointDefinition) {
  if (!endpoint.type) {
    if ("deploymentUrl" in endpoint && "agents" in endpoint) {
      return EndpointType.LangGraphPlatform;
    } else {
      return EndpointType.CopilotKit;
    }
  }

  return endpoint.type;
}