mcp-framework
Version:
Framework for building Model Context Protocol (MCP) servers in TypeScript
322 lines (321 loc) • 13.2 kB
JavaScript
import { randomUUID } from "node:crypto";
import { createServer } from "node:http";
import contentType from "content-type";
import getRawBody from "raw-body";
import { APIKeyAuthProvider } from "../../auth/providers/apikey.js";
import { DEFAULT_AUTH_ERROR } from "../../auth/types.js";
import { AbstractTransport } from "../base.js";
import { DEFAULT_SSE_CONFIG, DEFAULT_CORS_CONFIG } from "./types.js";
import { logger } from "../../core/Logger.js";
import { getRequestHeader, setResponseHeaders } from "../../utils/headers.js";
import { PING_SSE_MESSAGE } from "../utils/ping-message.js";
/**
 * Base response headers applied to every SSE stream response before the
 * CORS headers and any user-configured headers are layered on top.
 */
const SSE_HEADERS = {
    "Content-Type": "text/event-stream",
    "Cache-Control": "no-cache",
    Connection: "keep-alive",
};
/**
 * Transport that exposes a single Server-Sent Events stream over a plain
 * `node:http` server and receives client messages through HTTP POST.
 *
 * A GET on `config.endpoint` opens the event stream and announces the POST
 * URL (`config.messageEndpoint?sessionId=<id>`) in an `endpoint` event.
 * A POST on `config.messageEndpoint` carrying that session id delivers one
 * JSON message to the registered message handler. Only one SSE stream is
 * tracked at a time; a new GET replaces the previous stream.
 */
export class SSEServerTransport extends AbstractTransport {
    type = "sse";
    /** HTTP server created by start(); undefined while the transport is stopped. */
    _server;
    /** Response object of the currently active SSE stream, if any. */
    _sseResponse;
    /** Random UUID generated once per transport; POSTs must echo it as `sessionId`. */
    _sessionId;
    /** DEFAULT_SSE_CONFIG overlaid with the constructor argument. */
    _config;
    /** Timer handle of the periodic keep-alive ping for the active stream. */
    _keepAliveInterval;

    /**
     * @param {object} [config] Overrides merged on top of DEFAULT_SSE_CONFIG.
     */
    constructor(config = {}) {
        super();
        this._sessionId = randomUUID();
        this._config = {
            ...DEFAULT_SSE_CONFIG,
            ...config
        };
        // The auth provider instance is replaced by its class name so the
        // provider object itself is never serialized into the log.
        logger.debug(`SSE transport configured with: ${JSON.stringify({
            ...this._config,
            auth: this._config.auth ? {
                provider: this._config.auth.provider.constructor.name,
                endpoints: this._config.auth.endpoints
            } : undefined
        })}`);
    }

    /**
     * Builds the CORS response headers from DEFAULT_CORS_CONFIG overlaid
     * with `config.cors`.
     *
     * @param {boolean} [includeMaxAge=false] Also emit `Access-Control-Max-Age`
     *   (used for preflight responses).
     * @returns {object} Header name → value map.
     */
    getCorsHeaders(includeMaxAge = false) {
        const corsConfig = {
            allowOrigin: DEFAULT_CORS_CONFIG.allowOrigin,
            allowMethods: DEFAULT_CORS_CONFIG.allowMethods,
            allowHeaders: DEFAULT_CORS_CONFIG.allowHeaders,
            exposeHeaders: DEFAULT_CORS_CONFIG.exposeHeaders,
            maxAge: DEFAULT_CORS_CONFIG.maxAge,
            ...this._config.cors
        };
        const headers = {
            "Access-Control-Allow-Origin": corsConfig.allowOrigin,
            "Access-Control-Allow-Methods": corsConfig.allowMethods,
            "Access-Control-Allow-Headers": corsConfig.allowHeaders,
            "Access-Control-Expose-Headers": corsConfig.exposeHeaders
        };
        if (includeMaxAge) {
            headers["Access-Control-Max-Age"] = corsConfig.maxAge;
        }
        return headers;
    }

