UNPKG

openai-plugins

Version:

A TypeScript library that provides an OpenAI-compatible client for the Model Context Protocol (MCP).

1,217 lines (1,069 loc) 37.8 kB
import OriginalOpenAI from "openai";
import { Client } from "@modelcontextprotocol/sdk/client/index.js";
import { SSEClientTransport } from "@modelcontextprotocol/sdk/client/sse.js";
import {
  CallToolResultSchema,
  ToolListChangedNotificationSchema,
} from "@modelcontextprotocol/sdk/types.js";
import { EventSource } from "eventsource";

declare global {
  var EventSource: typeof EventSource;
}

// Install the `eventsource` package's implementation as the global EventSource
// when the runtime does not provide one (the SSE transport relies on a global).
if (typeof globalThis.EventSource === "undefined") {
  globalThis.EventSource = EventSource;
}

// Use more specific types for better compatibility with OpenAI SDK
interface ChatCompletionRole {
  role: "system" | "user" | "assistant" | "tool" | "function";
  content: string;
  tool_calls?: any[];
  tool_call_id?: string;
  name?: string;
}

interface ChatCompletionParams {
  model: string;
  messages: ChatCompletionRole[];
  max_tokens?: number;
  temperature?: number;
  stream?: boolean;
  tools?: any[];
  tool_choice?: string | object;
  response_format?: { type: string };
  [key: string]: any; // Allow other properties
}

// ---------------------------------------------------------------------------
// Logging
// ---------------------------------------------------------------------------

const LVL = { debug: 0, info: 1, warn: 2, error: 3 } as const;
export type MCP_LogLevel = keyof typeof LVL;
type LogLevelValue = (typeof LVL)[MCP_LogLevel];

const LEVEL_TAGS = ["DEBUG", "INFO", "WARN", "ERROR"] as const;

/** Current minimum level that `log` emits. Exported so callers can inspect it. */
export let MCP_LOG_LVL: LogLevelValue = LVL["debug"];

/**
 * True when `value` is one of the level names in `LVL`.
 * Uses an own-property check so prototype keys such as "toString" are rejected.
 */
function isLogLevel(value: unknown): value is MCP_LogLevel {
  return (
    typeof value === "string" &&
    Object.prototype.hasOwnProperty.call(LVL, value)
  );
}

/**
 * Set the minimum MCP log level.
 * Unknown names are reported with a warning and the current level is kept.
 * An info line is logged only when the level actually changes.
 */
export const setMcpLogLevel = (level: MCP_LogLevel): void => {
  const requested: unknown = level;
  if (!isLogLevel(requested)) {
    log(
      LVL.warn,
      `Invalid MCP log level: ${String(requested)}. Using current level.`,
    );
    return;
  }
  const previous = MCP_LOG_LVL;
  MCP_LOG_LVL = LVL[requested];
  if (previous !== MCP_LOG_LVL) {
    log(LVL.info, `MCP log level set to ${requested.toUpperCase()}`);
  }
};

/**
 * Write a timestamped, tagged log line when `lvl` >= the current MCP level.
 * Each message is written exactly once, through the console method for its level.
 * Previously every non-info message was printed twice.
 */
export const log = (lvl: number, msg: string): void => {
  if (lvl < MCP_LOG_LVL) return;
  const tag = LEVEL_TAGS[lvl] || String(lvl);
  const logMsg = `[${new Date().toISOString()}] [MCP] [${tag}] ${msg}`;
  if (lvl >= LVL.error) {
    console.error(logMsg);
  } else if (lvl >= LVL.warn) {
    console.warn(logMsg);
  } else if (lvl === LVL.debug) {
    console.debug(logMsg);
  } else {
    console.log(logMsg);
  }
};

// ---------------------------------------------------------------------------
// Small shared helpers
// ---------------------------------------------------------------------------

/** Human-readable message for any thrown value. */
function errorText(e: unknown): string {
  return e instanceof Error ? e.message : String(e);
}

/** Read an environment variable; returns undefined where `process` is unavailable (e.g. browsers). */
function readEnv(name: string): string | undefined {
  if (typeof process === "undefined" || !process.env) return undefined;
  return process.env[name];
}

/** Stop a timer from keeping the Node event loop alive; a no-op in runtimes without `unref`. */
function unrefTimer(timer: unknown): void {
  const candidate = timer as { unref?: () => void } | null;
  if (
    candidate &&
    typeof candidate === "object" &&
    typeof candidate.unref === "function"
  ) {
    candidate.unref();
  }
}

/**
 * Race `work` against a deadline of `ms` milliseconds.
 * On timeout, rejects with `timeoutError` if given, otherwise resolves `undefined`.
 * The deadline timer is always cleared, so it never keeps the process alive.
 */
function settleWithin<T>(
  work: Promise<T>,
  ms: number,
  timeoutError?: Error,
): Promise<T | undefined> {
  let cancel = () => {};
  const deadline = new Promise<undefined>((resolve, reject) => {
    const handle = setTimeout(
      () => (timeoutError ? reject(timeoutError) : resolve(undefined)),
      ms,
    );
    cancel = () => clearTimeout(handle);
  });
  return Promise.race([work, deadline]).finally(() => cancel());
}

