UNPKG

mcp-framework

Version:

Framework for building Model Context Protocol (MCP) servers in TypeScript

322 lines (321 loc) 13.2 kB
import { randomUUID } from "node:crypto";
import { createServer } from "node:http";
import contentType from "content-type";
import getRawBody from "raw-body";
import { APIKeyAuthProvider } from "../../auth/providers/apikey.js";
import { DEFAULT_AUTH_ERROR } from "../../auth/types.js";
import { AbstractTransport } from "../base.js";
import { DEFAULT_SSE_CONFIG, DEFAULT_CORS_CONFIG } from "./types.js";
import { logger } from "../../core/Logger.js";
import { getRequestHeader, setResponseHeaders } from "../../utils/headers.js";
import { PING_SSE_MESSAGE } from "../utils/ping-message.js";

/** Response headers that mark a response as a Server-Sent Events stream. */
const SSE_HEADERS = {
  "Content-Type": "text/event-stream",
  "Cache-Control": "no-cache",
  "Connection": "keep-alive",
};

/** Delay between keep-alive pings written to an open SSE stream, in milliseconds. */
const KEEP_ALIVE_INTERVAL_MS = 15000;

/**
 * MCP transport that talks to one client over HTTP + Server-Sent Events.
 *
 * - `GET <config.endpoint>` opens the SSE stream used for server -> client messages.
 *   The first event ("endpoint") tells the client where to POST.
 * - `POST <config.messageEndpoint>?sessionId=<id>` delivers client -> server messages.
 *
 * Only one SSE stream is active at a time; a new GET replaces the previous stream.
 * The session id is generated once per transport instance.
 */
export class SSEServerTransport extends AbstractTransport {
  type = "sse";
  _server;
  _sseResponse;
  _sessionId;
  _config;
  _keepAliveInterval;

  /**
   * @param {object} [config={}] - Overrides shallow-merged over DEFAULT_SSE_CONFIG.
   *   Keys read by this class: port, endpoint, messageEndpoint, maxMessageSize,
   *   headers, cors, auth ({ provider, endpoints: { sse, messages } }).
   */
  constructor(config = {}) {
    super();
    this._sessionId = randomUUID();
    this._config = { ...DEFAULT_SSE_CONFIG, ...config };
    // Log the config without serializing the auth provider instance itself.
    // `provider?.` keeps this from throwing when `auth` is set without a provider
    // (handleAuthentication treats that case as "no auth").
    const auth = this._config.auth;
    logger.debug(`SSE transport configured with: ${JSON.stringify({
      ...this._config,
      auth: auth
        ? { provider: auth.provider?.constructor.name, endpoints: auth.endpoints }
        : undefined,
    })}`);
  }

  /**
   * Builds the CORS response headers from DEFAULT_CORS_CONFIG overlaid with `config.cors`.
   *
   * @param {boolean} [includeMaxAge=false] - Also emit Access-Control-Max-Age (used for preflight).
   * @returns {Record<string, unknown>} Header name -> value map.
   */
  getCorsHeaders(includeMaxAge = false) {
    const corsConfig = {
      allowOrigin: DEFAULT_CORS_CONFIG.allowOrigin,
      allowMethods: DEFAULT_CORS_CONFIG.allowMethods,
      allowHeaders: DEFAULT_CORS_CONFIG.allowHeaders,
      exposeHeaders: DEFAULT_CORS_CONFIG.exposeHeaders,
      maxAge: DEFAULT_CORS_CONFIG.maxAge,
      ...this._config.cors,
    };
    const headers = {
      "Access-Control-Allow-Origin": corsConfig.allowOrigin,
      "Access-Control-Allow-Methods": corsConfig.allowMethods,
      "Access-Control-Allow-Headers": corsConfig.allowHeaders,
      "Access-Control-Expose-Headers": corsConfig.exposeHeaders,
    };
    if (includeMaxAge) {
      headers["Access-Control-Max-Age"] = corsConfig.maxAge;
    }
    return headers;
  }

