UNPKG

@juspay/neurolink

Version:

Universal AI Development Platform with working MCP integration, multi-provider support, voice (TTS/STT/realtime), and professional CLI. 58+ external MCP servers discoverable, multimodal file processing, RAG pipelines. Build, test, and deploy AI applications.

710 lines (709 loc) 28.8 kB
/**
 * MCP Client Factory
 * Creates and manages MCP clients for external servers
 * Supports stdio, SSE, WebSocket, and HTTP transports
 * Enhanced with retry, rate limiting, and OAuth 2.1 support
 */
import { Client } from "@modelcontextprotocol/sdk/client/index.js";
import { StdioClientTransport } from "@modelcontextprotocol/sdk/client/stdio.js";
import { SSEClientTransport } from "@modelcontextprotocol/sdk/client/sse.js";
import { WebSocketClientTransport } from "@modelcontextprotocol/sdk/client/websocket.js";
import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/streamableHttp.js";
import { spawn } from "child_process";
import { mcpLogger } from "../utils/logger.js";
import { globalCircuitBreakerManager } from "./mcpCircuitBreaker.js";
import { CircuitBreakerOpenError } from "../types/index.js";
import { withHTTPRetry, DEFAULT_HTTP_RETRY_CONFIG } from "./httpRetryHandler.js";
import { globalRateLimiterManager } from "./httpRateLimiter.js";
import { NeuroLinkOAuthProvider, InMemoryTokenStorage } from "./auth/index.js";
import {
  SpanSerializer,
  SpanType,
  SpanStatus,
  getMetricsAggregator,
} from "../observability/index.js";
import { getActiveTraceContext } from "../telemetry/traceContext.js";

/**
 * Default timeout for MCP client creation in milliseconds.
 * Configurable via MCP_CLIENT_TIMEOUT env var; an unset, zero, or non-numeric
 * value falls back to 60000, and anything below 5000 is raised to 5000.
 * Covers transport setup, connection, and handshake.
 * Set to 60s to accommodate stdio servers that may be slow to start,
 * especially when multiple MCP servers are started concurrently.
 */
const DEFAULT_CLIENT_TIMEOUT = Math.max(5000, Number(process.env.MCP_CLIENT_TIMEOUT) || 60000);

/** Every transport this factory can build. */
const SUPPORTED_TRANSPORTS = ["stdio", "sse", "websocket", "http"];

/** Transports addressed by a URL rather than by a spawned command. */
const URL_TRANSPORTS = ["sse", "websocket", "http"];

/**
 * MCPClientFactory
 * Factory class for creating MCP clients with different transports
 */
export class MCPClientFactory {
  static NEUROLINK_IMPLEMENTATION = {
    name: "neurolink-sdk",
    version: "1.0.0",
  };

  static DEFAULT_CAPABILITIES = {
    sampling: {},
    roots: {
      listChanged: false,
    },
  };