/** True for non-null objects that implement the async-iteration protocol (e.g. SDK streams). */
function isAsyncIterable(value: unknown): value is AsyncIterable<any> {
  return (
    value !== null && typeof value === "object" && Symbol.asyncIterator in value
  );
}

interface Provider {
  name: string;
  regex: RegExp;
  baseURL: string;
  keyEnv: string;
}

type Message = {
  role: "system" | "user" | "assistant" | "tool";
  content: string;
  tool_calls?: ToolCall[];
  tool_call_id?: string;
  name?: string;
};

type ToolCall = {
  id: string;
  function: {
    name: string;
    arguments: string;
  };
};

/** Public MCP configuration accepted by `OpenAIOptions.mcp` and the "mcp" plugin. */
export interface MCPConfig {
  serverUrl?: string;
  serverUrls?: string[];
  headers?: Record<string, string>;
  maxToolCalls?: number;
  toolTimeoutSec?: number;
  disconnectAfterUse?: boolean;
  connectionTimeoutMs?: number;
  maxMessageGroups?: number;
  finalResponseSystemPrompt?: string;
  secondPassSystemPrompt?: string;
  modelName?: string;
  maxOutputTokens?: number;
  tokenRateLimit?: number;
  rateLimitWindowMs?: number;
  noWaitOnTpm?: boolean;
  logLevel?: MCP_LogLevel;
  /** Idle watchdog for MCPClient: disconnect after this many ms (values <= 0 disable it). */
  forceCleanupTimeoutMs?: number;
}

// Internal MCPConfig used by MCPClient (all defaults resolved)
interface InternalMCPConfig {
  serverUrls: string[];
  headers: Record<string, string>;
  finalResponseSystemPrompt?: string;
  secondPassSystemPrompt: string;
  modelName: string;
  maxOutputTokens: number;
  maxToolCalls: number;
  toolTimeoutSec: number;
  disconnectAfterUse: boolean;
  connectionTimeoutMs: number;
  maxMessageGroups: number;
  tokenRateLimit: number;
  rateLimitWindowMs: number;
  noWaitOnTpm: boolean;
  logLevel?: MCP_LogLevel;
  forceCleanupTimeoutMs: number;
}

// NOTE(review): depending on the installed @modelcontextprotocol/sdk version,
// SSEClientTransport may keep its EventSource in a private field. In that case
// `eventSource` is undefined here and `transport.close()` is what releases the
// socket. Confirm against the SDK in use.
interface MCPTransport extends SSEClientTransport {
  eventSource?: EventSource;
}

// Every MCPClient registers itself here until it disconnects, so the
// process-exit handler below can close any sockets still open.
const activeConnections = new Set<MCPClient>();

// Disconnect every tracked client, logging (not throwing) individual failures.
// NOTE(review): not invoked anywhere in this module; kept for parity with the original.
const forceCleanupAllConnections = async () => {
  log(
    LVL.warn,
    `Force cleaning up ${activeConnections.size} active connections`,
  );
  const disconnectPromises = Array.from(activeConnections).map((client) =>
    client.disconnect().catch((e) => {
      log(LVL.error, `Error during forced disconnect: ${errorText(e)}`);
    }),
  );
  await Promise.allSettled(disconnectPromises);
  activeConnections.clear();
};

// On process exit only synchronous work runs, so just close any EventSource
// sockets directly. Guarded so non-Node runtimes without `process.on` don't crash.
if (typeof process !== "undefined" && typeof process.on === "function") {
  process.on("exit", () => {
    try {
      Array.from(activeConnections).forEach((client) => {
        try {
          client.transport?.eventSource?.close();
        } catch (e) {
          // Ignore errors during process exit
        }
      });
      activeConnections.clear();
    } catch (e) {
      // Ignore errors during process exit
    }
  });
}

/**
 * Normalise a model response into an async iterable of chat chunks.
 * Async iterables (SDK streams) are re-yielded unchanged.
 * Anything else becomes one `{ choices: [{ delta: { content } }] }` chunk.
 * `cleanup`, if given, runs exactly once when iteration ends for any reason:
 * exhaustion, error, or the consumer stopping early. Its own errors are logged, not thrown.
 */
async function* asIterable(resp: any, cleanup?: () => Promise<void>) {
  const iterable = isAsyncIterable(resp);
  try {
    if (iterable) {
      try {
        for await (const x of resp) yield x;
      } catch (e) {
        log(LVL.error, `Error in async iteration: ${errorText(e)}`);
        throw e;
      }
      return;
    }
    const content =
      resp?.choices?.[0]?.message?.content ??
      resp?.choices?.[0]?.delta?.content ??
      resp?.content ??
      JSON.stringify(resp);
    yield { choices: [{ delta: { content } }] };
  } finally {
    if (cleanup) {
      try {
        await cleanup();
      } catch (e) {
        const label = iterable
          ? "Error in iteration cleanup"
          : "Error in non-iteration cleanup";
        log(LVL.error, `${label}: ${errorText(e)}`);
      }
    }
  }
}

