mcp-framework
Version:
Framework for building Model Context Protocol (MCP) servers in TypeScript
427 lines (426 loc) • 19.3 kB
JavaScript
import { Server } from "@modelcontextprotocol/sdk/server/index.js";
import { CallToolRequestSchema, ListToolsRequestSchema, ListPromptsRequestSchema, GetPromptRequestSchema, ListResourcesRequestSchema, ListResourceTemplatesRequestSchema, ReadResourceRequestSchema, SubscribeRequestSchema, UnsubscribeRequestSchema, } from "@modelcontextprotocol/sdk/types.js";
import { readFileSync } from "fs";
import { join, resolve, dirname } from "path";
import { logger } from "./Logger.js";
import { ToolLoader } from "../loaders/toolLoader.js";
import { PromptLoader } from "../loaders/promptLoader.js";
import { ResourceLoader } from "../loaders/resourceLoader.js";
import { StdioServerTransport } from "../transports/stdio/server.js";
import { SSEServerTransport } from "../transports/sse/server.js";
import { DEFAULT_SSE_CONFIG } from "../transports/sse/types.js";
import { HttpStreamTransport } from "../transports/http/server.js";
import { DEFAULT_HTTP_STREAM_CONFIG } from "../transports/http/types.js";
import { DEFAULT_CORS_CONFIG } from "../transports/sse/types.js";
import { createRequire } from 'module';
const require = createRequire(import.meta.url);
/**
 * Checks whether `msg` has the shape of a JSON-RPC 2.0 request:
 * a string `method`, `jsonrpc === "2.0"`, and an `id` member.
 * @param {unknown} msg - Candidate message.
 * @returns {boolean} true for a request, false otherwise (including null/undefined).
 */
function isRequest(msg) {
  // Boolean(msg) guards property access on null/undefined and keeps the result a strict boolean.
  return Boolean(msg) && typeof msg.method === 'string' && msg.jsonrpc === "2.0" && 'id' in msg;
}
/**
 * Checks whether `msg` has the shape of a JSON-RPC 2.0 response:
 * `jsonrpc === "2.0"`, an `id` member, and either a `result` or an `error` member.
 * @param {unknown} msg - Candidate message.
 * @returns {boolean} true for a response, false otherwise (including null/undefined).
 */
function isResponse(msg) {
  // Boolean(msg) guards property access on null/undefined and keeps the result a strict boolean.
  return Boolean(msg) && msg.jsonrpc === "2.0" && 'id' in msg && ('result' in msg || 'error' in msg);
}
/**
 * Checks whether `msg` has the shape of a JSON-RPC 2.0 notification:
 * a string `method`, `jsonrpc === "2.0"`, and no `id` member.
 * @param {unknown} msg - Candidate message.
 * @returns {boolean} true for a notification, false otherwise (including null/undefined).
 */
function isNotification(msg) {
  // Boolean(msg) guards property access on null/undefined and keeps the result a strict boolean.
  return Boolean(msg) && typeof msg.method === 'string' && msg.jsonrpc === "2.0" && !('id' in msg);
}
/**
 * Model Context Protocol server built on the SDK `Server`.
 *
 * start() does the following:
 * - loads tools, prompts and resources through the project's loaders;
 * - records the matching capabilities and registers request handlers;
 * - connects a stdio, SSE or HTTP-stream transport;
 * - resolves only after stop() has run.
 */
export class MCPServer {
  server;
  toolsMap = new Map();
  promptsMap = new Map();
  resourcesMap = new Map();
  toolLoader;
  promptLoader;
  resourceLoader;
  serverName;
  serverVersion;
  basePath;
  transportConfig;
  // Filled in by detectCapabilities(); this same object is handed to the SDK Server in the constructor.
  capabilities = {};
  isRunning = false;
  // True while stop() is executing, so re-entrant calls (e.g. from transport.onclose) are ignored.
  isStopping = false;
  transport;
  shutdownPromise;
  shutdownResolve;
  // [signal, listener] pairs installed by start() and removed again by stop().
  signalListeners = [];

