UNPKG

@copilotkit/runtime

Version:

<img src="https://github.com/user-attachments/assets/0a6b64d9-e193-4940-a3f6-60334ac34084" alt="banner" style="border-radius: 12px; border: 2px solid #d6d4fa;" />

135 lines (133 loc) 5.35 kB
import "reflect-metadata";
import { convertActionInputToOpenAITool, convertMessageToOpenAIMessage, getChatCompletionsForStreaming, limitMessagesToTokenCount } from "./utils.mjs";
import { convertServiceAdapterError } from "../shared/error-utils.mjs";
import { getSdkClientOptions } from "../shared/sdk-client-utils.mjs";
import { createOpenAI } from "@ai-sdk/openai";
import Openai from "openai";
import { randomUUID } from "@copilotkit/shared";

//#region src/service-adapters/openai/openai-adapter.ts

/** Model used when neither the constructor params nor the request name one. */
const DEFAULT_MODEL = "gpt-4o";

/**
 * Keeps every message except tool-result messages that do not answer a known
 * tool call.
 *
 * A result message survives only if its `actionExecutionId` matches the `id`
 * of an action-execution message in the same list. Only the first result per
 * action execution is kept; later duplicates are dropped.
 * NOTE(review): presumably this exists because the OpenAI API rejects `tool`
 * messages that do not respond to a preceding tool call — confirm.
 *
 * @param messages CopilotKit messages exposing `isActionExecutionMessage()`
 *   and `isResultMessage()`.
 * @returns A new, filtered array; the input array is not modified.
 */
function filterOrphanedToolResults(messages) {
  const pendingToolCallIds = new Set();
  for (const message of messages) {
    if (message.isActionExecutionMessage()) pendingToolCallIds.add(message.id);
  }
  return messages.filter((message) => {
    if (!message.isResultMessage()) return true;
    // Set#delete returns true only when the id was still present, so the
    // first matching result is kept and any duplicate is dropped.
    return pendingToolCallIds.delete(message.actionExecutionId);
  });
}

/**
 * Maps CopilotKit's forwarded `toolChoice` onto OpenAI's `tool_choice` field.
 *
 * `"function"` becomes a named-function choice using `toolChoiceFunctionName`.
 * Any other value, including `undefined`, is passed through unchanged.
 */
function resolveToolChoice(forwardedParameters) {
  if (forwardedParameters?.toolChoice === "function") {
    return { type: "function", function: { name: forwardedParameters.toolChoiceFunctionName } };
  }
  return forwardedParameters?.toolChoice;
}

/**
 * Re-emits a streamed chat completion as CopilotKit text-message and
 * action-execution events.
 *
 * At most one text message or one tool call is "open" at a time, tracked by
 * `mode`. A chunk carrying a new tool-call `id` closes whatever is open. While
 * a tool call is open, a chunk with no tool-call data closes it. Anything
 * still open when the stream ends is closed afterwards. Only the first choice
 * and the first tool-call delta of each chunk are inspected.
 *
 * @param stream Async-iterable of chat-completion chunks.
 * @param eventStream$ Event sink supplied by `eventSource.stream`.
 */
async function forwardStreamToEvents(stream, eventStream$) {
  let mode = null; // null | "message" | "function"
  let currentMessageId;
  let currentToolCallId;

  for await (const chunk of stream) {
    if (chunk.choices.length === 0) continue;
    const toolCall = chunk.choices[0].delta.tool_calls?.[0];
    const content = chunk.choices[0].delta.content;

    // Close the open item when this chunk starts something different.
    if (mode === "message" && toolCall?.id) {
      mode = null;
      eventStream$.sendTextMessageEnd({ messageId: currentMessageId });
    } else if (mode === "function" && (toolCall === undefined || toolCall?.id)) {
      mode = null;
      eventStream$.sendActionExecutionEnd({ actionExecutionId: currentToolCallId });
    }

    // Nothing open: start a tool call (preferred) or a text message.
    if (mode === null) {
      if (toolCall?.id) {
        mode = "function";
        currentToolCallId = toolCall.id;
        eventStream$.sendActionExecutionStart({
          actionExecutionId: currentToolCallId,
          parentMessageId: chunk.id,
          actionName: toolCall.function.name,
        });
      } else if (content) {
        mode = "message";
        currentMessageId = chunk.id;
        eventStream$.sendTextMessageStart({ messageId: currentMessageId });
      }
    }

    // Forward this chunk's payload into whatever is now open.
    if (mode === "message" && content) {
      eventStream$.sendTextMessageContent({ messageId: currentMessageId, content });
    } else if (mode === "function" && toolCall?.function?.arguments) {
      eventStream$.sendActionExecutionArgs({
        actionExecutionId: currentToolCallId,
        args: toolCall.function.arguments,
      });
    }
  }

  if (mode === "message") eventStream$.sendTextMessageEnd({ messageId: currentMessageId });
  else if (mode === "function") eventStream$.sendActionExecutionEnd({ actionExecutionId: currentToolCallId });
}

