@copilotkit/runtime
Version: 
<img src="https://github.com/user-attachments/assets/0a6b64d9-e193-4940-a3f6-60334ac34084" alt="banner" style="border-radius: 12px; border: 2px solid #d6d4fa;" />
548 lines (503 loc) • 16.3 kB
text/typescript
import {
  Action,
  CopilotKitError,
  CopilotKitErrorCode,
  CopilotKitLowLevelError,
  ensureStructuredError,
  randomId,
  Severity,
} from "@copilotkit/shared";
import { plainToInstance } from "class-transformer";
import {
  catchError,
  concat,
  concatMap,
  EMPTY,
  firstValueFrom,
  from,
  of,
  ReplaySubject,
  scan,
  Subject,
} from "rxjs";
import { ActionInput } from "../graphql/inputs/action.input";
import { ActionExecutionMessage, ResultMessage, TextMessage } from "../graphql/types/converted";
import { GuardrailsResult } from "../graphql/types/guardrails-result.type";
import { isRemoteAgentAction } from "../lib/runtime/remote-actions";
import { generateHelpfulErrorMessage } from "../lib/streaming";
import telemetry from "../lib/telemetry-client";
import { streamLangChainResponse } from "./langchain/utils";
/**
 * Discriminator values used in the `type` field of every runtime event.
 * Each member's string value is identical to its name.
 */
export enum RuntimeEventTypes {
  // Text message lifecycle: start, content chunk(s), end.
  TextMessageStart = "TextMessageStart",
  TextMessageContent = "TextMessageContent",
  TextMessageEnd = "TextMessageEnd",
  // Action execution lifecycle: start, argument chunk(s), end, then result.
  ActionExecutionStart = "ActionExecutionStart",
  ActionExecutionArgs = "ActionExecutionArgs",
  ActionExecutionEnd = "ActionExecutionEnd",
  ActionExecutionResult = "ActionExecutionResult",
  // Agent state snapshot.
  AgentStateMessage = "AgentStateMessage",
  // Named meta events (see RuntimeMetaEventName).
  MetaEvent = "MetaEvent",
  // Error event carrying a message and an optional code.
  RunError = "RunError",
}
/**
 * Names carried in the `name` field of events whose `type` is
 * `RuntimeEventTypes.MetaEvent`. Each member's string value equals its name.
 */
export enum RuntimeMetaEventName {
  LangGraphInterruptEvent = "LangGraphInterruptEvent",
  LangGraphInterruptResumeEvent = "LangGraphInterruptResumeEvent",
  CopilotKitLangGraphInterruptEvent = "CopilotKitLangGraphInterruptEvent",
}
/**
 * Union of the meta event shapes. All variants share
 * `type: RuntimeEventTypes.MetaEvent` and are told apart by `name`.
 * Note that the payload field differs per variant (`value` vs. `data`).
 */
export type RunTimeMetaEvent =
  // Interrupt event: payload is a plain string in `value`.
  | {
      type: RuntimeEventTypes.MetaEvent;
      name: RuntimeMetaEventName.LangGraphInterruptEvent;
      value: string;
    }
  // CopilotKit interrupt event: payload is an object with a string value and a list of messages.
  | {
      type: RuntimeEventTypes.MetaEvent;
      name: RuntimeMetaEventName.CopilotKitLangGraphInterruptEvent;
      data: { value: string; messages: (TextMessage | ActionExecutionMessage | ResultMessage)[] };
    }
  // Interrupt resume event: payload is a plain string in `data`.
  | {
      type: RuntimeEventTypes.MetaEvent;
      name: RuntimeMetaEventName.LangGraphInterruptResumeEvent;
      data: string;
    };
/**
 * Event shape for `RuntimeEventTypes.RunError`: a required `message`
 * and an optional `code`, both strings.
 */
export type RuntimeErrorEvent = {
  type: RuntimeEventTypes.RunError;
  message: string;
  code?: string;
};
/**
 * Discriminated union (on `type`) of every event that can flow through a
 * `RuntimeEventSubject`.
 */