  /**
   * Create an MCP client for the given server configuration.
   * HTTP/SSE configs may additionally be rate limited and retried; every
   * attempt goes through a per-server circuit breaker. An `mcp.connect`
   * observability span is recorded for both outcomes.
   *
   * @param {object} config - Server configuration (`id`, `transport`, `command`/`url`, ...).
   * @param {number} [timeout=DEFAULT_CLIENT_TIMEOUT] - Connect and handshake timeout in ms.
   * @returns {Promise<object>} On success the internal result (`client`, `transport`,
   *   `process`, `capabilities`) plus `success: true` and `duration`; on failure
   *   `{ success: false, error, duration }`. Connection failures are reported in
   *   the result rather than thrown.
   */
  static async createClient(config, timeout = DEFAULT_CLIENT_TIMEOUT) {
    const startTime = Date.now();
    const { traceId, parentSpanId } = getActiveTraceContext();
    const obsSpan = SpanSerializer.createSpan(SpanType.MCP_TRANSPORT, "mcp.connect", {
      "mcp.transport": config.transport,
      "mcp.operation": "connect",
      "mcp.server_id": config.id,
    }, parentSpanId, traceId);
    try {
      mcpLogger.info(`[MCPClientFactory] Creating client for ${config.id}`, {
        transport: config.transport,
        command: config.command,
        hasRetryConfig: !!config.retryConfig,
        hasRateLimiting: !!config.rateLimiting,
        hasAuth: !!config.auth,
      });
      const isHttpBased = this.isHttpBasedTransport(config);

      // Acquire rate limit token if rate limiting is configured for HTTP transport
      if (isHttpBased && config.rateLimiting) {
        const requestsPerMinute = config.rateLimiting.requestsPerMinute ?? 60;
        const rateLimiter = globalRateLimiterManager.getLimiter(config.id, {
          requestsPerWindow: requestsPerMinute,
          windowMs: 60000,
          maxBurst: config.rateLimiting.maxBurst ?? 10,
          useTokenBucket: config.rateLimiting.useTokenBucket ?? true,
          refillRate: requestsPerMinute / 60,
        });
        await rateLimiter.acquire();
        mcpLogger.debug(`[MCPClientFactory] Rate limit token acquired for ${config.id}`);
      }

      // Create circuit breaker for this server
      const circuitBreaker = globalCircuitBreakerManager.getBreaker(`mcp-client-${config.id}`, {
        failureThreshold: 3,
        resetTimeout: 30000,
        operationTimeout: timeout,
      });
      const createClientOperation = async () =>
        circuitBreaker.execute(async () => this.createClientInternal(config, timeout));

      // Wrap with retry logic if retry config is provided for HTTP transport
      let result;
      if (isHttpBased && config.retryConfig) {
        const retryConfig = config.retryConfig;
        const maxAttempts = retryConfig.maxAttempts ?? DEFAULT_HTTP_RETRY_CONFIG.maxAttempts;
        mcpLogger.debug(`[MCPClientFactory] Using retry logic for ${config.id}`, {
          maxAttempts,
        });
        result = await withHTTPRetry(createClientOperation, {
          maxAttempts,
          initialDelay: retryConfig.initialDelay ?? DEFAULT_HTTP_RETRY_CONFIG.initialDelay,
          maxDelay: retryConfig.maxDelay ?? DEFAULT_HTTP_RETRY_CONFIG.maxDelay,
          backoffMultiplier: retryConfig.backoffMultiplier ?? DEFAULT_HTTP_RETRY_CONFIG.backoffMultiplier,
        });
      } else {
        result = await createClientOperation();
      }

      mcpLogger.info(`[MCPClientFactory] Client created successfully for ${config.id}`, {
        duration: Date.now() - startTime,
        capabilities: result.capabilities,
      });
      this.recordConnectSpan(obsSpan, startTime, SpanStatus.OK);
      return {
        ...result,
        success: true,
        duration: Date.now() - startTime,
      };
    } catch (error) {
      const errorMessage = error instanceof Error ? error.message : String(error);
      if (error instanceof CircuitBreakerOpenError) {
        // Circuit breaker open: log at warn (not error) since this is expected
        // protection behavior, and preserve the structured metadata.
        mcpLogger.warn(`[MCPClientFactory] Client creation blocked by circuit breaker for ${config.id}`, {
          serverId: config.id,
          breakerState: error.breakerState,
          retryAfter: error.retryAfter,
          retryAfterMs: error.retryAfterMs,
          failureCount: error.failureCount,
        });
        this.recordConnectSpan(obsSpan, startTime, SpanStatus.ERROR, `Circuit breaker open: ${errorMessage}`);
      } else {
        mcpLogger.error(`[MCPClientFactory] Failed to create client for ${config.id}:`, error);
        this.recordConnectSpan(obsSpan, startTime, SpanStatus.ERROR, errorMessage);
      }
      return {
        success: false,
        error: errorMessage,
        duration: Date.now() - startTime,
      };
    }
  }

  /**
   * Whether the config uses a transport that goes through rate limiting and retry.
   * Only "http" and "sse" qualify (websocket does not).
   */
  static isHttpBasedTransport(config) {
    return config.transport === "http" || config.transport === "sse";
  }

