openai-plugins
Version:
A TypeScript library that provides an OpenAI-compatible client for the Model Context Protocol (MCP).
1,217 lines (1,069 loc) • 37.8 kB
text/typescript
import OriginalOpenAI from "openai";
import { Client } from "@modelcontextprotocol/sdk/client/index.js";
import { SSEClientTransport } from "@modelcontextprotocol/sdk/client/sse.js";
import {
CallToolResultSchema,
ToolListChangedNotificationSchema,
} from "@modelcontextprotocol/sdk/types.js";
import { EventSource } from "eventsource";
// Polyfill: the MCP SDK's SSE transport needs a global EventSource, which
// browsers provide but Node does not. Install the `eventsource` package's
// implementation when the global is absent.
declare global {
  // NOTE(review): `typeof EventSource` is intended to resolve against the
  // imported binding so the global matches the polyfill's type — confirm
  // this resolves as expected under the project's tsconfig.
  var EventSource: typeof EventSource;
}
if (typeof globalThis.EventSource === "undefined") {
  globalThis.EventSource = EventSource;
}
// Use more specific types for better compatibility with OpenAI SDK
// A single chat message as exchanged with the completions API.
interface ChatCompletionRole {
  role: "system" | "user" | "assistant" | "tool" | "function";
  content: string;
  tool_calls?: any[]; // present on assistant messages that invoke tools
  tool_call_id?: string; // present on tool messages answering a call
  name?: string; // tool/function name for tool and function messages
}
// Parameters accepted by chat.completions.create; the index signature lets
// any extra SDK options pass through untouched.
interface ChatCompletionParams {
  model: string;
  messages: ChatCompletionRole[];
  max_tokens?: number;
  temperature?: number;
  stream?: boolean;
  tools?: any[];
  tool_choice?: string | object;
  response_format?: { type: string };
  [key: string]: any; // Allow other properties
}
// Improved logging system with export to allow external configuration
// Numeric severity ranks; higher means more severe.
const LVL = { debug: 0, info: 1, warn: 2, error: 3 } as const;
export type MCP_LogLevel = keyof typeof LVL;
type LogLevelValue = (typeof LVL)[MCP_LogLevel];
// Export LOG_LVL to allow external configuration
// Current minimum severity that will be emitted (defaults to debug).
export let MCP_LOG_LVL: LogLevelValue = LVL["debug"];
/**
 * Sets the global MCP log level. A confirmation line is emitted only when
 * the level actually changes; an invalid level is rejected with a warning
 * and the current level is kept.
 */
export const setMcpLogLevel = (level: MCP_LogLevel): void => {
  const next = LVL[level];
  if (next === undefined) {
    log(LVL.warn, `Invalid MCP log level: ${level}. Using current level.`);
    return;
  }
  // Compare numeric ranks directly — LVL is injective, so this matches the
  // previous name-based comparison.
  const changed = next !== MCP_LOG_LVL;
  MCP_LOG_LVL = next;
  if (changed) {
    log(LVL.info, `MCP log level set to ${level.toUpperCase()}`);
  }
};
/**
 * Emits a timestamped, tagged MCP log line. Messages below MCP_LOG_LVL are
 * suppressed. Every emitted message goes to console.log; error, warn, and
 * debug severities are additionally mirrored to their dedicated console
 * channels so external log filters can pick them up.
 */
