UNPKG

@bratcliffe909/mcp-server-segmind

Version:

Model Context Protocol server for Segmind API - Generate images and videos using AI models

263 lines 11.1 kB
import * as fs from 'fs/promises';
import * as path from 'path';
import { z } from 'zod';
import { modelRegistry, ModelCategory } from '../models/registry.js';
import { imageCache } from '../utils/image-cache.js';
import { logger } from '../utils/logger.js';
import { BaseTool } from './base.js';

/** Upper bound applied to the decoded payload of `data:image/...` URIs. */
const MAX_INLINE_IMAGE_BYTES = 10 * 1024 * 1024;

/** Matches a Windows drive-rooted path such as `C:\images\cat.png`. */
const WINDOWS_DRIVE_PATH = /^[A-Za-z]:\\/;

/** Matches the header of a supported base64 image data URI. */
const IMAGE_DATA_URI_PREFIX = /^data:image\/(png|jpeg|jpg|webp|gif);base64,/;

/** Standard-alphabet base64 with optional `=` padding (whitespace already removed). */
const STRICT_BASE64 = /^[A-Za-z0-9+\/]+={0,2}$/;

/** File extension -> MIME type used when a local file is loaded into the image cache. */
const MIME_TYPES_BY_EXTENSION = Object.freeze({
  '.jpg': 'image/jpeg',
  '.jpeg': 'image/jpeg',
  '.png': 'image/png',
  '.gif': 'image/gif',
  '.webp': 'image/webp',
  '.bmp': 'image/bmp',
});

/** Model IDs preferred by `selectModel` when the caller does not pick a usable one. */
const CONTROLNET_MODEL_ID = 'controlnet';
const DEFAULT_MODEL_ID = 'flux-kontext-pro';

/**
 * Decode a base64 string, rejecting anything that is not well-formed.
 *
 * `Buffer.from(str, 'base64')` never throws and silently skips characters
 * outside the alphabet, so it cannot be used on its own to validate input.
 *
 * @param {string} data - Candidate base64 text; embedded whitespace is tolerated.
 * @returns {Buffer|null} The decoded bytes, or `null` when `data` is empty or malformed.
 */
function decodeStrictBase64(data) {
  const compact = data.replace(/\s+/g, '');
  if (compact.length === 0 || !STRICT_BASE64.test(compact)) {
    return null;
  }
  // A base64 body (without padding) can never have length ≡ 1 (mod 4).
  const unpaddedLength = compact.replace(/=+$/, '').length;
  if (unpaddedLength % 4 === 1) {
    return null;
  }
  const buffer = Buffer.from(compact, 'base64');
  return buffer.length > 0 ? buffer : null;
}

const TransformImageSchema = z.object({
  image: z.string().describe('Input image as base64 string or URL'),
  prompt: z.string().min(1).max(2000).describe('Transformation prompt describing desired changes'),
  model: z.string().optional().describe('Model ID to use for transformation'),
  negative_prompt: z.string().optional().describe('What to avoid in the transformation'),
  strength: z.number().min(0).max(1).default(0.75).describe('Transformation strength (0=no change, 1=complete change)'),
  mask: z.string().optional().describe('Mask image for inpainting (base64 or URL)'),
  control_type: z.enum(['canny', 'depth', 'pose', 'scribble', 'segmentation']).optional().describe('ControlNet type'),
  control_strength: z.number().min(0).max(2).default(1).describe('Control strength for ControlNet'),
  seed: z.number().int().optional().describe('Seed for reproducible generation'),
  output_format: z.enum(['png', 'jpeg', 'webp']).default('png').describe('Output image format'),
  save_location: z.string().optional().describe('Directory path to save the image. Overrides default save location.'),
});

/**
 * MCP tool that transforms an existing image with an image-to-image model.
 *
 * Relies on `callModel`, `mergeWithDefaults` and `createErrorResponse`, which
 * are invoked on `this` and are not defined in this file (presumably provided
 * by `BaseTool`).
 */
export class TransformImageTool extends BaseTool {
  name = 'transform_image';
  description = 'Transform existing images using AI models with various control methods';

