UNPKG

@bratcliffe909/mcp-server-segmind

Version:

Model Context Protocol server for Segmind API - Generate images and videos using AI models

203 lines 10.1 kB
import { z } from 'zod';
import { modelRegistry, ModelCategory } from '../models/registry.js';
import { OutputType } from '../models/types.js';
import { logger } from '../utils/logger.js';
import { BaseTool } from './base.js';

/** Builds an optional string field carrying the given description. */
const optionalText = (description) => z.string().optional().describe(description);

/** Builds an optional numeric field carrying the given description. */
const optionalNumber = (description) => z.number().optional().describe(description);

/**
 * Input schema for the specialized (audio) generation tool.
 *
 * The numeric ranges quoted in the descriptions are hints only: this schema
 * attaches no min/max to those fields. Only `seed` (integer) and
 * `num_outputs` (integer, 1-4, default 1) are constrained here.
 */
const SpecializedGenerationSchema = z.object({
  type: z.enum(['tts', 'music']).describe('Type of specialized content to generate'),
  model: optionalText('Specific model to use'),
  text: optionalText('Text to convert to speech (for TTS type)'),
  voice: optionalText('Voice selection for TTS (orpheus: tara, dan, josh, emma)'),
  speed_factor: optionalNumber('Playback speed (0.5-1.5). 0.5 = slowest, 1.0 = normal, 1.5 = fastest'),
  top_p: optionalNumber('Controls word variety (0.1-1.0, higher = rarer words)'),
  temperature: optionalNumber('Controls randomness/expressiveness (0.1-2.0)'),
  max_new_tokens: optionalNumber('Maximum tokens for TTS (controls audio length - higher = longer audio)'),
  repetition_penalty: optionalNumber('Penalty for repeated phrases (1.0-2.0, orpheus only)'),
  cfg_scale: optionalNumber('How strictly to follow text (1-5, dia only)'),
  cfg_filter_top_k: optionalNumber('Token filtering (10-100, dia only)'),
  input_audio: optionalText('Base64 audio for voice cloning (dia only)'),
  duration: optionalNumber('Duration in seconds for music/audio generation'),
  prompt: optionalText('General description or style instructions'),
  negative_prompt: optionalText('What to avoid in generation'),
  seed: z.number().int().optional().describe('Seed for reproducible generation'),
  num_outputs: z.number().int().min(1).max(4).default(1).describe('Number of variations to generate'),
  display_mode: z
    .enum(['display', 'save', 'both'])
    .default('display')
    .describe('How to return the image: display (show image), save (return base64 for saving), both (show image and provide base64)'),
});

/** Copy rule: forward the value whenever the caller supplied one (even 0 or ''). */
const copyIfDefined = (value) => value !== undefined;

/** Copy rule: forward the value only when it is truthy (0 and '' are skipped). */
const copyIfTruthy = (value) => Boolean(value);

/**
 * Per-model recipe for translating tool input into model parameters.
 *
 * `primary` is always copied, even when undefined. `optional` entries are
 * visited in order and copied only when their rule accepts the value, so the
 * listing order is also the key order of the resulting object.
 */
const MODEL_PARAMETER_PLANS = new Map([
  [
    'dia-tts',
    {
      primary: 'text',
      optional: [
        ['speed_factor', copyIfDefined],
        ['top_p', copyIfDefined],
        ['temperature', copyIfDefined],
        ['max_new_tokens', copyIfDefined],
        ['cfg_scale', copyIfDefined],
        ['cfg_filter_top_k', copyIfDefined],
        ['input_audio', copyIfTruthy],
      ],
    },
  ],
  [
    'orpheus-tts',
    {
      primary: 'text',
      optional: [
        ['voice', copyIfTruthy],
        ['top_p', copyIfDefined],
        ['temperature', copyIfDefined],
        ['max_new_tokens', copyIfDefined],
        ['repetition_penalty', copyIfDefined],
      ],
    },
  ],
  [
    'lyria-2',
    {
      primary: 'prompt',
      optional: [
        ['duration', copyIfTruthy],
        ['negative_prompt', copyIfTruthy],
      ],
    },
  ],
  [
    'minimax-music',
    {
      primary: 'prompt',
      optional: [['duration', copyIfTruthy]],
    },
  ],
]);