  /**
   * Creates the HTTP server and starts listening on `config.port`.
   *
   * @returns {Promise<void>} Resolves once listening.
   * @throws {Error} If already started; the promise rejects if the server fails to
   *   listen (e.g. the port is in use). In that case the transport can be started again.
   */
  async start() {
    if (this._server) {
      throw new Error("SSE transport already started");
    }
    const server = createServer(async (req, res) => {
      try {
        await this.handleRequest(req, res);
      } catch (error) {
        logger.error(`Error handling request: ${error}`);
        // Once a response has started (e.g. an SSE stream), writeHead would throw
        // ERR_HTTP_HEADERS_SENT from inside this callback; just end the response instead.
        if (!res.headersSent) {
          res.writeHead(500).end("Internal Server Error");
        } else if (!res.writableEnded) {
          res.end();
        }
      }
    });
    this._server = server;
    server.on("error", (error) => {
      logger.error(`SSE server error: ${error}`);
      this._onerror?.(error);
    });
    // Single place where _onclose fires for server shutdown (close() relies on this).
    server.on("close", () => {
      logger.info("SSE server closed");
      this._onclose?.();
    });
    return new Promise((resolve, reject) => {
      // Without this, a listen failure would leave start() pending forever and
      // _server set, blocking any retry.
      const onStartupError = (error) => {
        if (this._server === server) {
          this._server = undefined;
        }
        reject(error);
      };
      server.once("error", onStartupError);
      server.listen(this._config.port, () => {
        server.off("error", onStartupError);
        logger.info(`SSE transport listening on port ${this._config.port}`);
        resolve();
      });
    });
  }

  /**
   * Routes one HTTP request: CORS preflight, SSE stream setup (GET), or message
   * delivery (POST); anything else gets 404.
   *
   * SSE auth runs only when `auth.endpoints.sse` is truthy; message auth runs
   * unless `auth.endpoints.messages` is explicitly `false`.
   */
  async handleRequest(req, res) {
    logger.debug(`Incoming request: ${req.method} ${req.url}`);
    if (req.method === "OPTIONS") {
      setResponseHeaders(res, this.getCorsHeaders(true));
      res.writeHead(204).end();
      return;
    }
    setResponseHeaders(res, this.getCorsHeaders());

    let url;
    try {
      // Host may be absent (HTTP/1.0) or malformed; only the path and query are used.
      url = new URL(req.url ?? "/", `http://${req.headers.host ?? "localhost"}`);
    } catch {
      res.writeHead(400).end("Bad Request");
      return;
    }

    if (req.method === "GET" && url.pathname === this._config.endpoint) {
      if (this._config.auth?.endpoints?.sse) {
        const isAuthenticated = await this.handleAuthentication(req, res, "SSE connection");
        if (!isAuthenticated) return;
      }
      if (this._sseResponse) {
        if (!this._sseResponse.writableEnded) {
          logger.warn("SSE connection already established; closing the old connection to allow a new one.");
          this._sseResponse.end();
        }
        // Also clears the old keep-alive timer, even if the old stream had already ended.
        this.cleanupConnection();
      }
      this.setupSSEConnection(res);
      return;
    }

    if (req.method === "POST" && url.pathname === this._config.messageEndpoint) {
      const sessionId = url.searchParams.get("sessionId");
      if (sessionId !== this._sessionId) {
        logger.warn(`Invalid session ID received: ${sessionId}, expected: ${this._sessionId}`);
        res.writeHead(403).end("Invalid session ID");
        return;
      }
      if (this._config.auth?.endpoints?.messages !== false) {
        const isAuthenticated = await this.handleAuthentication(req, res, "message");
        if (!isAuthenticated) return;
      }
      await this.handlePostMessage(req, res);
      return;
    }

    res.writeHead(404).end("Not Found");
  }

  /**
   * Authenticates a request with the configured provider and writes an error
   * response on failure.
   *
   * @param {import("node:http").IncomingMessage} req
   * @param {import("node:http").ServerResponse} res
   * @param {string} context - Label used in log lines ("SSE connection" / "message").
   * @returns {Promise<boolean>} true if allowed (or no provider configured); false
   *   if the response has already been written with an auth error.
   */
  async handleAuthentication(req, res, context) {
    const provider = this._config.auth?.provider;
    if (!provider) {
      return true;
    }
    const isApiKey = provider instanceof APIKeyAuthProvider;

    // For API-key auth, a missing header is rejected before calling the provider.
    if (isApiKey && !getRequestHeader(req.headers, provider.getHeaderName())) {
      this.sendAuthError(res, provider, provider.getAuthError?.() || DEFAULT_AUTH_ERROR);
      return false;
    }

    const authResult = await provider.authenticate(req);
    if (!authResult) {
      const error = provider.getAuthError?.() || DEFAULT_AUTH_ERROR;
      logger.warn(`Authentication failed for ${context}:`);
      logger.warn(`- Client IP: ${req.socket.remoteAddress}`);
      logger.warn(`- Error: ${error.message}`);
      this.sendAuthError(res, provider, error);
      return false;
    }

    logger.info(`Authentication successful for ${context}:`);
    logger.info(`- Client IP: ${req.socket.remoteAddress}`);
    logger.info(`- Auth Type: ${provider.constructor.name}`);
    return true;
  }

