UNPKG

openai-plugins

Version:

A TypeScript library that provides an OpenAI-compatible client for the Model Context Protocol (MCP).

576 lines (575 loc) 23.1 kB
import OriginalOpenAI from "openai";
import { Client } from "@modelcontextprotocol/sdk/client/index.js";
import { SSEClientTransport } from "@modelcontextprotocol/sdk/client/sse.js";
import {
  CallToolResultSchema,
  ToolListChangedNotificationSchema,
} from "@modelcontextprotocol/sdk/types.js";
import { EventSource } from "eventsource";

// The MCP SSE transport expects a global EventSource; install the
// `eventsource` package's implementation when the runtime has none.
if (typeof globalThis.EventSource === "undefined") {
  globalThis.EventSource = EventSource;
}

// ---------------------------------------------------------------------------
// Logging
// ---------------------------------------------------------------------------

/** Numeric severity per level name; a message is printed when its severity >= MCP_LOG_LVL. */
const LVL = Object.freeze({ debug: 0, info: 1, warn: 2, error: 3 });
const LVL_TAGS = ["DEBUG", "INFO", "WARN", "ERROR"];

/** Current minimum severity that is printed. Change it with `setMcpLogLevel`. */
export let MCP_LOG_LVL = LVL.debug;

/**
 * Print a timestamped `[MCP]` log line when `lvl` is at or above MCP_LOG_LVL.
 * Each message is written exactly once, through the console method that
 * matches its severity (error, warn, debug, otherwise log).
 *
 * @param {number} lvl - Severity, one of the LVL values (0 = debug … 3 = error).
 * @param {string} msg - Message text.
 */
export const log = (lvl, msg) => {
  if (lvl < MCP_LOG_LVL) return;
  const line = `[${new Date().toISOString()}] [MCP] [${LVL_TAGS[lvl]}] ${msg}`;
  if (lvl >= LVL.error) {
    console.error(line);
  } else if (lvl >= LVL.warn) {
    console.warn(line);
  } else if (lvl === LVL.debug) {
    console.debug(line);
  } else {
    console.log(line);
  }
};

/** True only for the own level names of LVL ("debug", "info", "warn", "error"). */
const isLogLevel = (level) =>
  typeof level === "string" && Object.prototype.hasOwnProperty.call(LVL, level);

/**
 * Set the global MCP log level by name. Unknown names are reported and ignored.
 *
 * @param {"debug"|"info"|"warn"|"error"} level - New minimum level.
 */
export const setMcpLogLevel = (level) => {
  if (!isLogLevel(level)) {
    log(LVL.warn, `Invalid MCP log level: ${level}. Using current level.`);
    return;
  }
  const isChanging = LVL[level] !== MCP_LOG_LVL;
  MCP_LOG_LVL = LVL[level];
  // Only announce real changes so repeated configuration does not spam the log.
  if (isChanging) {
    log(LVL.info, `MCP log level set to ${level.toUpperCase()}`);
  }
};

// ---------------------------------------------------------------------------
// Response helpers
// ---------------------------------------------------------------------------

/**
 * Normalise a chat-completion result into an async iterable of stream-style
 * chunks. Async-iterable responses (streams) are passed through; anything else
 * becomes a single `{ choices: [{ delta: { content } }] }` chunk.
 *
 * @param {any} resp - A stream or a non-stream completion object.
 */
async function* asIterable(resp) {
  if (typeof resp?.[Symbol.asyncIterator] === "function") {
    yield* resp;
    return;
  }
  const content =
    resp?.choices?.[0]?.message?.content ??
    resp?.choices?.[0]?.delta?.content ??
    resp?.content ??
    JSON.stringify(resp);
  yield { choices: [{ delta: { content } }] };
}

// ---------------------------------------------------------------------------
// Multi-provider routing
// ---------------------------------------------------------------------------

