UNPKG

mcp-framework

Version:

Framework for building Model Context Protocol (MCP) servers in TypeScript

427 lines (426 loc) 19.3 kB
import { Server } from "@modelcontextprotocol/sdk/server/index.js";
import {
  CallToolRequestSchema,
  ListToolsRequestSchema,
  ListPromptsRequestSchema,
  GetPromptRequestSchema,
  ListResourcesRequestSchema,
  ListResourceTemplatesRequestSchema,
  ReadResourceRequestSchema,
  SubscribeRequestSchema,
  UnsubscribeRequestSchema,
} from "@modelcontextprotocol/sdk/types.js";
import { readFileSync } from "fs";
import { join, resolve, dirname } from "path";
import { logger } from "./Logger.js";
import { ToolLoader } from "../loaders/toolLoader.js";
import { PromptLoader } from "../loaders/promptLoader.js";
import { ResourceLoader } from "../loaders/resourceLoader.js";
import { StdioServerTransport } from "../transports/stdio/server.js";
import { SSEServerTransport } from "../transports/sse/server.js";
import { DEFAULT_SSE_CONFIG, DEFAULT_CORS_CONFIG } from "../transports/sse/types.js";
import { HttpStreamTransport } from "../transports/http/server.js";
import { DEFAULT_HTTP_STREAM_CONFIG } from "../transports/http/types.js";
import { createRequire } from "module";

// ES modules have no `require`; this one is used below to load the framework's
// package.json and to resolve the SDK's install location.
const require = createRequire(import.meta.url);

/**
 * MCP server built on the SDK `Server`.
 *
 * `start()` loads tools, prompts and resources through the project loaders,
 * enables the matching capabilities, registers request handlers, connects the
 * configured transport (stdio, SSE or HTTP stream) and then waits until
 * `stop()` runs (called directly, on SIGINT/SIGTERM, or when the transport
 * closes).
 */
export class MCPServer {
  server;
  toolsMap = new Map();
  promptsMap = new Map();
  resourcesMap = new Map();
  toolLoader;
  promptLoader;
  resourceLoader;
  serverName;
  serverVersion;
  basePath;
  transportConfig;
  // Handed to the SDK Server in the constructor and later mutated in place by
  // detectCapabilities(). NOTE(review): this relies on the SDK keeping the same
  // object reference - confirm for the installed SDK version.
  capabilities = {};
  isRunning = false;
  // True while stop() is executing; makes re-entrant stop() calls no-ops.
  isStopping = false;
  transport;
  shutdownPromise;
  shutdownResolve;
  // [signal, listener] pairs installed by start(), removed again by stop().
  signalListeners = [];

  /**
   * @param {object} [config]
   * @param {string} [config.basePath] - Path handed to the loaders; defaults to
   *   the entry script (`process.argv[1]`), then `process.cwd()`.
   * @param {string} [config.name] - Server name; defaults to package.json `name`.
   * @param {string} [config.version] - Server version; defaults to package.json `version`.
   * @param {object} [config.transport] - `{ type, options, auth }`; defaults to
   *   `{ type: "stdio" }`. This object is copied, never mutated.
   */
  constructor(config = {}) {
    this.basePath = this.resolveBasePath(config.basePath);
    this.serverName = config.name ?? this.getDefaultName();
    this.serverVersion = config.version ?? this.getDefaultVersion();

    // Shallow-copy so the caller's transport config (and its options) is not mutated.
    const transportConfig = { ...(config.transport ?? { type: "stdio" }) };
    if (transportConfig.auth) {
      // Mirror the top-level auth setting into options.auth.
      transportConfig.options = { ...transportConfig.options, auth: transportConfig.auth };
    }
    this.transportConfig = transportConfig;

    logger.info(`Initializing MCP Server: ${this.serverName}@${this.serverVersion}`);
    logger.debug(`Base path: ${this.basePath}`);
    logger.debug(`Transport config: ${JSON.stringify(this.transportConfig)}`);

    this.toolLoader = new ToolLoader(this.basePath);
    this.promptLoader = new PromptLoader(this.basePath);
    this.resourceLoader = new ResourceLoader(this.basePath);

    this.server = new Server(
      { name: this.serverName, version: this.serverVersion },
      { capabilities: this.capabilities },
    );
    logger.debug(`SDK Server instance created.`);
  }