/** Options that, when supplied for a TTS request, steer auto-selection towards dia-tts. */
const DIA_SELECTION_KEYS = ['speed_factor', 'input_audio', 'cfg_scale', 'cfg_filter_top_k'];

/**
 * Wraps a message in the error-shaped tool result used by this tool.
 *
 * @param {string} message - Text placed in the single content entry.
 * @returns {{content: Array<{type: string, text: string}>, isError: boolean}}
 */
function textErrorResponse(message) {
  return {
    content: [{ type: 'text', text: message }],
    isError: true,
  };
}

/**
 * Tool that produces text-to-speech or music output through a model taken
 * from the registry's SPECIALIZED_GENERATION category.
 *
 * `callModel`, `mergeWithDefaults` and `createErrorResponse` are not defined
 * in this file; they are presumably inherited from BaseTool (not visible here).
 */
export class SpecializedGenerationTool extends BaseTool {
  name = 'specialized_generation';
  description = 'Generate audio content including text-to-speech and music';

  /**
   * Validates the input, picks a model, generates `num_outputs` variations
   * one after another, and appends a textual summary to the collected content.
   *
   * Nothing is thrown to the caller: schema failures and any other exception
   * are logged and converted through `createErrorResponse`.
   *
   * @param {object} params - Raw tool arguments, parsed with SpecializedGenerationSchema.
   * @returns {Promise<object>} `{ content }` on success; an error-shaped result otherwise.
   */
  async execute(params) {
    try {
      const input = SpecializedGenerationSchema.parse(params);

      const requirements = this.validateTypeRequirements(input);
      if (!requirements.isValid) {
        return textErrorResponse(requirements.error);
      }

      const model = this.selectModel(input);
      if (!model) {
        return textErrorResponse(`No suitable model found for ${input.type} generation.`);
      }
      logger.info(`Selected model ${model.id} for ${input.type} generation`);

      const variationCount = input.num_outputs;
      const collected = [];
      for (let index = 0; index < variationCount; index += 1) {
        if (variationCount > 1) {
          logger.info(`Generating variation ${index + 1} of ${variationCount}`);
        }

        const candidateParams = await this.prepareModelParameters(input, model, index);
        const checked = modelRegistry.validateModelParameters(model.id, candidateParams);
        if (!checked.success) {
          // An invalid variation aborts the whole request; content gathered
          // from earlier variations is not returned.
          return textErrorResponse(`Invalid parameters for model ${model.id}: ${checked.error}`);
        }

        const response = await this.callModel(model, checked.data, input.display_mode);
        collected.push(...response.content);
      }

      collected.push({
        type: 'text',
        text: this.generateSummary(input, model),
      });
      return { content: collected };
    } catch (error) {
      logger.error('Specialized generation failed', { error });
      return this.createErrorResponse(error);
    }
  }

  /**
   * Checks the field each generation type cannot work without:
   * `text` for 'tts', `prompt` for 'music'. Empty strings count as missing.
   *
   * @param {object} params - Schema-validated input.
   * @returns {{isValid: boolean, error?: string}} `error` is set only when invalid.
   */
  validateTypeRequirements(params) {
    if (params.type === 'tts' && !params.text) {
      return { isValid: false, error: 'Text-to-speech generation requires text parameter' };
    }
    if (params.type === 'music' && !params.prompt) {
      return { isValid: false, error: 'Music generation requires prompt parameter' };
    }
    return { isValid: true };
  }

