UNPKG

@ldavis9000aws/swarmui-generator

Version:

A Model Context Protocol server for SwarmUI image generation with TypeScript

253 lines 14.1 kB
/**
 * MCP Server Implementation
 * This module implements the MCP server using the low-level Server API for explicit tool definition,
 * aligning with working Amazon Bedrock MCP server examples.
 */
import { Server } from "@modelcontextprotocol/sdk/server/index.js";
import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js";
import { ListToolsRequestSchema, CallToolRequestSchema, McpError, ErrorCode } from "@modelcontextprotocol/sdk/types.js";
import { z } from "zod";
import { zodToJsonSchema } from 'zod-to-json-schema';
import { SwarmUIClient } from '../client/swarmui-client.js';
import { getServerConfig } from './server-config.js';
import { generateImagesSchema as generateImagesToolDetails } from '../tools/generate-images.js';
import { listModelsSchema as listModelsToolDetails } from '../tools/list-models.js';
import { listSchedulersSchema as listSchedulersToolDetails } from '../tools/list-schedulers.js';
import { systemStatusSchema as systemStatusToolDetails } from '../tools/system-status.js';

/** Argument schema for the `generate_images` tool (also exported to clients as JSON Schema). */
const generateImagesParamsZodSchema = z.object({
    prompt: z.string().describe('The main textual description of the image to be generated.'),
    negative_prompt: z.string().optional().describe('Optional. Elements to avoid in the generated image.'),
    width: z.number().int().min(256).max(4096).optional().describe('Optional. Image width in pixels (default 1024). Divisible by 16. Aspect ratio with height between 1:4 and 4:1. Total pixels < 4,194,304.'),
    height: z.number().int().min(256).max(4096).optional().describe('Optional. Image height in pixels (default 1024). Divisible by 16. Aspect ratio with width between 1:4 and 4:1. Total pixels < 4,194,304.'),
    guidance_scale: z.number().min(1.0).max(30.0).optional().describe('Optional. How closely to follow the prompt (1.0-30.0, default 6.5).'),
    num_inference_steps: z.number().int().min(1).max(150).optional().describe('Optional. Denoising steps count (1-150, default might be 20-30).'),
    num_images: z.number().int().min(1).max(10).optional().describe('Optional. Number of images to generate (1-10, default 1).'),
    seed: z.number().int().min(0).max(4294967295).optional().describe('Optional. Random seed for reproducibility (0-4294967295, default varies).'),
    model_name: z.string().optional().describe('Optional. Specific model to use for generation.'),
    scheduler: z.string().optional().describe('Optional. Scheduler/sampler algorithm to use.')
});

/** Argument schema for the `list_models` tool (also exported to clients as JSON Schema). */
const listModelsParamsZodSchema = z.object({
    path: z.string().optional().describe('Optional. The specific folder path within the SwarmUI model directory to search.'),
    depth: z.number().int().optional().describe('Optional. Maximum depth of subfolders to search recursively.'),
    subtype: z.string().optional().describe('Optional. Filter by a specific model subtype.'),
    sortBy: z.string().optional().describe('Optional. Criterion to sort models: "Name" or "Date".'),
    sortReverse: z.boolean().optional().describe('Optional. If true, reverses sort order.'),
    allowRemote: z.boolean().optional().describe('Optional. If true, includes models from remote backends.')
});

/**
 * Matches a base64 image data URL, capturing the MIME type and the payload.
 * Accepts subtypes containing `.`, `+` or `-` (e.g. `image/svg+xml`). No `g` flag, so reuse is stateless.
 */
const DATA_URL_PATTERN = /^data:(image\/[\w.+-]+);base64,(.+)$/;

/** File-extension → MIME type lookup; a Map avoids prototype keys like "constructor" matching. */
const IMAGE_MIME_TYPES_BY_EXTENSION = new Map([
    ['png', 'image/png'],
    ['jpg', 'image/jpeg'],
    ['jpeg', 'image/jpeg'],
    ['webp', 'image/webp'],
    ['gif', 'image/gif']
]);

/** Data URLs longer than this are abbreviated in logs and error messages. */
const MAX_IMAGE_REF_DISPLAY_LENGTH = 64;

/**
 * Converts an unknown thrown value into a human-readable message.
 * @param {unknown} error - Anything caught by a `catch` clause.
 * @returns {string} The Error's message, or the value stringified.
 */
function errorMessage(error) {
    return error instanceof Error ? error.message : String(error);
}

/**
 * Returns `value` if it is an array, otherwise an empty array.
 * Used so a response that omits a list is treated as "nothing found" instead of crashing on `.length`.
 * @param {unknown} value
 * @returns {Array}
 */
function asArray(value) {
    return Array.isArray(value) ? value : [];
}