  /**
   * Choose the base path given to the loaders.
   * @param {string} [configPath] - Explicit path from the config.
   * @returns {string} `configPath` if truthy, else `process.argv[1]` if set,
   *   else `process.cwd()`.
   */
  resolveBasePath(configPath) {
    if (configPath) {
      return configPath;
    }
    if (process.argv[1]) {
      return process.argv[1];
    }
    return process.cwd();
  }

  /**
   * Build the transport named by `this.transportConfig.type`.
   *
   * User options are merged over the transport's defaults. Unknown types fall
   * back to stdio with a warning. The transport's onclose/onerror callbacks are
   * wired up before it is returned.
   * @returns {object} The transport instance.
   */
  createTransport() {
    logger.debug(`Creating transport: ${this.transportConfig.type}`);
    const options = this.transportConfig.options || {};
    const authConfig = this.transportConfig.auth ?? options.auth;
    let transport;

    switch (this.transportConfig.type) {
      case "sse": {
        const sseConfig = {
          ...DEFAULT_SSE_CONFIG,
          ...options,
          cors: { ...DEFAULT_CORS_CONFIG, ...options.cors },
          auth: authConfig,
        };
        transport = new SSEServerTransport(sseConfig);
        break;
      }
      case "http-stream": {
        const httpConfig = {
          ...DEFAULT_HTTP_STREAM_CONFIG,
          ...options,
          cors: { ...DEFAULT_CORS_CONFIG, ...(options.cors || {}) },
          session: { ...DEFAULT_HTTP_STREAM_CONFIG.session, ...(options.session || {}) },
          resumability: {
            ...DEFAULT_HTTP_STREAM_CONFIG.resumability,
            ...(options.resumability || {}),
          },
          auth: authConfig,
        };
        logger.debug(`Creating HttpStreamTransport with effective responseMode: ${httpConfig.responseMode}`);
        transport = new HttpStreamTransport(httpConfig);
        break;
      }
      case "stdio":
      default:
        if (this.transportConfig.type !== "stdio") {
          logger.warn(`Unsupported type '${this.transportConfig.type}', defaulting to stdio.`);
        }
        transport = new StdioServerTransport();
        break;
    }

    // If the transport closes while the server is running, stop the whole server.
    transport.onclose = () => {
      logger.info(`Transport (${transport.type}) closed.`);
      if (this.isRunning) {
        this.stop().catch((error) => {
          logger.error(`Shutdown error after transport close: ${error}`);
          process.exit(1);
        });
      }
    };
    transport.onerror = (error) => {
      logger.error(`Transport (${transport.type}) error: ${error.message}\n${error.stack}`);
    };
    return transport;
  }

  /**
   * Read and parse `package.json` from the current working directory.
   * @returns {object|null} The parsed manifest, or null (with a warning logged)
   *   if it cannot be read or parsed.
   */
  readPackageJson() {
    try {
      const packagePath = join(process.cwd(), "package.json");
      const packageJson = JSON.parse(readFileSync(packagePath, "utf-8"));
      logger.debug(`Successfully read package.json from project root: ${packagePath}`);
      return packageJson;
    } catch (error) {
      logger.warn(`Could not read package.json from project root: ${error}`);
      return null;
    }
  }

  /** @returns {string} package.json `name`, or "unnamed-mcp-server" (logging an error). */
  getDefaultName() {
    const packageJson = this.readPackageJson();
    if (packageJson?.name) {
      return packageJson.name;
    }
    logger.error("Couldn't find project name in package json");
    return "unnamed-mcp-server";
  }

  /** @returns {string} package.json `version`, or "0.0.0". */
  getDefaultVersion() {
    const packageJson = this.readPackageJson();
    if (packageJson?.version) {
      return packageJson.version;
    }
    return "0.0.0";
  }

