UNPKG

mcp-framework

Version:

Framework for building Model Context Protocol (MCP) servers in TypeScript

413 lines (412 loc) 18.4 kB
import { Server } from '@modelcontextprotocol/sdk/server/index.js';
import {
  CallToolRequestSchema,
  ListToolsRequestSchema,
  ListPromptsRequestSchema,
  GetPromptRequestSchema,
  ListResourcesRequestSchema,
  ListResourceTemplatesRequestSchema,
  ReadResourceRequestSchema,
  SubscribeRequestSchema,
  UnsubscribeRequestSchema,
} from '@modelcontextprotocol/sdk/types.js';
import { readFileSync } from 'fs';
import { join, resolve, dirname } from 'path';
import { logger } from './Logger.js';
import { ToolLoader } from '../loaders/toolLoader.js';
import { PromptLoader } from '../loaders/promptLoader.js';
import { ResourceLoader } from '../loaders/resourceLoader.js';
import { StdioServerTransport } from '../transports/stdio/server.js';
import { SSEServerTransport } from '../transports/sse/server.js';
import { DEFAULT_SSE_CONFIG } from '../transports/sse/types.js';
import { HttpStreamTransport } from '../transports/http/server.js';
import { DEFAULT_HTTP_STREAM_CONFIG } from '../transports/http/types.js';
import { DEFAULT_CORS_CONFIG } from '../transports/sse/types.js';
import { createRequire } from 'module';

const require = createRequire(import.meta.url);

/**
 * `JSON.stringify` replacer that hides any `auth` entry. Auth settings can
 * carry secrets (keys, tokens, provider instances) and must not reach logs.
 *
 * @param {string} key - Property name being serialized.
 * @param {*} value - Property value being serialized.
 * @returns {*} `'[redacted]'` for a defined `auth` value, otherwise `value`.
 */
function redactAuth(key, value) {
  return key === 'auth' && value !== undefined ? '[redacted]' : value;
}

/**
 * An MCP server that discovers tools, prompts and resources through the
 * framework loaders, wires them into an SDK `Server`, and serves them over the
 * configured transport (stdio by default, or SSE / HTTP stream).
 */
export class MCPServer {
  server;
  toolsMap = new Map();
  promptsMap = new Map();
  resourcesMap = new Map();
  toolLoader;
  promptLoader;
  resourceLoader;
  serverName;
  serverVersion;
  basePath;
  transportConfig;
  capabilities = {};
  isRunning = false;
  // True while stop() is in progress; guards against re-entrant stop calls.
  isStopping = false;
  transport;
  shutdownPromise;
  shutdownResolve;
  // [signal, listener] pairs registered by start(), removed again by stop().
  signalHandlers = [];

  /**
   * @param {object} [config] - Server configuration.
   * @param {string} [config.basePath] - Directory the loaders scan; defaults to
   *   the entry script's directory, then `process.cwd()`.
   * @param {string} [config.name] - Server name; defaults to package.json `name`.
   * @param {string} [config.version] - Server version; defaults to package.json `version`.
   * @param {object} [config.transport] - Transport config (`type`, `options`, `auth`);
   *   defaults to `{ type: 'stdio' }`. It is copied, never mutated.
   */
  constructor(config = {}) {
    this.basePath = this.resolveBasePath(config.basePath);
    this.serverName = config.name ?? this.getDefaultName();
    this.serverVersion = config.version ?? this.getDefaultVersion();

    // Copy the config so that folding `auth` into `options` does not mutate
    // the object the caller passed in.
    const transportConfig = { ...(config.transport ?? { type: 'stdio' }) };
    if (transportConfig.auth) {
      transportConfig.options = { ...transportConfig.options, auth: transportConfig.auth };
    }
    this.transportConfig = transportConfig;

    logger.info(`Initializing MCP Server: ${this.serverName}@${this.serverVersion}`);
    logger.debug(`Base path: ${this.basePath}`);
    logger.debug(`Transport config: ${JSON.stringify(this.transportConfig, redactAuth)}`);

    this.toolLoader = new ToolLoader(this.basePath);
    this.promptLoader = new PromptLoader(this.basePath);
    this.resourceLoader = new ResourceLoader(this.basePath);
  }

