UNPKG

@bratcliffe909/mcp-server-segmind

Version:

Model Context Protocol (MCP) server for the Segmind API: generate images and videos using AI models.

266 lines 11.8 kB
import * as fs from 'fs/promises';
import * as os from 'os';
import * as path from 'path';
import { z } from 'zod';
import { modelRegistry, ModelCategory } from '../models/registry.js';
import { imageCache } from '../utils/image-cache.js';
import { logger } from '../utils/logger.js';
import { BaseTool } from './base.js';

/** Maximum image size (in bytes) accepted for enhancement operations: 20 MB. */
const MAX_ENHANCE_IMAGE_BYTES = 20 * 1024 * 1024;

/** MIME types for image files read from disk, keyed by lower-case extension (with dot). */
const MIME_TYPES_BY_EXTENSION = Object.freeze({
    '.jpg': 'image/jpeg',
    '.jpeg': 'image/jpeg',
    '.png': 'image/png',
    '.gif': 'image/gif',
    '.webp': 'image/webp',
    '.bmp': 'image/bmp',
});

/** Absolute Windows paths such as `C:\img.png` or `C:/img.png`. */
const WINDOWS_DRIVE_PATH_PATTERN = /^[A-Za-z]:[\\/]/;

/** Inline base64 data URLs accepted as image input. */
const BASE64_DATA_URL_PATTERN = /^data:image\/(png|jpeg|jpg|webp);base64,/;

/** Standard or URL-safe base64 alphabet followed by at most two '=' padding characters. */
const BASE64_BODY_PATTERN = /^[A-Za-z0-9+/_-]+={0,2}$/;

/**
 * Raw base64 of a JPEG always starts with "/9j/" (the encoding of the FF D8 FF magic bytes).
 * Without this exception such input would be mistaken for an absolute POSIX path.
 */
const JPEG_BASE64_PREFIX = '/9j/';

/**
 * Preferred model ids per operation, tried in order; the first enhancement model in the
 * registry is the fallback when none of them is registered.
 */
const PREFERRED_MODEL_IDS = new Map([
    ['upscale', ['esrgan']],
    ['remove_background', ['bg-removal']],
    ['restore', ['face-restoration', 'esrgan']],
    ['colorize', ['colorization']],
    ['denoise', ['denoising', 'esrgan']],
]);

const EnhanceImageSchema = z.object({
    image: z.string().describe('Input image as base64 string, data URL, http(s) URL, absolute file path, or image cache ID from prepare_image'),
    operation: z.enum(['upscale', 'restore', 'remove_background', 'colorize', 'denoise']).describe('Enhancement operation to perform'),
    model: z.string().optional().describe('Specific model to use for enhancement'),
    scale: z.enum(['2', '4', '8']).default('4').describe('Upscaling factor (for upscale operation)'),
    face_enhance: z.boolean().default(false).describe('Enhance faces during upscaling'),
    return_mask: z.boolean().default(false).describe('Return mask for background removal'),
    alpha_matting: z.boolean().default(true).describe('Use alpha matting for cleaner edges'),
    denoise_strength: z.number().min(0).max(1).default(0.5).describe('Denoising strength'),
    batch_size: z.number().int().min(1).max(10).default(1).describe('Number of images to process'),
    save_location: z.string().optional().describe('Directory path to save the image. Overrides default save location.'),
});

/**
 * MCP tool that enhances an image (upscale, restore, remove background, colorize, denoise)
 * using a model from the IMAGE_ENHANCEMENT category of the model registry.
 */
export class EnhanceImageTool extends BaseTool {
    name = 'enhance_image';
    description = 'Enhance images with upscaling, restoration, background removal, and more';