export type RuntimeEvent =
  // Text message events, correlated by `messageId`.
  | { type: RuntimeEventTypes.TextMessageStart; messageId: string; parentMessageId?: string }
  | {
      type: RuntimeEventTypes.TextMessageContent;
      messageId: string;
      content: string;
    }
  | { type: RuntimeEventTypes.TextMessageEnd; messageId: string }
  // Action execution events, correlated by `actionExecutionId`.
  | {
      type: RuntimeEventTypes.ActionExecutionStart;
      actionExecutionId: string;
      actionName: string;
      parentMessageId?: string;
    }
  // `args` is a string fragment; the consumer in this file concatenates fragments.
  | { type: RuntimeEventTypes.ActionExecutionArgs; actionExecutionId: string; args: string }
  | { type: RuntimeEventTypes.ActionExecutionEnd; actionExecutionId: string }
  // `result` is a string; emitters in this file produce it via ResultMessage.encodeResult.
  | {
      type: RuntimeEventTypes.ActionExecutionResult;
      actionName: string;
      actionExecutionId: string;
      result: string;
    }
  // Agent state snapshot; `state` is carried as a string.
  | {
      type: RuntimeEventTypes.AgentStateMessage;
      threadId: string;
      agentName: string;
      nodeName: string;
      runId: string;
      active: boolean;
      role: string;
      state: string;
      running: boolean;
    }
  | RunTimeMetaEvent
  | RuntimeErrorEvent;
/**
 * Accumulator carried by the `scan` step in
 * `RuntimeEventSource.processRuntimeEvents`: the current event plus the
 * bookkeeping for the action execution that is currently being streamed.
 */
interface RuntimeEventWithState {
  // The event most recently seen by the reducer; `null` only in the seed value.
  event: RuntimeEvent | null;
  // True when the last started action's name matched one of the server-side actions.
  callActionServerSide: boolean;
  // The matched server-side action, if any.
  action: Action<any> | null;
  // Id taken from the last ActionExecutionStart event.
  actionExecutionId: string | null;
  // Concatenation of all ActionExecutionArgs fragments since the last start (reset to "" on start).
  args: string;
  // `parentMessageId` taken from the last ActionExecutionStart event.
  actionExecutionParentMessageId: string | null;
}
/**
 * Producer callback: receives the event subject to write events into and
 * returns a promise. A rejection of that promise is handled in
 * `RuntimeEventSource.processRuntimeEvents`.
 */
type EventSourceCallback = (eventStream$: RuntimeEventSubject) => Promise<void>;
/**
 * A `ReplaySubject` of runtime events with typed convenience emitters.
 *
 * The subject is constructed without arguments. Every `send*` method builds
 * one event (or a fixed start/content/end sequence) and forwards it through
 * `next`.
 */
export class RuntimeEventSubject extends ReplaySubject<RuntimeEvent> {
  constructor() {
    super();
  }

  /** Emits a `TextMessageStart` event for `messageId`. */
  sendTextMessageStart(params: { messageId: string; parentMessageId?: string }): void {
    const { messageId, parentMessageId } = params;
    const startEvent: RuntimeEvent = {
      type: RuntimeEventTypes.TextMessageStart,
      messageId,
      parentMessageId,
    };
    this.next(startEvent);
  }

  /** Emits a `TextMessageContent` event carrying one chunk of text. */
  sendTextMessageContent(params: { messageId: string; content: string }): void {
    const { messageId, content } = params;
    const contentEvent: RuntimeEvent = {
      type: RuntimeEventTypes.TextMessageContent,
      content,
      messageId,
    };
    this.next(contentEvent);
  }

  /** Emits a `TextMessageEnd` event for `messageId`. */
  sendTextMessageEnd(params: { messageId: string }): void {
    const { messageId } = params;
    const endEvent: RuntimeEvent = { type: RuntimeEventTypes.TextMessageEnd, messageId };
    this.next(endEvent);
  }