  /**
   * Register request handlers on the SDK server.
   *
   * Tool handlers are always registered. Prompt handlers are registered only
   * when the prompts capability was detected, and resource handlers (list,
   * read, templates, subscribe, unsubscribe) only when the resources
   * capability was detected.
   */
  setupHandlers() {
    this.server.setRequestHandler(ListToolsRequestSchema, async (request) => {
      logger.debug(`Received ListTools request: ${JSON.stringify(request)}`);
      const tools = Array.from(this.toolsMap.values()).map((tool) => tool.toolDefinition);
      logger.debug(`Found ${tools.length} tools to return`);
      logger.debug(`Tool definitions: ${JSON.stringify(tools)}`);
      const response = { tools, nextCursor: undefined };
      logger.debug(`Sending ListTools response: ${JSON.stringify(response)}`);
      return response;
    });

    this.server.setRequestHandler(CallToolRequestSchema, async (request) => {
      logger.debug(`Tool call request received for: ${request.params.name}`);
      logger.debug(`Tool call arguments: ${JSON.stringify(request.params.arguments)}`);
      const tool = this.toolsMap.get(request.params.name);
      if (!tool) {
        const availableTools = Array.from(this.toolsMap.keys());
        const errorMsg = `Unknown tool: ${request.params.name}. Available tools: ${availableTools.join(", ")}`;
        logger.error(errorMsg);
        throw new Error(errorMsg);
      }
      try {
        logger.debug(`Executing tool: ${tool.name}`);
        const toolRequest = { params: request.params, method: "tools/call" };
        const result = await tool.toolCall(toolRequest);
        logger.debug(`Tool execution successful: ${JSON.stringify(result)}`);
        return result;
      } catch (error) {
        const errorMsg = `Tool execution failed: ${error}`;
        logger.error(errorMsg);
        // Keep the original failure reachable for callers that inspect `cause`.
        throw new Error(errorMsg, { cause: error });
      }
    });

    if (this.capabilities.prompts) {
      this.server.setRequestHandler(ListPromptsRequestSchema, async () => ({
        prompts: Array.from(this.promptsMap.values()).map((prompt) => prompt.promptDefinition),
      }));

      this.server.setRequestHandler(GetPromptRequestSchema, async (request) => {
        const prompt = this.promptsMap.get(request.params.name);
        if (!prompt) {
          throw new Error(`Unknown prompt: ${request.params.name}. Available prompts: ${Array.from(this.promptsMap.keys()).join(", ")}`);
        }
        return { messages: await prompt.getMessages(request.params.arguments) };
      });
    }

    if (this.capabilities.resources) {
      this.server.setRequestHandler(ListResourcesRequestSchema, async () => ({
        resources: Array.from(this.resourcesMap.values()).map((resource) => resource.resourceDefinition),
      }));

      this.server.setRequestHandler(ReadResourceRequestSchema, async (request) => {
        const resource = this.resourcesMap.get(request.params.uri);
        if (!resource) {
          throw new Error(`Unknown resource: ${request.params.uri}. Available resources: ${Array.from(this.resourcesMap.keys()).join(", ")}`);
        }
        return { contents: await resource.read() };
      });

      // Resource templates are not supported yet, so this always returns an empty list.
      this.server.setRequestHandler(ListResourceTemplatesRequestSchema, async () => {
        logger.debug(`Received ListResourceTemplates request`);
        const response = { resourceTemplates: [], nextCursor: undefined };
        logger.debug(`Sending ListResourceTemplates response: ${JSON.stringify(response)}`);
        return response;
      });

      this.server.setRequestHandler(SubscribeRequestSchema, async (request) => {
        const resource = this.resourcesMap.get(request.params.uri);
        if (!resource) {
          throw new Error(`Unknown resource: ${request.params.uri}`);
        }
        if (!resource.subscribe) {
          throw new Error(`Resource ${request.params.uri} does not support subscriptions`);
        }
        await resource.subscribe();
        return {};
      });

      this.server.setRequestHandler(UnsubscribeRequestSchema, async (request) => {
        const resource = this.resourcesMap.get(request.params.uri);
        if (!resource) {
          throw new Error(`Unknown resource: ${request.params.uri}`);
        }
        if (!resource.unsubscribe) {
          throw new Error(`Resource ${request.params.uri} does not support subscriptions`);
        }
        await resource.unsubscribe();
        return {};
      });
    }
  }

