UNPKG

@4ourlab/mcp-client-gemini

Version:

MCP (Model Context Protocol) implementation for Gemini models

256 lines 11.1 kB
"use strict";
/*
 * MCP client that exposes the tools of one or more stdio MCP servers to a
 * Gemini model.
 *
 * NOTE(review): this appears to be compiled TypeScript output (CommonJS
 * helpers, `exports.__esModule`, trailing sourceMappingURL). Mirror these
 * changes in the .ts source, otherwise the next build will overwrite them.
 */
var __importDefault = (this && this.__importDefault) || function (mod) {
    return (mod && mod.__esModule) ? mod : { "default": mod };
};
Object.defineProperty(exports, "__esModule", { value: true });
exports.MCPClient = void 0;
const index_js_1 = require("@modelcontextprotocol/sdk/client/index.js");
const generative_ai_1 = require("@google/generative-ai");
const stdio_js_1 = require("@modelcontextprotocol/sdk/client/stdio.js");
const promises_1 = __importDefault(require("readline/promises"));
const fs_1 = __importDefault(require("fs"));
const path_1 = __importDefault(require("path"));

/** Identity used when the working directory has no usable package.json. */
const FALLBACK_CLIENT_IDENTITY = Object.freeze({
    name: "@4ourlab/mcp-client-gemini",
    version: "0.0.0",
});

/** Maximum number of model -> tool -> model round trips for a single query. */
const MAX_TOOL_ROUNDS = 5;

/** JSON Schema keywords stripped from MCP tool schemas before they are sent to Gemini. */
const UNSUPPORTED_SCHEMA_KEYS = new Set(["$schema", "additionalProperties"]);

/**
 * Read the client name/version that is reported to MCP servers.
 *
 * NOTE(review): this reads package.json from process.cwd(), i.e. the host
 * application's manifest rather than this library's own. That is the original
 * behaviour and is kept; the difference is that a missing or invalid file no
 * longer makes require() of this module throw.
 *
 * @returns {{ name: string, version: string }}
 */
function loadClientIdentity() {
    try {
        const raw = fs_1.default.readFileSync(path_1.default.join(process.cwd(), "package.json"), "utf-8");
        const manifest = JSON.parse(raw);
        return {
            name: typeof manifest?.name === "string" && manifest.name
                ? manifest.name
                : FALLBACK_CLIENT_IDENTITY.name,
            version: typeof manifest?.version === "string" && manifest.version
                ? manifest.version
                : FALLBACK_CLIENT_IDENTITY.version,
        };
    }
    catch {
        // Best effort by design: fall back to a default identity rather than
        // failing at module load time.
        return { ...FALLBACK_CLIENT_IDENTITY };
    }
}

const packageJson = loadClientIdentity();

/**
 * Return a deep copy of a JSON Schema fragment with the keys in
 * UNSUPPORTED_SCHEMA_KEYS removed at every level.
 *
 * Object.fromEntries is used (rather than assignment) so a key literally named
 * "__proto__" becomes an own property instead of changing the prototype.
 *
 * @param {unknown} schema - An MCP tool input schema or a fragment of one.
 * @returns {unknown} The sanitized copy; the input is not mutated.
 */
function sanitizeSchema(schema) {
    if (Array.isArray(schema)) {
        return schema.map(sanitizeSchema);
    }
    if (schema === null || typeof schema !== "object") {
        return schema;
    }
    return Object.fromEntries(Object.entries(schema)
        .filter(([key]) => !UNSUPPORTED_SCHEMA_KEYS.has(key))
        .map(([key, value]) => [key, key === "properties" ? sanitizeProperties(value) : sanitizeSchema(value)]));
}

/**
 * Sanitize a `properties` map. Its keys are user-defined parameter names (not
 * schema keywords), so they are all kept and only the values are sanitized.
 *
 * @param {unknown} properties
 * @returns {unknown}
 */
function sanitizeProperties(properties) {
    if (properties === null || typeof properties !== "object" || Array.isArray(properties)) {
        return sanitizeSchema(properties);
    }
    return Object.fromEntries(Object.entries(properties)
        .map(([paramName, paramSchema]) => [paramName, sanitizeSchema(paramSchema)]));
}

/**
 * True for a schema that declares no arguments: `{}`, a non-object value, or an
 * object schema without any properties.
 *
 * @param {unknown} schema
 * @returns {boolean}
 */