export const log = (lvl: number, msg: string): void => {
  if (lvl < MCP_LOG_LVL) return;
  const tags = ["DEBUG", "INFO", "WARN", "ERROR"];
  const line = `[${new Date().toISOString()}] [MCP] [${tags[lvl]}] ${msg}`;
  // Primary channel — always written for visibility.
  console.log(line);
  // Secondary channel, chosen by severity (info has no mirror).
  switch (true) {
    case lvl >= LVL.error:
      console.error(line);
      break;
    case lvl >= LVL.warn:
      console.warn(line);
      break;
    case lvl === LVL.debug:
      console.debug(line);
      break;
  }
};
// Describes an upstream OpenAI-compatible provider and how to select it.
interface Provider {
  name: string;
  regex: RegExp; // matched against the requested model name
  baseURL: string;
  keyEnv: string; // environment variable holding the provider's API key
}
// Internal transcript message used by MCPClient bookkeeping.
type Message = {
  role: "system" | "user" | "assistant" | "tool";
  content: string;
  tool_calls?: ToolCall[]; // assistant messages that requested tools
  tool_call_id?: string; // tool messages: id of the call being answered
  name?: string; // tool messages: name of the tool that ran
};
// A single tool invocation as produced by the model (OpenAI function-call
// shape: arguments arrive as a JSON string).
type ToolCall = {
  id: string;
  function: {
    name: string;
    arguments: string;
  };
};
// Updated MCPConfig to match the interface structure
// Public, all-optional MCP configuration accepted from callers; defaults
// are applied when MCPClient normalizes this into InternalMCPConfig.
export interface MCPConfig {
  serverUrl?: string; // single server (convenience form of serverUrls)
  serverUrls?: string[];
  headers?: Record<string, string>; // extra headers for the SSE request
  maxToolCalls?: number; // cap on tool invocations per request
  toolTimeoutSec?: number;
  disconnectAfterUse?: boolean;
  connectionTimeoutMs?: number;
  maxMessageGroups?: number; // transcript trimming window
  finalResponseSystemPrompt?: string; // preferred alias
  secondPassSystemPrompt?: string; // legacy alias for the same prompt
  modelName?: string;
  maxOutputTokens?: number;
  tokenRateLimit?: number;
  rateLimitWindowMs?: number;
  noWaitOnTpm?: boolean;
  logLevel?: MCP_LogLevel;
  forceCleanupTimeoutMs?: number; // New option for forcing cleanup
}
// Internal MCPConfig used by MCPClient
// Same fields as MCPConfig but with defaults resolved, so almost everything
// is required here.
interface InternalMCPConfig {
  serverUrls: string[];
  headers: Record<string, string>;
  finalResponseSystemPrompt?: string;
  secondPassSystemPrompt: string;
  modelName: string;
  maxOutputTokens: number;
  maxToolCalls: number;
  toolTimeoutSec: number;
  disconnectAfterUse: boolean;
  connectionTimeoutMs: number;
  maxMessageGroups: number;
  tokenRateLimit: number;
  rateLimitWindowMs: number;
  noWaitOnTpm: boolean;
  logLevel?: MCP_LogLevel;
  forceCleanupTimeoutMs: number; // New option for forcing cleanup
}
// SSE transport with an optional handle on the underlying EventSource so
// cleanup paths can close the raw socket directly.
interface MCPTransport extends SSEClientTransport {
  eventSource?: EventSource;
}
// Registry of live MCPClient instances so they can be cleaned up in bulk.
const activeConnections = new Set<MCPClient>();
// Disconnects every tracked client, logging (but never propagating)
// individual failures, then empties the registry.
const forceCleanupAllConnections = async () => {
  log(
    LVL.warn,
    `Force cleaning up ${activeConnections.size} active connections`,
  );
  const pending: Promise<void>[] = [];
  for (const client of activeConnections) {
    pending.push(
      client.disconnect().catch((e) => {
        const errorMessage = e instanceof Error ? e.message : String(e);
        log(LVL.error, `Error during forced disconnect: ${errorMessage}`);
      }),
    );
  }
  await Promise.allSettled(pending);
  activeConnections.clear();
};
// On process exit, synchronously close any EventSource sockets still open.
// Only synchronous work is possible inside an "exit" handler, so the full
// async disconnect() path is deliberately skipped here.
if (typeof process !== "undefined") {
  process.on("exit", () => {
    try {
      for (const client of activeConnections) {
        try {
          client.transport?.eventSource?.close();
        } catch {
          // Ignore errors during process exit
        }
      }
      activeConnections.clear();
    } catch {
      // Ignore errors during process exit
    }
  });
}
/**
 * Normalizes a completion response into an async-iterable stream of chunks.
 *
 * - Async-iterable responses are passed through chunk-by-chunk.
 * - Anything else is reduced to its most plausible content field and emitted
 *   as a single OpenAI-style delta chunk.
 *
 * The optional `cleanup` callback is guaranteed to run exactly once when
 * iteration finishes — normally, on error, or when the consumer abandons
 * the generator early. (Fix: the original non-iterable path ran cleanup
 * only after the yield completed, so a consumer that broke out of the loop
 * leaked the connection.)
 */
async function* asIterable(resp: any, cleanup?: () => Promise<void>) {
  // Runs the cleanup callback, logging (never propagating) its failures.
  const runCleanup = async (label: string) => {
    if (!cleanup) return;
    try {
      await cleanup();
    } catch (e) {
      const errorMessage = e instanceof Error ? e.message : String(e);
      log(LVL.error, `Error in ${label} cleanup: ${errorMessage}`);
    }
  };
  // For async iterables, we iterate and handle cleanup
  if (resp && Symbol.asyncIterator in resp) {
    try {
      for await (const x of resp) yield x;
    } catch (e) {
      const errorMessage = e instanceof Error ? e.message : String(e);
      log(LVL.error, `Error in async iteration: ${errorMessage}`);
      // Rethrow to allow proper handling upstream
      throw e;
    } finally {
      await runCleanup("iteration");
    }
    return;
  }
  // For non-iterables, wrap the best-guess content in one delta chunk.
  const content =
    resp?.choices?.[0]?.message?.content ??
    resp?.choices?.[0]?.delta?.content ??
    resp?.content ??
    JSON.stringify(resp);
  try {
    yield { choices: [{ delta: { content } }] };
  } finally {
    // finally fires even if the consumer never resumes past the yield.
    await runCleanup("non-iteration");
  }
}
// Known providers, checked in declaration order; the first regex match wins
// and providerFor() falls back to OpenAI when nothing matches.
const PROVIDERS: Provider[] = [
  {
    name: "openai",
    regex: /^(gpt|text-|davinci|curie|babbage|ada|dall-e)/i,
    baseURL: "https://api.openai.com/v1",
    keyEnv: "OPENAI_API_KEY",
  },
  {
    name: "anthropic",
    regex: /^claude/i,
    baseURL: "https://api.anthropic.com/v1",
    keyEnv: "ANTHROPIC_API_KEY",
  },
  {
    name: "gemini",
    regex: /^gemini/i,
    baseURL: "https://generativelanguage.googleapis.com/v1beta/openai",
    keyEnv: "GEMINI_API_KEY",
  },
];
// API key supplied via the OpenAI constructor; falls back to env vars.
let globalApiKey: string | null = null;
// One lazily-created SDK client per provider name.
const providerCache = new Map<string, OriginalOpenAI>();
// Resolves (and caches) the OpenAI-compatible SDK client for a model name,
// defaulting to the first provider (OpenAI) when no regex matches.
function providerFor(model: string): OriginalOpenAI {
  const info = PROVIDERS.find((p) => p.regex.test(model)) ?? PROVIDERS[0];
  log(LVL.debug, `Using provider ${info.name} for model ${model}`);
  const cached = providerCache.get(info.name);
  if (cached) return cached;
  log(LVL.debug, `Creating new provider instance for ${info.name}`);
  const instance = new OriginalOpenAI({
    apiKey: globalApiKey || process.env[info.keyEnv],
    baseURL: info.baseURL,
  });
  providerCache.set(info.name, instance);
  return instance;
}
// Update Plugin interface to match expected structure
// Public plugin contract: receives the request plus a `next` continuation
// that forwards to the remainder of the chain (ending at the real SDK call).
export interface Plugin {
  name: string;
  handle: (
    params: OriginalOpenAI.Chat.ChatCompletionCreateParams,
    next: (
      p: OriginalOpenAI.Chat.ChatCompletionCreateParams,
    ) => Promise<OriginalOpenAI.Chat.ChatCompletion>,
  ) => Promise<OriginalOpenAI.Chat.ChatCompletion>;
}
// Internal Plugin type used internally
type InternalPlugin = {
  name: string;
  handle: PluginHandler;
};
// Handler signature used by the internal plugin chain; `next` accepts an
// optional opaque context argument that implementations may ignore.
type PluginHandler = (
  params: ChatCompletionParams,
  next: (p: ChatCompletionParams, context?: any) => Promise<any>,
) => Promise<any>;
/**
 * Client for a single MCP (Model Context Protocol) session.
 *
 * Connects to the first reachable server URL over SSE, exposes the server's
 * tools in OpenAI function-calling format, and tracks per-request transcript
 * state (user messages, assistant messages, tool responses). Instances
 * register themselves in `activeConnections` for forced cleanup on exit.
 */