/** Model-name patterns mapped to OpenAI-compatible endpoints; PROVIDERS[0] is the fallback. */
const PROVIDERS = [
  {
    name: "openai",
    regex: /^(gpt|text-|davinci|curie|babbage|ada|dall-e)/i,
    baseURL: "https://api.openai.com/v1",
    keyEnv: "OPENAI_API_KEY",
  },
  {
    name: "anthropic",
    regex: /^claude/i,
    baseURL: "https://api.anthropic.com/v1",
    keyEnv: "ANTHROPIC_API_KEY",
  },
  {
    name: "gemini",
    regex: /^gemini/i,
    baseURL: "https://generativelanguage.googleapis.com/v1beta/openai",
    keyEnv: "GEMINI_API_KEY",
  },
];

/** API key from the most recently constructed OpenAI wrapper; overrides per-provider env keys. */
let globalApiKey = null;

/** Provider clients keyed by provider name AND key, so a changed key gets a fresh client. */
const providerCache = new Map();

/**
 * Return a (cached) OpenAI-compatible client for the provider matching `model`.
 *
 * @param {string} model - Model name, e.g. "gpt-4" or "claude-3".
 */
function providerFor(model) {
  const info = PROVIDERS.find((p) => p.regex.test(model)) ?? PROVIDERS[0];
  log(LVL.debug, `Using provider ${info.name} for model ${model}`);
  const apiKey = globalApiKey || process.env[info.keyEnv];
  const cacheKey = `${info.name}:${apiKey ?? ""}`;
  if (!providerCache.has(cacheKey)) {
    log(LVL.debug, `Creating new provider instance for ${info.name}`);
    providerCache.set(
      cacheKey,
      new OriginalOpenAI({ apiKey, baseURL: info.baseURL }),
    );
  }
  return providerCache.get(cacheKey);
}

// ---------------------------------------------------------------------------
// MCP client
// ---------------------------------------------------------------------------

const DEFAULT_SERVER_URL = "http://0.0.0.0:3000/mcp";

/** Tool output longer than this many characters is cut before it reaches the model. */
const MAX_TOOL_RESULT_CHARS = 8000;

/**
 * Normalise a server-URL setting (array or comma-separated string) into a
 * non-empty list of http(s) URLs, falling back to DEFAULT_SERVER_URL.
 */