  /**
   * Validate the request, pick a model, build its parameters and run it.
   *
   * @param {unknown} params - Raw tool arguments; validated against `TransformImageSchema`.
   * @returns {Promise<object>} A response object with a `content` array; failures
   *   detected here additionally carry `isError: true`. Thrown errors (including
   *   schema validation errors) are converted via `createErrorResponse`.
   */
  async execute(params) {
    try {
      const validated = TransformImageSchema.parse(params);

      const imageValidation = await this.validateImageInput(validated.image);
      if (!imageValidation.isValid) {
        return {
          content: [{
            type: 'text',
            text: `Invalid image input: ${imageValidation.error}`,
          }],
          isError: true,
        };
      }

      const model = this.selectModel(validated);
      if (!model) {
        return {
          content: [{
            type: 'text',
            text: 'No suitable model found for image transformation.',
          }],
          isError: true,
        };
      }
      logger.info(`Selected model ${model.id} for image transformation`);

      const modelParams = await this.prepareModelParameters(validated, model, imageValidation);
      const paramValidation = modelRegistry.validateModelParameters(model.id, modelParams);
      if (!paramValidation.success) {
        return {
          content: [{
            type: 'text',
            text: `Invalid parameters for model ${model.id}: ${paramValidation.error}`,
          }],
          isError: true,
        };
      }

      // An explicit save_location wins; otherwise an image loaded from disk is
      // written back to where it came from.
      // NOTE(review): originalFilePath is a file path while save_location is
      // described as a directory — confirm callModel accepts both.
      let saveLocation = validated.save_location;
      if (!saveLocation && imageValidation.originalFilePath) {
        saveLocation = imageValidation.originalFilePath;
        logger.info(`Saving transformed image back to original location: ${saveLocation}`);
      }

      const result = await this.callModel(model, paramValidation.data, saveLocation);
      const content = [
        ...result.content,
        {
          type: 'text',
          text: `\nImage transformed using ${model.name} with strength ${validated.strength}`,
        },
      ];
      return { content };
    } catch (error) {
      logger.error('Image transformation failed', { error });
      return this.createErrorResponse(error);
    }
  }

  /**
   * Choose the model that will perform the transformation.
   *
   * Order of preference: the requested model (if registered as image-to-image),
   * then `controlnet` when a control type is given, then `flux-kontext-pro`,
   * then the first registered image-to-image model.
   *
   * @param {object} params - Validated tool arguments.
   * @returns {object|undefined} The chosen model, or `undefined` if none is registered.
   */
  selectModel(params) {
    if (params.model) {
      const requested = modelRegistry.getModel(params.model);
      if (requested && requested.category === ModelCategory.IMAGE_TO_IMAGE) {
        return requested;
      }
      logger.warn(`Model ${params.model} not found or not an image-to-image model`);
    }

    const candidates = modelRegistry.getModelsByCategory(ModelCategory.IMAGE_TO_IMAGE);
    // Masked (inpainting) and plain requests share the same default model.
    const preferredId = params.control_type ? CONTROLNET_MODEL_ID : DEFAULT_MODEL_ID;
    return candidates.find((m) => m.id === preferredId) || candidates[0];
  }

  /**
   * Translate validated tool arguments into the parameter object for `model`.
   *
   * Optional fields are only forwarded when the model's parameter schema
   * (`model.parameters.shape`) declares them.
   *
   * @param {object} params - Validated tool arguments.
   * @param {object} model - Model chosen by `selectModel`.
   * @param {object} imageValidation - Result of `validateImageInput` for `params.image`.
   * @returns {Promise<object>} Parameters merged with the model defaults.
   * @throws {Error} If a cached image referenced by the input is no longer in the cache.
   */
  async prepareModelParameters(params, model, imageValidation) {
    const shape = model.parameters.shape;
    const baseParams = {
      prompt: params.prompt,
      image: await this.processImageInput(params.image, imageValidation),
    };

    if (params.negative_prompt && shape.negative_prompt) {
      baseParams.negative_prompt = params.negative_prompt;
    }
    if (params.seed !== undefined && shape.seed) {
      baseParams.seed = params.seed;
    }

    if (model.id === CONTROLNET_MODEL_ID) {
      if (params.control_type) {
        baseParams.control_type = params.control_type;
      }
      if (params.control_strength !== undefined) {
        baseParams.control_strength = params.control_strength;
      }
    } else if (params.strength !== undefined && shape.strength) {
      baseParams.strength = params.strength;
    }

    if (params.mask) {
      if (!shape.mask) {
        logger.warn(`Model ${model.id} does not accept a mask; ignoring mask input`);
      } else {
        const maskValidation = await this.validateImageInput(params.mask);
        if (maskValidation.isValid) {
          baseParams.mask = await this.processImageInput(params.mask, maskValidation);
        } else {
          // Best effort: the transformation still runs, but the dropped mask is
          // now visible in the logs instead of vanishing silently.
          logger.warn(`Ignoring invalid mask input: ${maskValidation.error}`);
        }
      }
    }

    if (shape.output_format) {
      baseParams.output_format = params.output_format;
    }
    if (shape.base64 !== undefined) {
      baseParams.base64 = false;
    }

    return this.mergeWithDefaults(baseParams, model);
  }