const PROVIDERS: Provider[] = [
  {
    name: "openai",
    regex: /^(gpt|text-|davinci|curie|babbage|ada|dall-e)/i,
    baseURL: "https://api.openai.com/v1",
    keyEnv: "OPENAI_API_KEY",
  },
  {
    name: "anthropic",
    regex: /^claude/i,
    baseURL: "https://api.anthropic.com/v1",
    keyEnv: "ANTHROPIC_API_KEY",
  },
  {
    name: "gemini",
    regex: /^gemini/i,
    baseURL: "https://generativelanguage.googleapis.com/v1beta/openai",
    keyEnv: "GEMINI_API_KEY",
  },
];

// apiKey of the most recently constructed OpenAI wrapper; it takes precedence
// over the per-provider environment variable in providerFor().
let globalApiKey: string | null = null;
const providerCache = new Map<string, OriginalOpenAI>();

/**
 * Return an OpenAI-compatible client for `model`, chosen by model-name prefix.
 * Unmatched models fall back to the first provider, "openai".
 * Clients are cached per (provider, resolved API key). A later wrapper with a
 * different key therefore gets its own client instead of a stale cached one.
 */
function providerFor(model: string): OriginalOpenAI {
  const info = PROVIDERS.find((p) => p.regex.test(model)) || PROVIDERS[0];
  log(LVL.debug, `Using provider ${info.name} for model ${model}`);
  const apiKey = globalApiKey || readEnv(info.keyEnv);
  const cacheKey = `${info.name}\u0000${apiKey ?? ""}`;
  let client = providerCache.get(cacheKey);
  if (!client) {
    log(LVL.debug, `Creating new provider instance for ${info.name}`);
    client = new OriginalOpenAI({ apiKey, baseURL: info.baseURL });
    providerCache.set(cacheKey, client);
  }
  return client;
}

// Update Plugin interface to match expected structure
export interface Plugin {
  name: string;
  handle: (
    params: OriginalOpenAI.Chat.ChatCompletionCreateParams,
    next: (
      p: OriginalOpenAI.Chat.ChatCompletionCreateParams,
    ) => Promise<OriginalOpenAI.Chat.ChatCompletion>,
  ) => Promise<OriginalOpenAI.Chat.ChatCompletion>;
}

// Internal Plugin type used internally
type InternalPlugin = {
  name: string;
  handle: PluginHandler;
};

type PluginHandler = (
  params: ChatCompletionParams,
  next: (p: ChatCompletionParams, context?: any) => Promise<any>,
) => Promise<any>;

const DEFAULT_MCP_URL = "http://0.0.0.0:3000/mcp";

/**
 * A single-use connection to one MCP server over SSE. Responsibilities:
 * - connect to the first reachable URL in `serverUrls` and load its tools;
 * - expose those tools in OpenAI `tools` format;
 * - execute tool calls and record their responses.
 * With `disconnectAfterUse`, an idle watchdog force-disconnects after
 * `forceCleanupTimeoutMs`. The watchdog is suspended while a tool call is running.
 */
class MCPClient {
  private client: Client;
  private connected: boolean = false;
  /** Public so the process-exit handler can close the socket synchronously. */
  transport: MCPTransport | null = null;
  private tools: any[] = [];
  private toolsLoadAttempted: boolean = false;
  private userMessages: Message[] = [];
  private assistantMessages: Message[] = [];
  private toolResponses: Record<string, Message> = {};
  private errorCount: number = 0;
  private reconnecting: boolean = false;
  private cfg: InternalMCPConfig;
  private cleanupTimeout: NodeJS.Timeout | null = null;
  /** Guards against concurrent disconnect() calls. */
  private isDisconnecting: boolean = false;

  // Token encoder: tiktoken when it can be required, otherwise a rough
  // 4-characters-per-token approximation.
  // NOTE(review): not referenced elsewhere in this module.
  static encTok = (() => {
    try {
      return require("tiktoken").encoding_for_model("gpt-4");
    } catch {
      return {
        encode: (s: string) =>
          new Array(Math.ceil((s || "").length / 4)).fill(0),
      };
    }
  })();