    /**
     * Validates the request, selects a model and runs it `batch_size` times.
     * @param {unknown} params - Raw tool arguments; parsed with EnhanceImageSchema.
     * @returns {Promise<{content: Array<object>, isError?: boolean}>} Tool response whose
     *   content is every result item from the model calls followed by a text summary.
     */
    async execute(params) {
        try {
            const validated = EnhanceImageSchema.parse(params);
            const imageValidation = await this.validateImageInput(validated.image);
            if (!imageValidation.isValid) {
                return {
                    content: [{
                        type: 'text',
                        text: `Invalid image input: ${imageValidation.error}`,
                    }],
                    isError: true,
                };
            }
            const model = this.selectModel(validated);
            if (!model) {
                return {
                    content: [{
                        type: 'text',
                        text: `No suitable model found for ${validated.operation} operation.`,
                    }],
                    isError: true,
                };
            }
            logger.info(`Selected model ${model.id} for ${validated.operation} operation`);
            let saveLocation = validated.save_location;
            if (!saveLocation && imageValidation.originalFilePath) {
                // NOTE(review): save_location is documented as a directory, but this passes the
                // input *file* path; confirm how callModel interprets it.
                saveLocation = imageValidation.originalFilePath;
                logger.info(`Saving enhanced image back to original location: ${saveLocation}`);
            }
            // Every batch iteration sends an identical request, so build and validate the
            // parameters once rather than re-reading the image on each pass.
            const modelParams = await this.prepareModelParameters(validated, model, imageValidation);
            const paramValidation = modelRegistry.validateModelParameters(model.id, modelParams);
            if (!paramValidation.success) {
                return {
                    content: [{
                        type: 'text',
                        text: `Invalid parameters for model ${model.id}: ${paramValidation.error}`,
                    }],
                    isError: true,
                };
            }
            const results = [];
            for (let i = 0; i < validated.batch_size; i++) {
                if (validated.batch_size > 1) {
                    logger.info(`Processing image ${i + 1} of ${validated.batch_size}`);
                }
                // Shallow copy so each call gets its own object, as it did when the
                // parameters were rebuilt per iteration.
                const result = await this.callModel(model, { ...paramValidation.data }, saveLocation);
                results.push(...result.content);
            }
            results.push({
                type: 'text',
                text: this.generateSummary(validated, model),
            });
            return { content: results };
        }
        catch (error) {
            logger.error('Image enhancement failed', { error });
            return this.createErrorResponse(error);
        }
    }

    /**
     * Picks the model for the request: the explicitly requested one if it exists and is an
     * enhancement model, otherwise the preferred model for the operation, otherwise the
     * first registered enhancement model (undefined when there are none).
     */
    selectModel(params) {
        if (params.model) {
            const requested = modelRegistry.getModel(params.model);
            if (requested && requested.category === ModelCategory.IMAGE_ENHANCEMENT) {
                return requested;
            }
            logger.warn(`Model ${params.model} not found or not an enhancement model`);
        }
        const enhancementModels = modelRegistry.getModelsByCategory(ModelCategory.IMAGE_ENHANCEMENT);
        const preferredIds = PREFERRED_MODEL_IDS.get(params.operation) ?? [];
        for (const id of preferredIds) {
            const match = enhancementModels.find((m) => m.id === id);
            if (match) {
                return match;
            }
        }
        return enhancementModels[0];
    }

    /**
     * Builds the request payload for `model`, adding only the options that model accepts,
     * then merges in the model defaults via BaseTool.mergeWithDefaults.
     */
    async prepareModelParameters(params, model, imageValidation) {
        const baseParams = {
            image: await this.processImageInput(params.image, imageValidation),
        };
        switch (model.id) {
            case 'esrgan':
                baseParams.scale = parseInt(params.scale, 10);
                baseParams.face_enhance = params.face_enhance;
                break;
            case 'bg-removal':
                baseParams.return_mask = params.return_mask;
                baseParams.alpha_matting = params.alpha_matting;
                break;
            default:
                // Other models only receive options present in their zod parameter shape.
                if (params.operation === 'upscale' && model.parameters.shape.scale) {
                    baseParams.scale = params.scale;
                }
                if (params.operation === 'denoise' && model.parameters.shape.denoise_strength) {
                    baseParams.denoise_strength = params.denoise_strength;
                }
                break;
        }
        if (model.parameters.shape.base64 !== undefined) {
            baseParams.base64 = false;
        }
        return this.mergeWithDefaults(baseParams, model);
    }