  /**
   * Creates the loaders and the SDK server instance. Nothing is loaded or connected until start().
   *
   * @param {object} [config]
   * @param {string} [config.basePath] - Passed to the loaders; defaults to process.argv[1], then process.cwd().
   * @param {string} [config.name] - Defaults to "name" from package.json in process.cwd(), else "unnamed-mcp-server".
   * @param {string} [config.version] - Defaults to "version" from package.json in process.cwd(), else "0.0.0".
   * @param {object} [config.transport] - `{ type, options, auth }`; defaults to `{ type: "stdio" }`.
   */
  constructor(config = {}) {
    this.basePath = this.resolveBasePath(config.basePath);
    this.serverName = config.name ?? this.getDefaultName();
    this.serverVersion = config.version ?? this.getDefaultVersion();
    const transportConfig = config.transport ?? { type: "stdio" };
    // A top-level `auth` is mirrored into `options.auth`. Build a shallow copy so the
    // caller's config object is never mutated.
    this.transportConfig = transportConfig.auth
      ? {
          ...transportConfig,
          options: { ...transportConfig.options, auth: transportConfig.auth },
        }
      : transportConfig;
    logger.info(`Initializing MCP Server: ${this.serverName}@${this.serverVersion}`);
    logger.debug(`Base path: ${this.basePath}`);
    // Auth settings may hold credentials or non-serialisable provider objects, so keep them out of the log.
    const redactAuth = (key, value) => (key === "auth" && value !== undefined ? "[redacted]" : value);
    logger.debug(`Transport config: ${JSON.stringify(this.transportConfig, redactAuth)}`);
    this.toolLoader = new ToolLoader(this.basePath);
    this.promptLoader = new PromptLoader(this.basePath);
    this.resourceLoader = new ResourceLoader(this.basePath);
    this.server = new Server({ name: this.serverName, version: this.serverVersion }, { capabilities: this.capabilities });
    logger.debug(`SDK Server instance created.`);
  }

  /**
   * Picks the base path for the loaders.
   * @param {string} [configPath] - Explicit path; wins when truthy.
   * @returns {string} configPath, else process.argv[1] (the entry script), else process.cwd().
   */
  resolveBasePath(configPath) {
    if (configPath) {
      return configPath;
    }
    if (process.argv[1]) {
      return process.argv[1];
    }
    return process.cwd();
  }

  /**
   * Builds the transport selected by `transportConfig.type`: "sse", "http-stream", or stdio otherwise.
   * User options are layered over the framework defaults (cors, session and resumability are merged
   * one level deep), and the transport's onclose/onerror callbacks are wired up.
   * @returns the transport instance; start() connects it afterwards.
   */
  createTransport() {
    logger.debug(`Creating transport: ${this.transportConfig.type}`);
    let transport;
    const options = this.transportConfig.options || {};
    const authConfig = this.transportConfig.auth ?? options.auth;
    switch (this.transportConfig.type) {
      case "sse": {
        const sseConfig = {
          ...DEFAULT_SSE_CONFIG,
          ...options,
          cors: { ...DEFAULT_CORS_CONFIG, ...options.cors },
          auth: authConfig,
        };
        transport = new SSEServerTransport(sseConfig);
        break;
      }
      case "http-stream": {
        const httpConfig = {
          ...DEFAULT_HTTP_STREAM_CONFIG,
          ...options,
          cors: {
            ...DEFAULT_CORS_CONFIG,
            ...(options.cors || {}),
          },
          session: {
            ...DEFAULT_HTTP_STREAM_CONFIG.session,
            ...(options.session || {}),
          },
          resumability: {
            ...DEFAULT_HTTP_STREAM_CONFIG.resumability,
            ...(options.resumability || {}),
          },
          auth: authConfig,
        };
        logger.debug(`Creating HttpStreamTransport with effective responseMode: ${httpConfig.responseMode}`);
        transport = new HttpStreamTransport(httpConfig);
        break;
      }
      case "stdio":
      default:
        if (this.transportConfig.type !== "stdio") {
          logger.warn(`Unsupported type '${this.transportConfig.type}', defaulting to stdio.`);
        }
        transport = new StdioServerTransport();
        break;
    }
    // A transport that closes on its own (e.g. client disconnect) shuts the whole server down.
    transport.onclose = () => {
      logger.info(`Transport (${transport.type}) closed.`);
      if (this.isRunning) {
        this.stop().catch((error) => {
          logger.error(`Shutdown error after transport close: ${error}`);
          process.exit(1);
        });
      }
    };
    transport.onerror = (error) => {
      logger.error(`Transport (${transport.type}) error: ${error.message}\n${error.stack}`);
    };
    return transport;
  }