  /** Emits a complete text message: start, a single content chunk, end. */
  sendTextMessage(messageId: string, content: string): void {
    const target = { messageId };
    this.sendTextMessageStart(target);
    this.sendTextMessageContent({ ...target, content });
    this.sendTextMessageEnd(target);
  }

  /** Emits an `ActionExecutionStart` event. */
  sendActionExecutionStart(params: {
    actionExecutionId: string;
    actionName: string;
    parentMessageId?: string;
  }): void {
    const { actionExecutionId, actionName, parentMessageId } = params;
    const startEvent: RuntimeEvent = {
      type: RuntimeEventTypes.ActionExecutionStart,
      actionExecutionId,
      actionName,
      parentMessageId,
    };
    this.next(startEvent);
  }

  /** Emits an `ActionExecutionArgs` event carrying one fragment of the argument string. */
  sendActionExecutionArgs(params: { actionExecutionId: string; args: string }): void {
    const { actionExecutionId, args } = params;
    const argsEvent: RuntimeEvent = {
      type: RuntimeEventTypes.ActionExecutionArgs,
      args,
      actionExecutionId,
    };
    this.next(argsEvent);
  }

  /** Emits an `ActionExecutionEnd` event. */
  sendActionExecutionEnd(params: { actionExecutionId: string }): void {
    const { actionExecutionId } = params;
    const endEvent: RuntimeEvent = {
      type: RuntimeEventTypes.ActionExecutionEnd,
      actionExecutionId,
    };
    this.next(endEvent);
  }

  /** Emits a complete action execution: start, a single args fragment, end. */
  sendActionExecution(params: {
    actionExecutionId: string;
    actionName: string;
    args: string;
    parentMessageId?: string;
  }): void {
    const { actionExecutionId, actionName, args, parentMessageId } = params;
    this.sendActionExecutionStart({ actionExecutionId, actionName, parentMessageId });
    this.sendActionExecutionArgs({ actionExecutionId, args });
    this.sendActionExecutionEnd({ actionExecutionId });
  }

  /**
   * Emits an `ActionExecutionResult` event. The `result` / `error` pair is
   * turned into the event's string payload by `ResultMessage.encodeResult`.
   */
  sendActionExecutionResult(params: {
    actionExecutionId: string;
    actionName: string;
    result?: string;
    error?: { code: string; message: string };
  }): void {
    const { actionExecutionId, actionName, result, error } = params;
    const encoded = ResultMessage.encodeResult(result, error);
    const resultEvent: RuntimeEvent = {
      type: RuntimeEventTypes.ActionExecutionResult,
      actionName,
      actionExecutionId,
      result: encoded,
    };
    this.next(resultEvent);
  }

  /** Emits an `AgentStateMessage` event with the given snapshot fields. */
  sendAgentStateMessage(params: {
    threadId: string;
    agentName: string;
    nodeName: string;
    runId: string;
    active: boolean;
    role: string;
    state: string;
    running: boolean;
  }): void {
    const { threadId, agentName, nodeName, runId, active, role, state, running } = params;
    const stateEvent: RuntimeEvent = {
      type: RuntimeEventTypes.AgentStateMessage,
      threadId,
      agentName,
      nodeName,
      runId,
      active,
      role,
      state,
      running,
    };
    this.next(stateEvent);
  }
}
/**
 * Connects a producer callback (which writes runtime events into a
 * `RuntimeEventSubject`) with a consumer pipeline that forwards those events
 * and executes server-side actions when their argument stream ends.
 *
 * An optional `errorHandler` / `errorContext` pair is notified about
 * streaming and action-execution failures; it is only invoked when both are
 * set.
 */
export class RuntimeEventSource {
  private eventStream$ = new RuntimeEventSubject();
  private callback!: EventSourceCallback;
  private errorHandler?: (error: any, context: any) => Promise<void>;
  private errorContext?: any;

  constructor(params?: {
    errorHandler?: (error: any, context: any) => Promise<void>;
    errorContext?: any;
  }) {
    this.errorHandler = params?.errorHandler;
    this.errorContext = params?.errorContext;
  }

