@copilotkit/runtime
Version:
<img src="https://github.com/user-attachments/assets/0a6b64d9-e193-4940-a3f6-60334ac34084" alt="banner" style="border-radius: 12px; border: 2px solid #d6d4fa;" />
307 lines (276 loc) • 9.86 kB
text/typescript
import {
ActionExecutionMessage,
Message,
ResultMessage,
TextMessage,
} from "../../graphql/types/converted";
import {
AIMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
HumanMessage,
SystemMessage,
ToolMessage,
} from "@langchain/core/messages";
import { DynamicStructuredTool } from "@langchain/core/tools";
import { z } from "zod";
import { ActionInput } from "../../graphql/inputs/action.input";
import { LangChainReturnType } from "./types";
import { RuntimeEventSubject } from "../events";
import { randomId, convertJsonSchemaToZodSchema } from "@copilotkit/shared";
/**
 * Converts a CopilotKit runtime message into the matching LangChain message.
 *
 * - Text messages map by role: `user` → `HumanMessage`, `assistant` → `AIMessage`,
 *   `system` → `SystemMessage`.
 * - Action execution messages become an empty-content `AIMessage` carrying a single tool call.
 * - Result messages become a `ToolMessage` linked to the originating action execution.
 *
 * NOTE(review): for a text message with any other role, or a message of none of the three
 * kinds above, nothing is returned (`undefined` at runtime despite the declared return type).
 */
export function convertMessageToLangChainMessage(message: Message): BaseMessage {
  if (message.isTextMessage()) {
    switch (message.role) {
      case "user":
        return new HumanMessage(message.content);
      case "assistant":
        return new AIMessage(message.content);
      case "system":
        return new SystemMessage(message.content);
    }
  } else if (message.isActionExecutionMessage()) {
    // The action execution id doubles as the tool-call id so the later ToolMessage can refer to it.
    const toolCall = {
      id: message.id,
      args: message.arguments,
      name: message.name,
    };
    return new AIMessage({ content: "", tool_calls: [toolCall] });
  } else if (message.isResultMessage()) {
    const toolCallId = message.actionExecutionId;
    return new ToolMessage({ content: message.result, tool_call_id: toolCallId });
  }
}
/**
 * Wraps a CopilotKit action description as a LangChain `DynamicStructuredTool`.
 *
 * The action's JSON schema string is parsed and converted to a zod schema for the tool.
 * The tool's `func` does nothing and resolves to `""`. Presumably the action is executed
 * elsewhere, and this tool only advertises the action to the model (confirm with callers).
 *
 * @param actionInput - The action whose name, description and `jsonSchema` are used.
 * @returns The constructed tool (typed `any`).
 */
export function convertActionInputToLangChainTool(actionInput: ActionInput): any {
  const { name, description, jsonSchema } = actionInput;
  // The second argument `true` is passed through unchanged from the original implementation.
  const schema = convertJsonSchemaToZodSchema(JSON.parse(jsonSchema), true) as z.ZodObject<any>;
  return new DynamicStructuredTool({
    name,
    description,
    schema,
    func: async () => "",
  });
}
/** Arguments accepted by `streamLangChainResponse`. */
interface StreamLangChainResponseParams {
  /** Value returned by the LangChain function whose output is being forwarded. */
  result: LangChainReturnType;
  /** Subject that receives the text-message / action-execution events. */
  eventStream$: RuntimeEventSubject;
  /**
   * The action currently being executed, if any. When present, a string result is sent as
   * this action's result instead of as a text message.
   */
  actionExecution?: {
    /** Id of the action execution. */
    id: string;
    /** Name of the executed action. */
    name: string;
  };
}
/**
 * Returns the name of the constructor of `object`, or `""` when it cannot be determined.
 *
 * Primitives, `null`/`undefined`, functions, and objects without a (truthy) constructor
 * name all yield `""`.
 */
function getConstructorName(object: any): string {
  if (object === null || typeof object !== "object") {
    return "";
  }
  const ctorName = object.constructor?.name;
  return ctorName ? ctorName : "";
}
/**
 * Returns true when `Object.prototype.toString` reports `value` as `[object <tag>]`,
 * i.e. when its `Symbol.toStringTag` equals `tag` exactly.
 *
 * NOTE(review): tag matching (rather than `instanceof`) is presumably used so that messages
 * created by another copy of `@langchain/core` are still recognised — confirm.
 */
function hasStringTag(value: unknown, tag: string): boolean {
  return Object.prototype.toString.call(value) === `[object ${tag}]`;
}
/** Type guard: `message` is tagged exactly `AIMessage`. */
function isAIMessage(message: any): message is AIMessage {
  return hasStringTag(message, "AIMessage");
}
/** Type guard: `message` is tagged exactly `AIMessageChunk`. */
function isAIMessageChunk(message: any): message is AIMessageChunk {
  return hasStringTag(message, "AIMessageChunk");
}
/**
 * Type guard: `message` is tagged exactly `BaseMessageChunk`.
 * NOTE(review): the comparison is exact, so subclasses reporting their own tag do not match.
 */