  /**
   * Picks the directory the loaders scan.
   *
   * @param {string} [configPath] - Explicit base path, used verbatim when given.
   * @returns {string} `configPath`, else the entry script's directory, else `process.cwd()`.
   */
  resolveBasePath(configPath) {
    if (configPath) {
      return configPath;
    }
    if (process.argv[1]) {
      return dirname(process.argv[1]);
    }
    return process.cwd();
  }

  /**
   * Builds the transport described by `this.transportConfig`. Unknown types
   * fall back to stdio with a warning.
   *
   * @returns {object} A transport whose `onclose` stops this server and whose
   *   `onerror` logs the error.
   */
  createTransport() {
    logger.debug(`Creating transport: ${this.transportConfig.type}`);
    let transport;
    const options = this.transportConfig.options || {};
    const authConfig = this.transportConfig.auth ?? options.auth;

    switch (this.transportConfig.type) {
      case 'sse': {
        const sseConfig = {
          ...DEFAULT_SSE_CONFIG,
          ...options,
          cors: { ...DEFAULT_CORS_CONFIG, ...options.cors },
          auth: authConfig,
        };
        transport = new SSEServerTransport(sseConfig);
        break;
      }
      case 'http-stream': {
        const httpConfig = {
          ...DEFAULT_HTTP_STREAM_CONFIG,
          ...options,
          cors: { ...DEFAULT_CORS_CONFIG, ...(options.cors || {}) },
          auth: authConfig,
        };
        logger.debug(`Creating HttpStreamTransport. response mode: ${httpConfig.responseMode}`);
        transport = new HttpStreamTransport(httpConfig);
        break;
      }
      case 'stdio':
      default:
        if (this.transportConfig.type !== 'stdio') {
          logger.warn(`Unsupported type '${this.transportConfig.type}', defaulting to stdio.`);
        }
        transport = new StdioServerTransport();
        break;
    }

    transport.onclose = () => {
      logger.info(`Transport (${transport.type}) closed.`);
      // stop() ignores calls made while it is already running, so this is safe
      // even when the close was triggered by stop() itself.
      if (this.isRunning) {
        this.stop().catch((error) => {
          logger.error(`Shutdown error after transport close: ${error}`);
          process.exit(1);
        });
      }
    };
    transport.onerror = (error) => {
      logger.error(`Transport (${transport.type}) error: ${error.message}\n${error.stack}`);
    };
    return transport;
  }

  /**
   * Reads and parses `package.json` from the current working directory.
   *
   * @returns {object|null} The parsed package.json, or `null` when it cannot be
   *   read or parsed (a warning is logged).
   */
  readPackageJson() {
    try {
      const packagePath = join(process.cwd(), 'package.json');
      const packageJson = JSON.parse(readFileSync(packagePath, 'utf-8'));
      logger.debug(`Successfully read package.json from project root: ${packagePath}`);
      return packageJson;
    } catch (error) {
      logger.warn(`Could not read package.json from project root: ${error}`);
      return null;
    }
  }

  /** @returns {string} package.json `name`, or `'unnamed-mcp-server'`. */
  getDefaultName() {
    const packageJson = this.readPackageJson();
    if (packageJson?.name) {
      return packageJson.name;
    }
    logger.error("Couldn't find project name in package json");
    return 'unnamed-mcp-server';
  }

  /** @returns {string} package.json `version`, or `'0.0.0'`. */
  getDefaultVersion() {
    const packageJson = this.readPackageJson();
    if (packageJson?.version) {
      return packageJson.version;
    }
    return '0.0.0';
  }