class MCPClient {
  private client: Client;
  private connected: boolean = false;
  transport: MCPTransport | null = null; // Public for cleanup
  private tools: any[] = [];
  private toolsLoadAttempted: boolean = false;
  private userMessages: Message[] = [];
  private assistantMessages: Message[] = [];
  // Tool results keyed by tool_call id, replayed into the transcript.
  private toolResponses: Record<string, Message> = {};
  private errorCount: number = 0;
  private reconnecting: boolean = false;
  private cfg: InternalMCPConfig;
  private cleanupTimeout: NodeJS.Timeout | null = null;
  private isDisconnecting: boolean = false; // Flag to prevent concurrent disconnects
  // Token encoder; falls back to a ~4-chars-per-token stub when tiktoken is
  // not installed. NOTE(review): not referenced elsewhere in this class —
  // confirm whether callers outside this file rely on it before removing.
  static encTok = (() => {
    try {
      return require("tiktoken").encoding_for_model("gpt-4");
    } catch {
      return {
        encode: (s: string) =>
          new Array(Math.ceil((s || "").length / 4)).fill(0),
      };
    }
  })();
  constructor(cfg: MCPConfig = {}) {
    log(LVL.debug, "Initializing MCP Client");
    if (cfg.logLevel !== undefined && LVL[cfg.logLevel] !== undefined) {
      setMcpLogLevel(cfg.logLevel);
    }
    // Handle both serverUrl and serverUrls (config or environment).
    const raw =
      cfg.serverUrls ||
      (cfg.serverUrl ? [cfg.serverUrl] : null) ||
      process.env.MCP_SERVER_URLS ||
      process.env.MCP_SERVER_URL ||
      "http://0.0.0.0:3000/mcp";
    let urls = Array.isArray(raw) ? raw : String(raw).split(",");
    // Keep only well-formed http(s) URLs; fall back to the default if the
    // filter empties the list.
    urls = urls
      .map((u: string) => u.trim())
      .filter(Boolean)
      .filter((u: string) => /^https?:\/\//i.test(u));
    if (!urls.length) urls = ["http://0.0.0.0:3000/mcp"];
    this.cfg = {
      serverUrls: urls,
      headers: cfg.headers || {},
      secondPassSystemPrompt:
        cfg.finalResponseSystemPrompt ||
        cfg.secondPassSystemPrompt ||
        "Provide a helpful answer based on the tool results, addressing the user's original question.",
      modelName: cfg.modelName || "gpt-4",
      maxOutputTokens: cfg.maxOutputTokens ?? 4096,
      maxToolCalls: cfg.maxToolCalls ?? 15,
      toolTimeoutSec: cfg.toolTimeoutSec ?? 60,
      disconnectAfterUse: cfg.disconnectAfterUse ?? true,
      connectionTimeoutMs: cfg.connectionTimeoutMs ?? 5_000,
      maxMessageGroups: cfg.maxMessageGroups ?? 3,
      tokenRateLimit: cfg.tokenRateLimit ?? 29_000,
      rateLimitWindowMs: cfg.rateLimitWindowMs ?? 60_000,
      noWaitOnTpm: cfg.noWaitOnTpm ?? false,
      forceCleanupTimeoutMs: cfg.forceCleanupTimeoutMs ?? 30_000, // 30 seconds timeout by default
    };
    this.client = new Client({ name: "mcp-client", version: "0.1.0" });
    log(LVL.info, `MCP Client initialized with model: ${this.cfg.modelName}`);
    // Register this client for tracking
    activeConnections.add(this);
    // Set up forced cleanup timeout if disconnectAfterUse is true
    if (this.cfg.disconnectAfterUse && this.cfg.forceCleanupTimeoutMs > 0) {
      this.scheduleForceCleanup();
    }
  }
  // (Re)arms the watchdog that force-disconnects a client that was never
  // explicitly disconnected.
  private scheduleForceCleanup() {
    if (this.cleanupTimeout) {
      clearTimeout(this.cleanupTimeout);
    }
    this.cleanupTimeout = setTimeout(() => {
      log(LVL.warn, "Forced cleanup timeout triggered - disconnecting");
      this.disconnect().catch((e) => {
        const errorMessage = e instanceof Error ? e.message : String(e);
        log(LVL.error, `Error during forced disconnect: ${errorMessage}`);
      });
    }, this.cfg.forceCleanupTimeoutMs);
  }
  /**
   * Connects to the first reachable server URL, wiring SSE error handling
   * and tool-list-changed notifications, then loads the tool list.
   * Throws the last connection error if every URL fails.
   */
  async connect() {
    if (this.connected) {
      log(LVL.debug, "Already connected, skipping connect");
      return;
    }
    log(
      LVL.info,
      `Connecting to MCP servers: ${this.cfg.serverUrls.join(", ")}`,
    );
    let lastErr;
    for (const url of this.cfg.serverUrls) {
      // Fix: track the timeout timer so it can be cleared; previously the
      // pending timer kept the event loop alive after a fast connect/failure.
      let timeoutHandle: NodeJS.Timeout | undefined;
      try {
        log(LVL.debug, `Attempting connection to ${url}`);
        const transport = new SSEClientTransport(new URL(url), {
          requestInit: Object.keys(this.cfg.headers).length
            ? { headers: this.cfg.headers }
            : undefined,
        }) as MCPTransport;
        await Promise.race([
          this.client.connect(transport),
          new Promise((_, rej) => {
            timeoutHandle = setTimeout(
              rej,
              this.cfg.connectionTimeoutMs,
              new Error("Connection timeout"),
            );
          }),
        ]);
        if (transport.eventSource) {
          transport.eventSource.onerror = (ev: Event) =>
            this.#onSSEError(ev, url);
        }
        this.transport = transport;
        this.connected = true;
        this.client.setNotificationHandler(
          ToolListChangedNotificationSchema,
          () => {
            log(LVL.info, "Tool list changed – refreshing");
            this.updateTools();
          },
        );
        await this.updateTools();
        log(LVL.info, `Connected to MCP ${url}. Tools: ${this.tools.length}`);
        // Reset the cleanup timeout when we successfully connect
        if (this.cfg.disconnectAfterUse) {
          this.scheduleForceCleanup();
        }
        return;
      } catch (e: unknown) {
        lastErr = e;
        const errorMessage = e instanceof Error ? e.message : String(e);
        log(LVL.warn, `Connect failed (${url}) – ${errorMessage}`);
      } finally {
        if (timeoutHandle) clearTimeout(timeoutHandle);
      }
    }
    throw lastErr ?? new Error("MCP: all server URLs failed");
  }
  // SSE error handler: after more than 3 errors the tool list is dropped;
  // otherwise a single delayed tool refresh is attempted (debounced by the
  // `reconnecting` flag).
  #onSSEError(ev: Event, url: string) {
    this.errorCount++;
    const errorMessage = (ev as any)?.message ?? String(ev);
    log(
      LVL.warn,
      `SSE error (${url}): ${errorMessage}. count=${this.errorCount}`,
    );
    if (this.errorCount > 3) {
      this.tools = [];
      return;
    }
    if (this.reconnecting) return;
    this.reconnecting = true;
    setTimeout(async () => {
      this.reconnecting = false;
      try {
        await this.updateTools();
      } catch (e: unknown) {
        const errorMessage = e instanceof Error ? e.message : String(e);
        log(LVL.warn, `Reconnect refresh failed: ${errorMessage}`);
      }
    }, 1_000);
  }
  /**
   * Tears down the SSE socket and transport, always releasing local state
   * and deregistering from `activeConnections`. Safe to call repeatedly;
   * concurrent calls are coalesced via `isDisconnecting`.
   */
  async disconnect() {
    // Prevent concurrent disconnect operations
    if (this.isDisconnecting) {
      log(LVL.debug, "Disconnect already in progress, skipping");
      return;
    }
    this.isDisconnecting = true;
    // Clear the cleanup timeout if it exists
    if (this.cleanupTimeout) {
      clearTimeout(this.cleanupTimeout);
      this.cleanupTimeout = null;
    }
    if (this.connected) {
      log(LVL.info, "Disconnecting from MCP");
      try {
        // 1. EventSource cleanup - closes the actual socket
        if (this.transport?.eventSource) {
          log(LVL.debug, "Closing SSE EventSource connection");
          try {
            // Remove all event handlers first to prevent callbacks during close
            this.transport.eventSource.onmessage = null;
            this.transport.eventSource.onerror = null;
            this.transport.eventSource.onopen = null;
            // Then close the connection
            this.transport.eventSource.close();
            this.transport.eventSource = undefined;
            log(LVL.debug, "SSE EventSource closed successfully");
          } catch (sseError) {
            const errorMessage =
              sseError instanceof Error ? sseError.message : String(sseError);
            log(LVL.warn, `Error closing SSE connection: ${errorMessage}`);
            // Continue with other cleanup steps
          }
        }
        // 2. Transport cleanup
        if (this.transport) {
          log(LVL.debug, "Closing transport connection");
          try {
            await Promise.race([
              this.transport.close(),
              new Promise((resolve) => setTimeout(resolve, 1000)), // Timeout after 1 second
            ]);
            log(LVL.debug, "Transport closed successfully");
          } catch (transportError) {
            const errorMessage =
              transportError instanceof Error
                ? transportError.message
                : String(transportError);
            log(LVL.warn, `Error closing transport: ${errorMessage}`);
          }
        }
      } catch (e: unknown) {
        // Catch any unexpected errors in the overall process
        const errorMessage = e instanceof Error ? e.message : String(e);
        log(LVL.error, `Unexpected error during disconnect: ${errorMessage}`);
      } finally {
        // Always clean up state, regardless of errors
        this.connected = false;
        this.transport = null;
        activeConnections.delete(this);
        this.isDisconnecting = false;
        log(LVL.info, "Disconnect complete, all resources released");
      }
    } else {
      // If not connected, just ensure we're removed from active connections
      activeConnections.delete(this);
      this.isDisconnecting = false;
      log(LVL.debug, "No active connection to disconnect");
    }
  }
  // Refreshes the cached tool list from the server. When not connected the
  // cache is cleared (and, as before, the empty array is returned).
  async updateTools() {
    if (!this.connected) {
      log(LVL.debug, "Not connected, skipping tool update");
      return (this.tools = []);
    }
    log(LVL.debug, "Updating MCP tools list");
    this.toolsLoadAttempted = true;
    try {
      const { tools = [] } = (await this.client.listTools()) || {};
      this.tools = tools.map((t: any) => ({
        name: t.name,
        description: t.description || `Use ${t.name}`,
        input_schema: t.inputSchema,
        categories: (t.categories || []).map((c: string) => c.toLowerCase()),
      }));
      log(LVL.info, `Loaded ${this.tools.length} tools`);
    } catch (e: unknown) {
      const errorMessage = e instanceof Error ? e.message : String(e);
      log(LVL.warn, `listTools failed: ${errorMessage}`);
    }
  }
  /**
   * Converts an MCP tool descriptor into OpenAI function-calling format.
   * Non-object and empty schemas are normalized to an object schema with a
   * single free-form `query` string property.
   */
  formatTool(t: any) {
    let schema =
      typeof t.input_schema === "string"
        ? JSON.parse(t.input_schema)
        : t.input_schema || {};
    if (schema?.type !== "object") schema = { type: "object", properties: {} };
    // Fix: an object schema may legitimately omit `properties`; the previous
    // unconditional Object.keys(schema.properties) threw on such input.
    if (!schema.properties || !Object.keys(schema.properties).length)
      schema.properties = {
        query: { type: "string", description: `Input for ${t.name}` },
      };
    return {
      type: "function",
      function: {
        name: t.name,
        description: t.description,
        parameters: schema,
      },
    };
  }
  // Returns tools in OpenAI format, kicking off a best-effort background
  // refresh if nothing has been loaded yet.
  openAITools() {
    if (!this.tools.length && !this.toolsLoadAttempted) {
      log(LVL.debug, "No tools loaded, attempting to update tools");
      this.updateTools().catch(() => {});
    }
    return this.tools.map((t: any) => this.formatTool(t));
  }
  // Rebuilds the transcript: user messages first, then each assistant
  // message immediately followed by the tool responses it triggered.
  buildMsgs() {
    const out = [...this.userMessages];
    for (const m of this.assistantMessages) {
      out.push(m);
      m.tool_calls?.forEach((tc: any) => {
        const r = this.toolResponses[tc.id];
        if (r) out.push(r);
      });
    }
    return out;
  }
  /**
   * Trims the transcript to the first message group plus the most recent
   * `maxMessageGroups` groups, where a group is one message together with
   * the tool responses answering its tool calls.
   */
  trim(msgs: Message[]) {
    if (msgs.length <= 4) return msgs;
    const groups: Message[][] = [];
    const cur: Message[] = [];
    const flush = () => {
      if (cur.length) groups.push(cur.splice(0));
    };
    for (let i = 0; i < msgs.length; i++) {
      cur.push(msgs[i]);
      if (msgs[i].role === "assistant" && msgs[i].tool_calls?.length) {
        const ids = new Set(msgs[i].tool_calls?.map((t: any) => t.id) || []);
        for (let j = i + 1; j < msgs.length && msgs[j].role === "tool"; j++) {
          if (ids.has(msgs[j].tool_call_id!)) {
            cur.push(msgs[j]);
            i = j;
          } else break;
        }
      }
      flush();
    }
    // Fix: take the tail from groups[1:] so the first group is never
    // duplicated — previously groups.slice(-N) could include groups[0]
    // again whenever there were at most maxMessageGroups groups.
    return [
      groups[0] || [],
      ...groups.slice(1).slice(-this.cfg.maxMessageGroups),
    ].flat();
  }
  /**
   * Executes one tool call against the MCP server and records the result
   * (truncated to 8000 chars) as a tool message. Errors are captured as a
   * tool message rather than thrown, so one failing tool cannot abort the
   * whole request.
   */
  async processToolCall(tc: any) {
    let args;
    try {
      args = JSON.parse(tc.function.arguments);
    } catch {
      // Pass the raw string through if it is not valid JSON.
      args = tc.function.arguments;
    }
    log(LVL.info, `Processing tool call: ${tc.function.name}`);
    try {
      const r = await this.client.callTool(
        { name: tc.function.name, arguments: args },
        CallToolResultSchema,
        { timeout: this.cfg.toolTimeoutSec * 1_000 },
      );
      const txt = Array.isArray(r.content)
        ? r.content.map((c: any) => c.text).join("\n\n")
        : r.content || "No result";
      return (this.toolResponses[tc.id] = {
        role: "tool",
        tool_call_id: tc.id,
        name: tc.function.name,
        content:
          typeof txt === "string"
            ? txt.length > 8_000
              ? txt.slice(0, 8_000) + "\n\n[truncated]"
              : txt
            : "No result",
      });
    } catch (e: unknown) {
      const errorMessage = e instanceof Error ? e.message : String(e);
      log(LVL.warn, `Tool error: ${errorMessage}`);
      return (this.toolResponses[tc.id] = {
        role: "tool",
        tool_call_id: tc.id,
        name: tc.function.name,
        content: `Error: ${errorMessage}`,
      });
    }
  }
  // --- Simple state accessors used by the mcp plugin ---
  getUserMessages() {
    return this.userMessages;
  }
  getAssistantMessages() {
    return this.assistantMessages;
  }
  getTools() {
    return this.tools;
  }
  isToolsLoadAttempted() {
    return this.toolsLoadAttempted;
  }
  getConfig() {
    return this.cfg;
  }
}
// Terminal plugin: routes the request straight to the provider SDK matched
// by the model name. Note that `next` is intentionally never called — this
// plugin ends the chain, so anything composed after it will not run.
const multiModelPlugin: InternalPlugin = {
  name: "multiModelPlugin",
  async handle(params, next) {
    log(
      LVL.debug,
      `MultiModel plugin handling request for model: ${params.model}`,
    );
    return providerFor(params.model).chat.completions.create(params as any);
  },
};
// Maps a plugin name to a factory producing its internal implementation.
type PluginRegistry = {
  [key: string]: (config: any) => InternalPlugin;
};
const PLUGIN_REGISTRY: PluginRegistry = {
  mcp: (config: any) => mcpPlugin(config),
  multiModel: () => multiModelPlugin,
};
// Extension flag: when true, tool calls executed during MCP processing are
// echoed back on the final response message.
interface EnhancedCompletionParams extends ChatCompletionParams {
  return_tool_calls?: boolean;
}
/**
 * MCP plugin: intercepts chat completions and runs a two-pass tool-use flow.
 *
 * Pass 1 sends the (trimmed) transcript with the server's tools attached and
 * inspects the model's tool calls. Each call is executed against the MCP
 * server, and the collected results feed a second pass that produces the
 * final answer. If no server is configured or reachable, the request falls
 * through to `next` unchanged. Every exit path — success, error, streaming,
 * or abandoned stream — is responsible for disconnecting the MCP client.
 */