function isBaseMessageChunk(message: any): message is BaseMessageChunk {
  return hasStringTag(message, "BaseMessageChunk");
}
/**
 * When an action is being executed, emits a placeholder result for it ("Sending a message").
 * Does nothing when `actionExecution` is absent.
 *
 * Language models need a result after a function call. This tells them that the output
 * is delivered as a message instead.
 */
function maybeSendActionExecutionResultIsMessage(
  eventStream$: RuntimeEventSubject,
  actionExecution?: { id: string; name: string },
) {
  if (!actionExecution) {
    return;
  }
  const { id: actionExecutionId, name: actionName } = actionExecution;
  eventStream$.sendActionExecutionResult({
    actionExecutionId,
    actionName,
    result: "Sending a message",
  });
}
/**
 * Forwards the result of a LangChain function to the runtime event stream, then completes it.
 *
 * Supported result shapes:
 * 1. `string` — sent as one text message, or as the action execution's result when
 *    `actionExecution` is given.
 * 2. `AIMessage` — its text content and tool calls are emitted.
 * 3. `BaseMessageChunk` — its content and `lc_kwargs.tool_calls` are emitted.
 * 4. A readable stream (any object with `getReader`) — chunks are forwarded incrementally as
 *    start / content-or-args / end events for text messages and action executions.
 * 5. Any other value, when `actionExecution` is given — encoded with `encodeResult` and sent
 *    as the action's result.
 *
 * Errors while reading the stream are logged and end the streaming loop (best effort).
 *
 * @throws Error when `result` matches none of the shapes above and no `actionExecution` is given.
 */