  constructor(cfg: MCPConfig = {}) {
    log(LVL.debug, "Initializing MCP Client");
    if (isLogLevel(cfg.logLevel)) {
      setMcpLogLevel(cfg.logLevel);
    }

    // Server URL priority: serverUrls, serverUrl, MCP_SERVER_URLS (comma-separated),
    // MCP_SERVER_URL, then the built-in default. Non-http(s) entries are dropped.
    const raw =
      cfg.serverUrls ||
      (cfg.serverUrl ? [cfg.serverUrl] : null) ||
      readEnv("MCP_SERVER_URLS") ||
      readEnv("MCP_SERVER_URL") ||
      DEFAULT_MCP_URL;
    let urls = (Array.isArray(raw) ? raw : String(raw).split(","))
      .map((u: string) => u.trim())
      .filter(Boolean)
      .filter((u: string) => /^https?:\/\//i.test(u));
    if (!urls.length) urls = [DEFAULT_MCP_URL];

    this.cfg = {
      serverUrls: urls,
      headers: cfg.headers || {},
      secondPassSystemPrompt:
        cfg.finalResponseSystemPrompt ||
        cfg.secondPassSystemPrompt ||
        "Provide a helpful answer based on the tool results, addressing the user's original question.",
      modelName: cfg.modelName || "gpt-4",
      maxOutputTokens: cfg.maxOutputTokens ?? 4096,
      maxToolCalls: cfg.maxToolCalls ?? 15,
      toolTimeoutSec: cfg.toolTimeoutSec ?? 60,
      disconnectAfterUse: cfg.disconnectAfterUse ?? true,
      connectionTimeoutMs: cfg.connectionTimeoutMs ?? 5_000,
      maxMessageGroups: cfg.maxMessageGroups ?? 3,
      tokenRateLimit: cfg.tokenRateLimit ?? 29_000,
      rateLimitWindowMs: cfg.rateLimitWindowMs ?? 60_000,
      noWaitOnTpm: cfg.noWaitOnTpm ?? false,
      forceCleanupTimeoutMs: cfg.forceCleanupTimeoutMs ?? 30_000,
    };
    this.client = new Client({ name: "mcp-client", version: "0.1.0" });
    log(LVL.info, `MCP Client initialized with model: ${this.cfg.modelName}`);

    activeConnections.add(this);
    if (this.cfg.disconnectAfterUse && this.cfg.forceCleanupTimeoutMs > 0) {
      this.scheduleForceCleanup();
    }
  }

  /** (Re)arm the idle watchdog. The timer is unref'd so it never keeps Node alive. */
  private scheduleForceCleanup() {
    this.clearForceCleanup();
    this.cleanupTimeout = setTimeout(() => {
      log(LVL.warn, "Forced cleanup timeout triggered - disconnecting");
      this.disconnect().catch((e) => {
        log(LVL.error, `Error during forced disconnect: ${errorText(e)}`);
      });
    }, this.cfg.forceCleanupTimeoutMs);
    unrefTimer(this.cleanupTimeout);
  }

  /** Cancel the idle watchdog if it is armed. */
  private clearForceCleanup() {
    if (this.cleanupTimeout) {
      clearTimeout(this.cleanupTimeout);
      this.cleanupTimeout = null;
    }
  }

  /**
   * Connect to the first server URL that succeeds within `connectionTimeoutMs`,
   * then load its tool list.
   * The transport of every failed attempt is closed.
   * @throws the last connection error when every URL fails.
   */
  async connect() {
    if (this.connected) {
      log(LVL.debug, "Already connected, skipping connect");
      return;
    }
    log(
      LVL.info,
      `Connecting to MCP servers: ${this.cfg.serverUrls.join(", ")}`,
    );

    let lastErr: unknown;
    for (const url of this.cfg.serverUrls) {
      let transport: MCPTransport | undefined;
      try {
        log(LVL.debug, `Attempting connection to ${url}`);
        transport = new SSEClientTransport(new URL(url), {
          requestInit: Object.keys(this.cfg.headers).length
            ? { headers: this.cfg.headers }
            : undefined,
        }) as MCPTransport;

        await settleWithin(
          this.client.connect(transport),
          this.cfg.connectionTimeoutMs,
          new Error("Connection timeout"),
        );

        if (transport.eventSource) {
          transport.eventSource.onerror = (ev: Event) =>
            this.#onSSEError(ev, url);
        }
        this.transport = transport;
        this.connected = true;
        // Re-register in case a previous watchdog disconnect removed us.
        activeConnections.add(this);

        this.client.setNotificationHandler(
          ToolListChangedNotificationSchema,
          () => {
            log(LVL.info, "Tool list changed – refreshing");
            void this.updateTools();
          },
        );
        await this.updateTools();
        log(LVL.info, `Connected to MCP ${url}. Tools: ${this.tools.length}`);

        if (this.cfg.disconnectAfterUse && this.cfg.forceCleanupTimeoutMs > 0) {
          this.scheduleForceCleanup();
        }
        return;
      } catch (e: unknown) {
        lastErr = e;
        log(LVL.warn, `Connect failed (${url}) – ${errorText(e)}`);
        if (this.transport === transport) {
          this.transport = null;
          this.connected = false;
        }
        // Don't leak the socket of a failed/timed-out attempt.
        if (transport) {
          await settleWithin(transport.close(), 1_000).catch(() => undefined);
        }
      }
    }
    throw lastErr ?? new Error("MCP: all server URLs failed");
  }

  /**
   * SSE error handler.
   * Up to 3 errors each schedule a tool-list refresh 1 s later (at most one pending at a time).
   * After more than 3 errors the cached tool list is cleared.
   */
  #onSSEError(ev: Event, url: string) {
    this.errorCount++;
    const errorMessage = (ev as any)?.message ?? String(ev);
    log(
      LVL.warn,
      `SSE error (${url}): ${errorMessage}. count=${this.errorCount}`,
    );
    if (this.errorCount > 3) {
      this.tools = [];
      return;
    }
    if (this.reconnecting) return;
    this.reconnecting = true;
    const retry = setTimeout(async () => {
      this.reconnecting = false;
      try {
        await this.updateTools();
      } catch (e: unknown) {
        log(LVL.warn, `Reconnect refresh failed: ${errorText(e)}`);
      }
    }, 1_000);
    unrefTimer(retry);
  }