  /**
   * Registers MCP request handlers for every detected capability.
   *
   * Each family is registered only when its capability was detected, so a
   * server never handles methods it does not advertise. The SDK `Server`
   * validates handler registration against its declared capabilities.
   *
   * @param {Server} [server] - Target SDK server; defaults to `this.server`.
   */
  setupHandlers(server) {
    const targetServer = server || this.server;

    if (this.capabilities.tools) {
      targetServer.setRequestHandler(ListToolsRequestSchema, async (request) => {
        logger.debug(`Received ListTools request: ${JSON.stringify(request)}`);
        const tools = Array.from(this.toolsMap.values()).map((tool) => tool.toolDefinition);
        logger.debug(`Found ${tools.length} tools to return`);
        logger.debug(`Tool definitions: ${JSON.stringify(tools)}`);
        const response = { tools, nextCursor: undefined };
        logger.debug(`Sending ListTools response: ${JSON.stringify(response)}`);
        return response;
      });

      targetServer.setRequestHandler(CallToolRequestSchema, async (request) => {
        logger.debug(`Tool call request received for: ${request.params.name}`);
        logger.debug(`Tool call arguments: ${JSON.stringify(request.params.arguments)}`);
        const tool = this.toolsMap.get(request.params.name);
        if (!tool) {
          const availableTools = Array.from(this.toolsMap.keys());
          const errorMsg = `Unknown tool: ${request.params.name}. Available tools: ${availableTools.join(', ')}`;
          logger.error(errorMsg);
          throw new Error(errorMsg);
        }
        try {
          logger.debug(`Executing tool: ${tool.name}`);
          const toolRequest = { params: request.params, method: 'tools/call' };
          const result = await tool.toolCall(toolRequest);
          logger.debug(`Tool execution successful: ${JSON.stringify(result)}`);
          return result;
        } catch (error) {
          const errorMsg = `Tool execution failed: ${error}`;
          logger.error(errorMsg);
          // Keep the original error (and its stack) reachable for callers.
          throw new Error(errorMsg, { cause: error });
        }
      });
    }

    if (this.capabilities.prompts) {
      targetServer.setRequestHandler(ListPromptsRequestSchema, async () => ({
        prompts: Array.from(this.promptsMap.values()).map((prompt) => prompt.promptDefinition),
      }));

      targetServer.setRequestHandler(GetPromptRequestSchema, async (request) => {
        const prompt = this.promptsMap.get(request.params.name);
        if (!prompt) {
          throw new Error(`Unknown prompt: ${request.params.name}. Available prompts: ${Array.from(this.promptsMap.keys()).join(', ')}`);
        }
        return { messages: await prompt.getMessages(request.params.arguments) };
      });
    }

    if (this.capabilities.resources) {
      targetServer.setRequestHandler(ListResourcesRequestSchema, async () => ({
        resources: Array.from(this.resourcesMap.values()).map((resource) => resource.resourceDefinition),
      }));

      targetServer.setRequestHandler(ReadResourceRequestSchema, async (request) => {
        const resource = this.resourcesMap.get(request.params.uri);
        if (!resource) {
          throw new Error(`Unknown resource: ${request.params.uri}. Available resources: ${Array.from(this.resourcesMap.keys()).join(', ')}`);
        }
        return { contents: await resource.read() };
      });

      targetServer.setRequestHandler(ListResourceTemplatesRequestSchema, async () => {
        logger.debug(`Received ListResourceTemplates request`);
        const response = { resourceTemplates: [], nextCursor: undefined };
        logger.debug(`Sending ListResourceTemplates response: ${JSON.stringify(response)}`);
        return response;
      });

      targetServer.setRequestHandler(SubscribeRequestSchema, async (request) => {
        const resource = this.resourcesMap.get(request.params.uri);
        if (!resource) {
          throw new Error(`Unknown resource: ${request.params.uri}`);
        }
        if (!resource.subscribe) {
          throw new Error(`Resource ${request.params.uri} does not support subscriptions`);
        }
        await resource.subscribe();
        return {};
      });

      targetServer.setRequestHandler(UnsubscribeRequestSchema, async (request) => {
        const resource = this.resourcesMap.get(request.params.uri);
        if (!resource) {
          throw new Error(`Unknown resource: ${request.params.uri}`);
        }
        if (!resource.unsubscribe) {
          throw new Error(`Resource ${request.params.uri} does not support subscriptions`);
        }
        await resource.unsubscribe();
        return {};
      });
    }
  }