  /**
   * Registers the producer callback. The callback is not invoked here; it
   * runs when `processRuntimeEvents` is called.
   */
  async stream(callback: EventSourceCallback): Promise<void> {
    this.callback = callback;
  }

  /**
   * Sends `message` (prefixed with "❌ ") to the chat as a complete text
   * message. If no producer callback has been registered yet, one is
   * registered that emits the message; otherwise the message is written to
   * the event stream directly.
   */
  sendErrorMessageToChat(message = "An error occurred. Please try again.") {
    const errorMessage = `❌ ${message}`;
    if (!this.callback) {
      this.stream(async (eventStream$) => {
        eventStream$.sendTextMessage(randomId(), errorMessage);
      });
    } else {
      this.eventStream$.sendTextMessage(randomId(), errorMessage);
    }
  }

  /**
   * Starts the producer callback and returns the observable of processed
   * events.
   *
   * Events pass through unchanged, except that when an `ActionExecutionEnd`
   * arrives for an action found in `serverSideActions`, the action is
   * executed and the events it produces are inserted right after the end
   * event. A failure of that execution is converted into an
   * `ActionExecutionResult` event carrying the structured error.
   *
   * @param serverSideActions actions that are executed by this runtime.
   * @param guardrailsResult$ optional guardrails verdict forwarded to `executeAction`.
   * @param actionInputsWithoutAgents forwarded to `executeAction`.
   * @param threadId forwarded to `executeAction`.
   */
  processRuntimeEvents({
    serverSideActions,
    guardrailsResult$,
    actionInputsWithoutAgents,
    threadId,
  }: {
    serverSideActions: Action<any>[];
    guardrailsResult$?: Subject<GuardrailsResult>;
    actionInputsWithoutAgents: ActionInput[];
    threadId: string;
  }) {
    this.callback(this.eventStream$).catch(async (error) => {
      // Convert streaming errors to structured errors, but preserve already structured ones
      const structuredError = ensureStructuredError(error, convertStreamingErrorToStructured);
      // Wait for the handler before terminating the stream.
      await this.notifyErrorHandler(structuredError, "Error in streaming error handler:");
      // `error` terminates the subject; a subsequent `complete` would be a no-op.
      this.eventStream$.error(structuredError);
    });

    const initialState: RuntimeEventWithState = {
      event: null,
      callActionServerSide: false,
      args: "",
      actionExecutionId: null,
      action: null,
      actionExecutionParentMessageId: null,
    };

    return this.eventStream$.pipe(
      // track state
      scan((acc: RuntimeEventWithState, event: RuntimeEvent): RuntimeEventWithState => {
        // Always hand a fresh object downstream: mutating the accumulator in
        // place previously led to actions being executed multiple times.
        const next: RuntimeEventWithState = { ...acc, event };
        if (event.type === RuntimeEventTypes.ActionExecutionStart) {
          const { actionName } = event;
          const matchedAction = serverSideActions.find((action) => action.name === actionName);
          next.callActionServerSide = matchedAction !== undefined;
          next.args = "";
          next.actionExecutionId = event.actionExecutionId;
          next.action = matchedAction ?? null;
          next.actionExecutionParentMessageId = event.parentMessageId ?? null;
        } else if (event.type === RuntimeEventTypes.ActionExecutionArgs) {
          next.args += event.args;
        }
        return next;
      }, initialState),
      concatMap((eventWithState) => {
        const { event, action, actionExecutionId } = eventWithState;
        if (event === null) {
          // Only the seed has a null event and the seed itself is never emitted.
          return EMPTY;
        }
        if (
          event.type !== RuntimeEventTypes.ActionExecutionEnd ||
          !eventWithState.callActionServerSide ||
          action === null ||
          actionExecutionId === null
        ) {
          return of(event);
        }

        const toolCallEventStream$ = new RuntimeEventSubject();
        executeAction(
          toolCallEventStream$,
          guardrailsResult$ ?? null,
          action,
          eventWithState.args,
          eventWithState.actionExecutionParentMessageId,
          actionExecutionId,
          actionInputsWithoutAgents,
          threadId,
        ).catch((error) => {
          // Surface the failure on the tool-call stream. Swallowing it would
          // leave the stream open forever and stall concatMap. This is a
          // no-op when the stream has already terminated.
          toolCallEventStream$.error(error);
        });
        telemetry.capture("oss.runtime.server_action_executed", {});

        return concat(of(event), toolCallEventStream$).pipe(
          catchError((error) => {
            // Convert streaming errors to structured errors and send as action result, but preserve already structured ones
            const structuredError = ensureStructuredError(
              error,
              convertStreamingErrorToStructured,
            );
            // Fire and forget: the handler must not delay the error result.
            void this.notifyErrorHandler(
              structuredError,
              "Error in action execution error handler:",
              { action: { name: action.name, executionId: actionExecutionId } },
            );
            // The tool-call stream has already errored and ignores further
            // `next` calls, so deliver the error result on a fresh subject
            // that replays it to the downstream subscriber.
            const errorResult$ = new RuntimeEventSubject();
            errorResult$.sendActionExecutionResult({
              actionExecutionId,
              actionName: action.name,
              error: {
                code: structuredError.code,
                message: structuredError.message,
              },
            });
            errorResult$.complete();
            return errorResult$;
          }),
        );
      }),
    );
  }

