UNPKG

@copilotkit/runtime

Version:

<img src="https://github.com/user-attachments/assets/0a6b64d9-e193-4940-a3f6-60334ac34084" alt="banner" style="border-radius: 12px; border: 2px solid #d6d4fa;" />

156 lines (154 loc) • 6.43 kB
require("reflect-metadata");
const require_utils = require('./utils.cjs');

//#region src/service-adapters/openai/openai-assistant-adapter.ts

/**
 * Service adapter backed by the OpenAI Assistants API (`openai.beta.threads`).
 *
 * Each `process` call does one of two things:
 * - posts the latest user message to a thread and starts a new streamed run, or
 * - submits tool outputs to a run that is waiting for them.
 *
 * Events from the resulting run stream are forwarded to the CopilotKit event
 * source as text-message and action-execution events.
 */
class OpenAIAssistantAdapter {
  get name() {
    return "OpenAIAssistantAdapter";
  }

  /**
   * @param {object} params
   * @param {object} [params.openai] - Pre-built OpenAI client. When omitted, `ensureOpenAI` lazily creates
   *   `new OpenAI({})` from the `openai` package.
   * @param {string} params.assistantId - Assistant id sent as `assistant_id` when a new run is started.
   * @param {boolean} [params.codeInterpreterEnabled=true] - Pass `false` to leave out the `code_interpreter` tool.
   * @param {boolean} [params.fileSearchEnabled=true] - Pass `false` to leave out the `file_search` tool.
   * @param {boolean} [params.disableParallelToolCalls=false] - When truthy, `parallel_tool_calls: false` is sent
   *   with new runs and tool-output submissions.
   * @param {boolean} [params.keepSystemRole=false] - Forwarded to `convertMessageToOpenAIMessage`.
   */
  constructor(params) {
    if (params.openai) this._openai = params.openai;
    // Both built-in tools default to enabled; only an explicit `false` turns one off.
    this.codeInterpreterEnabled = params.codeInterpreterEnabled !== false;
    this.fileSearchEnabled = params.fileSearchEnabled !== false;
    this.assistantId = params.assistantId;
    this.disableParallelToolCalls = params.disableParallelToolCalls || false;
    this.keepSystemRole = params.keepSystemRole ?? false;
  }

  /** Returns the OpenAI client, creating a default one on first use if none was supplied. */
  ensureOpenAI() {
    if (!this._openai) {
      const OpenAI = require("openai").default;
      this._openai = new OpenAI({});
    }
    return this._openai;
  }

  /**
   * Handles one adapter request.
   *
   * Thread selection: reuses `request.extensions.openaiAssistantAPI.threadId` when present;
   * otherwise a new thread is created.
   *
   * Dispatch on the last message:
   * - a result message, with `request.runId` set: tool outputs are submitted to that run;
   * - a text message: it is posted as a user message and a new run is started.
   *
   * @returns {Promise<{runId: string, threadId: string, extensions: object}>} The run and thread ids,
   *   also mirrored into `extensions.openaiAssistantAPI` for the next request.
   * @throws {Error} "No actionable message found in the messages" when neither case applies
   *   (including an empty `messages` array).
   */
  async process(request) {
    const { messages, actions, eventSource, runId, forwardedParameters } = request;
    const openai = this.ensureOpenAI();

    let threadId = request.extensions?.openaiAssistantAPI?.threadId;
    if (!threadId) threadId = (await openai.beta.threads.create()).id;

    const lastMessage = messages.at(-1);
    let nextRunId;
    if (lastMessage?.isResultMessage() && runId) {
      nextRunId = await this.submitToolOutputs(threadId, runId, messages, eventSource);
    } else if (lastMessage?.isTextMessage()) {
      nextRunId = await this.submitUserMessage(threadId, messages, actions, eventSource, forwardedParameters);
    } else {
      throw new Error("No actionable message found in the messages");
    }

    return {
      runId: nextRunId,
      threadId,
      extensions: {
        ...request.extensions,
        openaiAssistantAPI: { threadId, runId: nextRunId },
      },
    };
  }

  /**
   * Submits the results of the tool calls a run is waiting on, then streams the continued run.
   *
   * @returns {Promise<string>} The same `runId` that was passed in.
   * @throws {Error} "No tool outputs required" when the run has no `required_action`.
   * @throws {Error} When the number of matching result messages differs from the number of
   *   requested tool calls.
   */
  async submitToolOutputs(threadId, runId, messages, eventSource) {
    const openai = this.ensureOpenAI();
    const run = await require_utils.retrieveThreadRun(openai, threadId, runId);
    if (!run.required_action) throw new Error("No tool outputs required");

    const toolCallIds = run.required_action.submit_tool_outputs.tool_calls.map((toolCall) => toolCall.id);
    const wantedIds = new Set(toolCallIds);
    const resultMessages = messages.filter(
      (message) => message.isResultMessage() && wantedIds.has(message.actionExecutionId),
    );
    if (toolCallIds.length !== resultMessages.length) {
      throw new Error("Number of function results does not match the number of tool calls");
    }

    const stream = require_utils.submitToolOutputsStream(openai, threadId, runId, {
      tool_outputs: resultMessages.map((message) => ({
        tool_call_id: message.actionExecutionId,
        output: message.result,
      })),
      ...(this.disableParallelToolCalls && { parallel_tool_calls: false }),
    });
    await this.streamResponse(stream, eventSource);
    return runId;
  }

