/*
 * @bratcliffe909/mcp-server-segmind
 * Version: (not listed)
 * Model Context Protocol server for the Segmind API - generate images and videos using AI models.
 * 168 lines • 7.5 kB
 * JavaScript
 */
import { z } from 'zod';
import { modelRegistry, ModelCategory } from '../models/registry.js';
import { logger } from '../utils/logger.js';
import { BaseTool } from './base.js';
/**
 * Input contract for the image generation tool.
 *
 * Only `prompt` is required. `num_images` falls back to 1 and `quality` to
 * 'standard' when omitted; every other field is optional.
 */
const GenerateImageSchema = (() => {
    // Width and height share one rule set: an integer multiple of 8 in [256, 2048].
    const dimension = (label) => z.number().int().multipleOf(8).min(256).max(2048).optional().describe(label);
    // Free-form optional string fields differ only in their description.
    const optionalText = (label) => z.string().optional().describe(label);

    return z.object({
        prompt: z.string().min(1).max(4000).describe('Text prompt describing the image to generate'),
        model: optionalText('Model ID to use for generation'),
        negative_prompt: optionalText('What to avoid in the generation'),
        width: dimension('Image width'),
        height: dimension('Image height'),
        num_images: z.number().int().min(1).max(4).default(1).describe('Number of images to generate'),
        seed: z.number().int().optional().describe('Seed for reproducible generation'),
        quality: z.enum(['draft', 'standard', 'high']).default('standard').describe('Quality preset'),
        style: optionalText('Style modifier (e.g., "photorealistic", "anime", "oil painting")'),
        save_location: optionalText('Directory path to save the image(s). Overrides default save location.'),
    });
})();
/**
 * Tool that generates one or more images from a text prompt.
 *
 * Flow: validate the input, pick a text-to-image model, translate the generic
 * options into that model's own parameter names, then call the model once per
 * requested image.
 *
 * `callModel`, `mergeWithDefaults` and `createErrorResponse` are not defined in
 * this class; they are expected to be inherited from BaseTool.
 */
export class GenerateImageTool extends BaseTool {
    name = 'generate_image';
    description = 'Generate images from text prompts using various AI models';

    /**
     * Validate the request, choose a model and generate the requested images.
     *
     * @param {unknown} params - Raw tool arguments; parsed with GenerateImageSchema.
     * @returns {Promise<object>} A result object with a `content` array. Early
     *   failures (no model, parameters rejected by the registry) carry
     *   `isError: true`; thrown errors are converted by `createErrorResponse`.
     */
    async execute(params) {
        try {
            const validated = GenerateImageSchema.parse(params);
            const model = this.selectModel(validated);
            if (!model) {
                return {
                    content: [{
                        type: 'text',
                        text: 'No suitable model found for image generation.',
                    }],
                    isError: true,
                };
            }
            logger.info(`Selected model ${model.id} for image generation`);

            const modelParams = this.prepareModelParameters(validated, model);
            const paramValidation = modelRegistry.validateModelParameters(model.id, modelParams);
            if (!paramValidation.success) {
                return {
                    content: [{
                        type: 'text',
                        text: `Invalid parameters for model ${model.id}: ${paramValidation.error}`,
                    }],
                    isError: true,
                };
            }

            const baseData = paramValidation.data;
            // Re-sending the exact same seed would make every image in the batch
            // identical, so images after the first use seed + index. The first
            // image keeps the caller's seed, and the batch stays reproducible.
            // Only an explicitly requested seed is offset; a seed that came from
            // model defaults is left untouched.
            const offsetSeed = validated.seed !== undefined && typeof baseData?.seed === 'number';

            const results = [];
            for (let i = 0; i < validated.num_images; i++) {
                logger.info(`Generating image ${i + 1} of ${validated.num_images}`);
                const callParams = offsetSeed && i > 0
                    ? { ...baseData, seed: baseData.seed + i }
                    : baseData;
                // NOTE(review): a failure reported through the call's return value
                // (rather than thrown) is not inspected here - confirm how
                // callModel signals errors.
                const result = await this.callModel(model, callParams, validated.save_location);
                results.push(...result.content);
            }
            results.push({
                type: 'text',
                text: `\nGenerated ${validated.num_images} image(s) using ${model.name}`,
            });
            return { content: results };
        }
        catch (error) {
            logger.error('Image generation failed', { error });
            return this.createErrorResponse(error);
        }
    }

