UNPKG

@bratcliffe909/mcp-server-segmind

Version:

Model Context Protocol server for Segmind API - Generate images and videos using AI models

865 lines 43.6 kB
import { Server } from '@modelcontextprotocol/sdk/server/index.js'; import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js'; import { ListToolsRequestSchema, CallToolRequestSchema, ListResourcesRequestSchema, ReadResourceRequestSchema, ListPromptsRequestSchema, GetPromptRequestSchema, } from '@modelcontextprotocol/sdk/types.js'; import { apiClient } from './api/client.js'; import { modelRegistry, ModelCategory } from './models/registry.js'; import { generateImageTool, transformImageTool, generateVideoTool, generateAudioTool, generateMusicTool, enhanceImageTool, estimateCostTool, readLocalImageTool, prepareImageTool } from './tools/index.js'; import { mapToSafeError, formatErrorResponse } from './utils/errors.js'; import { logger } from './utils/logger.js'; export class SegmindMCPServer { server; state; constructor() { this.state = { isInitialized: false, modelsLoaded: false, activeRequests: 0, rateLimits: new Map(), }; this.server = new Server({ name: '@segmind/mcp-server', version: '0.1.0', }, { capabilities: { tools: {}, resources: {}, prompts: {}, logging: {}, }, }); this.server.oninitialized = () => { void this.handlePostInitialization(); }; this.setupHandlers(); } async handlePostInitialization() { try { await this.sendMCPLog('info', 'Segmind MCP Server initialized'); const hasApiKey = !!process.env.SEGMIND_API_KEY; if (!hasApiKey) { await this.sendMCPLog('warning', 'SEGMIND_API_KEY not found', { help: 'Please set SEGMIND_API_KEY in your environment or MCP configuration', impact: 'API calls will fail without a valid API key', }); } } catch (err) { logger.error('Failed to send post-initialization notifications', { error: err }); } } async sendMCPLog(level, message, data) { try { await this.server.sendLoggingMessage({ level, logger: 'segmind-mcp', data: data ? { message, ...data } : message, }); } catch (err) { logger.error('Failed to send MCP log', { level, message, error: err }); } } setupHandlers() { this.server.setRequestHandler(ListToolsRequestSchema, async () => { logger.info('Listing available tools'); const t2iModels = modelRegistry.getModelsByCategory(ModelCategory.TEXT_TO_IMAGE); const modelIds = t2iModels.map(m => m.id); return { tools: [ { name: 'generate_image', description: 'Generate images from text prompts using various AI models. Returns base64-encoded image data with MIME type information.', inputSchema: { type: 'object', properties: { prompt: { type: 'string', description: 'Text prompt describing the image to generate', minLength: 1, maxLength: 4000, }, model: { type: 'string', description: `Model to use (default: auto-select based on prompt)`, enum: modelIds, }, negative_prompt: { type: 'string', description: 'What to avoid in the generation', }, width: { type: 'number', description: 'Image width (must be multiple of 8)', minimum: 256, maximum: 2048, multipleOf: 8, }, height: { type: 'number', description: 'Image height (must be multiple of 8)', minimum: 256, maximum: 2048, multipleOf: 8, }, num_images: { type: 'number', description: 'Number of images to generate', minimum: 1, maximum: 4, default: 1, }, seed: { type: 'number', description: 'Seed for reproducible generation', }, quality: { type: 'string', description: 'Quality preset', enum: ['draft', 'standard', 'high'], default: 'standard', }, style: { type: 'string', description: 'Style modifier (e.g., "photorealistic", "anime", "oil painting")', }, display_mode: { type: 'string', description: 'How to return the image: display (show image), save (return base64 for saving), both (show image and provide base64)', enum: ['display', 'save', 'both'], default: 'display', }, }, required: ['prompt'], }, }, { name: 'list_models', description: 'List available AI models by category', inputSchema: { type: 'object', properties: { category: { type: 'string', description: 'Filter by category', enum: Object.values(ModelCategory), }, }, }, }, { name: 'get_model_info', description: 'Get detailed information about a specific model', inputSchema: { type: 'object', properties: { model_id: { type: 'string', description: 'Model ID', }, }, required: ['model_id'], }, }, { name: 'transform_image', description: 'Transform existing images using AI with various control methods. Accepts file paths directly (e.g. C:\\photo.jpg), URLs, or base64. File paths are automatically processed without displaying the base64 string.', inputSchema: { type: 'object', properties: { image: { type: 'string', description: 'Input image as: file path (e.g. C:\\photo.jpg or /home/user/image.png), URL, base64 string, or image cache ID. File paths are automatically handled without displaying base64.', }, prompt: { type: 'string', description: 'Transformation prompt describing desired changes', minLength: 1, maxLength: 2000, }, model: { type: 'string', description: 'Model ID to use for transformation', enum: modelRegistry.getModelsByCategory(ModelCategory.IMAGE_TO_IMAGE).map(m => m.id), }, negative_prompt: { type: 'string', description: 'What to avoid in the transformation', }, strength: { type: 'number', description: 'Transformation strength (0=no change, 1=complete change)', minimum: 0, maximum: 1, default: 0.75, }, mask: { type: 'string', description: 'Mask image for inpainting (base64 or URL)', }, control_type: { type: 'string', description: 'ControlNet type', enum: ['canny', 'depth', 'pose', 'scribble', 'segmentation'], }, seed: { type: 'number', description: 'Seed for reproducible generation', }, display_mode: { type: 'string', description: 'How to return the image: display (show image), save (return base64 for saving), both (show image and provide base64)', enum: ['display', 'save', 'both'], default: 'display', }, }, required: ['image', 'prompt'], }, }, { name: 'generate_video', description: 'Generate videos from text prompts or animate static images', inputSchema: { type: 'object', properties: { prompt: { type: 'string', description: 'Text prompt or motion description for video generation', minLength: 1, maxLength: 2000, }, model: { type: 'string', description: 'Model ID to use for video generation', enum: modelRegistry.getModelsByCategory(ModelCategory.TEXT_TO_VIDEO).map(m => m.id), }, image: { type: 'string', description: 'Input image for image-to-video generation (base64 or URL)', }, duration: { type: 'number', description: 'Video duration in seconds', minimum: 1, maximum: 30, default: 5, }, fps: { type: 'number', description: 'Frames per second', minimum: 12, maximum: 60, default: 24, }, aspect_ratio: { type: 'string', description: 'Video aspect ratio', enum: ['16:9', '9:16', '1:1', '4:3'], default: '16:9', }, quality: { type: 'string', description: 'Video quality preset', enum: ['standard', 'high', 'ultra'], default: 'high', }, seed: { type: 'number', description: 'Seed for reproducible generation', }, save_location: { type: 'string', description: 'Directory path to save the video. Overrides default save location.', }, }, required: ['prompt'], }, }, { name: 'enhance_image', description: 'Enhance images with upscaling, restoration, background removal, and more. Accepts file paths directly (e.g. C:\\photo.jpg), URLs, or base64. File paths are automatically processed without displaying the base64 string.', inputSchema: { type: 'object', properties: { image: { type: 'string', description: 'Input image as: file path (e.g. C:\\photo.jpg or /home/user/image.png), URL, base64 string, or image cache ID. File paths are automatically handled without displaying base64.', }, operation: { type: 'string', description: 'Enhancement operation to perform', enum: ['upscale', 'restore', 'remove_background', 'colorize', 'denoise'], }, model: { type: 'string', description: 'Specific model to use for enhancement', enum: modelRegistry.getModelsByCategory(ModelCategory.IMAGE_ENHANCEMENT).map(m => m.id), }, scale: { type: 'string', description: 'Upscaling factor (for upscale operation)', enum: ['2', '4', '8'], default: '4', }, face_enhance: { type: 'boolean', description: 'Enhance faces during upscaling', default: false, }, return_mask: { type: 'boolean', description: 'Return mask for background removal', default: false, }, batch_size: { type: 'number', description: 'Number of images to process', minimum: 1, maximum: 10, default: 1, }, display_mode: { type: 'string', description: 'How to return the image: display (show image), save (return base64 for saving), both (show image and provide base64)', enum: ['display', 'save', 'both'], default: 'display', }, }, required: ['image', 'operation'], }, }, { name: 'generate_audio', description: 'Generate speech audio from text using TTS models', inputSchema: { type: 'object', properties: { text: { type: 'string', description: 'Text to convert to speech', }, model: { type: 'string', description: 'TTS model to use (dia-tts or orpheus-tts)', enum: modelRegistry.getModelsByCategory(ModelCategory.TEXT_TO_AUDIO).map(m => m.id), }, voice: { type: 'string', description: 'Voice selection for TTS (orpheus: tara, dan, josh, emma)', }, temperature: { type: 'number', description: 'Controls randomness/expressiveness (0.1-2.0)', minimum: 0.1, maximum: 2.0, }, top_p: { type: 'number', description: 'Controls word variety (0.1-1.0, higher = rarer words)', minimum: 0.1, maximum: 1.0, }, max_new_tokens: { type: 'number', description: 'Maximum tokens (controls audio length - higher = longer audio)', minimum: 100, maximum: 10000, }, speed_factor: { type: 'number', description: 'Playback speed (0.5-1.5). Default 0.94 = normal speech. Try 0.8 for slower, 1.1 for faster', minimum: 0.5, maximum: 1.5, }, cfg_scale: { type: 'number', description: 'How strictly to follow text (1-5, dia only)', minimum: 1, maximum: 5, }, cfg_filter_top_k: { type: 'number', description: 'Token filtering (10-100, dia only)', minimum: 10, maximum: 100, }, input_audio: { type: 'string', description: 'Base64 audio for voice cloning (dia only)', }, repetition_penalty: { type: 'number', description: 'Penalty for repeated phrases (1.0-2.0, orpheus only)', minimum: 1.0, maximum: 2.0, }, seed: { type: 'number', description: 'Seed for reproducible generation', }, display_mode: { type: 'string', description: 'How to return the audio: display (show audio), save (return base64 for saving), both (show audio and provide base64)', enum: ['display', 'save', 'both'], default: 'display', }, save_location: { type: 'string', description: 'Directory path to save the audio. Overrides default save location.', }, }, required: ['text'], }, }, { name: 'generate_music', description: 'Generate music from text descriptions', inputSchema: { type: 'object', properties: { prompt: { type: 'string', description: 'Text description of the music to generate', }, model: { type: 'string', description: 'Music generation model to use', enum: modelRegistry.getModelsByCategory(ModelCategory.TEXT_TO_MUSIC).map(m => m.id), }, duration: { type: 'number', description: 'Duration in seconds for the music', minimum: 1, maximum: 300, }, negative_prompt: { type: 'string', description: 'What to avoid in the generation', }, seed: { type: 'number', description: 'Seed for reproducible generation', }, num_outputs: { type: 'number', description: 'Number of variations to generate', minimum: 1, maximum: 4, default: 1, }, display_mode: { type: 'string', description: 'How to return the audio: display (show audio), save (return base64 for saving), both (show audio and provide base64)', enum: ['display', 'save', 'both'], default: 'display', }, save_location: { type: 'string', description: 'Directory path to save the music. Overrides default save location.', }, }, required: ['prompt'], }, }, { name: 'estimate_cost', description: 'Estimate the credit cost and time for image/video generation operations', inputSchema: { type: 'object', properties: { operation: { type: 'string', description: 'Type of operation (generate, transform, enhance, etc.)', }, model: { type: 'string', description: 'Model ID to estimate cost for', }, category: { type: 'string', description: 'Model category to list costs for', }, num_images: { type: 'number', description: 'Number of images to generate', minimum: 1, maximum: 10, }, num_outputs: { type: 'number', description: 'Number of outputs to generate', minimum: 1, maximum: 10, }, list_all: { type: 'boolean', description: 'List costs for all available models', }, }, }, }, { name: 'check_credits', description: 'Check remaining API credits', inputSchema: { type: 'object', properties: {}, }, }, { name: 'prepare_image', description: 'RECOMMENDED: Prepare a local image file for use with other tools. Returns a short ID instead of the full base64 string, avoiding display slowdowns. Always use this instead of read_local_image for image transformation tasks.', inputSchema: { type: 'object', properties: { file_path: { type: 'string', description: 'Absolute path to the image file', }, max_size_kb: { type: 'number', description: 'Maximum size in KB before warning (default: 800KB)', default: 800, }, }, required: ['file_path'], }, }, { name: 'read_local_image', description: 'Read a local image file and convert it to base64. WARNING: Returns the full base64 string which can be very large and slow to display. Use prepare_image instead for better performance.', inputSchema: { type: 'object', properties: { file_path: { type: 'string', description: 'Absolute path to the image file', }, return_format: { type: 'string', description: 'Format to return: base64 string or data URI', enum: ['base64', 'data_uri'], default: 'base64', }, }, required: ['file_path'], }, }, ], }; }); this.server.setRequestHandler(CallToolRequestSchema, async (request) => { const { name, arguments: args } = request.params; logger.info('Tool called', { tool: name }); this.state.activeRequests++; try { switch (name) { case 'generate_image': return await generateImageTool.execute(args); case 'transform_image': return await transformImageTool.execute(args); case 'generate_video': return await generateVideoTool.execute(args); case 'enhance_image': return await enhanceImageTool.execute(args); case 'generate_audio': return await generateAudioTool.execute(args); case 'generate_music': return await generateMusicTool.execute(args); case 'prepare_image': return await prepareImageTool.execute(args); case 'read_local_image': return await readLocalImageTool.execute(args); case 'list_models': { const category = args?.category; const models = category ? modelRegistry.getModelsByCategory(category) : modelRegistry.getAllModels(); const modelList = models.map(m => ({ id: m.id, name: m.name, description: m.description, category: m.category, creditsPerUse: m.creditsPerUse, estimatedTime: m.estimatedTime, })); return { content: [ { type: 'text', text: JSON.stringify(modelList, null, 2), }, ], }; } case 'get_model_info': { const modelId = args?.model_id; if (!modelId) { throw new Error('model_id is required'); } const model = modelRegistry.getModel(modelId); if (!model) { throw new Error(`Model ${modelId} not found`); } const info = { id: model.id, name: model.name, description: model.description, category: model.category, endpoint: model.endpoint, apiVersion: model.apiVersion, outputType: model.outputType, estimatedTime: model.estimatedTime, creditsPerUse: model.creditsPerUse, supportedFormats: model.supportedFormats, maxDimensions: model.maxDimensions, defaultParams: model.defaultParams, parameters: Object.keys(model.parameters.shape), }; return { content: [ { type: 'text', text: JSON.stringify(info, null, 2), }, ], }; } case 'estimate_cost': return await estimateCostTool.execute(args); case 'check_credits': { const credits = await apiClient.getCredits(); return { content: [ { type: 'text', text: `API Credits:\nRemaining: ${credits.remaining}\nUsed: ${credits.used}`, }, ], }; } default: throw new Error(`Unknown tool: ${name}`); } } catch (error) { const safeError = mapToSafeError(error); logger.error('Tool execution failed', formatErrorResponse(safeError)); await this.sendMCPLog('error', `Tool execution failed: ${name}`, { tool: name, error: safeError.userMessage, code: safeError.code, }); throw safeError; } finally { this.state.activeRequests--; } }); this.server.setRequestHandler(ListResourcesRequestSchema, async () => { logger.info('Listing available resources'); const categoryResources = Object.values(ModelCategory).map(category => ({ uri: `segmind://models/${category}`, name: `${category.replace('_', ' ').toUpperCase()} Models`, description: `List models in the ${category} category`, mimeType: 'application/json', })); return { resources: [ { uri: 'segmind://models', name: 'All Available Models', description: 'List all available Segmind models', mimeType: 'application/json', }, ...categoryResources, { uri: 'segmind://credits', name: 'API Credits', description: 'Check remaining API credits', mimeType: 'application/json', }, ], }; }); this.server.setRequestHandler(ReadResourceRequestSchema, async (request) => { const { uri } = request.params; logger.info('Reading resource', { uri }); try { if (uri === 'segmind://models') { const models = modelRegistry.getAllModels(); const modelData = models.map(m => ({ id: m.id, name: m.name, description: m.description, category: m.category, creditsPerUse: m.creditsPerUse, estimatedTime: m.estimatedTime, outputType: m.outputType, supportedFormats: m.supportedFormats, })); return { contents: [ { uri, mimeType: 'application/json', text: JSON.stringify({ totalModels: models.length, models: modelData, }, null, 2), }, ], }; } const categoryMatch = uri.match(/^segmind:\/\/models\/(.+)$/); if (categoryMatch) { const category = categoryMatch[1]; const models = modelRegistry.getModelsByCategory(category); if (models.length === 0) { throw new Error(`Invalid category: ${category}`); } const modelData = models.map(m => ({ id: m.id, name: m.name, description: m.description, creditsPerUse: m.creditsPerUse, estimatedTime: m.estimatedTime, outputType: m.outputType, supportedFormats: m.supportedFormats, defaultParams: m.defaultParams, })); return { contents: [ { uri, mimeType: 'application/json', text: JSON.stringify({ category, totalModels: models.length, models: modelData, }, null, 2), }, ], }; } if (uri === 'segmind://credits') { const credits = await apiClient.getCredits(); return { contents: [ { uri, mimeType: 'application/json', text: JSON.stringify({ credits: { remaining: credits.remaining, used: credits.used, }, lastUpdated: new Date().toISOString(), }, null, 2), }, ], }; } throw new Error(`Unknown resource: ${uri}`); } catch (error) { const safeError = mapToSafeError(error); logger.error('Resource read failed', formatErrorResponse(safeError)); throw safeError; } }); this.server.setRequestHandler(ListPromptsRequestSchema, async () => { logger.info('Listing available prompts'); return { prompts: [ { name: 'art_styles', description: 'Generate images in specific art styles', arguments: [ { name: 'style', description: 'Art style (e.g., impressionist, anime, photorealistic)', required: true, }, { name: 'subject', description: 'What to depict', required: true, }, ], }, ], }; }); this.server.setRequestHandler(GetPromptRequestSchema, async (request) => { const { name, arguments: args } = request.params; logger.info('Getting prompt', { name }); try { switch (name) { case 'art_styles': const style = args?.['style'] || 'photorealistic'; const subject = args?.['subject'] || 'landscape'; return { description: `Generate ${subject} in ${style} style`, messages: [ { role: 'user', content: { type: 'text', text: `Create a ${style} artwork depicting: ${subject}. Include rich details and appropriate artistic techniques for the ${style} style.`, }, }, ], }; default: throw new Error(`Unknown prompt: ${name}`); } } catch (error) { const safeError = mapToSafeError(error); logger.error('Prompt generation failed', formatErrorResponse(safeError)); throw safeError; } }); } async start() { try { const transport = new StdioServerTransport(); transport.onclose = () => { logger.info('Transport closed, shutting down server'); process.exit(0); }; transport.onerror = (error) => { logger.error('Transport error', { error }); process.exit(1); }; this.server.onerror = (error) => { logger.error('MCP Server Error', { error }); void this.sendMCPLog('error', 'MCP Server Error', { error: error.message || error }); }; await this.server.connect(transport); this.state.isInitialized = true; logger.info('Segmind MCP Server started successfully'); } catch (error) { logger.error('Failed to start server', { error }); throw error; } } async shutdown() { logger.info('Shutting down Segmind MCP Server'); const timeout = 30000; const start = Date.now(); while (this.state.activeRequests > 0 && Date.now() - start < timeout) { await new Promise(resolve => setTimeout(resolve, 100)); } if (this.state.activeRequests > 0) { logger.warn('Forcing shutdown with active requests', { activeRequests: this.state.activeRequests, }); } await this.server.close(); logger.info('Server shutdown complete'); } } export const server = new SegmindMCPServer(); //# sourceMappingURL=server.js.map