  /**
   * Enable the tools/prompts/resources capabilities for whichever kinds the
   * loaders report as present.
   * @returns {Promise<object>} The (mutated) `this.capabilities` object.
   */
  async detectCapabilities() {
    if (await this.toolLoader.hasTools()) {
      this.capabilities.tools = {};
      logger.debug("Tools capability enabled");
    }
    if (await this.promptLoader.hasPrompts()) {
      this.capabilities.prompts = {};
      logger.debug("Prompts capability enabled");
    }
    if (await this.resourceLoader.hasResources()) {
      this.capabilities.resources = {};
      logger.debug("Resources capability enabled");
    }
    // Optional call: skipped if the SDK Server has no `updateCapabilities`.
    // NOTE(review): confirm which capability-update API the installed SDK exposes.
    this.server.updateCapabilities?.(this.capabilities);
    logger.debug(`Capabilities updated: ${JSON.stringify(this.capabilities)}`);
    return this.capabilities;
  }

  /**
   * Best-effort lookup of the installed `@modelcontextprotocol/sdk` version.
   * @returns {string} The version, or "unknown" if it cannot be determined.
   */
  getSdkVersion() {
    try {
      const sdkSpecificFile = require.resolve("@modelcontextprotocol/sdk/server/index.js");
      // Assumes the resolved entry sits three directories below the package
      // root (presumably dist/<format>/server/) - verify if the SDK layout changes.
      const sdkRootDir = resolve(dirname(sdkSpecificFile), "..", "..", "..");
      const packageJsonPath = join(sdkRootDir, "package.json");
      const packageJson = JSON.parse(readFileSync(packageJsonPath, "utf-8"));
      if (packageJson?.version) {
        logger.debug(`Found SDK version: ${packageJson.version}`);
        return packageJson.version;
      }
      logger.warn("Could not determine SDK version from its package.json.");
      return "unknown";
    } catch (error) {
      logger.warn(`Failed to read SDK package.json: ${error.message}`);
      return "unknown";
    }
  }

  /**
   * Start the server and wait until it is stopped.
   *
   * Loads tools/prompts/resources, detects capabilities, registers handlers,
   * connects the transport and installs SIGINT/SIGTERM handlers.
   * @returns {Promise<void>} Resolves once stop() has completed.
   * @throws {Error} If the server is already running or startup fails.
   */
  async start() {
    // Checked outside the try block: the catch below resets isRunning, which
    // must not happen to an instance that is genuinely running.
    if (this.isRunning) {
      throw new Error("Server is already running");
    }
    try {
      this.isRunning = true;
      const frameworkPackageJson = require("../../package.json");
      const frameworkVersion = frameworkPackageJson.version || "unknown";
      const sdkVersion = this.getSdkVersion();
      logger.info(`Starting MCP server (Framework: ${frameworkVersion}, SDK: ${sdkVersion})...`);

      const tools = await this.toolLoader.loadTools();
      this.toolsMap = new Map(tools.map((tool) => [tool.name, tool]));
      const prompts = await this.promptLoader.loadPrompts();
      this.promptsMap = new Map(prompts.map((prompt) => [prompt.name, prompt]));
      const resources = await this.resourceLoader.loadResources();
      this.resourcesMap = new Map(resources.map((resource) => [resource.uri, resource]));

      await this.detectCapabilities();
      logger.info(`Capabilities detected: ${JSON.stringify(this.capabilities)}`);
      this.setupHandlers();

      this.transport = this.createTransport();
      logger.info(`Connecting transport (${this.transport.type}) to SDK Server...`);
      await this.server.connect(this.transport);

      logger.info(`Started ${this.serverName}@${this.serverVersion} successfully on transport ${this.transport.type}`);
      logger.info(`Tools (${tools.length}): ${tools.map((t) => t.name).join(", ") || "None"}`);
      if (this.capabilities.prompts) {
        logger.info(`Prompts (${prompts.length}): ${prompts.map((p) => p.name).join(", ") || "None"}`);
      }
      if (this.capabilities.resources) {
        logger.info(`Resources (${resources.length}): ${resources.map((r) => r.uri).join(", ") || "None"}`);
      }

      this.installSignalHandlers();

      // stop() calls shutdownResolve, which lets this start() call return.
      this.shutdownPromise = new Promise((resolveShutdown) => {
        this.shutdownResolve = resolveShutdown;
      });
      logger.info("Server running and ready.");
      await this.shutdownPromise;
    } catch (error) {
      logger.error(`Server failed to start: ${error.message}\n${error.stack}`);
      this.isRunning = false;
      throw error;
    }
  }