  /**
   * End the `mcp.connect` span, stamp its duration, and hand it to the metrics aggregator.
   *
   * @param {object} obsSpan - Span from `SpanSerializer.createSpan`.
   * @param {number} startTime - `Date.now()` captured when the span started.
   * @param {*} status - A `SpanStatus` value.
   * @param {string} [statusMessage] - Set on the ended span only when provided.
   */
  static recordConnectSpan(obsSpan, startTime, status, statusMessage) {
    obsSpan.durationMs = Date.now() - startTime;
    const endedObsSpan = SpanSerializer.endSpan(obsSpan, status);
    if (statusMessage !== undefined) {
      endedObsSpan.statusMessage = statusMessage;
    }
    getMetricsAggregator().recordSpan(endedObsSpan);
  }

  /**
   * Internal client creation: build the transport, connect, and handshake.
   * On any failure the transport (and child process, if one was returned) is
   * cleaned up before the error is rethrown.
   *
   * @param {object} config - Server configuration.
   * @param {number} timeout - Applied separately to connect and to the handshake.
   * @returns {Promise<{client: Client, transport: *, process: *, capabilities: object}>}
   */
  static async createClientInternal(config, timeout) {
    const transportResult = await this.createTransport(config);
    const transport = transportResult.transport;
    // Named childProcess so it does not shadow Node's global `process`.
    const childProcess = transportResult.process;
    try {
      const client = new Client(this.NEUROLINK_IMPLEMENTATION, {
        capabilities: this.DEFAULT_CAPABILITIES,
      });
      await this.raceWithTimeout(client.connect(transport), timeout, `Client connection timeout for ${config.id}`);
      // Perform handshake to get server capabilities
      const serverCapabilities = await this.performHandshake(client, timeout);
      mcpLogger.debug(`[MCPClientFactory] Handshake completed for ${config.id}`, {
        capabilities: serverCapabilities,
      });
      return {
        client,
        transport,
        process: childProcess,
        capabilities: serverCapabilities,
      };
    } catch (error) {
      // Clean up on failure
      try {
        await transport.close();
      } catch (closeError) {
        mcpLogger.debug(`[MCPClientFactory] Error closing transport during cleanup:`, closeError);
      }
      if (childProcess && !childProcess.killed) {
        childProcess.kill("SIGTERM");
      }
      throw error;
    }
  }

  /**
   * Race `operation` against a timeout.
   * The timer is always cleared afterwards. Otherwise a settled operation would
   * leave a pending timer that holds the event loop open for the full timeout.
   *
   * @param {Promise<*>} operation - The work to bound.
   * @param {number} timeout - Milliseconds before rejecting.
   * @param {string} message - Message of the timeout `Error`.
   * @returns {Promise<*>} The operation's result.
   * @throws {Error} With `message` when the timeout elapses first, or the operation's own error.
   */
  static async raceWithTimeout(operation, timeout, message) {
    let timeoutId;
    const timeoutPromise = new Promise((_, reject) => {
      timeoutId = setTimeout(() => reject(new Error(message)), timeout);
    });
    try {
      return await Promise.race([operation, timeoutPromise]);
    } finally {
      clearTimeout(timeoutId);
    }
  }

  /**
   * Create transport based on configuration.
   *
   * @throws {Error} For an unsupported `config.transport`.
   */
  static async createTransport(config) {
    switch (config.transport) {
      case "stdio":
        return this.createStdioTransport(config);
      case "sse":
        return this.createSSETransport(config);
      case "websocket":
        return this.createWebSocketTransport(config);
      case "http":
        return this.createHTTPTransport(config);
      default:
        throw new Error(`Unsupported transport type: ${config.transport}`);
    }
  }

  /**
   * Create a stdio transport.
   *
   * The server process is launched by `StdioClientTransport` itself when the
   * client connects. This method must NOT spawn its own copy. A separately
   * spawned child would be a second, unconnected instance of the server whose
   * piped stdout/stderr are never read.
   * Spawn failures (e.g. a missing command) surface as a rejected `client.connect()`.
   *
   * @returns {Promise<{transport: StdioClientTransport}>} No `process` handle:
   *   the SDK owns the child and terminates it in `transport.close()`.
   * @throws {Error} When `config.command` is missing.
   */
  static async createStdioTransport(config) {
    mcpLogger.debug(`[MCPClientFactory] Creating stdio transport for ${config.id}`, {
      command: config.command,
      args: config.args,
    });
    if (!config.command) {
      throw new Error(`Command is required for stdio transport`);
    }
    const transport = new StdioClientTransport({
      command: config.command,
      args: config.args || [],
      env: this.buildChildEnv(config.env),
      cwd: config.cwd,
      stderr: "ignore", // Suppress MCP server startup messages
    });
    return { transport };
  }