function mcpPlugin(opts: any = {}): InternalPlugin {
  return {
    name: "mcpPlugin",
    async handle(params: ChatCompletionParams, next: PluginHandler) {
      log(LVL.info, `MCP plugin handling request for model: ${params.model}`);
      const serverConfig =
        opts.serverUrls ||
        opts.serverUrl ||
        process.env.MCP_SERVER_URLS ||
        process.env.MCP_SERVER_URL;
      if (!serverConfig) {
        log(LVL.debug, "No MCP server config found, skipping MCP processing");
        return (next as any)(params, undefined);
      }
      const wantStream = params.stream === true;
      log(LVL.debug, `Request stream mode: ${wantStream}`);
      // Remember the caller's system message so it can be re-pinned in front
      // of the trimmed transcript for the first pass.
      const originalSystemMessage = params.messages.find(
        (m: any) => m.role === "system",
      );
      log(LVL.debug, "Creating MCP client");
      const mcp = new MCPClient({
        ...opts,
        serverUrls: opts.serverUrls,
        modelName: params.model,
        maxOutputTokens: params.max_tokens,
        // Force disconnectAfterUse to true always, regardless of user config
        // This ensures automatic cleanup without requiring client app changes
        disconnectAfterUse: true,
      });
      try {
        log(LVL.debug, "Connecting to MCP");
        await mcp.connect();
      } catch (e: unknown) {
        const errorMessage = e instanceof Error ? e.message : String(e);
        log(LVL.warn, `MCP unavailable – ${errorMessage}`);
        // Ensure we clean up even if connect fails
        await mcp.disconnect().catch(() => {});
        return (next as any)(params, undefined);
      }
      try {
        log(LVL.debug, `Processing ${params.messages.length} messages`);
        // Split the incoming history into user vs assistant buckets on the
        // client; non-user roles all land in the assistant bucket.
        params.messages.forEach((m: any) =>
          (m.role === "user"
            ? mcp.getUserMessages()
            : mcp.getAssistantMessages()
          ).push(m),
        );
        if (!mcp.getTools().length && !mcp.isToolsLoadAttempted())
          await mcp.updateTools();
        const tools = mcp.openAITools();
        log(LVL.debug, `Available tools: ${tools.length}`);
        const messagesWithSystem = mcp.trim(mcp.buildMsgs());
        const firstPassMessages = originalSystemMessage
          ? [
              originalSystemMessage,
              ...messagesWithSystem.filter((m: any) => m.role !== "system"),
            ]
          : messagesWithSystem;
        log(LVL.info, "Sending first pass request to model");
        // First pass is always non-streaming so tool calls can be inspected.
        const first = await (next as any)(
          {
            model: params.model,
            stream: false,
            max_tokens: params.max_tokens ?? 4096,
            messages: firstPassMessages as any,
            ...(tools.length && { tools, tool_choice: "auto" }),
          },
          undefined,
        );
        const assistant = first.choices[0].message;
        const calls = assistant.tool_calls ?? [];
        log(
          LVL.info,
          `First pass response received, tool calls: ${calls.length}`,
        );
        if (!calls.length) {
          log(LVL.debug, "No tool calls, returning direct response");
          // Always disconnect MCP client when no tool calls are needed
          log(LVL.debug, "Disconnecting MCP client (no tool calls)");
          await mcp.disconnect().catch((e) => {
            const errorMessage = e instanceof Error ? e.message : String(e);
            log(LVL.error, `Error disconnecting: ${errorMessage}`);
          });
          if (wantStream) {
            // NOTE(review): this re-issues the original request in streaming
            // mode (a second model call) instead of re-streaming `first` —
            // confirm this cost is intended.
            const raw = await (next as any)(
              { ...params, stream: true },
              undefined,
            );
            return asIterable(raw);
          }
          return first;
        }
        mcp.getAssistantMessages().push({
          role: "assistant",
          content: assistant.content,
          tool_calls: calls,
        });
        // Execute up to maxToolCalls tools sequentially, collecting their
        // outputs as markdown sections for the second pass.
        const summaries = [];
        log(
          LVL.info,
          `Processing ${Math.min(calls.length, mcp.getConfig().maxToolCalls)} tool calls`,
        );
        for (const tc of calls.slice(0, mcp.getConfig().maxToolCalls))
          summaries.push(
            `### ${tc.function.name}\n${(await mcp.processToolCall(tc)).content}`,
          );
        const finalResponseSystemPrompt =
          opts.finalResponseSystemPrompt ||
          opts.secondPassSystemPrompt ||
          mcp.getConfig().secondPassSystemPrompt;
        // Create a cleanup promise that will be used to disconnect after streaming is done
        let cleanupResolver = () => {}; // Initialize with a no-op function
        const cleanupPromise = new Promise<void>((resolve) => {
          cleanupResolver = resolve;
        });
        // Set up a timer to force cleanup if needed
        const forceCleanupTimer = setTimeout(() => {
          log(LVL.warn, "Force cleanup timer triggered for streaming response");
          cleanupResolver();
        }, 60000); // 60 second fallback
        log(LVL.info, "Sending follow-up request with tool results");
        // Second pass: fresh system prompt, the user's latest question, and
        // the concatenated tool summaries.
        const followParams = {
          model: params.model,
          stream: wantStream,
          max_tokens: params.max_tokens ?? 4096,
          messages: [
            { role: "system", content: finalResponseSystemPrompt },
            {
              role: "user",
              content: mcp.getUserMessages().at(-1)?.content || "",
            },
            { role: "user", content: summaries.join("\n\n") },
          ],
        };
        // Create a wrapper for streaming responses to handle cleanup
        if (wantStream) {
          log(LVL.debug, "Setting up stream response with auto-cleanup");
          // Get the raw follow-up response
          const rawFollow = await (next as any)(followParams, undefined);
          // Create the wrapped iterator that will handle cleanup
          const wrappedIterator = (async function* () {
            try {
              // Use proper async iteration with cleanup
              if (rawFollow && Symbol.asyncIterator in rawFollow) {
                try {
                  for await (const chunk of rawFollow) {
                    yield chunk;
                  }
                } catch (e) {
                  const errorMessage =
                    e instanceof Error ? e.message : String(e);
                  log(LVL.error, `Error in stream iteration: ${errorMessage}`);
                  throw e;
                } finally {
                  // Clean up when streaming is done
                  log(LVL.debug, "Stream completed, disconnecting MCP");
                  await mcp.disconnect().catch((e) => {
                    const errorMessage =
                      e instanceof Error ? e.message : String(e);
                    log(
                      LVL.error,
                      `Error disconnecting after stream: ${errorMessage}`,
                    );
                  });
                  cleanupResolver();
                  if (forceCleanupTimer) {
                    clearTimeout(forceCleanupTimer);
                  }
                }
              } else {
                // Handle non-iterator response
                const content =
                  rawFollow?.choices?.[0]?.message?.content ??
                  rawFollow?.choices?.[0]?.delta?.content ??
                  rawFollow?.content ??
                  JSON.stringify(rawFollow);
                yield { choices: [{ delta: { content } }] };
                // Clean up after yielding
                log(
                  LVL.debug,
                  "Non-stream response complete, disconnecting MCP",
                );
                await mcp.disconnect().catch((e) => {
                  const errorMessage =
                    e instanceof Error ? e.message : String(e);
                  log(
                    LVL.error,
                    `Error disconnecting after response: ${errorMessage}`,
                  );
                });
                cleanupResolver();
                if (forceCleanupTimer) {
                  clearTimeout(forceCleanupTimer);
                }
              }
            } catch (error) {
              // Clean up even on error
              const errorMessage =
                error instanceof Error ? error.message : String(error);
              log(LVL.error, `Stream error: ${errorMessage}`);
              await mcp.disconnect().catch(() => {});
              cleanupResolver();
              if (forceCleanupTimer) {
                clearTimeout(forceCleanupTimer);
              }
              throw error;
            }
          })();
          // Wait for cleanup in the background (failsafe)
          cleanupPromise.then(() => {
            mcp.disconnect().catch(() => {});
          });
          return wrappedIterator;
        }
        // For non-streaming responses, the process is simpler
        try {
          log(LVL.debug, "Processing non-streaming follow-up response");
          const follow = await (next as any)(followParams, undefined);
          // Process the follow-up response
          log(LVL.debug, "Processing final response");
          let final = "";
          // Defensive: a downstream handler might still return a stream.
          if (follow && Symbol.asyncIterator in follow) {
            for await (const ch of follow) {
              final += ch.choices?.[0]?.delta?.content || "";
            }
          } else if (follow?.choices?.[0]?.message?.content) {
            final = follow.choices[0].message.content;
          }
          assistant.content = final;
          // Always disconnect after processing
          log(LVL.debug, "Response complete, disconnecting MCP");
          await mcp.disconnect().catch((e) => {
            const errorMessage = e instanceof Error ? e.message : String(e);
            log(LVL.error, `Error disconnecting: ${errorMessage}`);
          });
          cleanupResolver();
          if (forceCleanupTimer) {
            clearTimeout(forceCleanupTimer);
          }
          log(LVL.info, "Request completed successfully");
          const paramsWithToolCalls = params as EnhancedCompletionParams;
          // Synthesize an OpenAI-shaped completion carrying the second-pass
          // text; usage is taken from the first pass.
          return {
            id: `chatcmpl-${Date.now()}`,
            object: "chat.completion",
            created: Math.floor(Date.now() / 1e3),
            model: params.model,
            usage: first.usage,
            choices: [
              {
                index: 0,
                finish_reason:
                  calls.length && paramsWithToolCalls.return_tool_calls
                    ? "tool_calls"
                    : "stop",
                message: {
                  role: "assistant",
                  content: assistant.content,
                  tool_calls: paramsWithToolCalls.return_tool_calls
                    ? calls
                    : undefined,
                },
              },
            ],
          };
        } catch (e) {
          // Clean up in error case
          const errorMessage = e instanceof Error ? e.message : String(e);
          log(LVL.error, `Error in follow-up response: ${errorMessage}`);
          await mcp.disconnect().catch(() => {});
          cleanupResolver();
          if (forceCleanupTimer) {
            clearTimeout(forceCleanupTimer);
          }
          throw e;
        }
      } catch (e) {
        // Always ensure connection is closed in case of errors
        log(LVL.debug, "Error in MCP plugin handler, disconnecting");
        await mcp.disconnect().catch((err) => {
          const errorMessage = err instanceof Error ? err.message : String(err);
          log(LVL.error, `Error disconnecting after error: ${errorMessage}`);
        });
        throw e;
      }
    },
  };
}
/**
 * Chains plugin handlers right-to-left around the base handler, so the
 * first plugin in the list becomes the outermost wrapper (the first to see
 * each request).
 */