    /**
     * Pick the model to use for a request.
     *
     * An explicitly requested model wins when it exists and is a text-to-image
     * model. Otherwise the choice depends on quality and style:
     * 'high' quality prefers `sdxl`; an "anime" or "enhance" style prefers
     * `fooocus`, then `sdxl-lightning`; everything else prefers `sdxl-lightning`.
     * Each preference falls back to the first registered text-to-image model.
     *
     * @param {{model?: string, quality?: string, style?: string}} params - Validated input.
     * @returns {object|undefined} The chosen model, or undefined when the
     *   registry has no text-to-image models.
     */
    selectModel(params) {
        if (params.model) {
            const requested = modelRegistry.getModel(params.model);
            if (requested && requested.category === ModelCategory.TEXT_TO_IMAGE) {
                return requested;
            }
            logger.warn(`Model ${params.model} not found or not a text-to-image model`);
        }

        const t2iModels = modelRegistry.getModelsByCategory(ModelCategory.TEXT_TO_IMAGE);
        const byId = (id) => t2iModels.find((m) => m.id === id);

        if (params.quality === 'high') {
            return byId('sdxl') || t2iModels[0];
        }
        // Match the style case-insensitively so "Anime" and "anime" behave alike.
        const style = params.style?.toLowerCase() ?? '';
        if (style.includes('anime') || style.includes('enhance')) {
            return byId('fooocus') || byId('sdxl-lightning') || t2iModels[0];
        }
        return byId('sdxl-lightning') || t2iModels[0];
    }

    /**
     * Translate the generic request options into the parameter names the
     * chosen model declares in `model.parameters.shape`.
     *
     * @param {object} params - Validated input (prompt, width, height, quality, seed, ...).
     * @param {object} model - Model descriptor with `id`, `parameters.shape`
     *   and optionally `defaultParams`.
     * @returns {object} The result of `mergeWithDefaults(baseParams, model)`.
     */
    prepareModelParameters(params, model) {
        const shape = model.parameters.shape;
        const baseParams = {
            prompt: this.enhancePrompt(params.prompt, params.style),
        };

        if (params.negative_prompt && shape.negative_prompt) {
            baseParams.negative_prompt = params.negative_prompt;
        }

        if (model.id === 'fooocus') {
            // fooocus receives a single "W*H" string instead of separate dimensions.
            baseParams.aspect_ratio = this.mapToFoocusAspectRatio(params.width, params.height);
        }
        else {
            // Use the img_width / img_height names when the model declares them,
            // otherwise plain width / height.
            if (params.width) {
                baseParams[shape.img_width ? 'img_width' : 'width'] = params.width;
            }
            if (params.height) {
                baseParams[shape.img_height ? 'img_height' : 'height'] = params.height;
            }
        }

        const defaultSteps = model.defaultParams?.num_inference_steps || 30;
        switch (params.quality) {
            case 'draft':
                if (shape.num_inference_steps) {
                    // A third of the default, never below 10. Rounded because a
                    // step count must be a whole number (e.g. 40 / 3 = 13.33...).
                    baseParams.num_inference_steps = Math.max(10, Math.round(defaultSteps / 3));
                }
                if (shape.steps) {
                    baseParams.steps = 20;
                }
                break;
            case 'high':
                if (shape.num_inference_steps) {
                    // Double the default, capped at 150.
                    baseParams.num_inference_steps = Math.min(150, defaultSteps * 2);
                }
                if (shape.steps) {
                    baseParams.steps = 60;
                }
                if (shape.quality) {
                    baseParams.quality = 'hd';
                }
                break;
            // 'standard' (or anything else) leaves step counts unset here.
        }

        if (params.seed !== undefined && shape.seed) {
            baseParams.seed = params.seed;
        }
        if (shape.base64 !== undefined) {
            baseParams.base64 = false;
        }
        return this.mergeWithDefaults(baseParams, model);
    }

    /**
     * Append style keywords that the prompt does not already mention.
     *
     * The style is lowercased and split on whitespace; keywords already present
     * in the prompt (case-insensitive substring match) are skipped.
     *
     * @param {string} prompt - The user's prompt.
     * @param {string} [style] - Optional style modifier such as "oil painting".
     * @returns {string} The prompt, with `, <keywords> style` appended when
     *   at least one keyword is missing.
     */
    enhancePrompt(prompt, style) {
        if (!style) {
            return prompt;
        }
        const promptLower = prompt.toLowerCase();
        const missingKeywords = style
            .toLowerCase()
            .split(/\s+/)
            .filter((keyword) => keyword !== '' && !promptLower.includes(keyword));
        if (missingKeywords.length === 0) {
            return prompt;
        }
        return `${prompt}, ${missingKeywords.join(' ')} style`;
    }

    /**
     * Build the "W*H" aspect-ratio string used for the fooocus model.
     *
     * Every size formats the same way (1152 x 896 becomes '1152*896'), so no
     * lookup table of common sizes is needed.
     *
     * @param {number} [width] - Requested width.
     * @param {number} [height] - Requested height.
     * @returns {string} '1024*1024' when either dimension is missing,
     *   otherwise `${width}*${height}`.
     */
    mapToFoocusAspectRatio(width, height) {
        if (!width || !height) {
            return '1024*1024';
        }
        return `${width}*${height}`;
    }
}
/** Module-level GenerateImageTool instance, created once when this module is loaded. */
export const generateImageTool = new GenerateImageTool();
//# sourceMappingURL=generate-image.js.map