  /**
   * Merge the current environment with per-server overrides.
   * Undefined values are dropped and the rest are stringified, as spawned
   * processes require string-only env values.
   *
   * @param {object} [overrides] - Per-server variables; these win over `process.env`.
   * @returns {Record<string, string>}
   */
  static buildChildEnv(overrides) {
    return Object.fromEntries(Object.entries({ ...process.env, ...overrides })
      .filter(([, value]) => value !== undefined)
      .map(([key, value]) => [key, String(value)]));
  }

  /**
   * Create SSE transport.
   *
   * @throws {Error} When `config.url` is missing or cannot be used as a URL.
   */
  static async createSSETransport(config) {
    if (!config.url) {
      throw new Error("URL is required for SSE transport");
    }
    mcpLogger.debug(`[MCPClientFactory] Creating SSE transport for ${config.id}`, {
      url: config.url,
    });
    try {
      const url = new URL(config.url);
      const transport = new SSEClientTransport(url);
      return { transport };
    } catch (error) {
      throw new Error(`Invalid SSE URL: ${error instanceof Error ? error.message : String(error)}`, { cause: error });
    }
  }

  /**
   * Create WebSocket transport.
   *
   * @throws {Error} When `config.url` is missing or cannot be used as a URL.
   */
  static async createWebSocketTransport(config) {
    if (!config.url) {
      throw new Error("URL is required for WebSocket transport");
    }
    mcpLogger.debug(`[MCPClientFactory] Creating WebSocket transport for ${config.id}`, {
      url: config.url,
    });
    try {
      const url = new URL(config.url);
      const transport = new WebSocketClientTransport(url);
      return { transport };
    } catch (error) {
      throw new Error(`Invalid WebSocket URL: ${error instanceof Error ? error.message : String(error)}`, { cause: error });
    }
  }

  /**
   * Create HTTP transport (Streamable HTTP).
   * Enhanced with OAuth 2.1, a per-request timeout, and static/auth headers.
   *
   * @throws {Error} "Invalid HTTP URL: ..." only when `config.url` fails to parse.
   *   Auth-setup or transport errors propagate with their own messages, so they
   *   are no longer mislabeled as URL problems.
   */
  static async createHTTPTransport(config) {
    if (!config.url) {
      throw new Error("URL is required for HTTP transport");
    }
    // Extract HTTP options with defaults. Only requestTimeout is consumed below;
    // the others are logged for diagnostics.
    const httpOptions = {
      connectionTimeout: config.httpOptions?.connectionTimeout ?? 30000,
      requestTimeout: config.httpOptions?.requestTimeout ?? 60000,
      idleTimeout: config.httpOptions?.idleTimeout ?? 120000,
      keepAliveTimeout: config.httpOptions?.keepAliveTimeout ?? 30000,
    };
    mcpLogger.debug(`[MCPClientFactory] Creating HTTP transport for ${config.id}`, {
      url: config.url,
      hasHeaders: !!config.headers,
      hasAuth: !!config.auth,
      authType: config.auth?.type,
      httpOptions,
    });
    let url;
    try {
      url = new URL(config.url);
    } catch (error) {
      throw new Error(`Invalid HTTP URL: ${error instanceof Error ? error.message : String(error)}`, { cause: error });
    }
    // Set up OAuth provider if configured
    const oauthProvider = await this.setupAuthProvider(config);
    // Build headers including authentication
    const headers = {
      ...(config.headers ?? {}),
    };
    if (config.auth) {
      const authHeader = await this.getAuthorizationHeader(config, oauthProvider);
      if (authHeader) {
        headers["Authorization"] = authHeader;
      }
    }
    // Custom fetch adds the request timeout and refreshes OAuth tokens per request
    const fetchWithEnhancements = this.createEnhancedFetch(config, httpOptions.requestTimeout, oauthProvider);
    const requestInit = {
      headers: Object.keys(headers).length > 0 ? headers : undefined,
    };
    const transport = new StreamableHTTPClientTransport(url, {
      requestInit,
      fetch: fetchWithEnhancements,
    });
    return { transport };
  }