  /**
   * Close the EventSource and the transport, then reset connection state.
   * Safe to call repeatedly or concurrently. Never throws; errors are logged.
   */
  async disconnect() {
    if (this.isDisconnecting) {
      log(LVL.debug, "Disconnect already in progress, skipping");
      return;
    }
    this.isDisconnecting = true;
    this.clearForceCleanup();

    if (!this.connected) {
      activeConnections.delete(this);
      this.isDisconnecting = false;
      log(LVL.debug, "No active connection to disconnect");
      return;
    }

    log(LVL.info, "Disconnecting from MCP");
    try {
      this.#closeEventSource();
      await this.#closeTransport();
    } catch (e: unknown) {
      log(LVL.error, `Unexpected error during disconnect: ${errorText(e)}`);
    } finally {
      this.connected = false;
      this.transport = null;
      activeConnections.delete(this);
      this.isDisconnecting = false;
      log(LVL.info, "Disconnect complete, all resources released");
    }
  }

  /** Detach handlers (so no callbacks fire during close) and close the EventSource, if exposed. */
  #closeEventSource() {
    const source = this.transport?.eventSource;
    if (!source) return;
    log(LVL.debug, "Closing SSE EventSource connection");
    try {
      source.onmessage = null;
      source.onerror = null;
      source.onopen = null;
      source.close();
      if (this.transport) this.transport.eventSource = undefined;
      log(LVL.debug, "SSE EventSource closed successfully");
    } catch (e) {
      log(LVL.warn, `Error closing SSE connection: ${errorText(e)}`);
    }
  }

  /** Close the transport, waiting at most 1 s. The wait timer is always cleared. */
  async #closeTransport() {
    if (!this.transport) return;
    log(LVL.debug, "Closing transport connection");
    try {
      await settleWithin(this.transport.close(), 1_000);
      log(LVL.debug, "Transport closed successfully");
    } catch (e) {
      log(LVL.warn, `Error closing transport: ${errorText(e)}`);
    }
  }

  /** Refresh `tools` from the server. Failures are logged and keep the previous list. */
  async updateTools() {
    if (!this.connected) {
      log(LVL.debug, "Not connected, skipping tool update");
      return (this.tools = []);
    }
    log(LVL.debug, "Updating MCP tools list");
    this.toolsLoadAttempted = true;
    try {
      const { tools = [] } = (await this.client.listTools()) || {};
      this.tools = tools.map((t: any) => ({
        name: t.name,
        description: t.description || `Use ${t.name}`,
        input_schema: t.inputSchema,
        categories: (t.categories || []).map((c: string) => c.toLowerCase()),
      }));
      log(LVL.info, `Loaded ${this.tools.length} tools`);
    } catch (e: unknown) {
      log(LVL.warn, `listTools failed: ${errorText(e)}`);
    }
  }

  /**
   * Convert a stored tool into an OpenAI function-tool definition.
   * Schemas that are missing, unparseable, or not `type: "object"` are replaced
   * by an empty object schema. A schema with no properties (or no `properties`
   * key at all) gets a single string `query` parameter. The stored schema is
   * never mutated.
   */
  formatTool(t: any) {
    let schema: any = t.input_schema;
    if (typeof schema === "string") {
      try {
        schema = JSON.parse(schema);
      } catch (e) {
        log(LVL.warn, `Invalid input schema for ${t.name}: ${errorText(e)}`);
        schema = undefined;
      }
    }
    if (!schema || typeof schema !== "object" || schema.type !== "object") {
      schema = { type: "object", properties: {} };
    }
    const properties =
      schema.properties && typeof schema.properties === "object"
        ? schema.properties
        : {};
    const parameters = Object.keys(properties).length
      ? { ...schema, properties }
      : {
          ...schema,
          properties: {
            query: { type: "string", description: `Input for ${t.name}` },
          },
        };
    return {
      type: "function",
      function: {
        name: t.name,
        description: t.description,
        parameters,
      },
    };
  }

  /** Loaded tools in OpenAI format; kicks off a background load if none was ever attempted. */
  openAITools() {
    if (!this.tools.length && !this.toolsLoadAttempted) {
      log(LVL.debug, "No tools loaded, attempting to update tools");
      this.updateTools().catch(() => {});
    }
    return this.tools.map((t: any) => this.formatTool(t));
  }

  /**
   * All user messages first, then each assistant/other message followed by any
   * recorded tool responses for its tool calls.
   * NOTE: this groups by role and does not preserve the original turn order.
   */
  buildMsgs() {
    const out = [...this.userMessages];
    for (const m of this.assistantMessages) {
      out.push(m);
      m.tool_calls?.forEach((tc: any) => {
        const r = this.toolResponses[tc.id];
        if (r) out.push(r);
      });
    }
    return out;
  }

  /**
   * Shorten a history to its first group plus the last `maxMessageGroups` groups.
   * A group is one message, or an assistant message with tool_calls together with
   * the tool messages that immediately follow it and answer those calls.
   * Histories of 4 or fewer messages, or with too few groups to drop any, are
   * returned unchanged. No group is ever duplicated.
   */
  trim(msgs: Message[]) {
    if (msgs.length <= 4) return msgs;

    const groups: Message[][] = [];
    for (let i = 0; i < msgs.length; i++) {
      const group: Message[] = [msgs[i]];
      const callIds = new Set(
        (msgs[i].tool_calls ?? []).map((t: any) => t.id),
      );
      if (msgs[i].role === "assistant" && callIds.size) {
        while (
          i + 1 < msgs.length &&
          msgs[i + 1].role === "tool" &&
          callIds.has(msgs[i + 1].tool_call_id!)
        ) {
          group.push(msgs[++i]);
        }
      }
      groups.push(group);
    }

    const keep = Math.max(0, Math.floor(this.cfg.maxMessageGroups));
    if (groups.length <= keep + 1) return groups.flat();
    const tail = keep > 0 ? groups.slice(-keep) : [];
    return [groups[0], ...tail].flat();
  }

  /**
   * Execute one OpenAI-style tool call against the MCP server.
   * Returns a `role: "tool"` message and records it under `tc.id`.
   * Text content is joined and truncated to 8,000 characters.
   * Errors become an "Error: …" tool message instead of being thrown.
   * The idle watchdog is paused during the call so long-running tools aren't cut off.
   */
  async processToolCall(tc: any) {
    const rawArgs = tc.function.arguments;
    let args: any;
    if (typeof rawArgs === "string" && rawArgs.trim() === "") {
      args = {};
    } else {
      try {
        args = JSON.parse(rawArgs);
      } catch {
        args = rawArgs;
      }
    }
    log(LVL.info, `Processing tool call: ${tc.function.name}`);

    const watchdogWasArmed = this.cleanupTimeout !== null;
    this.clearForceCleanup();
    try {
      const r = await this.client.callTool(
        { name: tc.function.name, arguments: args },
        CallToolResultSchema,
        { timeout: this.cfg.toolTimeoutSec * 1_000 },
      );
      const txt = Array.isArray(r.content)
        ? r.content.map((c: any) => c.text).join("\n\n")
        : r.content || "No result";
      return (this.toolResponses[tc.id] = {
        role: "tool",
        tool_call_id: tc.id,
        name: tc.function.name,
        content:
          typeof txt === "string"
            ? txt.length > 8_000
              ? txt.slice(0, 8_000) + "\n\n[truncated]"
              : txt
            : "No result",
      });
    } catch (e: unknown) {
      const errorMessage = errorText(e);
      log(LVL.warn, `Tool error: ${errorMessage}`);
      return (this.toolResponses[tc.id] = {
        role: "tool",
        tool_call_id: tc.id,
        name: tc.function.name,
        content: `Error: ${errorMessage}`,
      });
    } finally {
      if (watchdogWasArmed && this.connected && !this.isDisconnecting) {
        this.scheduleForceCleanup();
      }
    }
  }

  getUserMessages() {
    return this.userMessages;
  }

  getAssistantMessages() {
    return this.assistantMessages;
  }

  getTools() {
    return this.tools;
  }

  isToolsLoadAttempted() {
    return this.toolsLoadAttempted;
  }

  getConfig() {
    return this.cfg;
  }
}