    /**
     * Classifies and validates the image input.
     * - Absolute paths (`/…`, `~/…`, `C:\…`, `C:/…`) are read, size-checked and stored in the
     *   image cache (format 'cached', with originalFilePath).
     * - `img_…` strings must be live cache IDs (format 'cached').
     * - http(s) URLs are accepted as-is (format 'url').
     * - Supported data URLs and raw base64 must decode within the size limit (format 'base64').
     * @returns {Promise<{isValid: boolean, format?: string, cacheId?: string,
     *   originalFilePath?: string, error?: string}>}
     */
    async validateImageInput(input) {
        const isFilePath = WINDOWS_DRIVE_PATH_PATTERN.test(input)
            || input.startsWith('~/')
            || (input.startsWith('/') && !input.startsWith(JPEG_BASE64_PREFIX));
        if (isFilePath) {
            try {
                const filePath = input.startsWith('~/')
                    ? path.join(os.homedir(), input.slice(2))
                    : input;
                if (!path.isAbsolute(filePath)) {
                    return {
                        isValid: false,
                        error: 'File path must be absolute',
                    };
                }
                // stat both checks existence and lets us enforce the size limit before reading.
                const stats = await fs.stat(filePath);
                if (stats.size > MAX_ENHANCE_IMAGE_BYTES) {
                    return {
                        isValid: false,
                        error: 'Image size exceeds 20MB limit for enhancement operations',
                    };
                }
                const imageBuffer = await fs.readFile(filePath);
                const base64String = imageBuffer.toString('base64');
                const ext = path.extname(filePath).toLowerCase();
                const mimeType = MIME_TYPES_BY_EXTENSION[ext] || 'image/png';
                const cacheId = imageCache.store(base64String, mimeType, filePath);
                logger.info(`Automatically converted file path to cache ID: ${cacheId}`);
                return {
                    isValid: true,
                    format: 'cached',
                    cacheId,
                    originalFilePath: filePath,
                };
            }
            catch (error) {
                return {
                    isValid: false,
                    error: `Failed to read file: ${error instanceof Error ? error.message : 'Unknown error'}`,
                };
            }
        }
        if (input.startsWith('img_')) {
            if (imageCache.get(input)) {
                return {
                    isValid: true,
                    format: 'cached',
                    cacheId: input,
                };
            }
            return {
                isValid: false,
                error: `Image cache ID ${input} not found or expired. Please use prepare_image again.`,
            };
        }
        if (input.startsWith('http://') || input.startsWith('https://')) {
            return { isValid: true, format: 'url' };
        }
        const isDataUrl = BASE64_DATA_URL_PATTERN.test(input);
        const base64Data = isDataUrl ? input.slice(input.indexOf(',') + 1) : input;
        if (this.isValidBase64(base64Data)) {
            // byteLength computes the decoded size without allocating the decoded buffer
            // (an upper bound if the data contains whitespace).
            if (Buffer.byteLength(base64Data, 'base64') > MAX_ENHANCE_IMAGE_BYTES) {
                return {
                    isValid: false,
                    error: 'Image size exceeds 20MB limit for enhancement operations',
                };
            }
            return { isValid: true, format: 'base64' };
        }
        if (isDataUrl) {
            return {
                isValid: false,
                error: 'Invalid base64 image data',
            };
        }
        return {
            isValid: false,
            error: 'Image must be a valid URL, base64 encoded string, or image cache ID from prepare_image',
        };
    }

    /**
     * Reports whether `str` is plausibly base64: standard or URL-safe alphabet, optional
     * '=' padding, whitespace ignored. `Buffer.from(str, 'base64')` silently skips invalid
     * characters and never throws, so it cannot be used for this check.
     */
    isValidBase64(str) {
        if (typeof str !== 'string') {
            return false;
        }
        const compact = str.replace(/\s+/g, '');
        if (compact.length === 0 || !BASE64_BODY_PATTERN.test(compact)) {
            return false;
        }
        // Padded data must be whole 4-character quanta; unpadded data can never end with a
        // single dangling character.
        return compact.includes('=') ? compact.length % 4 === 0 : compact.length % 4 !== 1;
    }

    /**
     * Converts validated input into the value sent as the model's `image` parameter:
     * cached images become their base64 payload, URLs pass through, data URLs are
     * stripped to their base64 part, and anything else is returned unchanged.
     * @throws {Error} If a cached image disappeared between validation and use.
     */
    async processImageInput(input, validation) {
        if (validation.format === 'cached' && validation.cacheId) {
            const cachedImage = imageCache.get(validation.cacheId);
            if (cachedImage) {
                return cachedImage.base64;
            }
            throw new Error(`Image cache ID ${validation.cacheId} not found`);
        }
        if (input.startsWith('http')) {
            return input;
        }
        if (input.startsWith('data:image/')) {
            const base64Part = input.split(',')[1];
            return base64Part || input;
        }
        return input;
    }

    /** Builds the human-readable summary appended after the model results. */
    generateSummary(params, model) {
        const operations = {
            upscale: `Upscaled ${params.scale}x${params.face_enhance ? ' with face enhancement' : ''}`,
            restore: 'Restored and enhanced image quality',
            remove_background: `Removed background${params.return_mask ? ' (with mask)' : ''}${params.alpha_matting ? ' using alpha matting' : ''}`,
            colorize: 'Colorized black and white image',
            denoise: `Denoised with strength ${params.denoise_strength}`,
        };
        const operationText = operations[params.operation] || params.operation;
        return `\nEnhancement complete:\n- Operation: ${operationText}\n- Model: ${model.name}\n- Batch Size: ${params.batch_size}${params.batch_size > 1 ? ' images' : ' image'}`;
    }
}

export const enhanceImageTool = new EnhanceImageTool();
//# sourceMappingURL=enhance-image.js.map