    /**
     * Creates the HTTP server and starts listening on `config.port`.
     *
     * @returns {Promise<void>} Resolves once the server is listening; rejects
     *   if the server reports an error before it starts listening (for
     *   example when the port is already in use).
     * @throws {Error} If the transport has already been started.
     */
    async start() {
        if (this._server) {
            throw new Error("SSE transport already started");
        }
        return new Promise((resolve, reject) => {
            const server = createServer(async (req, res) => {
                try {
                    await this.handleRequest(req, res);
                }
                catch (error) {
                    logger.error(`Error handling request: ${error}`);
                    // writeHead() throws once headers are on the wire (e.g. the
                    // SSE stream already started), so only send a 500 when that
                    // is still possible; otherwise just terminate the response.
                    if (!res.headersSent) {
                        res.writeHead(500).end("Internal Server Error");
                    }
                    else if (!res.writableEnded) {
                        res.end();
                    }
                }
            });
            this._server = server;
            let started = false;
            server.on("error", (error) => {
                logger.error(`SSE server error: ${error}`);
                this._onerror?.(error);
                if (!started) {
                    // Startup failed: release the slot so start() can be
                    // retried, and settle the promise instead of hanging.
                    if (this._server === server) {
                        this._server = undefined;
                    }
                    reject(error);
                }
            });
            server.on("close", () => {
                logger.info("SSE server closed");
                this._onclose?.();
            });
            server.listen(this._config.port, () => {
                started = true;
                logger.info(`SSE transport listening on port ${this._config.port}`);
                resolve();
            });
        });
    }

    /**
     * Routes one incoming HTTP request: CORS preflight, SSE stream setup
     * (GET on `config.endpoint`), message delivery (POST on
     * `config.messageEndpoint`), or 404 for anything else.
     *
     * @param {import("node:http").IncomingMessage} req
     * @param {import("node:http").ServerResponse} res
     * @returns {Promise<void>}
     */
    async handleRequest(req, res) {
        logger.debug(`Incoming request: ${req.method} ${req.url}`);
        if (req.method === "OPTIONS") {
            setResponseHeaders(res, this.getCorsHeaders(true));
            res.writeHead(204).end();
            return;
        }
        setResponseHeaders(res, this.getCorsHeaders());
        const url = new URL(req.url, `http://${req.headers.host}`);
        const sessionId = url.searchParams.get("sessionId");
        if (req.method === "GET" && url.pathname === this._config.endpoint) {
            // The SSE endpoint is authenticated only when explicitly enabled.
            if (this._config.auth?.endpoints?.sse) {
                const isAuthenticated = await this.handleAuthentication(req, res, "SSE connection");
                if (!isAuthenticated)
                    return;
            }
            if (this._sseResponse?.writableEnded) {
                this._sseResponse = undefined;
            }
            if (this._sseResponse) {
                logger.warn("SSE connection already established; closing the old connection to allow a new one.");
                this._sseResponse.end();
                this.cleanupConnection();
            }
            this.setupSSEConnection(res);
            return;
        }
        if (req.method === "POST" && url.pathname === this._config.messageEndpoint) {
            if (sessionId !== this._sessionId) {
                logger.warn(`Invalid session ID received: ${sessionId}, expected: ${this._sessionId}`);
                res.writeHead(403).end("Invalid session ID");
                return;
            }
            // The message endpoint is authenticated unless explicitly disabled.
            if (this._config.auth?.endpoints?.messages !== false) {
                const isAuthenticated = await this.handleAuthentication(req, res, "message");
                if (!isAuthenticated)
                    return;
            }
            await this.handlePostMessage(req, res);
            return;
        }
        res.writeHead(404).end("Not Found");
    }

    /**
     * Runs the configured auth provider against the request and, on failure,
     * writes the error response itself.
     *
     * @param {import("node:http").IncomingMessage} req
     * @param {import("node:http").ServerResponse} res
     * @param {string} context Label used in log lines ("SSE connection", "message").
     * @returns {Promise<boolean>} `true` when the request may proceed (also when
     *   no provider is configured); `false` when a response was already sent.
     */
    async handleAuthentication(req, res, context) {
        const provider = this._config.auth?.provider;
        if (!provider) {
            return true;
        }
        const isApiKey = provider instanceof APIKeyAuthProvider;
        // API-key providers advertise which header they expect via WWW-Authenticate.
        const challengeHeader = isApiKey ? provider.getHeaderName() : undefined;
        if (isApiKey) {
            // Short-circuit when the API-key header is absent altogether.
            const headerValue = getRequestHeader(req.headers, provider.getHeaderName());
            if (!headerValue) {
                const error = provider.getAuthError?.() || DEFAULT_AUTH_ERROR;
                this._sendAuthError(res, error, challengeHeader);
                return false;
            }
        }
        const authResult = await provider.authenticate(req);
        if (!authResult) {
            const error = provider.getAuthError?.() || DEFAULT_AUTH_ERROR;
            logger.warn(`Authentication failed for ${context}:`);
            logger.warn(`- Client IP: ${req.socket.remoteAddress}`);
            logger.warn(`- Error: ${error.message}`);
            this._sendAuthError(res, error, challengeHeader);
            return false;
        }
        logger.info(`Authentication successful for ${context}:`);
        logger.info(`- Client IP: ${req.socket.remoteAddress}`);
        logger.info(`- Auth Type: ${provider.constructor.name}`);
        return true;
    }