function parseServerUrls(raw) {
  const candidates = Array.isArray(raw) ? raw : String(raw).split(",");
  const urls = candidates
    .map((u) => String(u).trim())
    .filter((u) => /^https?:\/\//i.test(u));
  return urls.length ? urls : [DEFAULT_SERVER_URL];
}

/**
 * Extract the text of an MCP CallTool result, truncated to MAX_TOOL_RESULT_CHARS.
 * Content parts without a string `text` are skipped; an empty or non-text
 * result yields "No result".
 */
function toolResultText(result) {
  const content = result?.content;
  const text = Array.isArray(content)
    ? content
        .map((part) => part?.text)
        .filter((t) => typeof t === "string")
        .join("\n\n")
    : content;
  if (typeof text !== "string" || text === "") return "No result";
  return text.length > MAX_TOOL_RESULT_CHARS
    ? `${text.slice(0, MAX_TOOL_RESULT_CHARS)}\n\n[truncated]`
    : text;
}

/**
 * Thin wrapper around the MCP SDK client. It connects over SSE to the first
 * reachable server URL, caches the server's tool list in OpenAI function-tool
 * format, keeps the conversation for the current request, and executes tool
 * calls.
 */
class MCPClient {
  /**
   * Token encoder if `tiktoken` can be loaded with `require`; otherwise a rough
   * stand-in (~4 characters per token). NOTE(review): `require` is not defined
   * in an ES module, so the fallback is presumably what runs here.
   */
  static encTok = (() => {
    try {
      return require("tiktoken").encoding_for_model("gpt-4");
    } catch {
      return {
        encode: (s) => new Array(Math.ceil((s || "").length / 4)).fill(0),
      };
    }
  })();

  /**
   * @param {object} [cfg] - Options: serverUrls (array or comma-separated string),
   *   headers, finalResponseSystemPrompt / secondPassSystemPrompt, modelName,
   *   maxOutputTokens, maxToolCalls, toolTimeoutSec, disconnectAfterUse,
   *   connectionTimeoutMs, maxMessageGroups, tokenRateLimit, rateLimitWindowMs,
   *   noWaitOnTpm, logLevel.
   */
  constructor(cfg = {}) {
    this.connected = false;
    this.transport = null;
    this.tools = [];
    this.toolsLoadAttempted = false;
    this.userMessages = [];
    this.assistantMessages = [];
    // Every message in insertion order; buildMsgs() replays it so the
    // conversation keeps its original turn order.
    this.history = [];
    this.toolResponses = {};
    this.errorCount = 0;
    this.reconnecting = false;
    log(LVL.debug, "Initializing MCP Client");
    if (isLogLevel(cfg.logLevel)) {
      setMcpLogLevel(cfg.logLevel);
    }
    const rawUrls =
      cfg.serverUrls ??
      process.env.MCP_SERVER_URLS ??
      process.env.MCP_SERVER_URL ??
      DEFAULT_SERVER_URL;
    this.cfg = {
      serverUrls: parseServerUrls(rawUrls),
      headers: cfg.headers || {},
      secondPassSystemPrompt:
        cfg.finalResponseSystemPrompt ||
        cfg.secondPassSystemPrompt ||
        "Provide a helpful answer based on the tool results, addressing the user's original question.",
      modelName: cfg.modelName || "gpt-4",
      maxOutputTokens: cfg.maxOutputTokens ?? 4096,
      maxToolCalls: cfg.maxToolCalls ?? 15,
      toolTimeoutSec: cfg.toolTimeoutSec ?? 60,
      disconnectAfterUse: cfg.disconnectAfterUse ?? true,
      connectionTimeoutMs: cfg.connectionTimeoutMs ?? 5000,
      maxMessageGroups: cfg.maxMessageGroups ?? 3,
      tokenRateLimit: cfg.tokenRateLimit ?? 29000,
      rateLimitWindowMs: cfg.rateLimitWindowMs ?? 60000,
      noWaitOnTpm: cfg.noWaitOnTpm ?? false,
    };
    this.client = new Client({ name: "mcp-client", version: "0.1.0" });
    log(LVL.info, `MCP Client initialized with model: ${this.cfg.modelName}`);
  }

  /**
   * Connect to the first configured server URL that answers within
   * `connectionTimeoutMs`, then load its tools.
   *
   * @throws The last connection error when every URL fails.
   */
  async connect() {
    if (this.connected) {
      log(LVL.debug, "Already connected, skipping connect");
      return;
    }
    log(LVL.info, `Connecting to MCP servers: ${this.cfg.serverUrls.join(", ")}`);
    let lastErr;
    for (const url of this.cfg.serverUrls) {
      let transport;
      let timer;
      try {
        log(LVL.debug, `Attempting connection to ${url}`);
        transport = new SSEClientTransport(new URL(url), {
          requestInit: Object.keys(this.cfg.headers).length
            ? { headers: this.cfg.headers }
            : undefined,
        });
        const timeout = new Promise((_, reject) => {
          timer = setTimeout(
            reject,
            this.cfg.connectionTimeoutMs,
            new Error("Connection timeout"),
          );
        });
        await Promise.race([this.client.connect(transport), timeout]);
        if (transport.eventSource) {
          transport.eventSource.onerror = (ev) => this.#onSSEError(ev, url);
        }
        this.transport = transport;
        this.connected = true;
        this.client.setNotificationHandler(ToolListChangedNotificationSchema, () => {
          log(LVL.info, "Tool list changed – refreshing");
          // updateTools() handles its own errors, so fire-and-forget is safe.
          void this.updateTools();
        });
        await this.updateTools();
        log(LVL.info, `Connected to MCP ${url}. Tools: ${this.tools.length}`);
        return;
      } catch (e) {
        lastErr = e;
        log(LVL.warn, `Connect failed (${url}) – ${e?.message || String(e)}`);
        this.connected = false;
        this.transport = null;
        // Don't leave a half-open SSE stream behind before trying the next URL.
        if (transport) await this.#closeTransportQuietly(transport);
      } finally {
        // Without this the pending timer keeps the process alive after success.
        clearTimeout(timer);
      }
    }
    throw lastErr ?? new Error("MCP: all server URLs failed");
  }

  /** Close the current transport (and its EventSource, if exposed). Errors are logged, not thrown. */
  async disconnect() {
    if (!this.connected || !this.transport) return;
    log(LVL.info, "Disconnecting from MCP");
    try {
      this.transport.eventSource?.close();
      await this.transport.close();
    } catch (e) {
      log(LVL.warn, `Error during disconnect: ${e.message}`);
    } finally {
      this.connected = false;
      this.transport = null;
    }
  }

  /**
   * Refresh `this.tools` from the server. Failures are logged and the
   * previous list is kept.
   *
   * @returns {Promise<Array|undefined>} `[]` when not connected.
   */
  async updateTools() {
    if (!this.connected) {
      log(LVL.debug, "Not connected, skipping tool update");
      return (this.tools = []);
    }
    log(LVL.debug, "Updating MCP tools list");
    this.toolsLoadAttempted = true;
    try {
      const { tools = [] } = (await this.client.listTools()) || {};
      this.tools = tools.map((t) => ({
        name: t.name,
        description: t.description || `Use ${t.name}`,
        input_schema: t.inputSchema,
        categories: (t.categories || []).map((c) => String(c).toLowerCase()),
      }));
      log(LVL.info, `Loaded ${this.tools.length} tools`);
    } catch (e) {
      log(LVL.warn, `listTools failed: ${e?.message || String(e)}`);
    }
  }

  /**
   * Convert a cached MCP tool into an OpenAI `function` tool definition.
   * A missing, invalid or non-object schema becomes an object schema; an
   * object with no properties gets a single string `query` parameter. The
   * cached schema is never mutated.
   */
  formatTool(t) {
    let schema;
    try {
      schema =
        typeof t.input_schema === "string"
          ? JSON.parse(t.input_schema)
          : t.input_schema;
    } catch (e) {
      log(LVL.warn, `Invalid input schema for tool ${t.name}: ${e?.message || String(e)}`);
      schema = undefined;
    }
    if (schema?.type !== "object") schema = { type: "object", properties: {} };
    const properties =
      schema.properties && typeof schema.properties === "object"
        ? schema.properties
        : {};
    const parameters = Object.keys(properties).length
      ? { ...schema, properties }
      : {
          ...schema,
          properties: {
            query: { type: "string", description: `Input for ${t.name}` },
          },
        };
    return {
      type: "function",
      function: { name: t.name, description: t.description, parameters },
    };
  }

  /** Cached tools in OpenAI format; triggers a background load if none were ever requested. */
  openAITools() {
    if (!this.tools.length && !this.toolsLoadAttempted) {
      log(LVL.debug, "No tools loaded, attempting to update tools");
      this.updateTools().catch(() => {});
    }
    return this.tools.map((t) => this.formatTool(t));
  }

  /**
   * Record a conversation message. User messages go to `userMessages`; every
   * other role goes to `assistantMessages`. All messages are also kept in order.
   */
  addMessage(m) {
    (m.role === "user" ? this.userMessages : this.assistantMessages).push(m);
    this.history.push(m);
  }

  /**
   * Rebuild the conversation in its original order. After an assistant message
   * with tool calls, any tool responses recorded by processToolCall() are
   * inserted, unless the history already contains a tool message for that call.
   */
  buildMsgs() {
    const answered = new Set(
      this.history.filter((m) => m.role === "tool").map((m) => m.tool_call_id),
    );
    const out = [];
    for (const m of this.history) {
      out.push(m);
      for (const tc of m.role === "assistant" ? m.tool_calls ?? [] : []) {
        const response = this.toolResponses[tc.id];
        if (response && !answered.has(tc.id)) out.push(response);
      }
    }
    return out;
  }

  /**
   * Keep the first message group plus the last `maxMessageGroups` groups.
   * A group is one message, or an assistant message together with the tool
   * messages that immediately follow it and answer its tool calls. Lists of
   * ≤ 4 messages, or lists with too few groups to drop any, are returned
   * unchanged (so no group is ever duplicated).
   */
  trim(msgs) {
    if (msgs.length <= 4) return msgs;
    const groups = [];
    for (let i = 0; i < msgs.length; i++) {
      const group = [msgs[i]];
      const ids = new Set(
        msgs[i].role === "assistant" ? (msgs[i].tool_calls ?? []).map((tc) => tc.id) : [],
      );
      while (
        ids.size &&
        i + 1 < msgs.length &&
        msgs[i + 1].role === "tool" &&
        ids.has(msgs[i + 1].tool_call_id)
      ) {
        group.push(msgs[++i]);
      }
      groups.push(group);
    }
    const keep = this.cfg.maxMessageGroups;
    if (groups.length <= keep + 1) return msgs;
    return [groups[0], ...(keep > 0 ? groups.slice(-keep) : [])].flat();
  }

  /**
   * Execute one OpenAI tool call on the MCP server and record the result as a
   * `tool` message in `toolResponses`. Errors become an "Error: …" tool
   * message instead of being thrown.
   */
  async processToolCall(tc) {
    const { name, arguments: rawArgs } = tc.function;
    let args;
    if (typeof rawArgs !== "string") {
      args = rawArgs ?? {};
    } else if (!rawArgs.trim()) {
      // Models send "" for argument-less calls; MCP expects an object.
      args = {};
    } else {
      try {
        args = JSON.parse(rawArgs);
      } catch {
        args = rawArgs;
      }
    }
    log(LVL.info, `Processing tool call: ${name}`);
    let content;
    try {
      const result = await this.client.callTool(
        { name, arguments: args },
        CallToolResultSchema,
        { timeout: this.cfg.toolTimeoutSec * 1000 },
      );
      content = toolResultText(result);
    } catch (e) {
      log(LVL.warn, `Tool error: ${e?.message || String(e)}`);
      content = `Error: ${e?.message || String(e)}`;
    }
    return (this.toolResponses[tc.id] = {
      role: "tool",
      tool_call_id: tc.id,
      name,
      content,
    });
  }

  getUserMessages() {
    return this.userMessages;
  }

  getAssistantMessages() {
    return this.assistantMessages;
  }

  getTools() {
    return this.tools;
  }

  isToolsLoadAttempted() {
    return this.toolsLoadAttempted;
  }

  getConfig() {
    return this.cfg;
  }

  /**
   * SSE error handler. It counts errors; after more than 3 it drops the tool
   * list. Otherwise it schedules one tool-list refresh 1s later (at most one
   * pending).
   */
  #onSSEError(ev, url) {
    this.errorCount++;
    log(LVL.warn, `SSE error (${url}): ${ev?.message ?? String(ev)}. count=${this.errorCount}`);
    if (this.errorCount > 3) {
      this.tools = [];
      return;
    }
    if (this.reconnecting) return;
    this.reconnecting = true;
    setTimeout(async () => {
      this.reconnecting = false;
      try {
        await this.updateTools();
      } catch (e) {
        log(LVL.warn, `Reconnect refresh failed: ${e?.message || String(e)}`);
      }
    }, 1000);
  }

  /** Best-effort close of a transport that failed to connect. */
  async #closeTransportQuietly(transport) {
    try {
      transport.eventSource?.close();
      await transport.close();
    } catch (e) {
      log(LVL.debug, `Ignoring error while closing transport: ${e?.message || String(e)}`);
    }
  }
}

