@bratcliffe909/mcp-server-segmind
Version:
Model Context Protocol (MCP) server for the Segmind API: generate images and videos using AI models.
263 lines • 11.1 kB
JavaScript
import * as fs from 'fs/promises';
import * as path from 'path';
import { z } from 'zod';
import { modelRegistry, ModelCategory } from '../models/registry.js';
import { imageCache } from '../utils/image-cache.js';
import { logger } from '../utils/logger.js';
import { BaseTool } from './base.js';
/**
 * Zod schema for the arguments accepted by the `transform_image` tool.
 *
 * `strength` defaults to 0.75, `control_strength` to 1 and `output_format` to
 * 'png' when omitted. The `image` and `mask` descriptions list every input form
 * that TransformImageTool.validateImageInput actually accepts.
 */
const TransformImageSchema = z.object({
  image: z.string().describe('Input image: absolute file path, image cache ID from prepare_image, URL, or base64 string (raw or data URL)'),
  prompt: z.string().min(1).max(2000).describe('Transformation prompt describing desired changes'),
  model: z.string().optional().describe('Model ID to use for transformation'),
  negative_prompt: z.string().optional().describe('What to avoid in the transformation'),
  // .default() is applied before .describe(), so the description lives on the defaulted type.
  strength: z.number().min(0).max(1).default(0.75).describe('Transformation strength (0=no change, 1=complete change)'),
  mask: z.string().optional().describe('Mask image for inpainting: absolute file path, image cache ID, URL, or base64 string'),
  control_type: z.enum(['canny', 'depth', 'pose', 'scribble', 'segmentation']).optional().describe('ControlNet type'),
  control_strength: z.number().min(0).max(2).default(1).describe('Control strength for ControlNet'),
  seed: z.number().int().optional().describe('Seed for reproducible generation'),
  output_format: z.enum(['png', 'jpeg', 'webp']).default('png').describe('Output image format'),
  save_location: z.string().optional().describe('Directory path to save the image. Overrides default save location.'),
});
export class TransformImageTool extends BaseTool {
name = 'transform_image';
description = 'Transform existing images using AI models with various control methods';
async execute(params) {
try {
const validated = TransformImageSchema.parse(params);
const imageValidation = await this.validateImageInput(validated.image);
if (!imageValidation.isValid) {
return {
content: [{
type: 'text',
text: `Invalid image input: ${imageValidation.error}`,
}],
isError: true,
};
}
const model = this.selectModel(validated);
if (!model) {
return {
content: [{
type: 'text',
text: 'No suitable model found for image transformation.',
}],
isError: true,
};
}
logger.info(`Selected model ${model.id} for image transformation`);
const modelParams = await this.prepareModelParameters(validated, model, imageValidation);
const paramValidation = modelRegistry.validateModelParameters(model.id, modelParams);
if (!paramValidation.success) {
return {
content: [{
type: 'text',
text: `Invalid parameters for model ${model.id}: ${paramValidation.error}`,
}],
isError: true,
};
}
let saveLocation = validated.save_location;
if (!saveLocation && imageValidation.originalFilePath) {
saveLocation = imageValidation.originalFilePath;
logger.info(`Saving transformed image back to original location: ${saveLocation}`);
}
const result = await this.callModel(model, paramValidation.data, saveLocation);
const content = [...result.content];
content.push({
type: 'text',
text: `\nImage transformed using ${model.name} with strength ${validated.strength}`,
});
return { content };
}
catch (error) {
logger.error('Image transformation failed', { error });
return this.createErrorResponse(error);
}
}
selectModel(params) {
if (params.model) {
const model = modelRegistry.getModel(params.model);
if (model && model.category === ModelCategory.IMAGE_TO_IMAGE) {
return model;
}
logger.warn(`Model ${params.model} not found or not an image-to-image model`);
}
const img2imgModels = modelRegistry.getModelsByCategory(ModelCategory.IMAGE_TO_IMAGE);
if (params.control_type) {
return img2imgModels.find(m => m.id === 'controlnet') || img2imgModels[0];
}
if (params.mask) {
return img2imgModels.find(m => m.id === 'flux-kontext-pro') || img2imgModels[0];
}
return img2imgModels.find(m => m.id === 'flux-kontext-pro') || img2imgModels[0];
}
async prepareModelParameters(params, model, imageValidation) {
const baseParams = {
prompt: params.prompt,
image: await this.processImageInput(params.image, imageValidation),
};
if (params.negative_prompt && model.parameters.shape.negative_prompt) {
baseParams.negative_prompt = params.negative_prompt;
}
if (params.seed !== undefined && model.parameters.shape.seed) {
baseParams.seed = params.seed;
}
switch (model.id) {
case 'controlnet':
if (params.control_type) {
baseParams.control_type = params.control_type;
}
if (params.control_strength !== undefined) {
baseParams.control_strength = params.control_strength;
}
break;
default:
if (params.strength !== undefined && model.parameters.shape.strength) {
baseParams.strength = params.strength;
}
break;
}
if (params.mask) {
const maskValidation = await this.validateImageInput(params.mask);
if (maskValidation.isValid && model.parameters.shape.mask) {
baseParams.mask = await this.processImageInput(params.mask, maskValidation);
}
}
if (model.parameters.shape.output_format) {
baseParams.output_format = params.output_format;
}
if (model.parameters.shape.base64 !== undefined) {
baseParams.base64 = false;
}
return this.mergeWithDefaults(baseParams, model);
}
async validateImageInput(input) {
const isFilePath = (input.match(/^[A-Za-z]:\\/) ||
input.startsWith('/') ||
input.startsWith('~/'));
if (isFilePath) {
try {
let filePath = input;
if (input.startsWith('~/')) {
filePath = input.replace('~', process.env.HOME || process.env.USERPROFILE || '');
}
if (!path.isAbsolute(filePath)) {
return {
isValid: false,
error: 'File path must be absolute',
};
}
await fs.access(filePath);
const imageBuffer = await fs.readFile(filePath);
const base64String = imageBuffer.toString('base64');
const ext = path.extname(filePath).toLowerCase();
const mimeTypes = {
'.jpg': 'image/jpeg',
'.jpeg': 'image/jpeg',
'.png': 'image/png',
'.gif': 'image/gif',
'.webp': 'image/webp',
'.bmp': 'image/bmp',
};
const mimeType = mimeTypes[ext] || 'image/png';
const cacheId = imageCache.store(base64String, mimeType, filePath);
logger.info(`Automatically converted file path to cache ID: ${cacheId}`);
return {
isValid: true,
format: 'cached',
cacheId,
size: imageBuffer.length,
originalFilePath: filePath,
};
}
catch (error) {
return {
isValid: false,
error: `Failed to read file: ${error instanceof Error ? error.message : 'Unknown error'}`,
};
}
}
if (input.startsWith('img_')) {
const cachedImage = imageCache.get(input);
if (cachedImage) {
return {
isValid: true,
format: 'cached',
cacheId: input,
size: cachedImage.size,
};
}
return {
isValid: false,
error: `Image cache ID ${input} not found or expired. Please use prepare_image again.`,
};
}
if (input.startsWith('http://') || input.startsWith('https://')) {
return {
isValid: true,
format: 'url',
};
}
const base64Regex = /^data:image\/(png|jpeg|jpg|webp|gif);base64,/;
const match = input.match(base64Regex);
if (match) {
try {
const base64Data = input.split(',')[1] || '';
const buffer = Buffer.from(base64Data, 'base64');
const size = buffer.length;
if (size > 10 * 1024 * 1024) {
return {
isValid: false,
error: 'Image size exceeds 10MB limit',
};
}
return {
isValid: true,
format: 'base64',
size,
};
}
catch (error) {
return {
isValid: false,
error: 'Invalid base64 image data',
};
}
}
try {
const buffer = Buffer.from(input, 'base64');
if (buffer.length > 0) {
return {
isValid: true,
format: 'base64',
size: buffer.length,
};
}
}
catch {
}
return {
isValid: false,
error: 'Image must be a valid URL, base64 encoded string, or image cache ID from prepare_image',
};
}
async processImageInput(input, validation) {
if (validation.format === 'cached' && validation.cacheId) {
const cachedImage = imageCache.get(validation.cacheId);
if (cachedImage) {
return cachedImage.base64;
}
throw new Error(`Image cache ID ${validation.cacheId} not found`);
}
if (validation.format === 'url') {
return input;
}
if (input.startsWith('data:image/')) {
const base64Part = input.split(',')[1];
return base64Part || input;
}
return input;
}
}
/**
 * Module-level singleton instance of TransformImageTool, created once when
 * this module is first evaluated and exported for consumers of the module.
 */
const transformImageTool = new TransformImageTool();
export { transformImageTool };
//# sourceMappingURL=transform-image.js.map