  /**
   * Writes a JSON authentication-error response with `error.status`. API-key
   * providers also get a WWW-Authenticate header naming the expected header.
   *
   * @param {import("node:http").ServerResponse} res
   * @param {object} provider - The configured auth provider.
   * @param {{ status: number, message: string }} error
   */
  sendAuthError(res, provider, error) {
    if (provider instanceof APIKeyAuthProvider) {
      res.setHeader("WWW-Authenticate", `ApiKey realm="MCP Server", header="${provider.getHeaderName()}"`);
    }
    res
      .writeHead(error.status, { "Content-Type": "application/json" })
      .end(JSON.stringify({ error: error.message, status: error.status, type: "authentication_error" }));
  }

  /**
   * Turns `res` into the active SSE stream. It sends the "endpoint" event, starts
   * periodic pings, and wires close/error events to connection cleanup.
   *
   * @param {import("node:http").ServerResponse} res
   */
  setupSSEConnection(res) {
    logger.debug(`Setting up SSE connection for session: ${this._sessionId}`);
    const headers = { ...SSE_HEADERS, ...this.getCorsHeaders(), ...this._config.headers };
    setResponseHeaders(res, headers);
    logger.debug(`SSE headers set: ${JSON.stringify(headers)}`);

    if (res.socket) {
      // Flush small events immediately and never time out the long-lived socket.
      res.socket.setNoDelay(true);
      res.socket.setTimeout(0);
      res.socket.setKeepAlive(true, 1000);
      logger.debug('Socket optimized for SSE connection');
    }

    const endpointUrl = `${this._config.messageEndpoint}?sessionId=${this._sessionId}`;
    logger.debug(`Sending endpoint URL: ${endpointUrl}`);
    res.write(`event: endpoint\ndata: ${endpointUrl}\n\n`);

    logger.debug('Sending initial keep-alive');
    this._sseResponse = res;
    // The timer pings *this* stream only. Writing via this._sseResponse could hit a
    // newer stream if this timer somehow outlived its connection.
    this._keepAliveInterval = setInterval(() => {
      if (res.writableEnded) {
        return;
      }
      try {
        res.write(PING_SSE_MESSAGE);
      } catch (error) {
        logger.error(`Error sending keep-alive: ${error}`);
        this.cleanupConnection(res);
      }
    }, KEEP_ALIVE_INTERVAL_MS);

    // Pass `res` so that events from a replaced stream (its "close" fires after
    // end()) cannot tear down the stream that replaced it.
    const cleanup = () => this.cleanupConnection(res);
    res.on("close", () => {
      logger.info(`SSE connection closed for session: ${this._sessionId}`);
      cleanup();
    });
    res.on("error", (error) => {
      logger.error(`SSE connection error for session ${this._sessionId}: ${error}`);
      this._onerror?.(error);
      cleanup();
    });
    logger.info(`SSE connection established successfully for session: ${this._sessionId}`);
  }

  /**
   * Reads the request's JSON body.
   *
   * @param {import("node:http").IncomingMessage} req
   * @returns {Promise<unknown>} The parsed JSON value.
   * @throws {Error} On a non-"application/json" content type, a body larger than
   *   `config.maxMessageSize`, or invalid JSON.
   */
  async readJsonBody(req) {
    const ct = contentType.parse(req.headers["content-type"] ?? "");
    if (ct.type !== "application/json") {
      throw new Error(`Unsupported content-type: ${ct.type}`);
    }
    const rawBody = await getRawBody(req, {
      limit: this._config.maxMessageSize,
      encoding: ct.parameters.charset ?? "utf-8",
    });
    const parsed = JSON.parse(rawBody.toString());
    logger.debug(`Received message: ${JSON.stringify(parsed)}`);
    return parsed;
  }