  /**
   * Invokes the configured error handler, if both handler and context are
   * set. Failures of the handler itself (sync or async) are logged with
   * `logLabel` and never propagated.
   *
   * @param extraContext merged over the configured context when provided.
   */
  private async notifyErrorHandler(
    error: unknown,
    logLabel: string,
    extraContext?: Record<string, unknown>,
  ): Promise<void> {
    if (!this.errorHandler || !this.errorContext) {
      return;
    }
    try {
      const context = extraContext
        ? { ...this.errorContext, ...extraContext }
        : this.errorContext;
      await this.errorHandler(error, context);
    } catch (errorHandlerError) {
      console.error(logLabel, errorHandlerError);
    }
  }
}
/**
 * Executes a server-side action and writes the resulting events into
 * `eventStream$`.
 *
 * Every failure path handled here reports an `ActionExecutionResult` with an
 * error and completes `eventStream$`, so a consumer concatenating this
 * stream is never left waiting.
 *
 * @param eventStream$ stream that receives the action's events.
 * @param guardrailsResult$ when provided, execution waits for its first value
 *   and is skipped (stream completed) if the status is "denied".
 * @param action the action to run.
 * @param actionArguments JSON string with the arguments; may be empty.
 * @param actionExecutionParentMessageId parent message id used for the
 *   synthesized agent execution message (falls back to `actionExecutionId`).
 * @param actionExecutionId id of this execution.
 * @param actionInputsWithoutAgents forwarded to the remote agent handler.
 * @param threadId forwarded to the remote agent handler.
 */