// Routes the request to the provider client chosen by model name; it does not call `next`.
const multiModelPlugin: InternalPlugin = {
  name: "multiModelPlugin",
  async handle(params, next) {
    log(
      LVL.debug,
      `MultiModel plugin handling request for model: ${params.model}`,
    );
    return providerFor(params.model).chat.completions.create(params as any);
  },
};

type PluginRegistry = {
  [key: string]: (config: any) => InternalPlugin;
};

const PLUGIN_REGISTRY: PluginRegistry = {
  mcp: (config: any) => mcpPlugin(config),
  multiModel: () => multiModelPlugin,
};

interface EnhancedCompletionParams extends ChatCompletionParams {
  return_tool_calls?: boolean;
}

/**
 * Plugin that augments a chat completion with MCP tools. The flow:
 * 1. A first (non-streamed) request is sent with the MCP tools attached.
 * 2. If the model requested tools, up to `maxToolCalls` are executed.
 * 3. The MCP connection is released.
 * 4. A follow-up request turns the tool results into the final answer.
 * Falls through to `next` untouched when no MCP server is configured or reachable.
 */
function mcpPlugin(opts: any = {}): InternalPlugin {
  return {
    name: "mcpPlugin",
    async handle(params: ChatCompletionParams, next: PluginHandler) {
      log(LVL.info, `MCP plugin handling request for model: ${params.model}`);
      const callNext = (p: any) => (next as any)(p, undefined);

      const serverConfig =
        opts.serverUrls ||
        opts.serverUrl ||
        readEnv("MCP_SERVER_URLS") ||
        readEnv("MCP_SERVER_URL");
      if (!serverConfig) {
        log(LVL.debug, "No MCP server config found, skipping MCP processing");
        return callNext(params);
      }

      const wantStream = params.stream === true;
      log(LVL.debug, `Request stream mode: ${wantStream}`);
      const originalSystemMessage = params.messages.find(
        (m: any) => m.role === "system",
      );

      log(LVL.debug, "Creating MCP client");
      const mcp = new MCPClient({
        ...opts,
        serverUrls: opts.serverUrls,
        modelName: params.model,
        maxOutputTokens: params.max_tokens,
        // Always single-use: the client is released by this handler.
        disconnectAfterUse: true,
      });

      // Disconnect, logging (never throwing) any failure.
      const releaseMcp = (context: string) =>
        mcp.disconnect().catch((e: unknown) => {
          log(LVL.error, `Error disconnecting (${context}): ${errorText(e)}`);
        });

      try {
        log(LVL.debug, "Connecting to MCP");
        await mcp.connect();
      } catch (e: unknown) {
        log(LVL.warn, `MCP unavailable – ${errorText(e)}`);
        await mcp.disconnect().catch(() => {});
        return callNext(params);
      }

      try {
        log(LVL.debug, `Processing ${params.messages.length} messages`);
        params.messages.forEach((m: any) =>
          (m.role === "user"
            ? mcp.getUserMessages()
            : mcp.getAssistantMessages()
          ).push(m),
        );

        if (!mcp.getTools().length && !mcp.isToolsLoadAttempted())
          await mcp.updateTools();
        const tools = mcp.openAITools();
        log(LVL.debug, `Available tools: ${tools.length}`);

        // Trim the caller's history in its original turn order. buildMsgs()
        // would regroup it by role and scramble multi-turn conversations.
        const conversation = mcp.trim(
          params.messages.filter((m: any) => m.role !== "system") as Message[],
        );
        const firstPassMessages = originalSystemMessage
          ? [originalSystemMessage, ...conversation]
          : conversation;

        log(LVL.info, "Sending first pass request to model");
        const first = await callNext({
          model: params.model,
          stream: false,
          max_tokens: params.max_tokens ?? 4096,
          messages: firstPassMessages,
          ...(tools.length ? { tools, tool_choice: "auto" } : {}),
        });

        const assistant = first?.choices?.[0]?.message;
        if (!assistant) {
          throw new Error("MCP: first-pass response contained no message");
        }
        const calls: any[] = assistant.tool_calls ?? [];
        log(
          LVL.info,
          `First pass response received, tool calls: ${calls.length}`,
        );

        if (!calls.length) {
          log(LVL.debug, "No tool calls, returning direct response");
          await releaseMcp("no tool calls");
          if (wantStream) {
            // NOTE(review): this sends a second, streamed request rather than
            // replaying `first`. That costs an extra completion and may produce
            // a different answer.
            return asIterable(await callNext({ ...params, stream: true }));
          }
          return first;
        }

        mcp.getAssistantMessages().push({
          role: "assistant",
          content: assistant.content,
          tool_calls: calls,
        });

        const maxToolCalls = mcp.getConfig().maxToolCalls;
        log(
          LVL.info,
          `Processing ${Math.min(calls.length, maxToolCalls)} tool calls`,
        );
        const summaries: string[] = [];
        for (const tc of calls.slice(0, maxToolCalls)) {
          const result = await mcp.processToolCall(tc);
          summaries.push(`### ${tc.function.name}\n${result.content}`);
        }

        // The follow-up request never talks to MCP, so release the connection
        // now instead of holding it open for the whole (possibly long) stream.
        await releaseMcp("tool calls complete");

        const finalResponseSystemPrompt =
          opts.finalResponseSystemPrompt ||
          opts.secondPassSystemPrompt ||
          mcp.getConfig().secondPassSystemPrompt;

        log(LVL.info, "Sending follow-up request with tool results");
        const followParams = {
          model: params.model,
          stream: wantStream,
          max_tokens: params.max_tokens ?? 4096,
          messages: [
            { role: "system", content: finalResponseSystemPrompt },
            {
              role: "user",
              content: mcp.getUserMessages().at(-1)?.content || "",
            },
            { role: "user", content: summaries.join("\n\n") },
          ],
        };
        const follow = await callNext(followParams);

        if (wantStream) {
          log(LVL.debug, "Returning streamed follow-up response");
          return asIterable(follow);
        }

        log(LVL.debug, "Processing final response");
        let final = "";
        if (isAsyncIterable(follow)) {
          for await (const ch of follow) {
            final += ch?.choices?.[0]?.delta?.content || "";
          }
        } else if (follow?.choices?.[0]?.message?.content) {
          final = follow.choices[0].message.content;
        }
        assistant.content = final;

        log(LVL.info, "Request completed successfully");
        const paramsWithToolCalls = params as EnhancedCompletionParams;
        return {
          id: `chatcmpl-${Date.now()}`,
          object: "chat.completion",
          created: Math.floor(Date.now() / 1e3),
          model: params.model,
          usage: first.usage,
          choices: [
            {
              index: 0,
              finish_reason: paramsWithToolCalls.return_tool_calls
                ? "tool_calls"
                : "stop",
              message: {
                role: "assistant",
                content: assistant.content,
                tool_calls: paramsWithToolCalls.return_tool_calls
                  ? calls
                  : undefined,
              },
            },
          ],
        };
      } catch (e) {
        log(LVL.debug, "Error in MCP plugin handler, disconnecting");
        await releaseMcp("after error");
        throw e;
      }
    },
  };
}