  /**
   * Enables capability entries for each kind of item the loaders report.
   *
   * @returns {Promise<object>} The (mutated) `this.capabilities` object.
   */
  async detectCapabilities() {
    if (await this.toolLoader.hasTools()) {
      this.capabilities.tools = {};
      logger.debug('Tools capability enabled');
    }
    if (await this.promptLoader.hasPrompts()) {
      this.capabilities.prompts = {};
      logger.debug('Prompts capability enabled');
    }
    if (await this.resourceLoader.hasResources()) {
      this.capabilities.resources = {};
      logger.debug('Resources capability enabled');
    }
    logger.debug(`Capabilities detected: ${JSON.stringify(this.capabilities)}`);
    return this.capabilities;
  }

  /**
   * Looks up the installed MCP SDK version. It resolves an SDK file and walks
   * three directories up to the package root.
   *
   * @returns {string} The SDK version, or `'unknown'` when it cannot be read.
   */
  getSdkVersion() {
    try {
      const sdkSpecificFile = require.resolve('@modelcontextprotocol/sdk/server/index.js');
      const sdkRootDir = resolve(dirname(sdkSpecificFile), '..', '..', '..');
      const correctPackageJsonPath = join(sdkRootDir, 'package.json');
      const packageJson = JSON.parse(readFileSync(correctPackageJsonPath, 'utf-8'));
      if (packageJson?.version) {
        logger.debug(`Found SDK version: ${packageJson.version}`);
        return packageJson.version;
      }
      logger.warn('Could not determine SDK version from its package.json.');
      return 'unknown';
    } catch (error) {
      logger.warn(`Failed to read SDK package.json: ${error.message}`);
      return 'unknown';
    }
  }

  /**
   * Loads tools, prompts and resources, then connects the SDK server to the
   * transport. The returned promise stays pending until the server is stopped.
   *
   * @returns {Promise<void>} Resolves once `stop()` completes.
   * @throws {Error} If the server is already running, a tool fails validation,
   *   or startup fails for any other reason.
   */
  async start() {
    // Checked before the try block: its catch resets `isRunning`, which would
    // otherwise mark an already-running server as stopped.
    if (this.isRunning) {
      throw new Error('Server is already running');
    }
    try {
      this.isRunning = true;
      const frameworkPackageJson = require('../../package.json');
      const frameworkVersion = frameworkPackageJson.version || 'unknown';
      const sdkVersion = this.getSdkVersion();
      logger.info(`Starting MCP server: (Framework: ${frameworkVersion}, SDK: ${sdkVersion})...`);

      const tools = await this.toolLoader.loadTools();
      this.toolsMap = new Map(tools.map((tool) => [tool.name, tool]));
      for (const tool of tools) {
        if ('validate' in tool && typeof tool.validate === 'function') {
          try {
            tool.validate();
          } catch (error) {
            logger.error(`Tool validation failed for '${tool.name}': ${error.message}`);
            throw new Error(`Tool validation failed for '${tool.name}': ${error.message}`, { cause: error });
          }
        }
      }

      const prompts = await this.promptLoader.loadPrompts();
      this.promptsMap = new Map(prompts.map((prompt) => [prompt.name, prompt]));
      const resources = await this.resourceLoader.loadResources();
      this.resourcesMap = new Map(resources.map((resource) => [resource.uri, resource]));

      await this.detectCapabilities();
      logger.info(`Capabilities detected: ${JSON.stringify(this.capabilities)}`);

      this.server = new Server(
        { name: this.serverName, version: this.serverVersion },
        { capabilities: this.capabilities },
      );
      logger.debug(`SDK Server instance created with capabilities: ${JSON.stringify(this.capabilities)}`);
      this.setupHandlers();

      this.transport = this.createTransport();
      logger.info(`Connecting transport (${this.transport.type}) to SDK Server...`);
      await this.server.connect(this.transport);

      logger.info(`Started ${this.serverName}@${this.serverVersion} successfully on transport ${this.transport.type}`);
      logger.info(`Tools (${tools.length}): ${tools.map((t) => t.name).join(', ') || 'None'}`);
      if (this.capabilities.prompts) {
        logger.info(`Prompts (${prompts.length}): ${prompts.map((p) => p.name).join(', ') || 'None'}`);
      }
      if (this.capabilities.resources) {
        logger.info(`Resources (${resources.length}): ${resources.map((r) => r.uri).join(', ') || 'None'}`);
      }

      this.registerSignalHandlers();

      this.shutdownPromise = new Promise((resolveShutdown) => {
        this.shutdownResolve = resolveShutdown;
      });
      logger.info('Server running and ready.');
      await this.shutdownPromise;
    } catch (error) {
      logger.error(`Server failed to start: ${error.message}\n${error.stack}`);
      this.isRunning = false;
      throw error;
    }
  }