  /**
   * Resolves the model to run.
   *
   * An explicitly requested model wins when the registry knows it and it
   * belongs to the SPECIALIZED_GENERATION category; otherwise a warning is
   * logged and selection falls back to the per-type preference order below.
   *
   * @param {object} params - Schema-validated input.
   * @returns {object|null|undefined} The chosen model; `undefined` when none of
   *   the preferred ids is registered, `null` for an unrecognised type.
   */
  selectModel(params) {
    if (params.model) {
      const requested = modelRegistry.getModel(params.model);
      if (requested && requested.category === ModelCategory.SPECIALIZED_GENERATION) {
        return requested;
      }
      logger.warn(`Model ${params.model} not found or not a specialized model`);
    }

    const candidates = modelRegistry.getModelsByCategory(ModelCategory.SPECIALIZED_GENERATION);
    const byId = (id) => candidates.find((candidate) => candidate.id === id);

    if (params.type === 'tts') {
      // A "[S" marker in the text or any of the dia-selection options
      // prefers dia-tts; otherwise orpheus-tts is tried first.
      const hasSpeakerMarker = Boolean(params.text && params.text.includes('[S'));
      const hasDiaOption = DIA_SELECTION_KEYS.some((key) => params[key] !== undefined);
      if (hasSpeakerMarker || hasDiaOption) {
        return byId('dia-tts') || byId('orpheus-tts');
      }
      return byId('orpheus-tts') || byId('dia-tts');
    }

    if (params.type === 'music') {
      // Requests longer than 30 seconds go to minimax-music only, with no fallback.
      if (params.duration && params.duration > 30) {
        return byId('minimax-music');
      }
      return byId('lyria-2') || byId('minimax-music');
    }

    return null;
  }

  /**
   * Builds the parameter object for one variation of the selected model.
   *
   * @param {object} params - Schema-validated input.
   * @param {object} model - Selected model; `id`, `outputType` and
   *   `parameters.shape` are read.
   * @param {number} variationIndex - Zero-based variation number, added to the
   *   caller's seed so each variation gets a distinct seed.
   * @returns {Promise<object>} Whatever `mergeWithDefaults` returns for the
   *   assembled parameters.
   */
  async prepareModelParameters(params, model, variationIndex) {
    const assembled = {};

    const plan = MODEL_PARAMETER_PLANS.get(model.id);
    if (plan === undefined) {
      logger.warn(`Unknown model ${model.id} in specialized generation`);
    } else {
      assembled[plan.primary] = params[plan.primary];
      for (const [key, shouldCopy] of plan.optional) {
        const value = params[key];
        if (shouldCopy(value)) {
          assembled[key] = value;
        }
      }
    }

    // A seed is forwarded only when the caller gave one and the model's
    // parameter shape declares a seed entry.
    if (params.seed !== undefined) {
      const variationSeed = params.seed + variationIndex;
      if (model.parameters.shape.seed) {
        assembled.seed = variationSeed;
      }
    }

    // Models whose shape declares base64 get it set explicitly: false for
    // audio/video output types, true for every other output type.
    if (model.parameters.shape.base64 !== undefined) {
      const producesMedia =
        model.outputType === OutputType.AUDIO || model.outputType === OutputType.VIDEO;
      assembled.base64 = !producesMedia;
    }

    return this.mergeWithDefaults(assembled, model);
  }

  /**
   * Produces the closing summary line for a completed request.
   *
   * @param {object} params - Schema-validated input (`type`, `num_outputs`, `duration`).
   * @param {object} model - Model that was used; only `name` is read.
   * @returns {string} Summary text beginning with a newline.
   */
  generateSummary(params, model) {
    const typeDescriptions = {
      tts: `speech audio from text`,
      // 30 is only the number shown when the caller gave no duration.
      music: `${params.duration || 30}-second music track`,
    };
    const subject = typeDescriptions[params.type] || params.type;
    const suffix = params.num_outputs > 1 ? ' variations' : '';
    return `\nGenerated ${params.num_outputs} ${subject}${suffix} using ${model.name}`;
  }
}

export const specializedGenerationTool = new SpecializedGenerationTool();
//# sourceMappingURL=specialized-generation.js.map