    /**
     * Writes a JSON authentication-error response.
     *
     * @param {import("node:http").ServerResponse} res
     * @param {{status: number, message: string}} error Status and message to report.
     * @param {string} [apiKeyHeaderName] When given, a `WWW-Authenticate: ApiKey`
     *   challenge naming this header is added.
     */
    _sendAuthError(res, error, apiKeyHeaderName) {
        if (apiKeyHeaderName !== undefined) {
            res.setHeader("WWW-Authenticate", `ApiKey realm="MCP Server", header="${apiKeyHeaderName}"`);
        }
        // The body is JSON, so label it as such.
        res.setHeader("Content-Type", "application/json");
        res.writeHead(error.status).end(JSON.stringify({
            error: error.message,
            status: error.status,
            type: "authentication_error"
        }));
    }

    /**
     * Turns `res` into the active SSE stream: sets the stream headers, tunes
     * the socket, sends the `endpoint` event with the POST URL, starts the
     * 15-second keep-alive ping and wires up teardown on close/error.
     *
     * @param {import("node:http").ServerResponse} res
     */
    setupSSEConnection(res) {
        logger.debug(`Setting up SSE connection for session: ${this._sessionId}`);
        const headers = {
            ...SSE_HEADERS,
            ...this.getCorsHeaders(),
            ...this._config.headers
        };
        setResponseHeaders(res, headers);
        logger.debug(`SSE headers set: ${JSON.stringify(headers)}`);
        if (res.socket) {
            res.socket.setNoDelay(true);
            res.socket.setTimeout(0);
            res.socket.setKeepAlive(true, 1000);
            logger.debug('Socket optimized for SSE connection');
        }
        const endpointUrl = `${this._config.messageEndpoint}?sessionId=${this._sessionId}`;
        logger.debug(`Sending endpoint URL: ${endpointUrl}`);
        res.write(`event: endpoint\ndata: ${endpointUrl}\n\n`);
        logger.debug('Sending initial keep-alive');
        // Never leave a previous timer running when a new one is installed.
        if (this._keepAliveInterval) {
            clearInterval(this._keepAliveInterval);
        }
        this._keepAliveInterval = setInterval(() => {
            if (this._sseResponse && !this._sseResponse.writableEnded) {
                try {
                    this._sseResponse.write(PING_SSE_MESSAGE);
                }
                catch (error) {
                    logger.error(`Error sending keep-alive: ${error}`);
                    this.cleanupConnection();
                }
            }
        }, 15000);
        this._sseResponse = res;
        // A replaced connection fires its "close" event only after the new
        // stream is installed; tearing down unconditionally would kill the new
        // stream's keep-alive timer, so only clean up if `res` is still active.
        const cleanup = () => {
            if (this._sseResponse === res) {
                this.cleanupConnection();
            }
        };
        res.on("close", () => {
            logger.info(`SSE connection closed for session: ${this._sessionId}`);
            cleanup();
        });
        res.on("error", (error) => {
            logger.error(`SSE connection error for session ${this._sessionId}: ${error}`);
            this._onerror?.(error);
            cleanup();
        });
        res.on("end", () => {
            logger.info(`SSE connection ended for session: ${this._sessionId}`);
            cleanup();
        });
        logger.info(`SSE connection established successfully for session: ${this._sessionId}`);
    }