  /**
   * Installs SIGINT/SIGTERM listeners that stop the server. Listeners from a
   * previous run are removed first, so repeated start/stop cycles do not
   * accumulate them.
   */
  registerSignalHandlers() {
    this.removeSignalHandlers();
    const shutdownHandler = async (signal) => {
      if (!this.isRunning) return;
      logger.info(`Received ${signal}. Shutting down...`);
      try {
        await this.stop();
      } catch (e) {
        logger.error(`Shutdown error via ${signal}: ${e.message}`);
        process.exit(1);
      }
    };
    for (const signal of ['SIGINT', 'SIGTERM']) {
      const listener = () => shutdownHandler(signal);
      process.on(signal, listener);
      this.signalHandlers.push([signal, listener]);
    }
  }

  /** Removes the process signal listeners installed by registerSignalHandlers(). */
  removeSignalHandlers() {
    for (const [signal, listener] of this.signalHandlers) {
      process.off(signal, listener);
    }
    this.signalHandlers = [];
  }

  /**
   * Closes the transport and the SDK server, removes signal listeners, and
   * resolves the promise returned by `start()`.
   *
   * Calls made while the server is not running, or while a stop is already in
   * progress, return immediately. The transport's `onclose` callback calls
   * stop() while `isRunning` is still true, so a close triggered here could
   * otherwise re-enter this method.
   *
   * @returns {Promise<void>}
   * @throws {Error} If closing the transport or the SDK server failed. Both
   *   are still attempted and the server is marked stopped.
   */
  async stop() {
    if (!this.isRunning) {
      logger.debug('Stop called, but server not running.');
      return;
    }
    if (this.isStopping) {
      logger.debug('Stop already in progress.');
      return;
    }
    // Set before the first await so that synchronous re-entry is also caught.
    this.isStopping = true;
    try {
      logger.info('Stopping server...');
      let transportError = null;
      let sdkServerError = null;

      if (this.transport) {
        try {
          logger.debug(`Closing transport (${this.transport.type})...`);
          await this.transport.close();
          logger.info(`Transport closed.`);
        } catch (e) {
          transportError = e;
          logger.error(`Error closing transport: ${e.message}`);
        }
        this.transport = undefined;
      }

      if (this.server) {
        try {
          logger.debug('Closing SDK Server...');
          await this.server.close();
          logger.info('SDK Server closed.');
        } catch (e) {
          sdkServerError = e;
          logger.error(`Error closing SDK Server: ${e.message}`);
        }
      }

      this.removeSignalHandlers();
      this.isRunning = false;

      if (this.shutdownResolve) {
        this.shutdownResolve();
        logger.debug('Shutdown promise resolved.');
      } else {
        logger.warn('Shutdown resolve function not found.');
      }

      if (transportError || sdkServerError) {
        logger.error('Errors occurred during server stop.');
        throw new Error(
          `Server stop failed. TransportError: ${transportError?.message}, SDKServerError: ${sdkServerError?.message}`,
          { cause: transportError ?? sdkServerError },
        );
      }
      logger.info('MCP server stopped successfully.');
    } catch (error) {
      logger.error(`Error stopping server: ${error}`);
      throw error;
    } finally {
      this.isStopping = false;
    }
  }

  /** @returns {boolean} Whether the server is currently running. */
  get IsRunning() {
    return this.isRunning;
  }
}