function isArgumentlessSchema(schema) {
    if (schema === null || typeof schema !== "object" || Object.keys(schema).length === 0) {
        return true;
    }
    const { properties } = schema;
    return schema.type === "object" && (properties == null || Object.keys(properties).length === 0);
}

/**
 * Convert a stored MCP tool description into a Gemini function declaration.
 *
 * @param {{ name: string, description: string, input_schema?: object }} tool
 * @returns {{ name: string, description: string, parameters?: object }}
 */
function toFunctionDeclaration(tool) {
    const declaration = { name: tool.name, description: tool.description };
    const parameters = sanitizeSchema(tool.input_schema ?? {});
    // Gemini rejects OBJECT parameter schemas whose `properties` is empty, so
    // argument-less tools are declared without `parameters` (it is optional).
    if (!isArgumentlessSchema(parameters)) {
        declaration.parameters = parameters;
    }
    return declaration;
}

/**
 * Render the `content` of an MCP tool result as plain text.
 *
 * Text items contribute their text, embedded resources with text contribute
 * that text, image/audio items are summarised instead of inlining their
 * payload, and anything else is JSON-encoded.
 *
 * @param {unknown} content
 * @returns {string}
 */
function formatToolContent(content) {
    if (!Array.isArray(content)) {
        return typeof content === "string" ? content : JSON.stringify(content ?? null);
    }
    return content.map((item) => {
        if (item?.type === "text") {
            return String(item.text ?? "");
        }
        if (item?.type === "resource" && typeof item.resource?.text === "string") {
            return item.resource.text;
        }
        if (item?.type === "image" || item?.type === "audio") {
            return `[${item.type} content (${item.mimeType ?? "unknown MIME type"}) omitted]`;
        }
        return JSON.stringify(item ?? null);
    }).join("\n");
}

/**
 * Build the child-process environment for a server entry.
 *
 * @param {Record<string, string> | undefined} serverEnv - `env` from the config entry.
 * @returns {Record<string, string> | undefined} undefined lets the SDK apply its
 *   own default environment; otherwise the configured variables are layered on
 *   top of that default so PATH etc. are not lost.
 */
function buildServerEnv(serverEnv) {
    if (serverEnv == null) {
        return undefined;
    }
    return { ...stdio_js_1.getDefaultEnvironment(), ...serverEnv };
}

/**
 * Execute one Gemini function call on the MCP server that owns the tool.
 *
 * @param {MCPClient} mcpClient - Client whose connections are searched.
 * @param {{ name: string, args?: object }} functionCall - The model's call.
 * @returns {Promise<{ text: string, part: object }>} The tool output as text
 *   and the matching `functionResponse` part to send back to Gemini.
 * @throws {Error} "Error executing tool <name>: ..." when the tool is unknown
 *   or the MCP call fails.
 */
async function runToolCall(mcpClient, functionCall) {
    const toolName = functionCall.name;
    try {
        const serverConnection = mcpClient.findToolServer(toolName);
        if (!serverConnection) {
            throw new Error(`Tool ${toolName} not found in any connected server`);
        }
        const result = await serverConnection.client.callTool({
            name: toolName,
            arguments: functionCall.args ?? {},
        });
        const text = formatToolContent(result.content ?? result.toolResult);
        // Tool-level failures (isError) are reported to the model, not thrown.
        const response = result.isError ? { content: text, isError: true } : { content: text };
        return { text, part: { functionResponse: { name: toolName, response } } };
    }
    catch (toolError) {
        throw new Error(`Error executing tool ${toolName}: ${toolError}`, { cause: toolError });
    }
}

class MCPClient {
    mcpConnections = [];
    genAI;
    model;
    serverConfigPath;
    systemPrompt;

    /**
     * @param {string} apiKey - Google Generative AI API key.
     * @param {string} model - Gemini model name passed to getGenerativeModel().
     * @param {string} serverConfigPath - Path to the MCP server config JSON,
     *   resolved against process.cwd() when relative.
     * @param {string} [systemPrompt=""] - Optional system instruction.
     */
    constructor(apiKey, model, serverConfigPath, systemPrompt = "") {
        this.genAI = new generative_ai_1.GoogleGenerativeAI(apiKey);
        this.model = model;
        this.serverConfigPath = serverConfigPath;
        this.systemPrompt = systemPrompt;
    }