/** Chain plugins so that plugins[0] runs first and `base` runs last. */
function compose(
  plugins: InternalPlugin[],
  base: PluginHandler,
): PluginHandler {
  return plugins.reduceRight<PluginHandler>(
    (next, plugin) => (params) => plugin.handle(params, next),
    base,
  );
}

type ChatCompletionsCreate =
  typeof OriginalOpenAI.prototype.chat.completions.create;

// Explicitly include all required fields from the OpenAI SDK
export interface OpenAIOptions {
  apiKey?: string;
  organization?: string;
  baseURL?: string;
  timeout?: number;
  maxRetries?: number;
  defaultQuery?: Record<string, string>;
  defaultHeaders?: Record<string, string>;
  dangerouslyAllowBrowser?: boolean;
  plugins?: string[] | Plugin[] | string | null;
  pluginConfig?: Record<string, any>;
  mcp?: MCPConfig;
  mcpLogLevel?: MCP_LogLevel;
}

/**
 * Drop-in replacement for the OpenAI SDK client.
 * `chat.completions.create` is routed through the configured plugin chain
 * before reaching the real SDK method.
 */
class OpenAI extends OriginalOpenAI {
  constructor(options: OpenAIOptions = {}) {
    super({
      apiKey: options.apiKey,
      organization: options.organization,
      baseURL: options.baseURL,
      timeout: options.timeout,
      maxRetries: options.maxRetries,
      defaultQuery: options.defaultQuery,
      defaultHeaders: options.defaultHeaders,
      dangerouslyAllowBrowser: options.dangerouslyAllowBrowser,
    });

    if (options.mcpLogLevel || (options.mcp && options.mcp.logLevel)) {
      setMcpLogLevel(options.mcpLogLevel || options.mcp?.logLevel || "debug");
    }

    log(LVL.info, `Initializing OpenAI client with plugins`);
    // Module-global: the most recently constructed instance's key is used by providerFor().
    globalApiKey = options.apiKey || null;

    // `options.mcp` overrides matching keys in `pluginConfig.mcp` instead of discarding it.
    const pluginConfig: Record<string, any> = {
      ...(options.pluginConfig || {}),
      mcp: {
        ...(options.pluginConfig?.mcp || {}),
        ...(options.mcp || {}),
      },
    };

    const activePlugins = this.#loadPlugins(
      options.plugins || null,
      pluginConfig,
    );
    const originalCreate = this.chat.completions.create.bind(
      this.chat.completions,
    ) as ChatCompletionsCreate;
    const handler = compose(activePlugins, (p) => originalCreate(p as any));
    this.chat.completions.create = handler as unknown as ChatCompletionsCreate;

