UNPKG

@ithena-one/mcp-governance

Version:

Governance layer (Identity, RBAC, Credentials, Audit, Logging, Tracing) for Model Context Protocol (MCP) servers.

274 lines 14.7 kB
/* eslint-disable @typescript-eslint/no-explicit-any */
// src/core/governed-server.ts
import { McpError, ErrorCode as McpErrorCode, } from '@modelcontextprotocol/sdk/types.js';
import { z } from 'zod';
import { GovernancePipeline } from './governance-pipeline.js';
import { LifecycleManager } from './lifecycle-manager.js';
import { mapErrorToPayload } from '../utils/error-mapper.js';
import { generateEventId, buildTransportContext } from '../utils/helpers.js';
import { defaultLogger } from '../defaults/logger.js';
import { defaultAuditStore } from '../defaults/audit.js';
import { defaultTraceContextProvider } from '../defaults/tracing.js';
import { defaultDerivePermission } from '../defaults/permissions.js';
import { defaultSanitizeForAudit } from '../defaults/sanitization.js';
/**
 * Wraps a base Model Context Protocol (MCP) Server to add a governance layer.
 *
 * Handlers are stored locally via setRequestHandler / setNotificationHandler
 * and are only registered with the base server (as thin wrappers that delegate
 * to a GovernancePipeline) when connect() is called.
 */