// ---------------------------------------------------------------------------
// Plugins
// ---------------------------------------------------------------------------

/** Routes each request to the provider matching `params.model`; never calls `next`. */
const multiModelPlugin = {
  name: "multiModelPlugin",
  async handle(params, next) {
    log(LVL.debug, `MultiModel plugin handling request for model: ${params.model}`);
    return providerFor(params.model).chat.completions.create(params);
  },
};

/** Plugin factories by name; each receives that plugin's entry from `pluginConfig`. */
const PLUGIN_REGISTRY = {
  mcp: (config) => mcpPlugin(config),
  multiModel: () => multiModelPlugin,
};

/**
 * Run one chat request through a connected MCPClient.
 * 1. A non-streaming first pass is sent with the MCP tools.
 * 2. Requested tool calls are executed (up to maxToolCalls).
 * 3. A follow-up pass answers from the tool output.
 *
 * @returns A completion object, or an async iterable of chunks for `stream: true`.
 */
async function runMcpRequest(mcp, params, next) {
  const wantStream = params.stream === true;
  log(LVL.debug, `Request stream mode: ${wantStream}`);
  const originalSystemMessage = params.messages.find((m) => m.role === "system");

  log(LVL.debug, `Processing ${params.messages.length} messages`);
  for (const m of params.messages) mcp.addMessage(m);
  if (!mcp.getTools().length && !mcp.isToolsLoadAttempted()) await mcp.updateTools();
  const tools = mcp.openAITools();
  log(LVL.debug, `Available tools: ${tools.length}`);

  const trimmed = mcp.trim(mcp.buildMsgs());
  const firstPassMessages = originalSystemMessage
    ? [originalSystemMessage, ...trimmed.filter((m) => m.role !== "system")]
    : trimmed;

  log(LVL.info, "Sending first pass request to model");
  const first = await next(
    {
      model: params.model,
      stream: false,
      max_tokens: params.max_tokens ?? 4096,
      messages: firstPassMessages,
      ...(tools.length ? { tools, tool_choice: "auto" } : {}),
    },
    undefined,
  );
  const assistant = first.choices[0].message;
  const calls = assistant.tool_calls ?? [];
  log(LVL.info, `First pass response received, tool calls: ${calls.length}`);

  if (!calls.length) {
    log(LVL.debug, "No tool calls, returning direct response");
    if (!wantStream) return first;
    // The first pass was non-streaming; re-issue the original request to get a real stream.
    const raw = await next({ ...params, stream: true }, undefined);
    return asIterable(raw);
  }

  mcp.addMessage({ role: "assistant", content: assistant.content, tool_calls: calls });
  const { maxToolCalls, secondPassSystemPrompt } = mcp.getConfig();
  const toRun = calls.slice(0, maxToolCalls);
  log(LVL.info, `Processing ${toRun.length} tool calls`);
  const summaries = [];
  for (const tc of toRun) {
    const { content } = await mcp.processToolCall(tc);
    summaries.push(`### ${tc.function.name}\n${content}`);
  }

  log(LVL.info, "Sending follow-up request with tool results");
  const follow = await next(
    {
      model: params.model,
      stream: wantStream,
      max_tokens: params.max_tokens ?? 4096,
      messages: [
        { role: "system", content: secondPassSystemPrompt },
        { role: "user", content: mcp.getUserMessages().at(-1)?.content || "" },
        { role: "user", content: summaries.join("\n\n") },
      ],
    },
    undefined,
  );
  if (wantStream) {
    log(LVL.debug, "Returning stream response");
    return asIterable(follow);
  }

  log(LVL.debug, "Processing final response");
  let finalText = "";
  for await (const chunk of asIterable(follow)) {
    finalText += chunk.choices?.[0]?.delta?.content || "";
  }
  log(LVL.info, "Request completed successfully");
  const returnToolCalls = Boolean(params.return_tool_calls);
  return {
    id: `chatcmpl-${Date.now()}`,
    object: "chat.completion",
    created: Math.floor(Date.now() / 1e3),
    model: params.model,
    usage: first.usage,
    choices: [
      {
        index: 0,
        finish_reason: returnToolCalls ? "tool_calls" : "stop",
        message: {
          role: "assistant",
          content: finalText,
          tool_calls: returnToolCalls ? calls : undefined,
        },
      },
    ],
  };
}