async function executeAction(
  eventStream$: RuntimeEventSubject,
  guardrailsResult$: Subject<GuardrailsResult> | null,
  action: Action<any>,
  actionArguments: string,
  actionExecutionParentMessageId: string | null,
  actionExecutionId: string,
  actionInputsWithoutAgents: ActionInput[],
  threadId: string,
) {
  // Reports an error result for this execution and terminates the stream.
  const failAndClose = (code: string, message: string) => {
    eventStream$.sendActionExecutionResult({
      actionExecutionId,
      actionName: action.name,
      error: { code, message },
    });
    eventStream$.complete();
  };

  if (guardrailsResult$) {
    const { status } = await firstValueFrom(guardrailsResult$);
    if (status === "denied") {
      eventStream$.complete();
      return;
    }
  }

  // Prepare arguments for function calling
  let args: Record<string, any>[] = [];
  if (actionArguments) {
    try {
      args = JSON.parse(actionArguments);
    } catch (e) {
      console.error("Action argument unparsable", { actionArguments });
      // Complete the stream as well; otherwise the consumer would hang.
      failAndClose("INVALID_ARGUMENTS", "Failed to parse action arguments");
      return;
    }
  }

  // handle LangGraph agents
  if (isRemoteAgentAction(action)) {
    const result = `${action.name} agent started`;
    const agentExecution = plainToInstance(ActionExecutionMessage, {
      id: actionExecutionId,
      createdAt: new Date(),
      name: action.name,
      // Reuse the already parsed arguments: parsing the raw string a second
      // time would throw when no arguments were supplied.
      arguments: args,
      parentMessageId: actionExecutionParentMessageId ?? actionExecutionId,
    });
    const agentExecutionResult = plainToInstance(ResultMessage, {
      id: "result-" + actionExecutionId,
      createdAt: new Date(),
      actionExecutionId,
      actionName: action.name,
      result,
    });
    eventStream$.sendActionExecutionResult({
      actionExecutionId,
      actionName: action.name,
      result,
    });

    // Preserve already structured CopilotKit errors, only convert unstructured errors
    const reportAgentError = (err: unknown) => {
      const structuredError = ensureStructuredError(err, convertStreamingErrorToStructured);
      failAndClose(structuredError.code, structuredError.message);
    };

    try {
      const stream = await action.remoteAgentHandler({
        name: action.name,
        threadId,
        actionInputsWithoutAgents,
        additionalMessages: [agentExecution, agentExecutionResult],
      });
      // forward to eventStream$
      from(stream).subscribe({
        next: (event) => eventStream$.next(event),
        error: reportAgentError,
        complete: () => eventStream$.complete(),
      });
    } catch (err) {
      // The handler failed before producing a stream.
      reportAgentError(err);
    }
  } else {
    // call the function
    try {
      const result = await action.handler?.(args);
      await streamLangChainResponse({
        result,
        eventStream$,
        actionExecution: {
          name: action.name,
          id: actionExecutionId,
        },
      });
    } catch (e) {
      console.error("Error in action handler", e);
      // Handlers may throw non-Error values (including null/undefined).
      const thrownMessage = (e as { message?: unknown } | null | undefined)?.message;
      failAndClose(
        "HANDLER_ERROR",
        typeof thrownMessage === "string" ? thrownMessage : String(e),
      );
    }
  }
}
/**
 * Maps an arbitrary streaming failure onto a CopilotKit error.
 *
 * The message always comes from `generateHelpfulErrorMessage` with the
 * context string "event streaming connection". Errors that look like
 * network failures (judged by message fragments or an `UND_ERR_SOCKET` code
 * on the error or its cause) become a `CopilotKitLowLevelError` wrapping the
 * original error; everything else becomes a `CopilotKitError` with code
 * `UNKNOWN` and severity `CRITICAL`.
 */
function convertStreamingErrorToStructured(error: any): CopilotKitError {
  const description = generateHelpfulErrorMessage(error, "event streaming connection");

  const messageContains = (fragment: string) => error?.message?.includes(fragment);
  const leadingFragments = ["fetch failed", "ECONNREFUSED", "ENOTFOUND", "ETIMEDOUT", "terminated"];

  const looksLikeNetworkFailure =
    leadingFragments.some((fragment) => messageContains(fragment)) ||
    error?.cause?.code === "UND_ERR_SOCKET" ||
    messageContains("other side closed") ||
    error?.code === "UND_ERR_SOCKET";

  if (!looksLikeNetworkFailure) {
    // Generic fallback for anything that is not recognised as a network problem.
    return new CopilotKitError({
      message: description,
      code: CopilotKitErrorCode.UNKNOWN,
      severity: Severity.CRITICAL,
    });
  }

  // Network-related: keep the original error (wrapping non-Error values).
  const underlying = error instanceof Error ? error : new Error(String(error));
  return new CopilotKitLowLevelError({
    error: underlying,
    url: "event streaming connection",
    message: description,
  });
}