  /**
   * Classify and validate an image input string.
   *
   * Recognised forms, checked in this order: local file path (absolute, or
   * starting with `~/`), image cache ID (`img_` prefix), `http(s)` URL,
   * `data:image/...;base64,` URI, and raw base64.
   *
   * A local file is read and stored in `imageCache`; the result then has
   * `format: 'cached'` and carries `originalFilePath`.
   *
   * @param {string} input - The image argument supplied by the caller.
   * @returns {Promise<object>} `{ isValid: true, format, ... }` on success or
   *   `{ isValid: false, error }` on failure. Never rejects for file errors.
   */
  async validateImageInput(input) {
    const isFilePath =
      WINDOWS_DRIVE_PATH.test(input) || input.startsWith('/') || input.startsWith('~/');
    if (isFilePath) {
      try {
        let filePath = input;
        if (input.startsWith('~/')) {
          const homeDir = process.env.HOME || process.env.USERPROFILE;
          if (!homeDir) {
            // Without this guard "~/x" would collapse to "/x" and read the wrong file.
            return {
              isValid: false,
              error: 'Cannot expand "~": no home directory is set in the environment',
            };
          }
          filePath = path.join(homeDir, input.slice(2));
        }
        if (!path.isAbsolute(filePath)) {
          return {
            isValid: false,
            error: 'File path must be absolute',
          };
        }

        await fs.access(filePath);
        const imageBuffer = await fs.readFile(filePath);
        const ext = path.extname(filePath).toLowerCase();
        const mimeType = MIME_TYPES_BY_EXTENSION[ext] || 'image/png';
        const cacheId = imageCache.store(imageBuffer.toString('base64'), mimeType, filePath);
        logger.info(`Automatically converted file path to cache ID: ${cacheId}`);
        return {
          isValid: true,
          format: 'cached',
          cacheId,
          size: imageBuffer.length,
          originalFilePath: filePath,
        };
      } catch (error) {
        return {
          isValid: false,
          error: `Failed to read file: ${error instanceof Error ? error.message : 'Unknown error'}`,
        };
      }
    }

    if (input.startsWith('img_')) {
      const cachedImage = imageCache.get(input);
      if (cachedImage) {
        return {
          isValid: true,
          format: 'cached',
          cacheId: input,
          size: cachedImage.size,
        };
      }
      return {
        isValid: false,
        error: `Image cache ID ${input} not found or expired. Please use prepare_image again.`,
      };
    }

    if (input.startsWith('http://') || input.startsWith('https://')) {
      return {
        isValid: true,
        format: 'url',
      };
    }

    if (IMAGE_DATA_URI_PREFIX.test(input)) {
      const decoded = decodeStrictBase64(input.split(',')[1] || '');
      if (!decoded) {
        return {
          isValid: false,
          error: 'Invalid base64 image data',
        };
      }
      if (decoded.length > MAX_INLINE_IMAGE_BYTES) {
        return {
          isValid: false,
          error: 'Image size exceeds 10MB limit',
        };
      }
      return {
        isValid: true,
        format: 'base64',
        size: decoded.length,
      };
    }

    // Raw base64 without a data-URI header.
    // NOTE(review): the 10MB limit is only enforced for data URIs, as before —
    // confirm whether raw base64 and files should be capped too.
    const rawDecoded = decodeStrictBase64(input);
    if (rawDecoded) {
      return {
        isValid: true,
        format: 'base64',
        size: rawDecoded.length,
      };
    }

    return {
      isValid: false,
      error: 'Image must be a valid URL, base64 encoded string, or image cache ID from prepare_image',
    };
  }

  /**
   * Convert a validated image input into the value placed in the model parameters.
   *
   * @param {string} input - The original image argument.
   * @param {object} validation - Matching result from `validateImageInput`.
   * @returns {Promise<string>} The cached base64 payload, the URL unchanged, the
   *   base64 part of a data URI, or the raw input as given.
   * @throws {Error} If `validation` refers to a cache ID that is no longer present.
   */
  async processImageInput(input, validation) {
    if (validation.format === 'cached' && validation.cacheId) {
      const cachedImage = imageCache.get(validation.cacheId);
      if (cachedImage) {
        return cachedImage.base64;
      }
      throw new Error(`Image cache ID ${validation.cacheId} not found`);
    }
    if (validation.format === 'url') {
      return input;
    }
    if (input.startsWith('data:image/')) {
      // Strip the "data:image/...;base64," header; fall back to the full input
      // if there is nothing after the comma.
      const base64Part = input.split(',')[1];
      return base64Part || input;
    }
    return input;
  }
}

export const transformImageTool = new TransformImageTool();
//# sourceMappingURL=transform-image.js.map