  /**
   * Create a fetch wrapper that aborts a request not answered within `timeoutMs`.
   * The timer covers the wait for response headers; it is cleared once `fetch`
   * resolves. The caller's own `init.signal` is honored as well: it was
   * previously overwritten, so transport-initiated aborts (e.g. on close) were lost.
   * Uses `AbortSignal.any` where available. Otherwise the caller's abort is
   * forwarded only until the response headers arrive.
   *
   * @param {number} timeoutMs
   * @returns {(input: *, init?: object) => Promise<Response>}
   */
  static createFetchWithTimeout(timeoutMs) {
    return async (input, init) => {
      const controller = new AbortController();
      const timeoutId = setTimeout(() => controller.abort(), timeoutMs);
      const callerSignal = init?.signal ?? undefined;
      let signal = controller.signal;
      let detachCallerSignal;
      if (callerSignal) {
        if (typeof AbortSignal.any === "function") {
          signal = AbortSignal.any([callerSignal, controller.signal]);
        } else if (callerSignal.aborted) {
          controller.abort(callerSignal.reason);
        } else {
          const forwardAbort = () => controller.abort(callerSignal.reason);
          callerSignal.addEventListener("abort", forwardAbort, { once: true });
          // Detach afterwards: long-lived transport signals would otherwise
          // accumulate one listener per request.
          detachCallerSignal = () => callerSignal.removeEventListener("abort", forwardAbort);
        }
      }
      try {
        return await fetch(input, { ...init, signal });
      } finally {
        clearTimeout(timeoutId);
        detachCallerSignal?.();
      }
    };
  }

  /**
   * Create an enhanced fetch function with timeout and per-request OAuth headers.
   * For `oauth2` auth, a fresh Authorization header is fetched from the provider
   * before every request. A failed lookup is logged and the request proceeds
   * without it.
   */
  static createEnhancedFetch(config, timeoutMs, oauthProvider) {
    const fetchWithTimeout = this.createFetchWithTimeout(timeoutMs);
    return async (input, init) => {
      // If OAuth is configured, ensure we have valid tokens
      if (oauthProvider && config.auth?.type === "oauth2") {
        try {
          const authHeader = await oauthProvider.getAuthorizationHeader(config.id);
          if (authHeader) {
            const existingHeaders = init?.headers ?? {};
            const headers = new Headers(existingHeaders);
            headers.set("Authorization", authHeader);
            init = {
              ...init,
              headers,
            };
          }
        } catch (error) {
          mcpLogger.warn(`[MCPClientFactory] OAuth token refresh failed for ${config.id}:`, error instanceof Error ? error.message : String(error));
          // Continue without auth - let the request fail naturally
        }
      }
      return fetchWithTimeout(input, init);
    };
  }

  /**
   * Set up OAuth provider if configured.
   *
   * @returns {Promise<NeuroLinkOAuthProvider | undefined>} A provider backed by
   *   a fresh in-memory token store for `oauth2` configs with `oauth` settings,
   *   otherwise undefined. PKCE defaults to enabled.
   */
  static async setupAuthProvider(config) {
    if (config.auth?.type === "oauth2" && config.auth.oauth) {
      const tokenStorage = new InMemoryTokenStorage();
      const oauthConfig = {
        clientId: config.auth.oauth.clientId,
        clientSecret: config.auth.oauth.clientSecret,
        authorizationUrl: config.auth.oauth.authorizationUrl,
        tokenUrl: config.auth.oauth.tokenUrl,
        redirectUrl: config.auth.oauth.redirectUrl,
        scope: config.auth.oauth.scope,
        usePKCE: config.auth.oauth.usePKCE ?? true,
      };
      const provider = new NeuroLinkOAuthProvider(oauthConfig, tokenStorage);
      mcpLogger.debug(`[MCPClientFactory] OAuth provider created for ${config.id}`, {
        clientId: oauthConfig.clientId,
        usePKCE: oauthConfig.usePKCE,
      });
      return provider;
    }
    return undefined;
  }