    /**
     * Connect to every server listed under `mcpServers` in the config file.
     *
     * Servers are started in parallel. A server that fails to start is logged
     * and skipped; the remaining connections are kept. Promise.allSettled is
     * used so that one failure neither aborts nor leaks the other servers.
     *
     * @throws {Error} "Failed to connect to MCP servers: ..." when the config
     *   cannot be loaded or no server could be connected.
     */
    async connectToServers() {
        try {
            const config = await this.loadServerConfig();
            const serverEntries = Object.entries(config.mcpServers);
            const outcomes = await Promise.allSettled(serverEntries.map(([serverName, serverConfig]) => this.getConnectionServer(serverName, serverConfig)));
            const connections = [];
            const failures = [];
            outcomes.forEach((outcome, index) => {
                const [serverName] = serverEntries[index];
                if (outcome.status === "fulfilled") {
                    connections.push(outcome.value);
                }
                else {
                    failures.push(new Error(`Failed to connect to server ${serverName}: ${outcome.reason}`, { cause: outcome.reason }));
                }
            });
            this.mcpConnections = connections;
            if (connections.length === 0) {
                throw new Error(`Failed to connect to any MCP servers: ${failures.map((failure) => failure.message).join("; ")}`, { cause: failures });
            }
            for (const failure of failures) {
                console.error(`Skipping MCP server. ${failure.message}`);
            }
        }
        catch (error) {
            throw new Error(`Failed to connect to MCP servers: ${error}`, { cause: error });
        }
    }

    /**
     * Load and validate the MCP server configuration file.
     *
     * @returns {Promise<{ mcpServers: Record<string, object> }>}
     * @throws {Error} when the file is missing, is not valid JSON, or has no
     *   non-empty `mcpServers` object.
     */
    async loadServerConfig() {
        // resolve() keeps relative paths cwd-relative (as join() did) but also
        // honours absolute paths, which join() would have mangled.
        const configPath = path_1.default.resolve(process.cwd(), this.serverConfigPath);
        if (!fs_1.default.existsSync(configPath)) {
            throw new Error(`${this.serverConfigPath} not found (resolved to ${configPath})`);
        }
        const configContent = fs_1.default.readFileSync(configPath, "utf-8");
        let config;
        try {
            config = JSON.parse(configContent);
        }
        catch (error) {
            throw new Error(`Invalid JSON in ${this.serverConfigPath}: ${error.message}`, { cause: error });
        }
        const servers = config?.mcpServers;
        if (servers === null || typeof servers !== "object" || Object.keys(servers).length === 0) {
            throw new Error(`No server configuration found in ${this.serverConfigPath}`);
        }
        return config;
    }

    /**
     * Start one stdio MCP server, perform the handshake and list its tools.
     *
     * @param {string} serverName - Key of the entry in `mcpServers`.
     * @param {{ command: string, args?: string[], env?: Record<string, string>, cwd?: string }} serverConfig
     * @returns {Promise<{ serverName: string, transport: object, client: object, tools: object[] }>}
     * @throws whatever the SDK throws; the client is closed first so no
     *   server process is left running.
     */
    async getConnectionServer(serverName, serverConfig) {
        const transport = new stdio_js_1.StdioClientTransport({
            command: serverConfig.command,
            args: serverConfig.args,
            env: buildServerEnv(serverConfig.env),
            cwd: serverConfig.cwd,
        });
        const client = new index_js_1.Client({ name: packageJson.name, version: packageJson.version });
        try {
            // connect() performs the initialize handshake; it must finish
            // before any request (such as listTools) is sent.
            await client.connect(transport);
            const toolsResult = await client.listTools();
            const tools = toolsResult.tools.map((tool) => ({
                name: tool.name,
                description: tool.description || "",
                input_schema: tool.inputSchema,
                serverName,
            }));
            return { serverName, transport, client, tools };
        }
        catch (error) {
            // Best-effort shutdown; the original error is what the caller needs.
            await client.close().catch(() => { });
            throw error;
        }
    }