  /**
   * Reads and parses package.json from the current working directory.
   * @returns {object|null} the parsed contents, or null (with a warning logged) when it cannot be read or parsed.
   */
  readPackageJson() {
    try {
      const projectRoot = process.cwd();
      const packagePath = join(projectRoot, "package.json");
      try {
        const packageContent = readFileSync(packagePath, "utf-8");
        const packageJson = JSON.parse(packageContent);
        logger.debug(`Successfully read package.json from project root: ${packagePath}`);
        return packageJson;
      } catch (error) {
        logger.warn(`Could not read package.json from project root: ${error}`);
        return null;
      }
    } catch (error) {
      // Reached only if resolving the path itself fails (e.g. process.cwd() throws).
      logger.warn(`Could not read package.json: ${error}`);
      return null;
    }
  }

  /** @returns {string} the project's package.json "name", or "unnamed-mcp-server" if missing. */
  getDefaultName() {
    const packageJson = this.readPackageJson();
    if (packageJson?.name) {
      return packageJson.name;
    }
    logger.error("Couldn't find project name in package json");
    return "unnamed-mcp-server";
  }

  /** @returns {string} the project's package.json "version", or "0.0.0" if missing. */
  getDefaultVersion() {
    const packageJson = this.readPackageJson();
    if (packageJson?.version) {
      return packageJson.version;
    }
    return "0.0.0";
  }

  /**
   * Registers the SDK request handlers.
   * - Tool handlers are always registered.
   * - Prompt handlers are registered only when `capabilities.prompts` is set.
   * - Resource handlers are registered only when `capabilities.resources` is set.
   * Unknown tool/prompt/resource names are rejected with an Error listing what is available.
   */
  setupHandlers() {
    this.server.setRequestHandler(ListToolsRequestSchema, async (request) => {
      logger.debug(`Received ListTools request: ${JSON.stringify(request)}`);
      const tools = Array.from(this.toolsMap.values()).map((tool) => tool.toolDefinition);
      logger.debug(`Found ${tools.length} tools to return`);
      logger.debug(`Tool definitions: ${JSON.stringify(tools)}`);
      const response = {
        tools: tools,
        nextCursor: undefined,
      };
      logger.debug(`Sending ListTools response: ${JSON.stringify(response)}`);
      return response;
    });
    this.server.setRequestHandler(CallToolRequestSchema, async (request) => {
      logger.debug(`Tool call request received for: ${request.params.name}`);
      logger.debug(`Tool call arguments: ${JSON.stringify(request.params.arguments)}`);
      const tool = this.toolsMap.get(request.params.name);
      if (!tool) {
        const availableTools = Array.from(this.toolsMap.keys());
        const errorMsg = `Unknown tool: ${request.params.name}. Available tools: ${availableTools.join(", ")}`;
        logger.error(errorMsg);
        throw new Error(errorMsg);
      }
      try {
        logger.debug(`Executing tool: ${tool.name}`);
        const toolRequest = {
          params: request.params,
          method: "tools/call",
        };
        const result = await tool.toolCall(toolRequest);
        logger.debug(`Tool execution successful: ${JSON.stringify(result)}`);
        return result;
      } catch (error) {
        const errorMsg = `Tool execution failed: ${error}`;
        logger.error(errorMsg);
        // Keep the original error (and its stack) reachable for callers via `cause`.
        throw new Error(errorMsg, { cause: error });
      }
    });
    if (this.capabilities.prompts) {
      // ListPrompts ignores its request argument.
      this.server.setRequestHandler(ListPromptsRequestSchema, async () => {
        return {
          prompts: Array.from(this.promptsMap.values()).map((prompt) => prompt.promptDefinition),
        };
      });
      this.server.setRequestHandler(GetPromptRequestSchema, async (request) => {
        const prompt = this.promptsMap.get(request.params.name);
        if (!prompt) {
          throw new Error(`Unknown prompt: ${request.params.name}. Available prompts: ${Array.from(this.promptsMap.keys()).join(", ")}`);
        }
        return {
          messages: await prompt.getMessages(request.params.arguments),
        };
      });
    }
    if (this.capabilities.resources) {
      this.server.setRequestHandler(ListResourcesRequestSchema, async () => {
        return {
          resources: Array.from(this.resourcesMap.values()).map((resource) => resource.resourceDefinition),
        };
      });
      this.server.setRequestHandler(ReadResourceRequestSchema, async (request) => {
        const resource = this.resourcesMap.get(request.params.uri);
        if (!resource) {
          throw new Error(`Unknown resource: ${request.params.uri}. Available resources: ${Array.from(this.resourcesMap.keys()).join(", ")}`);
        }
        return {
          contents: await resource.read(),
        };
      });
      this.server.setRequestHandler(ListResourceTemplatesRequestSchema, async () => {
        logger.debug(`Received ListResourceTemplates request`);
        // Resource templates are not supported yet, so the list is always empty.
        const response = {
          resourceTemplates: [],
          nextCursor: undefined,
        };
        logger.debug(`Sending ListResourceTemplates response: ${JSON.stringify(response)}`);
        return response;
      });
      this.server.setRequestHandler(SubscribeRequestSchema, async (request) => {
        const resource = this.resourcesMap.get(request.params.uri);
        if (!resource) {
          throw new Error(`Unknown resource: ${request.params.uri}`);
        }
        if (!resource.subscribe) {
          throw new Error(`Resource ${request.params.uri} does not support subscriptions`);
        }
        await resource.subscribe();
        return {};
      });
      this.server.setRequestHandler(UnsubscribeRequestSchema, async (request) => {
        const resource = this.resourcesMap.get(request.params.uri);
        if (!resource) {
          throw new Error(`Unknown resource: ${request.params.uri}`);
        }
        if (!resource.unsubscribe) {
          throw new Error(`Resource ${request.params.uri} does not support subscriptions`);
        }
        await resource.unsubscribe();
        return {};
      });
    }
  }