/**
 * Produces a short, log-friendly description of an image reference.
 * Paths are returned unchanged; long data URLs are truncated so a multi-megabyte
 * base64 payload is never dumped into stderr or an error message.
 * @param {unknown} imageRef - A path/URL or data URL returned by SwarmUI.
 * @returns {string}
 */
function describeImageRef(imageRef) {
    const text = String(imageRef);
    if (text.startsWith('data:') && text.length > MAX_IMAGE_REF_DISPLAY_LENGTH) {
        return `${text.slice(0, MAX_IMAGE_REF_DISPLAY_LENGTH)}... (${text.length} chars)`;
    }
    return text;
}

/**
 * Infers an image MIME type from the extension of a path or URL.
 * Query strings and fragments are ignored. Unknown or missing extensions fall back
 * to "image/jpeg", the type this server assumed for all non-PNG files.
 * @param {string} imagePath
 * @returns {string} MIME type.
 */
function guessMimeType(imagePath) {
    const withoutQuery = imagePath.split(/[?#]/)[0].toLowerCase();
    // Only look at the last path segment so dots in directory names are ignored.
    const fileName = withoutQuery.slice(withoutQuery.lastIndexOf('/') + 1);
    const dotIndex = fileName.lastIndexOf('.');
    const extension = dotIndex === -1 ? '' : fileName.slice(dotIndex + 1);
    return IMAGE_MIME_TYPES_BY_EXTENSION.get(extension) || 'image/jpeg';
}

/**
 * Turns one image reference from SwarmUI into an MCP content item.
 * Data URLs are decoded in place; anything else is fetched via `swarmui.downloadAndEncodeImage`.
 * Failures are reported as a text content item rather than thrown, so one bad image
 * does not hide the others.
 * @param {string} imageRef - A data URL or an image path/URL.
 * @param {SwarmUIClient} swarmui
 * @returns {Promise<object>} An `image` content item, or a `text` item describing the failure.
 */
async function imageRefToContent(imageRef, swarmui) {
    try {
        if (imageRef.startsWith('data:')) {
            const parts = imageRef.match(DATA_URL_PATTERN);
            if (!parts) {
                console.error(`Invalid data URL format for ${describeImageRef(imageRef)}`);
                return { type: "text", text: `Error processing image ${describeImageRef(imageRef)}: Invalid data URL format` };
            }
            return { type: "image", data: parts[2], mimeType: parts[1] };
        }
        const base64Data = await swarmui.downloadAndEncodeImage(imageRef);
        return { type: "image", data: base64Data, mimeType: guessMimeType(imageRef) };
    } catch (imageProcessingError) {
        console.error(`Error processing image ${describeImageRef(imageRef)}:`, imageProcessingError);
        return { type: "text", text: `Error processing image ${describeImageRef(imageRef)}: ${errorMessage(imageProcessingError)}` };
    }
}

/**
 * Handles the `generate_images` tool.
 * @param {object} args - Arguments already validated by `generateImagesParamsZodSchema`.
 * @param {SwarmUIClient} swarmui
 * @returns {Promise<{content: object[], isError: boolean}>} A summary text item followed by one
 *   item per returned image; `isError` is true only when the generation call itself fails.
 */
async function handleGenerateImages(args, swarmui) {
    const { prompt, negative_prompt, width, height, guidance_scale, num_inference_steps, seed, model_name, scheduler, num_images } = args;
    try {
        // Map the tool's snake_case parameters onto the client's option names.
        const result = await swarmui.generateImage(prompt, {
            negativePrompt: negative_prompt,
            width,
            height,
            cfgScale: guidance_scale,
            steps: num_inference_steps,
            seed,
            model: model_name,
            scheduler,
            batchSize: num_images
        });
        const images = asArray(result && result.images);
        if (images.length === 0) {
            return { content: [{ type: "text", text: 'No images returned from the API.' }], isError: false };
        }
        const resultContent = [
            { type: "text", text: `Generated ${images.length} image(s) for prompt: "${prompt}"` }
        ];
        // Sequential on purpose: keeps output order and avoids concurrent downloads against the client.
        for (const imageRef of images) {
            resultContent.push(await imageRefToContent(imageRef, swarmui));
        }
        return { content: resultContent, isError: false };
    } catch (error) {
        console.error('Error in handleGenerateImages:', error);
        return { content: [{ type: "text", text: `Error generating images: ${errorMessage(error)}` }], isError: true };
    }
}

/**
 * Formats one model record for the `list_models` text output.
 * @param {object} model - A model entry from `swarmui.listModels(...).files`.
 * @returns {string} A multi-line entry terminated by a blank line.
 */
function formatModelEntry(model) {
    let entry = `- Name: ${model.Name}\n`;
    entry += ` Display Name: ${model.FormattedName || model.Name}\n`;
    if (model.Description) entry += ` Description: ${model.Description}\n`;
    entry += ` Default Dims: ${model.DefaultWidth || 'N/A'}x${model.DefaultHeight || 'N/A'}\n`;
    entry += ` Status: ${model.IsActive ? 'Active' : 'Inactive'}\n`;
    if (model.Path) entry += ` File Path: ${model.Path}\n`;
    return `${entry}\n`;
}

/**
 * Handles the `list_models` tool: lists sub-folders and model files under a SwarmUI model path.
 * @param {object} args - Arguments already validated by `listModelsParamsZodSchema`.
 * @param {SwarmUIClient} swarmui
 * @returns {Promise<{content: object[], isError: boolean}>}
 */
async function handleListModels(args, swarmui) {
    const { path, depth, subtype, sortBy, sortReverse, allowRemote } = args;
    try {
        const response = await swarmui.listModels({ path, depth, subtype, sortBy, sortReverse, allowRemote });
        const models = asArray(response && response.files);
        const folders = asArray(response && response.folders);
        const searchPath = path || "root model directory";
        // Display only. NOTE(review): assumes the client's default depth is 1; confirm in SwarmUIClient.listModels.
        const searchDepth = depth === undefined ? 1 : depth;
        let resultText = `Listing for path: "${searchPath}" (depth: ${searchDepth})\n`;
        if (subtype) resultText += `Subtype filter: "${subtype}"\n`;
        resultText += "\n";
        if (folders.length > 0) {
            resultText += `Available Folders (${folders.length}):\n`;
            resultText += folders.map((folder) => `- [Folder] ${folder}\n`).join('');
            resultText += "\n";
        } else {
            resultText += "No sub-folders found matching the criteria.\n\n";
        }
        if (models.length > 0) {
            resultText += `Available Models (${models.length}):\n\n`;
            resultText += models.map((model) => formatModelEntry(model)).join('');
        } else {
            resultText += "No model files found matching the criteria in this path.\n";
        }
        return { content: [{ type: 'text', text: resultText.trim() }], isError: false };
    } catch (error) {
        console.error('Error in handleListModels:', error);
        return { content: [{ type: 'text', text: `Error listing models: ${errorMessage(error)}` }], isError: true };
    }
}

/**
 * Handles the `list_schedulers` tool.
 * @param {SwarmUIClient} swarmui
 * @returns {Promise<{content: object[], isError: boolean}>}
 */
async function handleListSchedulers(swarmui) {
    try {
        const schedulers = asArray(await swarmui.listSchedulers());
        if (schedulers.length === 0) {
            return { content: [{ type: 'text', text: 'No schedulers are currently available on the SwarmUI server.' }], isError: false };
        }
        const entries = schedulers.map((scheduler) => {
            let entry = `- Name: ${scheduler.Name}\n`;
            entry += ` Display Name: ${scheduler.FormattedName || scheduler.Name}\n`;
            if (scheduler.Description) entry += ` Description: ${scheduler.Description}\n`;
            return `${entry}\n`;
        });
        const resultText = `Available Schedulers/Samplers (${schedulers.length}):\n\n${entries.join('')}`;
        return { content: [{ type: 'text', text: resultText.trim() }], isError: false };
    } catch (error) {
        console.error('Error in handleListSchedulers:', error);
        return { content: [{ type: 'text', text: `Error listing schedulers: ${errorMessage(error)}` }], isError: true };
    }
}

/**
 * Handles the `get_system_status` tool: renders `swarmui.getEngineStatus()` as text.
 * Per-job details (prompt, progress, ETA) are shown only while `IsGenerating` is truthy.
 * @param {SwarmUIClient} swarmui
 * @returns {Promise<{content: object[], isError: boolean}>}
 */
async function handleGetSystemStatus(swarmui) {
    try {
        const status = await swarmui.getEngineStatus();
        let resultText = 'SwarmUI System Status:\n';
        resultText += '========================\n\n';
        resultText += `Model Load State: ${status.ModelLoadState || 'Not specified / Idle'}\n`;
        resultText += `Current Model Loaded: ${status.ModelName || 'None'}\n`;
        resultText += `Currently Generating Image: ${status.IsGenerating !== undefined ? (status.IsGenerating ? 'Yes' : 'No') : 'Unknown'}\n`;
        if (status.IsGenerating && status.CurrentPrompt) {
            resultText += `Current Prompt Being Processed: "${status.CurrentPrompt}"\n`;
        }
        if (status.IsGenerating && status.Progress !== undefined) {
            // NOTE(review): assumes Progress is a 0..1 fraction; confirm against the SwarmUI API.
            resultText += `Current Job Progress: ${Math.round(status.Progress * 100)}%\n`;
        }
        if (status.IsGenerating && status.ETA !== undefined) {
            resultText += `Estimated Time Remaining (ETA) for Current Job: ${status.ETA} seconds\n`;
        }
        resultText += `Jobs in Queue: ${status.QueueLength !== undefined ? status.QueueLength : 'Unknown'}\n`;
        const backends = asArray(status.RunningBackends);
        resultText += `Running Backend Services: ${backends.length > 0 ? backends.join(", ") : 'None reported or not specified'}\n`;
        return { content: [{ type: 'text', text: resultText.trim() }], isError: false };
    } catch (error) {
        console.error('Error in handleGetSystemStatus:', error);
        return { content: [{ type: 'text', text: `Error getting system status: ${errorMessage(error)}` }], isError: true };
    }
}

/**
 * Creates the SwarmUI MCP server, registers the tool handlers and connects it over stdio.
 * All diagnostics go to stderr because stdout carries the MCP protocol stream.
 * @param {string} apiUrl - Base URL of the SwarmUI API, passed to `SwarmUIClient`.
 * @returns {Server} The server instance. The stdio connection is established asynchronously;
 *   a connection failure is logged and sets `process.exitCode = 1`.
 */
export function startServer(apiUrl) {
    console.error('Initializing SwarmUI MCP Server...');
    const swarmui = new SwarmUIClient(apiUrl);
    const serverConfig = getServerConfig();
    const server = new Server({
        name: serverConfig.name,
        version: serverConfig.version,
    }, {
        capabilities: {
            tools: {},
        },
    });
    console.error(`SwarmUI MCP Server v${serverConfig.version} starting (using low-level Server API)`);
    console.error(`Connecting to SwarmUI API at: ${apiUrl}`);

    // The tool list is static, so convert the zod schemas to JSON Schema once, not per request.
    const tools = [
        { name: 'generate_images', description: generateImagesToolDetails.description, inputSchema: zodToJsonSchema(generateImagesParamsZodSchema) },
        { name: 'list_models', description: listModelsToolDetails.description, inputSchema: zodToJsonSchema(listModelsParamsZodSchema) },
        { name: 'list_schedulers', description: listSchedulersToolDetails.description, inputSchema: { type: 'object', properties: {} } },
        { name: 'get_system_status', description: systemStatusToolDetails.description, inputSchema: { type: 'object', properties: {} } }
    ];

    // Tool name → handler. Argument validation (zod `.parse`) throws ZodError, handled below.
    const toolHandlers = new Map([
        ['generate_images', (toolArgs) => handleGenerateImages(generateImagesParamsZodSchema.parse(toolArgs), swarmui)],
        ['list_models', (toolArgs) => handleListModels(listModelsParamsZodSchema.parse(toolArgs), swarmui)],
        ['list_schedulers', () => handleListSchedulers(swarmui)],
        ['get_system_status', () => handleGetSystemStatus(swarmui)]
    ]);

    server.setRequestHandler(ListToolsRequestSchema, async () => ({ tools }));

    server.setRequestHandler(CallToolRequestSchema, async (request) => {
        const toolName = request.params.name;
        const toolArgs = request.params.arguments || {};
        const handler = toolHandlers.get(toolName);
        if (!handler) {
            console.error(`[MCP CallTool] Unknown tool called: ${toolName}`);
            // The MCP tools spec reports unknown tools as JSON-RPC InvalidParams (-32602).
            throw new McpError(ErrorCode.InvalidParams, `Tool ${toolName} not found.`);
        }
        try {
            return await handler(toolArgs);
        } catch (error) {
            console.error(`[MCP CallTool] Error calling tool ${toolName}:`, error);
            if (error instanceof z.ZodError) {
                // `issues` is the canonical ZodError field (`errors` is a v3-only alias).
                const details = error.issues.map((issue) => `${issue.path.join('.')}: ${issue.message}`).join(", ");
                throw new McpError(ErrorCode.InvalidParams, `Invalid parameters for ${toolName}: ${details}`);
            }
            if (error instanceof McpError) {
                throw error;
            }
            return { content: [{ type: "text", text: `Error executing tool ${toolName}: ${errorMessage(error)}` }], isError: true };
        }
    });

    const transport = new StdioServerTransport();
    // connect() is asynchronous: report readiness only once it resolves, and never leave a rejection unhandled.
    server.connect(transport)
        .then(() => {
            console.error('SwarmUI MCP Server (using low-level API) connected via stdio. Ready for client.');
        })
        .catch((error) => {
            console.error('Failed to connect SwarmUI MCP Server via stdio:', error);
            process.exitCode = 1;
        });
    return server;
}
//# sourceMappingURL=server.js.map