  /**
   * Get authorization header based on auth configuration.
   *
   * @returns {Promise<string | undefined>} For `oauth2`, the provider's header;
   *   for `bearer`, `Bearer <token>`; otherwise undefined.
   */
  static async getAuthorizationHeader(config, oauthProvider) {
    if (!config.auth) {
      return undefined;
    }
    switch (config.auth.type) {
      case "oauth2":
        if (oauthProvider) {
          const header = await oauthProvider.getAuthorizationHeader(config.id);
          return header ?? undefined;
        }
        return undefined;
      case "bearer":
        if (config.auth.token) {
          return `Bearer ${config.auth.token}`;
        }
        return undefined;
      case "api-key":
        // API key is typically sent as a custom header, not Authorization.
        // NOTE(review): nothing in this factory adds that header; presumably
        // callers supply it via `config.headers` — confirm.
        return undefined;
      default:
        return undefined;
    }
  }

  /**
   * Perform MCP handshake and get server capabilities.
   * Best-effort: on failure or timeout it logs a warning and returns
   * DEFAULT_CAPABILITIES instead of throwing.
   * NOTE(review): the result is derived from this client's DEFAULT_CAPABILITIES,
   * not from the server's initialize response — confirm that is intended.
   */
  static async performHandshake(client, timeout) {
    try {
      // The MCP SDK handles the handshake during connect(); listing tools
      // verifies the connection is actually serving requests.
      const serverInfo = await this.raceWithTimeout(this.getServerInfo(client), timeout, "Handshake timeout");
      return this.extractCapabilities(serverInfo);
    } catch (error) {
      mcpLogger.warn("[MCPClientFactory] Handshake failed, but connection may still be valid:", error);
      // The connection might still work for basic operations
      return this.DEFAULT_CAPABILITIES;
    }
  }

  /**
   * Get server information by listing tools.
   * Never rejects: a failed tool listing yields an empty tool list.
   */
  static async getServerInfo(client) {
    try {
      const toolsResult = await client.listTools();
      return {
        tools: toolsResult.tools || [],
        capabilities: this.DEFAULT_CAPABILITIES,
      };
    } catch {
      mcpLogger.debug("[MCPClientFactory] Tool listing failed, server may not support tools yet");
      return {
        tools: [],
        capabilities: this.DEFAULT_CAPABILITIES,
      };
    }
  }

  /**
   * Extract capabilities from server info.
   * Returns DEFAULT_CAPABILITIES plus `tools: {}` whenever `serverInfo.tools` is
   * truthy. Note that an empty array is truthy, so this includes the empty-list
   * case produced by getServerInfo.
   */
  static extractCapabilities(serverInfo) {
    return {
      ...this.DEFAULT_CAPABILITIES,
      ...(serverInfo.tools ? { tools: {} } : {}),
    };
  }

  /**
   * Create a promise that rejects after `timeout` ms, or earlier if `abortSignal` aborts.
   * Kept for external callers; internal code uses raceWithTimeout, which also
   * clears its timer once the raced operation settles.
   *
   * @param {number} timeout
   * @param {string} message - Timeout error message; aborts use `Operation aborted: <message>`.
   * @param {AbortSignal} [abortSignal]
   * @returns {Promise<never>}
   */
  static createTimeoutPromise(timeout, message, abortSignal) {
    return new Promise((_, reject) => {
      let onAbort;
      const timeoutId = setTimeout(() => {
        if (onAbort) {
          abortSignal.removeEventListener("abort", onAbort);
        }
        reject(new Error(message));
      }, timeout);
      if (!abortSignal) {
        return;
      }
      onAbort = () => {
        clearTimeout(timeoutId);
        reject(new Error(`Operation aborted: ${message}`));
      };
      if (abortSignal.aborted) {
        // An already-aborted signal never fires "abort" again.
        onAbort();
        return;
      }
      abortSignal.addEventListener("abort", onAbort, { once: true });
    });
  }

