@copilotkit/runtime-client-gql
Version:
<div align="center"> <a href="https://copilotkit.ai" target="_blank"> <img src="https://github.com/copilotkit/copilotkit/raw/main/assets/banner.png" alt="CopilotKit Logo"> </a>
215 lines (203 loc) • 6.62 kB
text/typescript
import {
GenerateCopilotResponseMutation,
MessageInput,
MessageStatusCode,
} from "../graphql/@generated/graphql";
import {
ActionExecutionMessage,
AgentStateMessage,
Message,
ResultMessage,
TextMessage,
} from "./types";
import untruncateJson from "untruncate-json";
import { parseJson } from "@copilotkit/shared";
export function filterAgentStateMessages(messages: Message[]): Message[] {
return messages.filter((message) => !message.isAgentStateMessage());
}
export function convertMessagesToGqlInput(messages: Message[]): MessageInput[] {
return messages.map((message) => {
if (message.isTextMessage()) {
return {
id: message.id,
createdAt: message.createdAt,
textMessage: {
content: message.content,
role: message.role as any,
parentMessageId: message.parentMessageId,
},
};
} else if (message.isActionExecutionMessage()) {
return {
id: message.id,
createdAt: message.createdAt,
actionExecutionMessage: {
name: message.name,
arguments: JSON.stringify(message.arguments),
parentMessageId: message.parentMessageId,
},
};
} else if (message.isResultMessage()) {
return {
id: message.id,
createdAt: message.createdAt,
resultMessage: {
result: message.result,
actionExecutionId: message.actionExecutionId,
actionName: message.actionName,
},
};
} else if (message.isAgentStateMessage()) {
return {
id: message.id,
createdAt: message.createdAt,
agentStateMessage: {
threadId: message.threadId,
role: message.role,
agentName: message.agentName,
nodeName: message.nodeName,
runId: message.runId,
active: message.active,
running: message.running,
state: JSON.stringify(message.state),
},
};
} else {
throw new Error("Unknown message type");
}
});
}
export function filterAdjacentAgentStateMessages(
messages: GenerateCopilotResponseMutation["generateCopilotResponse"]["messages"],
): GenerateCopilotResponseMutation["generateCopilotResponse"]["messages"] {
const filteredMessages: GenerateCopilotResponseMutation["generateCopilotResponse"]["messages"] =
[];
messages.forEach((message, i) => {
// keep all other message types
if (message.__typename !== "AgentStateMessageOutput") {
filteredMessages.push(message);
} else {
const prevAgentStateMessageIndex = filteredMessages.findIndex(
// TODO: also check runId
(m) => m.__typename === "AgentStateMessageOutput" && m.agentName === message.agentName,
);
if (prevAgentStateMessageIndex === -1) {
filteredMessages.push(message);
} else {
filteredMessages[prevAgentStateMessageIndex] = message;
}
}
});
return filteredMessages;
}
export function convertGqlOutputToMessages(
messages: GenerateCopilotResponseMutation["generateCopilotResponse"]["messages"],
): Message[] {
return messages.map((message) => {
if (message.__typename === "TextMessageOutput") {
return new TextMessage({
id: message.id,
role: message.role,
content: message.content.join(""),
parentMessageId: message.parentMessageId,
createdAt: new Date(),
status: message.status || { code: MessageStatusCode.Pending },
});
} else if (message.__typename === "ActionExecutionMessageOutput") {
return new ActionExecutionMessage({
id: message.id,
name: message.name,
arguments: getPartialArguments(message.arguments),
parentMessageId: message.parentMessageId,
createdAt: new Date(),
status: message.status || { code: MessageStatusCode.Pending },
});
} else if (message.__typename === "ResultMessageOutput") {
return new ResultMessage({
id: message.id,
result: message.result,
actionExecutionId: message.actionExecutionId,
actionName: message.actionName,
createdAt: new Date(),
status: message.status || { code: MessageStatusCode.Pending },
});
} else if (message.__typename === "AgentStateMessageOutput") {
return new AgentStateMessage({
id: message.id,
threadId: message.threadId,
role: message.role,
agentName: message.agentName,
nodeName: message.nodeName,
runId: message.runId,
active: message.active,
running: message.running,
state: parseJson(message.state, {}),
createdAt: new Date(),
});
}
throw new Error("Unknown message type");
});
}
export function loadMessagesFromJsonRepresentation(json: any[]): Message[] {
const result: Message[] = [];
for (const item of json) {
if ("content" in item) {
result.push(
new TextMessage({
id: item.id,
role: item.role,
content: item.content,
parentMessageId: item.parentMessageId,
createdAt: item.createdAt || new Date(),
status: item.status || { code: MessageStatusCode.Success },
}),
);
} else if ("arguments" in item) {
result.push(
new ActionExecutionMessage({
id: item.id,
name: item.name,
arguments: item.arguments,
parentMessageId: item.parentMessageId,
createdAt: item.createdAt || new Date(),
status: item.status || { code: MessageStatusCode.Success },
}),
);
} else if ("result" in item) {
result.push(
new ResultMessage({
id: item.id,
result: item.result,
actionExecutionId: item.actionExecutionId,
actionName: item.actionName,
createdAt: item.createdAt || new Date(),
status: item.status || { code: MessageStatusCode.Success },
}),
);
} else if ("state" in item) {
result.push(
new AgentStateMessage({
id: item.id,
threadId: item.threadId,
role: item.role,
agentName: item.agentName,
nodeName: item.nodeName,
runId: item.runId,
active: item.active,
running: item.running,
state: item.state,
createdAt: item.createdAt || new Date(),
}),
);
}
}
return result;
}
function getPartialArguments(args: string[]) {
try {
if (!args.length) return {};
return JSON.parse(untruncateJson(args.join("")));
} catch (e) {
return {};
}
}