export class GovernedServer {
    /**
     * @param baseServer - The base MCP server to wrap. This class calls its
     *   connect, close, notification, setRequestHandler and
     *   setNotificationHandler methods and reads/writes its `onclose` property.
     * @param options - Governance options. Unset values fall back to the
     *   defaults imported above; RBAC is disabled unless `enableRbac` is set.
     * @throws {Error} If `enableRbac` is truthy but `roleStore` or
     *   `permissionStore` is missing.
     */
    constructor(baseServer, options = {}) {
        this.requestHandlers = new Map();
        this.notificationHandlers = new Map();
        this.baseServer = baseServer;
        this.options = {
            identityResolver: options.identityResolver,
            roleStore: options.roleStore,
            permissionStore: options.permissionStore,
            credentialResolver: options.credentialResolver,
            postAuthorizationHook: options.postAuthorizationHook,
            serviceIdentifier: options.serviceIdentifier,
            auditStore: options.auditStore ?? defaultAuditStore,
            logger: options.logger ?? defaultLogger,
            traceContextProvider: options.traceContextProvider ?? defaultTraceContextProvider,
            enableRbac: options.enableRbac ?? false,
            failOnCredentialResolutionError: options.failOnCredentialResolutionError ?? true,
            auditDeniedRequests: options.auditDeniedRequests ?? true,
            auditNotifications: options.auditNotifications ?? false,
            derivePermission: options.derivePermission ?? defaultDerivePermission,
            sanitizeForAudit: options.sanitizeForAudit ?? defaultSanitizeForAudit,
        };
        if (this.options.enableRbac && (!this.options.roleStore || !this.options.permissionStore)) {
            throw new Error("RoleStore and PermissionStore must be provided when RBAC is enabled.");
        }
        // The lifecycle manager receives every configured component; entries may
        // be undefined when the corresponding option was not supplied.
        this.lifecycleManager = new LifecycleManager(this.options.logger, [
            this.options.logger,
            this.options.auditStore,
            this.options.identityResolver,
            this.options.roleStore,
            this.options.permissionStore,
            this.options.credentialResolver,
        ]);
    }
    /** The transport passed to connect(), or undefined when not connected. */
    get transport() {
        return this.transportInternal;
    }
    /**
     * Initializes the governance components, registers the wrapper handlers on
     * the base server, connects it to `transport`, and installs an `onclose`
     * wrapper that runs governed cleanup when the base connection closes.
     *
     * @param transport - Transport handed to `baseServer.connect()`.
     * @throws {Error} If already connected, or (rethrown) whatever error made
     *   initialization or the base connect fail.
     */
    async connect(transport) {
        if (this.transportInternal) {
            throw new Error("GovernedServer is already connected.");
        }
        const logger = this.options.logger;
        logger.info("GovernedServer connecting...");
        this.transportInternal = transport;
        try {
            // --- Initialize Components ---
            await this.lifecycleManager.initialize();
            // --- Instantiate Pipeline ---
            // Pass necessary options and handler maps to the pipeline instance
            this.pipeline = new GovernancePipeline(this.options, this.requestHandlers, this.notificationHandlers);
            // --- Register Base Handlers ---
            this._registerBaseHandlers();
            // --- Connect Base Server ---
            await this.baseServer.connect(transport);
            // --- Setup Governed Close Handling ---
            const originalBaseOnClose = this.baseServer.onclose;
            const governedOnClose = () => {
                Promise.resolve().then(async () => {
                    logger.info("Base server connection closed, running governed cleanup...");
                    await this.lifecycleManager.shutdown();
                }).catch(err => {
                    logger.error("Error during component shutdown on close", err);
                }).finally(() => {
                    this.transportInternal = undefined;
                    this.pipeline = undefined;
                    // Put the caller's handler back so that a later connect() wraps
                    // it afresh instead of stacking wrapper upon wrapper. Skip the
                    // restore if someone replaced onclose in the meantime.
                    if (this.baseServer.onclose === governedOnClose) {
                        this.baseServer.onclose = originalBaseOnClose;
                    }
                    try {
                        originalBaseOnClose?.();
                    }
                    catch (err) {
                        // This chain is not awaited by anyone; a throw here would
                        // otherwise surface as an unhandled promise rejection.
                        logger.error("Error in original onclose handler", err);
                    }
                    logger.debug("Governed onclose handler finished.");
                });
            };
            this.baseServer.onclose = governedOnClose;
            logger.info("GovernedServer connected successfully.");
        }
        catch (error) {
            logger.error("GovernedServer connection failed during initialization", error);
            try {
                // Attempt cleanup on failure
                await this.lifecycleManager.shutdown();
            }
            catch (shutdownError) {
                // A failing cleanup must not mask the original error or leave
                // the server stuck in the "already connected" state.
                logger.error("Error during component shutdown after failed connect", shutdownError);
            }
            this.transportInternal = undefined;
            this.pipeline = undefined;
            throw error;
        }
    }
    /**
     * Shuts down the governance components and closes the base server.
     * Logs and returns immediately when no transport is set.
     *
     * NOTE(review): the onclose wrapper installed by connect() also calls
     * lifecycleManager.shutdown(), so shutdown presumably runs twice on this
     * path - confirm that LifecycleManager.shutdown() is idempotent.
     */
    async close() {
        const logger = this.options.logger;
        if (!this.transportInternal) {
            logger.info("GovernedServer close called, but already closed or not connected.");
            return;
        }
        logger.info("GovernedServer closing...");
        // Shutdown components first using the manager
        await this.lifecycleManager.shutdown();
        // Then close the base server (which should trigger our onclose handler)
        if (this.baseServer) {
            try {
                await this.baseServer.close();
            }
            catch (err) {
                logger.error("Error during baseServer.close()", err);
                // Ensure state is cleared anyway
                this.transportInternal = undefined;
                this.pipeline = undefined;
            }
        }
        else {
            this.transportInternal = undefined;
            this.pipeline = undefined;
        }
        logger.info("GovernedServer closed.");
    }
    /** Forwards an outgoing notification to the base server unchanged. */
    async notification(notification) {
        await this.baseServer.notification(notification);
    }
    // --- Handler Registration (stored locally until connect()) ---
    /**
     * Stores a governed request handler keyed by the schema's literal method.
     *
     * @param requestSchema - Schema whose `shape.method.value` is the method name.
     * @param handler - Handler stored alongside the schema.
     * @throws {Error} If called after connect().
     */
    setRequestHandler(requestSchema, handler) {
        const method = requestSchema.shape.method.value;
        if (this.transportInternal) {
            throw new Error(`Cannot register request handler for ${method} after connect() has been called.`);
        }
        if (this.requestHandlers.has(method)) {
            this.options.logger.warn(`Overwriting request handler for method: ${method}`);
        }
        this.requestHandlers.set(method, { handler: handler, schema: requestSchema });
        this.options.logger.debug(`Stored governed request handler for: ${method}`);
    }
    /**
     * Stores a governed notification handler keyed by the schema's literal method.
     *
     * @param notificationSchema - Schema whose `shape.method.value` is the method name.
     * @param handler - Handler stored alongside the schema.
     * @throws {Error} If called after connect().
     */
    setNotificationHandler(notificationSchema, handler) {
        const method = notificationSchema.shape.method.value;
        if (this.transportInternal) {
            throw new Error(`Cannot register notification handler for ${method} after connect() has been called.`);
        }
        if (this.notificationHandlers.has(method)) {
            this.options.logger.warn(`Overwriting notification handler for method: ${method}`);
        }
        this.notificationHandlers.set(method, { handler: handler, schema: notificationSchema });
        this.options.logger.debug(`Stored governed notification handler for: ${method}`);
    }
    // --- Wrapper Handler Creation and Registration ---
    /** Registers wrapper functions with the baseServer for all stored handlers. */
    _registerBaseHandlers() {
        this.options.logger.debug("Registering base server handlers for governed methods...");
        // WORKAROUND: Registering with a schema that explicitly includes `params: z.any().optional()`
        // appears necessary to prevent the current version of the base SDK Server
        // from stripping the params object before calling this wrapper handler.
        // This is related to an upstream issue/PR: https://github.com/modelcontextprotocol/typescript-sdk/pull/248
        // This workaround should be removed once the upstream fix is incorporated.
        const baseMethodSchema = (method) => z.object({
            jsonrpc: z.literal("2.0").optional(), // Allow flexibility from base SDK parsing
            id: z.union([z.string(), z.number()]).optional(), // Allow flexibility
            method: z.literal(method),
            params: z.any().optional() // Explicitly allow optional params of any type
        }).passthrough(); // Allow other fields like _meta
        for (const method of this.requestHandlers.keys()) {
            // Register with the base server using the more permissive schema
            this.baseServer.setRequestHandler(baseMethodSchema(method), this._createPipelineRequestHandler(method));
            this.options.logger.debug(`Registered base request handler for: ${method}`);
        }
        for (const method of this.notificationHandlers.keys()) {
            // Notifications also might have params, allow them minimally
            const notificationSchemaForBaseServer = z.object({
                jsonrpc: z.literal("2.0").optional(),
                method: z.literal(method),
                params: z.any().optional()
            }).passthrough();
            this.baseServer.setNotificationHandler(notificationSchemaForBaseServer, this._createPipelineNotificationHandler(method));
            this.options.logger.debug(`Registered base notification handler for: ${method}`);
        }
        this.options.logger.debug("Base handler registration complete.");
    }
    /**
     * Builds the per-message logger, operation context and initial audit record
     * that are handed to the pipeline.
     *
     * @param message - The incoming MCP request or notification.
     * @param mcpType - "request" or "notification"; only requests carry an id.
     * @returns An object with `logger`, `operationContext` and `auditRecord`.
     */
    _createInvocationContext(message, mcpType) {
        const eventId = generateEventId();
        const startTime = Date.now();
        const transportContext = buildTransportContext(this.transportInternal);
        const traceContext = this.options.traceContextProvider(transportContext, message);
        const isRequest = mcpType === "request";
        const baseLogger = this.options.logger;
        // Loggers without child() support are used as-is.
        const logger = baseLogger.child
            ? baseLogger.child({
                eventId,
                ...(isRequest && { requestId: message.id }),
                method: message.method,
                ...(traceContext?.traceId && { traceId: traceContext.traceId }),
                ...(traceContext?.spanId && { spanId: traceContext.spanId }),
                ...(transportContext.sessionId && { sessionId: transportContext.sessionId }),
            })
            : baseLogger;
        const operationContext = {
            eventId,
            timestamp: new Date(startTime),
            transportContext,
            traceContext,
            logger,
            mcpMessage: message,
            serviceIdentifier: this.options.serviceIdentifier,
        };
        const auditRecord = {
            eventId,
            timestamp: new Date(startTime).toISOString(),
            serviceIdentifier: this.options.serviceIdentifier,
            transport: transportContext,
            mcp: isRequest
                ? { type: mcpType, method: message.method, id: message.id }
                : { type: mcpType, method: message.method },
            trace: traceContext,
            identity: null,
        };
        return { logger, operationContext, auditRecord };
    }
    /**
     * Creates the wrapper that calls the request pipeline.
     *
     * The returned handler throws an McpError when the pipeline is not
     * initialized, and maps any error escaping the pipeline to an McpError.
     */
    _createPipelineRequestHandler(method) {
        return async (request, baseExtra) => {
            if (!this.pipeline) {
                this.options.logger.error(`Request received for ${method} but pipeline is not initialized. Server not connected?`);
                throw new McpError(McpErrorCode.InternalError, "GovernedServer pipeline not initialized.");
            }
            // --- Prepare Initial Context for Pipeline ---
            const { logger: requestLogger, operationContext, auditRecord } = this._createInvocationContext(request, "request");
            // --- Execute Pipeline ---
            try {
                requestLogger.debug(`Pipeline request handler invoked for: ${method}`);
                // Delegate actual execution to the pipeline instance
                return await this.pipeline.executeRequestPipeline(request, baseExtra, operationContext, auditRecord);
            }
            catch (error) {
                // Catch errors from the pipeline execution itself and map for baseServer
                requestLogger.error(`Unhandled error in request pipeline execution for ${method}`, error);
                const payload = mapErrorToPayload(error, McpErrorCode.InternalError, "Internal governance pipeline error");
                throw new McpError(payload.code, payload.message, payload.data);
            }
        };
    }
    /**
     * Creates the wrapper that calls the notification pipeline.
     *
     * The returned handler never throws: failures while preparing the context
     * or running the pipeline are logged and swallowed.
     */
    _createPipelineNotificationHandler(method) {
        return async (notification, baseExtra) => {
            if (!this.pipeline) {
                this.options.logger.error(`Notification received for ${method} but pipeline is not initialized. Server not connected?`);
                // Don't throw for notifications, just log
                return;
            }
            // Fall back to the base logger if context preparation itself fails.
            let notificationLogger = this.options.logger;
            try {
                // --- Prepare Initial Context ---
                const { logger, operationContext, auditRecord } = this._createInvocationContext(notification, "notification");
                notificationLogger = logger;
                // --- Execute Pipeline ---
                notificationLogger.debug(`Pipeline notification handler invoked for: ${method}`);
                await this.pipeline.executeNotificationPipeline(notification, baseExtra, operationContext, auditRecord);
            }
            catch (error) {
                // Log pipeline errors, but don't throw
                notificationLogger.error(`Unhandled error in notification pipeline execution for ${method}`, error);
            }
        };
    }
} // End GovernedServer class
//# sourceMappingURL=governed-server.js.map