/**
 * Create the MCP plugin. When no MCP server is configured or reachable, the
 * request is passed unchanged to `next`. Otherwise a per-request MCPClient
 * handles it and is disconnected afterwards (unless `disconnectAfterUse: false`).
 * Disconnection happens even when a model call throws.
 *
 * @param {object} [opts] - MCPClient options (serverUrls / serverUrl, headers,
 *   finalResponseSystemPrompt, maxToolCalls, disconnectAfterUse, ...).
 */
function mcpPlugin(opts = {}) {
  return {
    name: "mcpPlugin",
    async handle(params, next) {
      log(LVL.info, `MCP plugin handling request for model: ${params.model}`);
      const serverConfig =
        opts.serverUrls ||
        opts.serverUrl ||
        process.env.MCP_SERVER_URLS ||
        process.env.MCP_SERVER_URL;
      if (!serverConfig) {
        log(LVL.debug, "No MCP server config found, skipping MCP processing");
        return next(params, undefined);
      }
      log(LVL.debug, "Creating MCP client");
      const mcp = new MCPClient({
        ...opts,
        // Pass the resolved setting so a singular `serverUrl` option is honoured.
        serverUrls: serverConfig,
        modelName: params.model,
        maxOutputTokens: params.max_tokens ?? opts.maxOutputTokens,
      });
      try {
        log(LVL.debug, "Connecting to MCP");
        await mcp.connect();
      } catch (e) {
        log(LVL.warn, `MCP unavailable – ${e?.message || String(e)}`);
        return next(params, undefined);
      }
      try {
        return await runMcpRequest(mcp, params, next);
      } finally {
        // The client is per-request; release its SSE connection (default on).
        if (mcp.getConfig().disconnectAfterUse) {
          await mcp.disconnect();
        }
      }
    },
  };
}