    /**
     * Reads one JSON message from a POST request and passes it to the
     * registered message handler.
     *
     * Responds 409 when no SSE stream is open, 202 when the handler
     * completes, and 400 with a JSON-RPC style error body when parsing or
     * handling fails (the error is also forwarded to the error handler).
     *
     * @param {import("node:http").IncomingMessage} req May carry a pre-parsed `body`.
     * @param {import("node:http").ServerResponse} res
     * @returns {Promise<void>}
     */
    async handlePostMessage(req, res) {
        if (!this._sseResponse || this._sseResponse.writableEnded) {
            logger.warn(`Rejecting message: no active SSE connection for session ${this._sessionId}`);
            res.writeHead(409).end("SSE connection not established");
            return;
        }
        // Filled in once the message is parsed so the error path can report it.
        let currentMessage = {};
        try {
            // Use an already-parsed body when present; otherwise read and
            // parse the raw request stream.
            const rawMessage = req.body || await (async () => {
                const ct = contentType.parse(req.headers["content-type"] ?? "");
                if (ct.type !== "application/json") {
                    throw new Error(`Unsupported content-type: ${ct.type}`);
                }
                const rawBody = await getRawBody(req, {
                    limit: this._config.maxMessageSize,
                    encoding: ct.parameters.charset ?? "utf-8"
                });
                const parsed = JSON.parse(rawBody.toString());
                logger.debug(`Received message: ${JSON.stringify(parsed)}`);
                return parsed;
            })();
            const { id, method, params } = rawMessage;
            logger.debug(`Parsed message - ID: ${id}, Method: ${method}`);
            const rpcMessage = {
                jsonrpc: "2.0",
                id: id,
                method: method,
                params: params
            };
            currentMessage = {
                id: id,
                method: method
            };
            logger.debug(`Processing RPC message: ${JSON.stringify({
                id: id,
                method: method,
                params: params
            })}`);
            if (!this._onmessage) {
                throw new Error("No message handler registered");
            }
            await this._onmessage(rpcMessage);
            res.writeHead(202).end("Accepted");
            logger.debug(`Successfully processed message ${rpcMessage.id}`);
        }
        catch (error) {
            const errorMessage = error instanceof Error ? error.message : String(error);
            logger.error(`Error handling message for session ${this._sessionId}:`);
            logger.error(`- Error: ${errorMessage}`);
            logger.error(`- Method: ${currentMessage.method || "unknown"}`);
            // `??` rather than `||`: 0 is a legitimate JSON-RPC id and must
            // not be reported as missing.
            logger.error(`- Message ID: ${currentMessage.id ?? "unknown"}`);
            const errorResponse = {
                jsonrpc: "2.0",
                id: currentMessage.id ?? null,
                error: {
                    code: -32000,
                    message: errorMessage,
                    data: {
                        method: currentMessage.method || "unknown",
                        sessionId: this._sessionId,
                        connectionActive: Boolean(this._sseResponse),
                        type: "message_handler_error"
                    }
                }
            };
            // Skip the write if a response has already been started.
            if (!res.headersSent) {
                res.setHeader("Content-Type", "application/json");
                res.writeHead(400).end(JSON.stringify(errorResponse));
            }
            this._onerror?.(error);
        }
    }

    /**
     * Pushes one message to the client as an SSE `data:` event.
     *
     * @param {object} message JSON-serializable message.
     * @returns {Promise<void>}
     * @throws {Error} If no SSE stream is currently open.
     */
    async send(message) {
        if (!this._sseResponse || this._sseResponse.writableEnded) {
            throw new Error("SSE connection not established");
        }
        this._sseResponse.write(`data: ${JSON.stringify(message)}\n\n`);
    }

    /**
     * Ends the active SSE stream (if any) and stops the HTTP server.
     *
     * The close handler is notified by the server's "close" listener
     * registered in start(), so it is not invoked a second time here.
     *
     * @returns {Promise<void>} Resolves when the server has closed, or
     *   immediately when the transport was never started.
     */
    async close() {
        if (this._sseResponse && !this._sseResponse.writableEnded) {
            this._sseResponse.end();
        }
        this.cleanupConnection();
        const server = this._server;
        if (!server) {
            return;
        }
        await new Promise((resolve) => {
            server.close(() => {
                logger.info("SSE server stopped");
                this._server = undefined;
                resolve();
            });
        });
    }

    /**
     * Stops the keep-alive timer and forgets the active SSE response.
     * Does not end the response itself.
     */
    cleanupConnection() {
        if (this._keepAliveInterval) {
            clearInterval(this._keepAliveInterval);
            this._keepAliveInterval = undefined;
        }
        this._sseResponse = undefined;
    }

    /**
     * @returns {boolean} `true` while an HTTP server instance is held.
     */
    isRunning() {
        return Boolean(this._server);
    }
}