  /**
   * Install SIGINT/SIGTERM listeners that stop the server.
   *
   * Any listeners left over from a previous run are removed first, so repeated
   * start/stop cycles never accumulate listeners.
   */
  installSignalHandlers() {
    this.removeSignalHandlers();
    const shutdownHandler = async (signal) => {
      if (!this.isRunning) return;
      logger.info(`Received ${signal}. Shutting down...`);
      try {
        await this.stop();
      } catch (e) {
        logger.error(`Shutdown error via ${signal}: ${e.message}`);
        process.exit(1);
      }
    };
    for (const signal of ["SIGINT", "SIGTERM"]) {
      const listener = () => shutdownHandler(signal);
      process.on(signal, listener);
      this.signalListeners.push([signal, listener]);
    }
  }

  /**
   * Detach the listeners added by installSignalHandlers().
   *
   * Without this, a stopped server would keep intercepting SIGINT/SIGTERM and
   * silently ignore them.
   */
  removeSignalHandlers() {
    for (const [signal, listener] of this.signalListeners) {
      process.removeListener(signal, listener);
    }
    this.signalListeners = [];
  }

  /**
   * Stop the server and resolve the promise start() is waiting on.
   *
   * Closes the transport, then the SDK server, and marks the server as
   * stopped. It is a no-op when the server is not running. A call made while
   * a stop is already in progress also returns immediately; such a call can
   * come from the transport's onclose callback, which may fire while
   * transport.close() is being awaited.
   * @throws {Error} If closing the transport or the SDK server failed.
   */
  async stop() {
    if (!this.isRunning) {
      logger.debug("Stop called, but server not running.");
      return;
    }
    if (this.isStopping) {
      logger.debug("Stop already in progress.");
      return;
    }
    this.isStopping = true;
    try {
      logger.info("Stopping server...");
      this.removeSignalHandlers();
      let transportError = null;
      let sdkServerError = null;

      if (this.transport) {
        try {
          logger.debug(`Closing transport (${this.transport.type})...`);
          await this.transport.close();
          logger.info(`Transport closed.`);
        } catch (e) {
          transportError = e;
          logger.error(`Error closing transport: ${e.message}`);
        }
        this.transport = undefined;
      }

      if (this.server) {
        try {
          logger.debug("Closing SDK Server...");
          await this.server.close();
          logger.info("SDK Server closed.");
        } catch (e) {
          sdkServerError = e;
          logger.error(`Error closing SDK Server: ${e.message}`);
        }
      }

      this.isRunning = false;
      if (this.shutdownResolve) {
        this.shutdownResolve();
        logger.debug("Shutdown promise resolved.");
      } else {
        logger.warn("Shutdown resolve function not found.");
      }

      if (transportError || sdkServerError) {
        logger.error("Errors occurred during server stop.");
        throw new Error(`Server stop failed. TransportError: ${transportError?.message}, SDKServerError: ${sdkServerError?.message}`);
      }
      logger.info("MCP server stopped successfully.");
    } catch (error) {
      logger.error(`Error stopping server: ${error}`);
      throw error;
    } finally {
      this.isStopping = false;
    }
  }

  /** @returns {boolean} Whether the server is currently running. */
  get IsRunning() {
    return this.isRunning;
  }
}