/**
 * Chain plugins so `plugins[0]` runs first. Each plugin's `handle` receives
 * the rest of the chain (ending in `base`) as `next`.
 */
function compose(plugins, base) {
  return plugins.reduceRight(
    (next, plugin) => (params) => plugin.handle(params, next),
    base,
  );
}

// ---------------------------------------------------------------------------
// Public client
// ---------------------------------------------------------------------------

/**
 * Drop-in replacement for the `openai` client whose `chat.completions.create`
 * runs through the configured plugins before reaching the real API.
 */
class OpenAI extends OriginalOpenAI {
  /**
   * @param {object} [opts] - OpenAI client options plus `plugins` (plugin
   *   names such as "mcp" / "multiModel", or plugin objects), `pluginConfig`
   *   (per-plugin options keyed by name) and `mcpLogLevel`.
   */
  constructor(opts = {}) {
    super(opts);
    if (opts.mcpLogLevel) {
      setMcpLogLevel(opts.mcpLogLevel);
    }
    log(LVL.info, `Initializing OpenAI client with plugins`);
    globalApiKey = opts.apiKey || null;
    const activePlugins = this.#loadPlugins(opts.plugins || null, opts.pluginConfig || {});
    const originalCreate = this.chat.completions.create.bind(this.chat.completions);
    this.chat.completions.create = compose(activePlugins, (p) => originalCreate(p));
    log(LVL.info, `OpenAI client initialized with ${activePlugins.length} plugins`);
  }