  /**
   * Sets `capabilities.tools`, `.prompts` and `.resources` to `{}` when the corresponding loader
   * reports content, then notifies the SDK server if it exposes `updateCapabilities`.
   * @returns {Promise<object>} the (mutated) capabilities object.
   */
  async detectCapabilities() {
    if (await this.toolLoader.hasTools()) {
      this.capabilities.tools = {};
      logger.debug("Tools capability enabled");
    }
    if (await this.promptLoader.hasPrompts()) {
      this.capabilities.prompts = {};
      logger.debug("Prompts capability enabled");
    }
    if (await this.resourceLoader.hasResources()) {
      this.capabilities.resources = {};
      logger.debug("Resources capability enabled");
    }
    this.server.updateCapabilities?.(this.capabilities);
    logger.debug(`Capabilities updated: ${JSON.stringify(this.capabilities)}`);
    return this.capabilities;
  }

  /**
   * Reads the installed @modelcontextprotocol/sdk version. The package.json is located three
   * directories above the resolved "server/index.js" entry file.
   * @returns {string} the version, or "unknown" on any failure.
   */
  getSdkVersion() {
    try {
      const sdkSpecificFile = require.resolve("@modelcontextprotocol/sdk/server/index.js");
      const sdkRootDir = resolve(dirname(sdkSpecificFile), '..', '..', '..');
      const correctPackageJsonPath = join(sdkRootDir, "package.json");
      const packageContent = readFileSync(correctPackageJsonPath, "utf-8");
      const packageJson = JSON.parse(packageContent);
      if (packageJson?.version) {
        logger.debug(`Found SDK version: ${packageJson.version}`);
        return packageJson.version;
      }
      logger.warn("Could not determine SDK version from its package.json.");
      return "unknown";
    } catch (error) {
      logger.warn(`Failed to read SDK package.json: ${error.message}`);
      return "unknown";
    }
  }