    /**
     * Answer a query with Gemini, letting it call tools on the connected servers.
     *
     * Each round sends the conversation so far. Any `functionCall` parts are
     * executed over MCP, and the model's turn plus a `functionResponse` turn are
     * appended. This repeats until the model answers without calling tools or
     * MAX_TOOL_ROUNDS rounds of tool calls have been made.
     *
     * @param {string} query - The user's input query.
     * @returns {Promise<string>} The model's text parts joined with newlines.
     * @throws {Error} "Error processing query: ..." wrapping API or tool failures.
     */
    async processQuery(query) {
        try {
            const model = this.genAI.getGenerativeModel({ model: this.model });
            const functionDeclarations = this.getAllTools().map(toFunctionDeclaration);
            // Only advertise tools when there is at least one to call.
            const tools = functionDeclarations.length > 0 ? [{ functionDeclarations }] : undefined;
            const systemInstruction = this.systemPrompt || undefined;
            const contents = [{ role: "user", parts: [{ text: query }] }];
            const finalText = [];
            let lastToolOutputs = [];
            for (let round = 0; ; round++) {
                const result = await model.generateContent({ contents, tools, systemInstruction });
                const candidate = result.response.candidates?.[0];
                if (!candidate) {
                    if (round === 0) {
                        finalText.push("No response received from Gemini");
                    }
                    else {
                        // No answer after a tool call: surface the raw tool output instead.
                        finalText.push(...lastToolOutputs.map((output) => `Tool result: ${output}`));
                    }
                    break;
                }
                const functionCalls = [];
                for (const part of candidate.content?.parts ?? []) {
                    if (part.text) {
                        finalText.push(part.text);
                    }
                    else if (part.functionCall) {
                        functionCalls.push(part.functionCall);
                    }
                }
                if (functionCalls.length === 0) {
                    break;
                }
                if (round >= MAX_TOOL_ROUNDS) {
                    finalText.push("Stopped after reaching the maximum number of tool-call rounds.");
                    break;
                }
                const executions = [];
                for (const functionCall of functionCalls) {
                    // Sequential on purpose, as before: tools may have
                    // side effects whose order matters.
                    executions.push(await runToolCall(this, functionCall));
                }
                lastToolOutputs = executions.map((execution) => execution.text);
                // Echo the model's functionCall turn, then answer every call.
                contents.push(candidate.content, {
                    role: "function",
                    parts: executions.map((execution) => execution.part),
                });
            }
            if (finalText.length === 0) {
                finalText.push("The request could not be processed. Please try rephrasing your question.");
            }
            return finalText.join("\n");
        }
        catch (error) {
            throw new Error(`Error processing query: ${error}`, { cause: error });
        }
    }

    /**
     * @returns {object[]} The tools of all connected servers, in connection order.
     */
    getAllTools() {
        return this.mcpConnections.flatMap((connection) => connection.tools);
    }

    /**
     * Find the first connected server that exposes a tool with the given name.
     *
     * @param {string} toolName
     * @returns {object | null} The connection, or null when no server has it.
     */
    findToolServer(toolName) {
        return this.mcpConnections.find((connection) => connection.tools.some((tool) => tool.name === toolName)) || null;
    }

    /**
     * Run an interactive chat loop on stdin/stdout until the user types 'quit'.
     *
     * A failing query is reported and the loop continues; blank input is ignored.
     */
    async chatLoop() {
        const rl = promises_1.default.createInterface({
            input: process.stdin,
            output: process.stdout,
        });
        console.log("\nChat started!\nType your queries or 'quit' to exit.");
        try {
            while (true) {
                const message = await rl.question("\nQuery: ");
                const trimmed = message.trim();
                if (trimmed.toLowerCase() === "quit") {
                    break;
                }
                if (trimmed === "") {
                    continue;
                }
                try {
                    const response = await this.processQuery(message);
                    console.log("\nResponse:\n" + response);
                }
                catch (error) {
                    // One failed query should not end the whole session.
                    console.error("Error in chatLoop:", error);
                }
            }
        }
        catch (error) {
            // Reached when reading input itself fails.
            console.error("Error in chatLoop:", error);
        }
        finally {
            rl.close();
        }
    }

    /**
     * Close every server connection.
     *
     * All connections are closed even if some fail, and the connection list is
     * cleared.
     *
     * @throws {Error} "Error closing connection to server <name>: ..." for a
     *   single failure, or an AggregateError when several connections fail.
     */
    async cleanup() {
        const connections = this.mcpConnections;
        this.mcpConnections = [];
        const outcomes = await Promise.allSettled(connections.map((connection) => connection.client.close()));
        const failures = [];
        outcomes.forEach((outcome, index) => {
            if (outcome.status === "rejected") {
                failures.push(new Error(`Error closing connection to server ${connections[index].serverName}: ${outcome.reason}`, { cause: outcome.reason }));
            }
        });
        if (failures.length === 1) {
            throw failures[0];
        }
        if (failures.length > 1) {
            throw new AggregateError(failures, "Errors closing MCP server connections");
        }
    }
}
exports.MCPClient = MCPClient;
//# sourceMappingURL=mcpClient.js.map