function compose(
  plugins: InternalPlugin[],
  base: PluginHandler,
): PluginHandler {
  let chain: PluginHandler = base;
  for (let i = plugins.length - 1; i >= 0; i--) {
    const plugin = plugins[i];
    const inner = chain;
    chain = (params) => plugin.handle(params, inner);
  }
  return chain;
}
// The SDK's create() signature, used when swapping in the plugin chain.
type ChatCompletionsCreate =
  typeof OriginalOpenAI.prototype.chat.completions.create;
// Explicitly include all required fields from the OpenAI SDK
// Constructor options: standard OpenAI SDK options plus plugin wiring
// (`plugins`, `pluginConfig`) and MCP-specific configuration.
export interface OpenAIOptions {
  apiKey?: string;
  organization?: string;
  baseURL?: string;
  timeout?: number;
  maxRetries?: number;
  defaultQuery?: Record<string, string>;
  defaultHeaders?: Record<string, string>;
  dangerouslyAllowBrowser?: boolean;
  plugins?: string[] | Plugin[] | string | null;
  pluginConfig?: Record<string, any>;
  mcp?: MCPConfig;
  mcpLogLevel?: MCP_LogLevel;
}
// Main OpenAI class implementation
class OpenAI extends OriginalOpenAI {
constructor(options: OpenAIOptions = {}) {
// Pass all standard OpenAI options to the parent constructor
super({
apiKey: options.apiKey,
organization: options.organization,
baseURL: options.baseURL,
timeout: options.timeout,
maxRetries: options.maxRetries,
defaultQuery: options.defaultQuery,
defaultHeaders: options.defaultHeaders,
dangerouslyAllowBrowser: options.dangerouslyAllowBrowser,
});
if (options.mcpLogLevel || (options.mcp && options.mcp.logLevel)) {
setMcpLogLevel(options.mcpLogLevel || options.mcp?.logLevel || "debug");
}
log(LVL.info, `Initializing OpenAI client with plugins`);
globalApiKey = options.apiKey || null;
// Configure MCP plugin with mcp config if present
const pluginConfig = {
...(options.pluginConfig || {}),
mcp: options.mcp || {},
};
const activePlugins = this.#loadPlugins(
options.plugins || null,
pluginConfig,
);
const originalCreate = this.chat.completions.create.bind(
this.chat.completions,
) as ChatCompletionsCreate;
const handler = compose(activePlugins, (p) => originalCreate(p as any));
this.chat.completions.create = handler as unknown as ChatCompletionsCreate;
log(
LVL.info,
`OpenAI client initialized with ${activePlugins.length} plugins`,
);
}
#loadPlugins(
plugins: string | string[] | Plugin[] | null,
config: Record<string, any>,
): InternalPlugin[] {
// Handle case of Plugin objects directly
if (
Array.isArray(plugins) &&
plugins.length > 0 &&
typeof plugins[0] === "object"
) {
// Convert to internal plugin format
const pluginObjects = plugins as Plugin[];
const uniquePluginNames = pluginObjects
.map((p) => p.name)
.filter((name, index, self) => self.indexOf(name) === index);
log(
LVL.debug,
`Loading ${uniquePluginNames.length} object plugins: ${uniquePluginNames.join(", ")}`,
);
// Adapter to convert external Plugin to internal plugin format
return pluginObjects.map((p) => ({
name: p.name,
handle: async (params: any, next: any) => {
return p.handle(params, next);
},
}));
}
// Handle case of string plugin names
else if (plugins) {
const pluginNames: string[] = Array.isArray(plugins)
? Array.from(new Set(plugins as string[]))
: [plugins as string];
log(
LVL.debug,
`Loading ${pluginNames.length} plugins: ${pluginNames.join(", ")}`,
);
const pluginMap = new Map<string, InternalPlugin>();
pluginNames.forEach((name) => {
if (name === "mcp" && !pluginMap.has("mcp")) {
pluginMap.set("mcp", mcpPlugin(config.mcp || {}));
}
if (name === "multiModel" && !pluginMap.has("multiModel")) {
pluginMap.set("multiModel", multiModelPlugin);
}
});
return Array.from(pluginMap.values());
}
// Default to empty plugins array
return [];
}
}
// Export both as default and named - this is critical for it to work correctly
export default OpenAI;
export { OpenAI };
// For CommonJS compatibility: when a `module` object exists at runtime,
// overwrite module.exports so plain `require(...)` returns the class, with
// .OpenAI and .default aliases for both import styles.
// NOTE(review): mixing ESM `export` with module.exports assignment relies
// on the build emitting CJS — verify against the package's tsconfig/build.
if (typeof module !== "undefined") {
  module.exports = OpenAI;
  module.exports.OpenAI = OpenAI;
  module.exports.default = OpenAI;
}