  /**
   * Resolve the `plugins` option to plugin objects.
   * - Arrays starting with an object are returned as-is.
   * - A name or an array of names is resolved through PLUGIN_REGISTRY; each
   *   name is loaded once, in first-seen order. Unknown names are logged and
   *   skipped.
   */
  #loadPlugins(plugins, config) {
    if (Array.isArray(plugins) && plugins.length > 0 && typeof plugins[0] === "object") {
      const names = [...new Set(plugins.map((p) => p.name))];
      log(LVL.debug, `Loading ${names.length} object plugins: ${names.join(", ")}`);
      return plugins;
    }
    if (!plugins) return [];
    const pluginNames = Array.isArray(plugins) ? [...new Set(plugins)] : [plugins];
    log(LVL.debug, `Loading ${pluginNames.length} plugins: ${pluginNames.join(", ")}`);
    const loaded = [];
    for (const name of pluginNames) {
      const factory = Object.prototype.hasOwnProperty.call(PLUGIN_REGISTRY, name)
        ? PLUGIN_REGISTRY[name]
        : undefined;
      if (factory) {
        loaded.push(factory(config[name] || {}));
      } else {
        log(LVL.warn, `Unknown plugin: ${name}`);
      }
    }
    return loaded;
  }
}

// Export OpenAI as both default and named export for drop-in compatibility
export { OpenAI as default, OpenAI };