export async function streamLangChainResponse({
  result,
  eventStream$,
  actionExecution,
}: StreamLangChainResponseParams) {
  // We support several types of return values from LangChain functions:
  // 1. string
  if (typeof result === "string") {
    if (!actionExecution) {
      // Just send one chunk with the string as the content.
      eventStream$.sendTextMessage(randomId(), result);
    } else {
      // Send as a result
      eventStream$.sendActionExecutionResult({
        actionExecutionId: actionExecution.id,
        actionName: actionExecution.name,
        result: result,
      });
    }
  }
  // 2. AIMessage
  // Send the content and function calls of the AIMessage.
  else if (isAIMessage(result)) {
    maybeSendActionExecutionResultIsMessage(eventStream$, actionExecution);
    if (result.content) {
      eventStream$.sendTextMessage(randomId(), result.content as string);
    }
    // `tool_calls` may be absent on an AIMessage; treat that as "no tool calls".
    for (const toolCall of result.tool_calls ?? []) {
      eventStream$.sendActionExecution({
        actionExecutionId: toolCall.id || randomId(),
        actionName: toolCall.name,
        args: JSON.stringify(toolCall.args),
      });
    }
  }
  // 3. BaseMessageChunk
  // Send the content and function calls of the chunk.
  else if (isBaseMessageChunk(result)) {
    maybeSendActionExecutionResultIsMessage(eventStream$, actionExecution);
    if (result.lc_kwargs?.content) {
      eventStream$.sendTextMessage(randomId(), result.content as string);
    }
    if (result.lc_kwargs?.tool_calls) {
      for (const toolCall of result.lc_kwargs?.tool_calls) {
        eventStream$.sendActionExecution({
          actionExecutionId: toolCall.id || randomId(),
          actionName: toolCall.name,
          args: JSON.stringify(toolCall.args),
        });
      }
    }
  }
  // 4. IterableReadableStream
  // Stream the result of the LangChain function.
  // The typeof guard keeps `in` from throwing a TypeError on primitive results (numbers,
  // booleans), which must fall through to the action-result branch below.
  else if (result !== null && typeof result === "object" && "getReader" in result) {
    maybeSendActionExecutionResultIsMessage(eventStream$, actionExecution);
    const reader = result.getReader();
    let mode: "function" | "message" | null = null;
    // Ids of the currently open text message / action execution. End (and argument) events
    // use these, so they always match the id sent in the corresponding start event, even on
    // the final `done` read (which carries no value) or on chunks that omit the id.
    let currentMessageId = "";
    let currentActionExecutionId = "";
    // Tool-call state kept across AIMessageChunk chunks: the last seen name/id are reused
    // for later chunks that omit them.
    const toolCallDetails: {
      name: string | null;
      id: string | null;
      index: number | null;
      prevIndex: number | null;
    } = {
      name: null,
      id: null,
      index: null,
      prevIndex: null,
    };
    while (true) {
      try {
        const { done, value } = await reader.read();
        let toolCallName: string | undefined = undefined;
        let toolCallId: string | undefined = undefined;
        let toolCallArgs: string | undefined = undefined;
        let hasToolCall = false;
        let content = "";
        if (value && value.content) {
          // For array (multi-part) content only the text of the first part is forwarded.
          content = Array.isArray(value.content)
            ? (((value.content[0] as any)?.text ?? "") as string)
            : value.content;
        }
        if (isAIMessageChunk(value)) {
          const chunk = value.tool_call_chunks?.[0];
          toolCallArgs = chunk?.args;
          hasToolCall = chunk != null;
          if (chunk?.name) toolCallDetails.name = chunk.name;
          // Track the index so that separate calls (even of the same tool) can be told apart.
          if (chunk?.index != null) {
            toolCallDetails.index = chunk.index;
            if (toolCallDetails.prevIndex == null) toolCallDetails.prevIndex = chunk.index;
          }
          // Differentiate calls of the same tool with different index.
          if (chunk?.id) {
            toolCallDetails.id = chunk.index != null ? `${chunk.id}-idx-${chunk.index}` : chunk.id;
          }
          toolCallName = toolCallDetails.name ?? undefined;
          toolCallId = toolCallDetails.id ?? undefined;
        } else if (isBaseMessageChunk(value)) {
          const chunk = value.additional_kwargs?.tool_calls?.[0];
          toolCallName = chunk?.function?.name;
          toolCallId = chunk?.id;
          toolCallArgs = chunk?.function?.arguments;
          hasToolCall = chunk?.function != undefined;
        }
        // When switching from message to function or vice versa, or when the stream ends,
        // send the end event for whatever is open. A text message ends when this chunk
        // carries a tool call. The persisted tool-call id is deliberately not used here;
        // it stays set after the first tool call and would split every later text chunk
        // into its own message.
        if (mode === "message" && (hasToolCall || done)) {
          mode = null;
          eventStream$.sendTextMessageEnd({ messageId: currentMessageId });
        } else if (mode === "function" && (!hasToolCall || done)) {
          mode = null;
          eventStream$.sendActionExecutionEnd({ actionExecutionId: currentActionExecutionId });
        }
        if (done) {
          break;
        }
        // If we start a new message type, send the appropriate start event.
        if (mode === null) {
          if (hasToolCall && toolCallId && toolCallName) {
            mode = "function";
            currentActionExecutionId = toolCallId;
            eventStream$.sendActionExecutionStart({
              actionExecutionId: toolCallId,
              actionName: toolCallName,
              parentMessageId: value.lc_kwargs?.id,
            });
          } else if (content) {
            mode = "message";
            currentMessageId = value.lc_kwargs?.id || randomId();
            eventStream$.sendTextMessageStart({ messageId: currentMessageId });
          }
        }
        // Send the content events.
        if (mode === "message" && content) {
          eventStream$.sendTextMessageContent({
            messageId: currentMessageId,
            content,
          });
        } else if (mode === "function" && toolCallArgs) {
          // A new index means a new tool call: close the previous call under ITS id, then
          // open the new one before forwarding its arguments.
          if (toolCallDetails.index !== toolCallDetails.prevIndex && toolCallId) {
            eventStream$.sendActionExecutionEnd({ actionExecutionId: currentActionExecutionId });
            currentActionExecutionId = toolCallId;
            eventStream$.sendActionExecutionStart({
              actionExecutionId: toolCallId,
              actionName: toolCallName,
              parentMessageId: value.lc_kwargs?.id,
            });
            toolCallDetails.prevIndex = toolCallDetails.index;
          }
          eventStream$.sendActionExecutionArgs({
            actionExecutionId: currentActionExecutionId,
            args: toolCallArgs,
          });
        }
      } catch (error) {
        console.error("Error reading from stream", error);
        break;
      }
    }
  } else if (actionExecution) {
    eventStream$.sendActionExecutionResult({
      actionExecutionId: actionExecution.id,
      actionName: actionExecution.name,
      result: encodeResult(result),
    });
  }
  // unsupported type
  else {
    throw new Error("Invalid return type from LangChain function.");
  }
  eventStream$.complete();
}
/**
 * Serialises an action result into a string payload.
 *
 * Strings pass through unchanged, `undefined` becomes `""`, and every other value goes
 * through `JSON.stringify`. NOTE(review): `JSON.stringify` yields `undefined` for functions
 * and symbols, and throws for BigInt or circular values.
 */
function encodeResult(result: any): string {
  if (typeof result === "string") {
    return result;
  }
  return result === undefined ? "" : JSON.stringify(result);
}