    log(
      LVL.info,
      `OpenAI client initialized with ${activePlugins.length} plugins`,
    );
  }

  /**
   * Resolve plugin specs into internal plugins, preserving their order. Each entry is one of:
   * - a registered name ("mcp", "multiModel"): duplicate names load once;
   * - a Plugin object: always loaded.
   * Unknown names and malformed entries are skipped with a warning.
   */
  #loadPlugins(
    plugins: string | string[] | Plugin[] | null,
    config: Record<string, any>,
  ): InternalPlugin[] {
    if (!plugins) return [];
    const entries: Array<string | Plugin> = Array.isArray(plugins)
      ? plugins
      : [plugins];

    const loaded: InternalPlugin[] = [];
    const seenNames = new Set<string>();
    for (const entry of entries) {
      if (typeof entry === "string") {
        if (seenNames.has(entry)) continue;
        seenNames.add(entry);
        const factory = Object.prototype.hasOwnProperty.call(
          PLUGIN_REGISTRY,
          entry,
        )
          ? PLUGIN_REGISTRY[entry]
          : undefined;
        if (!factory) {
          log(LVL.warn, `Unknown plugin "${entry}" ignored`);
          continue;
        }
        loaded.push(factory(config[entry] || {}));
      } else if (entry && typeof entry.handle === "function") {
        const external = entry;
        loaded.push({
          name: external.name,
          handle: async (params: any, next: any) =>
            external.handle(params, next),
        });
      } else {
        log(LVL.warn, "Invalid plugin entry ignored");
      }
    }

    log(
      LVL.debug,
      `Loading ${loaded.length} plugins: ${loaded.map((p) => p.name).join(", ")}`,
    );
    return loaded;
  }
}

// Export both as default and named - this is critical for it to work correctly
export default OpenAI;
export { OpenAI };

// CommonJS compatibility. Replacing module.exports would otherwise drop the
// named exports, so re-attach them; MCP_LOG_LVL is exposed as a live getter.
if (typeof module !== "undefined") {
  module.exports = OpenAI;
  module.exports.OpenAI = OpenAI;
  module.exports.default = OpenAI;
  module.exports.setMcpLogLevel = setMcpLogLevel;
  module.exports.log = log;
  Object.defineProperty(module.exports, "MCP_LOG_LVL", {
    enumerable: true,
    get: () => MCP_LOG_LVL,
  });
}