  /**
   * Handles a client -> server POST. It parses the message (or uses a pre-parsed
   * `req.body`), normalizes it to a JSON-RPC 2.0 envelope, and hands it to the
   * registered message handler.
   *
   * Responds 409 if no SSE stream is open, 202 on success, and 400 with a JSON-RPC
   * error body on any parse/handler failure (which is also reported via _onerror).
   */
  async handlePostMessage(req, res) {
    if (!this._sseResponse || this._sseResponse.writableEnded) {
      logger.warn(`Rejecting message: no active SSE connection for session ${this._sessionId}`);
      res.writeHead(409).end("SSE connection not established");
      return;
    }

    let currentMessage = {};
    try {
      const rawMessage = req.body || await this.readJsonBody(req);
      if (rawMessage === null || typeof rawMessage !== "object" || Array.isArray(rawMessage)) {
        throw new Error("Invalid JSON-RPC message: expected a JSON object");
      }

      const { id, method, params, result, error: rpcError } = rawMessage;
      logger.debug(`Parsed message - ID: ${id}, Method: ${method}`);

      // Rebuild a clean JSON-RPC envelope. `result`/`error` are kept so client
      // responses to server-initiated requests are not reduced to empty shells.
      // Absent fields are omitted, not set to undefined, so `"id" in msg` still
      // distinguishes requests from notifications.
      const rpcMessage = { jsonrpc: "2.0" };
      for (const [key, value] of Object.entries({ id, method, params, result, error: rpcError })) {
        if (value !== undefined) {
          rpcMessage[key] = value;
        }
      }
      currentMessage = { id, method };
      logger.debug(`Processing RPC message: ${JSON.stringify({ id: id, method: method, params: params })}`);

      if (!this._onmessage) {
        throw new Error("No message handler registered");
      }
      await this._onmessage(rpcMessage);
      res.writeHead(202).end("Accepted");
      logger.debug(`Successfully processed message ${rpcMessage.id}`);
    } catch (error) {
      const errorMessage = error instanceof Error ? error.message : String(error);
      logger.error(`Error handling message for session ${this._sessionId}:`);
      logger.error(`- Error: ${errorMessage}`);
      logger.error(`- Method: ${currentMessage.method || "unknown"}`);
      logger.error(`- Message ID: ${currentMessage.id ?? "unknown"}`);
      const errorResponse = {
        jsonrpc: "2.0",
        // `??`, not `||`: 0 is a valid JSON-RPC id and must be echoed back.
        id: currentMessage.id ?? null,
        error: {
          code: -32000,
          message: errorMessage,
          data: {
            method: currentMessage.method || "unknown",
            sessionId: this._sessionId,
            connectionActive: Boolean(this._sseResponse),
            type: "message_handler_error",
          },
        },
      };
      if (!res.headersSent) {
        res.writeHead(400, { "Content-Type": "application/json" }).end(JSON.stringify(errorResponse));
      }
      this._onerror?.(error);
    }
  }

  /**
   * Pushes a message to the client as an SSE `data:` event.
   *
   * @param {unknown} message - JSON-serializable message.
   * @throws {Error} If no SSE stream is open.
   */
  async send(message) {
    if (!this._sseResponse || this._sseResponse.writableEnded) {
      throw new Error("SSE connection not established");
    }
    this._sseResponse.write(`data: ${JSON.stringify(message)}\n\n`);
  }

  /**
   * Ends the SSE stream (if open), stops pings, and closes the HTTP server.
   * `_onclose` is invoked once, by the server's "close" listener registered in start().
   *
   * @returns {Promise<void>} Resolves when the server has closed (immediately if never started).
   */
  async close() {
    if (this._sseResponse && !this._sseResponse.writableEnded) {
      this._sseResponse.end();
    }
    this.cleanupConnection();
    const server = this._server;
    if (!server) {
      return;
    }
    await new Promise((resolve) => {
      server.close(() => {
        logger.info("SSE server stopped");
        if (this._server === server) {
          this._server = undefined;
        }
        resolve();
      });
    });
  }

  /**
   * Stops the keep-alive timer and forgets the active SSE response.
   *
   * @param {import("node:http").ServerResponse} [res] - When given, cleanup happens
   *   only if `res` is still the active stream. Without it, cleanup is unconditional.
   */
  cleanupConnection(res) {
    if (res !== undefined && res !== this._sseResponse) {
      return;
    }
    if (this._keepAliveInterval) {
      clearInterval(this._keepAliveInterval);
      this._keepAliveInterval = undefined;
    }
    this._sseResponse = undefined;
  }

  /** @returns {boolean} Whether the HTTP server has been created and not yet closed. */
  isRunning() {
    return Boolean(this._server);
  }
}