/**
 * CopilotKit service adapter that sends chat requests to OpenAI's streaming
 * chat-completions endpoint. The response is forwarded through the request's
 * `eventSource`.
 */
var OpenAIAdapter = class {
  /** The OpenAI client; `undefined` until one is passed in or `ensureOpenAI()` creates it. */
  get openai() {
    return this._openai;
  }

  get name() {
    return "OpenAIAdapter";
  }

  /**
   * @param params Optional settings:
   *   - `openai`: an existing `openai` SDK client. If omitted, a default client
   *     is created lazily.
   *   - `model`: the default model (falls back to `DEFAULT_MODEL`).
   *   - `disableParallelToolCalls`: when truthy, requests that include tools
   *     send `parallel_tool_calls: false`.
   *   - `keepSystemRole`: forwarded to `convertMessageToOpenAIMessage` (default `false`).
   *   - `maxInputTokens`: forwarded to `limitMessagesToTokenCount`.
   */
  constructor(params) {
    this.model = DEFAULT_MODEL;
    this.provider = "openai";
    this.disableParallelToolCalls = false;
    this.keepSystemRole = false;
    if (params?.openai) this._openai = params.openai;
    if (params?.model) this.model = params.model;
    this.disableParallelToolCalls = params?.disableParallelToolCalls || false;
    this.keepSystemRole = params?.keepSystemRole ?? false;
    this.maxInputTokens = params?.maxInputTokens;
  }

  /**
   * Builds an `@ai-sdk/openai` model for `this.model`.
   *
   * The model reuses the OpenAI client's base URL, API key, organization and
   * project, plus the default headers and `fetch` reported by
   * `getSdkClientOptions`.
   */
  getLanguageModel() {
    const openai = this.ensureOpenAI();
    const options = getSdkClientOptions(openai);
    return createOpenAI({
      baseURL: openai.baseURL,
      apiKey: openai.apiKey,
      organization: openai.organization ?? undefined,
      project: openai.project ?? undefined,
      headers: options.defaultHeaders,
      fetch: options.fetch,
    })(this.model);
  }

  /** Returns the OpenAI client, creating a default `new Openai()` on first use. */
  ensureOpenAI() {
    if (!this._openai) this._openai = new Openai();
    return this._openai;
  }

  /**
   * Starts a streaming chat completion for a CopilotKit request and pipes the
   * result into `request.eventSource`.
   *
   * @param request Must provide `messages`, `actions` and `eventSource`; may
   *   provide `threadId`, `model` and `forwardedParameters`.
   * @returns `{ threadId }`: the request's thread id, or a fresh UUID. The
   *   method returns once streaming has been handed to the event source, not
   *   when the stream finishes.
   * @throws The result of `convertServiceAdapterError` if creating the stream
   *   throws synchronously. Errors during iteration are converted the same way
   *   and rethrown inside the `eventSource.stream` callback.
   */
  async process(request) {
    const { threadId: threadIdFromRequest, model = this.model, messages, actions, eventSource, forwardedParameters } = request;
    const tools = actions.map(convertActionInputToOpenAITool);
    const threadId = threadIdFromRequest ?? randomUUID();

    let openaiMessages = filterOrphanedToolResults(messages).map((m) =>
      convertMessageToOpenAIMessage(m, { keepSystemRole: this.keepSystemRole })
    );
    // NOTE(review): presumably trims history to fit maxInputTokens — confirm in ./utils.
    openaiMessages = limitMessagesToTokenCount(openaiMessages, tools, model, this.maxInputTokens);
    const toolChoice = resolveToolChoice(forwardedParameters);
    const hasTools = tools.length > 0;

    try {
      const stream = getChatCompletionsForStreaming(this.ensureOpenAI()).stream({
        model,
        stream: true,
        messages: openaiMessages,
        ...(hasTools && { tools }),
        ...(forwardedParameters?.maxTokens && { max_completion_tokens: forwardedParameters.maxTokens }),
        ...(forwardedParameters?.stop && { stop: forwardedParameters.stop }),
        ...(toolChoice && { tool_choice: toolChoice }),
        // The API rejects parallel_tool_calls when no tools are specified, so
        // only send it alongside tools.
        ...(this.disableParallelToolCalls && hasTools && { parallel_tool_calls: false }),
        // `!= null` rather than truthiness so that temperature 0 is honored.
        ...(forwardedParameters?.temperature != null && { temperature: forwardedParameters.temperature }),
      });

      eventSource.stream(async (eventStream$) => {
        try {
          await forwardStreamToEvents(stream, eventStream$);
        } catch (error) {
          console.error("[OpenAI] Error during API call:", error);
          throw convertServiceAdapterError(error, "OpenAI");
        }
        eventStream$.complete();
      });
    } catch (error) {
      console.error("[OpenAI] Error during API call:", error);
      throw convertServiceAdapterError(error, "OpenAI");
    }
    return { threadId };
  }
};

//#endregion
export { OpenAIAdapter };
// NOTE(review): this file was edited after compilation; the source map below
// no longer matches line/column positions exactly.
//# sourceMappingURL=openai-adapter.mjs.map