  /**
   * Starts the server. Steps, in order:
   * 1. Load tools, prompts and resources, then detect capabilities.
   * 2. Register handlers and connect the transport.
   * 3. Install SIGINT/SIGTERM handlers that call stop().
   * The returned promise resolves only after stop() has run.
   *
   * @throws {Error} if the server is already running, or if any startup step fails.
   */
  async start() {
    // Reject a second start() before touching isRunning. Resetting the flag here would make the
    // already-running server look stopped, and its later stop() would become a no-op.
    if (this.isRunning) {
      const error = new Error("Server is already running");
      logger.error(`Server failed to start: ${error.message}\n${error.stack}`);
      throw error;
    }
    this.isRunning = true;
    try {
      const frameworkPackageJson = require('../../package.json');
      const frameworkVersion = frameworkPackageJson.version || 'unknown';
      const sdkVersion = this.getSdkVersion();
      logger.info(`Starting MCP server (Framework: ${frameworkVersion}, SDK: ${sdkVersion})...`);
      const tools = await this.toolLoader.loadTools();
      this.toolsMap = new Map(tools.map((tool) => [tool.name, tool]));
      const prompts = await this.promptLoader.loadPrompts();
      this.promptsMap = new Map(prompts.map((prompt) => [prompt.name, prompt]));
      const resources = await this.resourceLoader.loadResources();
      this.resourcesMap = new Map(resources.map((resource) => [resource.uri, resource]));
      await this.detectCapabilities();
      logger.info(`Capabilities detected: ${JSON.stringify(this.capabilities)}`);
      this.setupHandlers();
      this.transport = this.createTransport();
      logger.info(`Connecting transport (${this.transport.type}) to SDK Server...`);
      await this.server.connect(this.transport);
      logger.info(`Started ${this.serverName}@${this.serverVersion} successfully on transport ${this.transport.type}`);
      logger.info(`Tools (${tools.length}): ${tools.map(t => t.name).join(', ') || 'None'}`);
      if (this.capabilities.prompts) {
        logger.info(`Prompts (${prompts.length}): ${prompts.map(p => p.name).join(', ') || 'None'}`);
      }
      if (this.capabilities.resources) {
        logger.info(`Resources (${resources.length}): ${resources.map(r => r.uri).join(', ') || 'None'}`);
      }
      const shutdownHandler = async (signal) => {
        if (!this.isRunning)
          return;
        logger.info(`Received ${signal}. Shutting down...`);
        try {
          await this.stop();
        } catch (e) {
          logger.error(`Shutdown error via ${signal}: ${e.message}`);
          process.exit(1);
        }
      };
      // Keep references to the listeners so stop() can remove them. Otherwise every restart
      // would add another pair, and SIGINT would stay swallowed after shutdown.
      this.signalListeners = ['SIGINT', 'SIGTERM'].map((signal) => [signal, () => shutdownHandler(signal)]);
      for (const [signal, listener] of this.signalListeners) {
        process.on(signal, listener);
      }
      this.shutdownPromise = new Promise((resolve) => {
        this.shutdownResolve = resolve;
      });
      logger.info("Server running and ready.");
      await this.shutdownPromise;
    } catch (error) {
      logger.error(`Server failed to start: ${error.message}\n${error.stack}`);
      this.isRunning = false;
      throw error;
    }
  }

  /**
   * Stops the server. Steps, in order:
   * 1. Close the transport, then the SDK server.
   * 2. Remove the signal listeners installed by start().
   * 3. Resolve the promise that start() is awaiting.
   * Does nothing if the server is not running or a stop is already in progress.
   *
   * @throws {Error} summarising transport/SDK close failures, after shutdown has otherwise completed.
   */
  async stop() {
    if (!this.isRunning) {
      logger.debug("Stop called, but server not running.");
      return;
    }
    // If a transport's close() invokes the onclose callback installed in createTransport(), that
    // callback calls stop() again while isRunning is still true; this guard makes such calls no-ops.
    if (this.isStopping) {
      logger.debug("Stop already in progress.");
      return;
    }
    this.isStopping = true;
    try {
      logger.info("Stopping server...");
      let transportError = null;
      let sdkServerError = null;
      if (this.transport) {
        try {
          logger.debug(`Closing transport (${this.transport.type})...`);
          await this.transport.close();
          logger.info(`Transport closed.`);
        } catch (e) {
          transportError = e;
          logger.error(`Error closing transport: ${e.message}`);
        }
        this.transport = undefined;
      }
      if (this.server) {
        try {
          logger.debug("Closing SDK Server...");
          await this.server.close();
          logger.info("SDK Server closed.");
        } catch (e) {
          sdkServerError = e;
          logger.error(`Error closing SDK Server: ${e.message}`);
        }
      }
      this.isRunning = false;
      for (const [signal, listener] of this.signalListeners ?? []) {
        process.off(signal, listener);
      }
      this.signalListeners = [];
      if (this.shutdownResolve) {
        this.shutdownResolve();
        logger.debug("Shutdown promise resolved.");
      } else {
        logger.warn("Shutdown resolve function not found.");
      }
      if (transportError || sdkServerError) {
        logger.error("Errors occurred during server stop.");
        throw new Error(`Server stop failed. TransportError: ${transportError?.message}, SDKServerError: ${sdkServerError?.message}`);
      }
      logger.info("MCP server stopped successfully.");
    } catch (error) {
      logger.error(`Error stopping server: ${error}`);
      throw error;
    } finally {
      this.isStopping = false;
    }
  }

  /** @returns {boolean} whether start() has succeeded and stop() has not yet completed. */
  get IsRunning() {
    return this.isRunning;
  }
}