  /**
   * Posts the latest user message to the thread and starts a streamed run.
   *
   * Run parameters:
   * - The first message is used as the run's `instructions` when it is a text message
   *   (otherwise `""`).
   * - The last of the remaining messages must convert to a `user`-role message.
   * - Tools are the converted `actions`, plus `code_interpreter` / `file_search` when enabled.
   * - `max_completion_tokens` is sent only when `forwardedParameters.maxTokens` is set.
   *
   * @returns {Promise<string>} Id of the newly created run.
   * @throws {Error} "No user message found" when there is no such user message.
   */
  async submitUserMessage(threadId, messages, actions, eventSource, forwardedParameters) {
    const openai = this.ensureOpenAI();

    // Copy so the caller's array is not mutated by shift().
    const remaining = [...messages];
    const instructionsMessage = remaining.shift();
    const instructions = instructionsMessage?.isTextMessage() ? instructionsMessage.content : "";

    // Only the last message is posted, so only that one needs converting.
    const lastMessage = remaining.at(-1);
    const userMessage = lastMessage
      ? require_utils.convertSystemMessageToAssistantAPI(
          require_utils.convertMessageToOpenAIMessage(lastMessage, { keepSystemRole: this.keepSystemRole }),
        )
      : undefined;
    if (userMessage?.role !== "user") throw new Error("No user message found");

    await openai.beta.threads.messages.create(threadId, { role: "user", content: userMessage.content });

    const tools = [
      ...actions.map(require_utils.convertActionInputToOpenAITool),
      ...(this.codeInterpreterEnabled ? [{ type: "code_interpreter" }] : []),
      ...(this.fileSearchEnabled ? [{ type: "file_search" }] : []),
    ];
    const stream = openai.beta.threads.runs.stream(threadId, {
      assistant_id: this.assistantId,
      instructions,
      tools,
      ...(forwardedParameters?.maxTokens && { max_completion_tokens: forwardedParameters.maxTokens }),
      ...(this.disableParallelToolCalls && { parallel_tool_calls: false }),
    });
    await this.streamResponse(stream, eventSource);
    return getRunIdFromStream(stream);
  }

  /**
   * Registers a callback on `eventSource` that iterates the run stream and translates its events.
   *
   * Event mapping:
   * - `thread.message.created` / `.delta` / `.completed` become text-message start / content / end.
   * - Function tool-call deltas in `thread.run.step.delta` become action-execution start / args.
   *
   * Every started action execution is ended exactly once: when a new tool call or message starts,
   * or when the stream finishes.
   */
  async streamResponse(stream, eventSource) {
    eventSource.stream(async (eventStream$) => {
      let inFunctionCall = false;
      let currentMessageId;
      let currentToolCallId;

      // Ends the open action execution (if any) and clears the flag so it is never ended twice.
      const endOpenToolCall = () => {
        if (!inFunctionCall) return;
        eventStream$.sendActionExecutionEnd({ actionExecutionId: currentToolCallId });
        inFunctionCall = false;
      };

      for await (const chunk of stream) {
        switch (chunk.event) {
          case "thread.message.created":
            endOpenToolCall();
            currentMessageId = chunk.data.id;
            eventStream$.sendTextMessageStart({ messageId: currentMessageId });
            break;

          case "thread.message.delta": {
            const part = chunk.data.delta.content?.[0];
            if (part?.type === "text") {
              eventStream$.sendTextMessageContent({ messageId: currentMessageId, content: part.text.value });
            }
            break;
          }

          case "thread.message.completed":
            eventStream$.sendTextMessageEnd({ messageId: currentMessageId });
            break;

          case "thread.run.step.delta": {
            const stepDetails = chunk.data.delta.step_details;
            const toolCall = stepDetails?.type === "tool_calls" ? stepDetails.tool_calls?.[0] : undefined;
            // Only function tool calls are forwarded as action executions.
            if (toolCall?.type !== "function") break;

            const toolCallId = toolCall.id;
            const toolCallName = toolCall.function?.name;
            const toolCallArgs = toolCall.function?.arguments;
            if (toolCallName && toolCallId) {
              // A delta carrying both id and name starts a new tool call.
              endOpenToolCall();
              inFunctionCall = true;
              currentToolCallId = toolCallId;
              eventStream$.sendActionExecutionStart({
                actionExecutionId: currentToolCallId,
                parentMessageId: chunk.data.id,
                actionName: toolCallName,
              });
            } else if (toolCallArgs) {
              // Later deltas only carry argument fragments for the current call.
              eventStream$.sendActionExecutionArgs({ actionExecutionId: currentToolCallId, args: toolCallArgs });
            }
            break;
          }
        }
      }

      endOpenToolCall();
      eventStream$.complete();
    });
  }
}

/**
 * Resolves with the id from the stream's `thread.run.created` event.
 *
 * Rejects if the stream emits `error`, or emits `end` before a run was created, so callers
 * awaiting the run id do not hang. All listeners are removed once the promise settles.
 *
 * @param {object} stream - Emitter exposing `on`/`off` for `"event"`, `"error"` and `"end"`.
 * @returns {Promise<string>}
 */
function getRunIdFromStream(stream) {
  return new Promise((resolve, reject) => {
    const cleanup = () => {
      stream.off("event", onEvent);
      stream.off("error", onError);
      stream.off("end", onEnd);
    };
    const onEvent = (event) => {
      if (event.event === "thread.run.created") {
        cleanup();
        resolve(event.data.id);
      }
    };
    const onError = (error) => {
      cleanup();
      reject(error);
    };
    const onEnd = () => {
      cleanup();
      reject(new Error("Stream ended before a run was created"));
    };
    stream.on("event", onEvent);
    stream.on("error", onError);
    stream.on("end", onEnd);
  });
}

//#endregion
exports.OpenAIAssistantAdapter = OpenAIAssistantAdapter;
//# sourceMappingURL=openai-assistant-adapter.cjs.map