  /**
   * Close an MCP client and clean up resources.
   * Errors are collected and logged as a warning rather than thrown. A running
   * child process gets SIGTERM, then SIGKILL if it has not exited within 5 seconds.
   */
  static async closeClient(client, transport, process) {
    const errors = [];
    try {
      await client.close();
    } catch (error) {
      errors.push(`Client close error: ${error instanceof Error ? error.message : String(error)}`);
    }
    try {
      await transport.close();
    } catch (error) {
      errors.push(`Transport close error: ${error instanceof Error ? error.message : String(error)}`);
    }
    // `ChildProcess.killed` only reports that a signal was delivered, not that
    // the process exited. So liveness is judged by exitCode/signalCode; otherwise
    // SIGKILL would never be sent after a successful SIGTERM.
    if (process && this.isProcessRunning(process)) {
      try {
        process.kill("SIGTERM");
        await new Promise((resolve) => {
          const onExit = () => {
            clearTimeout(forceKillTimeout);
            resolve();
          };
          const forceKillTimeout = setTimeout(() => {
            process.removeListener("exit", onExit);
            if (this.isProcessRunning(process)) {
              mcpLogger.warn("[MCPClientFactory] Force killing process");
              try {
                process.kill("SIGKILL");
              } catch (killError) {
                mcpLogger.debug("[MCPClientFactory] Error in force kill:", killError);
              }
            }
            resolve();
          }, 5000);
          // Registered synchronously after kill(); "exit" is emitted asynchronously,
          // so it cannot be missed.
          process.once("exit", onExit);
        });
      } catch (error) {
        errors.push(`Process kill error: ${error instanceof Error ? error.message : String(error)}`);
      }
    }
    if (errors.length > 0) {
      mcpLogger.warn("[MCPClientFactory] Errors during client cleanup:", errors);
    }
  }

  /**
   * Whether a child process has neither exited nor been terminated by a signal.
   *
   * @param {import("child_process").ChildProcess} childProcess
   * @returns {boolean}
   */
  static isProcessRunning(childProcess) {
    return childProcess.exitCode === null && childProcess.signalCode === null;
  }

  /**
   * Test connection to an MCP server.
   * Creates a client, lists tools as a liveness probe (a listing failure is
   * tolerated), and always tears the connection down.
   *
   * @returns {Promise<{success: boolean, capabilities?: object, error?: string}>}
   */
  static async testConnection(config, timeout = 5000) {
    let client;
    let transport;
    let childProcess;
    try {
      const result = await this.createClient(config, timeout);
      if (!result.success) {
        return { success: false, error: result.error };
      }
      client = result.client;
      transport = result.transport;
      childProcess = result.process;
      if (client) {
        try {
          await client.listTools();
        } catch {
          // Tool listing failure doesn't necessarily mean connection failure
          mcpLogger.debug("[MCPClientFactory] Tool listing failed during test, but connection may be valid");
        }
      }
      return {
        success: true,
        capabilities: result.capabilities,
      };
    } catch (error) {
      return {
        success: false,
        error: error instanceof Error ? error.message : String(error),
      };
    } finally {
      if (client && transport) {
        try {
          await this.closeClient(client, transport, childProcess);
        } catch (error) {
          mcpLogger.debug("[MCPClientFactory] Error cleaning up test connection:", error);
        }
      }
    }
  }

  /**
   * Validate MCP server configuration for client creation.
   * A command is required unless the transport is URL-based (sse/websocket/http);
   * it used to be demanded for every transport, which rejected valid URL configs.
   *
   * @returns {{isValid: boolean, errors: string[]}}
   */
  static validateClientConfig(config) {
    const errors = [];
    const isUrlTransport = URL_TRANSPORTS.includes(config.transport);
    if (!config.command && !isUrlTransport) {
      errors.push("Command is required");
    }
    if (!config.transport) {
      errors.push("Transport is required");
    }
    if (!SUPPORTED_TRANSPORTS.includes(config.transport)) {
      errors.push("Transport must be stdio, sse, websocket, or http");
    }
    if (isUrlTransport) {
      if (!config.url) {
        errors.push(`URL is required for ${config.transport} transport`);
      } else {
        try {
          new URL(config.url);
        } catch {
          errors.push(`Invalid URL for ${config.transport} transport`);
        }
      }
    }
    if (config.transport === "stdio" && !Array.isArray(config.args)) {
      errors.push("Args array is required for stdio transport");
    }
    return {
      isValid: errors.length === 0,
      errors,
    };
  }

  /**
   * Get supported transport types (a fresh array on every call).
   */
  static getSupportedTransports() {
    return [...SUPPORTED_TRANSPORTS];
  }

  /**
   * Get default client capabilities (a shallow copy).
   */
  static getDefaultCapabilities() {
    return { ...this.DEFAULT_CAPABILITIES };
  }
}