@bratcliffe909/mcp-server-segmind
Version:
Model Context Protocol server for Segmind API - Generate images and videos using AI models
203 lines • 10.1 kB
JavaScript
import { z } from 'zod';
import { modelRegistry, ModelCategory } from '../models/registry.js';
import { OutputType } from '../models/types.js';
import { logger } from '../utils/logger.js';
import { BaseTool } from './base.js';
/**
 * Input schema for the `specialized_generation` tool.
 *
 * `type` selects between text-to-speech and music generation. Most other
 * fields only apply to a particular type or model, as noted in each
 * description.
 *
 * The numeric ranges quoted in the descriptions are NOT enforced by this
 * schema (no `.min()`/`.max()` is attached). Per-model parameter validation
 * happens later through `modelRegistry.validateModelParameters`.
 */
const SpecializedGenerationSchema = z.object({
  type: z.enum(['tts', 'music']).describe('Type of specialized content to generate'),
  model: z.string().optional().describe('Specific model to use'),
  // --- text-to-speech options ---
  text: z.string().optional().describe('Text to convert to speech (for TTS type)'),
  voice: z.string().optional().describe('Voice selection for TTS (orpheus: tara, dan, josh, emma)'),
  speed_factor: z.number().optional().describe('Playback speed (0.5-1.5). 0.5 = slowest, 1.0 = normal, 1.5 = fastest'),
  top_p: z.number().optional().describe('Controls word variety (0.1-1.0, higher = rarer words)'),
  temperature: z.number().optional().describe('Controls randomness/expressiveness (0.1-2.0)'),
  max_new_tokens: z.number().optional().describe('Maximum tokens for TTS (controls audio length - higher = longer audio)'),
  repetition_penalty: z.number().optional().describe('Penalty for repeated phrases (1.0-2.0, orpheus only)'),
  cfg_scale: z.number().optional().describe('How strictly to follow text (1-5, dia only)'),
  cfg_filter_top_k: z.number().optional().describe('Token filtering (10-100, dia only)'),
  input_audio: z.string().optional().describe('Base64 audio for voice cloning (dia only)'),
  // --- music options ---
  duration: z.number().optional().describe('Duration in seconds for music/audio generation'),
  prompt: z.string().optional().describe('General description or style instructions'),
  negative_prompt: z.string().optional().describe('What to avoid in generation'),
  // --- shared options ---
  seed: z.number().int().optional().describe('Seed for reproducible generation'),
  num_outputs: z.number().int().min(1).max(4).default(1).describe('Number of variations to generate'),
  // This tool only produces audio; the description previously referred to an "image".
  display_mode: z.enum(['display', 'save', 'both']).default('display').describe('How to return the generated audio: display (show the audio), save (return base64 for saving), both (show the audio and provide base64)'),
});
/**
 * MCP tool that generates audio content (text-to-speech or music) using the
 * models registered under `ModelCategory.SPECIALIZED_GENERATION`.
 *
 * Flow:
 * 1. Validate input against `SpecializedGenerationSchema`.
 * 2. Pick a model, either the explicitly requested one or automatically from
 *    `type` and the supplied options.
 * 3. Generate `num_outputs` variations sequentially.
 * 4. Append a short text summary.
 */
export class SpecializedGenerationTool extends BaseTool {
  name = 'specialized_generation';
  description = 'Generate audio content including text-to-speech and music';

  /**
   * Run the tool.
   *
   * @param {unknown} params - Raw tool arguments; parsed with `SpecializedGenerationSchema`.
   * @returns {Promise<{content: Array<object>, isError?: boolean}>} The content
   *   produced by every variation followed by a text summary, or an error
   *   result with `isError: true`. Schema violations and thrown errors are
   *   converted by `this.createErrorResponse`.
   */
  async execute(params) {
    try {
      // `parse` throws on invalid input; the catch block below handles it.
      const validated = SpecializedGenerationSchema.parse(params);

      const validationResult = this.validateTypeRequirements(validated);
      if (!validationResult.isValid) {
        return {
          content: [{
            type: 'text',
            text: validationResult.error,
          }],
          isError: true,
        };
      }

      const model = this.selectModel(validated);
      if (!model) {
        return {
          content: [{
            type: 'text',
            text: `No suitable model found for ${validated.type} generation.`,
          }],
          isError: true,
        };
      }
      logger.info(`Selected model ${model.id} for ${validated.type} generation`);

      const results = [];
      // Variations are generated one after another: one model call each.
      for (let i = 0; i < validated.num_outputs; i++) {
        if (validated.num_outputs > 1) {
          logger.info(`Generating variation ${i + 1} of ${validated.num_outputs}`);
        }
        const modelParams = await this.prepareModelParameters(validated, model, i);
        const paramValidation = modelRegistry.validateModelParameters(model.id, modelParams);
        if (!paramValidation.success) {
          return {
            content: [{
              type: 'text',
              text: `Invalid parameters for model ${model.id}: ${paramValidation.error}`,
            }],
            isError: true,
          };
        }
        const result = await this.callModel(model, paramValidation.data, validated.display_mode);
        if (result.isError) {
          // Without this check the error content would be merged into a
          // success response, followed by a summary claiming every variation
          // was generated. Report the failure, but keep the content of any
          // variations that did succeed.
          return {
            content: [...results, ...result.content],
            isError: true,
          };
        }
        results.push(...result.content);
      }

      results.push({
        type: 'text',
        text: this.generateSummary(validated, model),
      });
      return { content: results };
    } catch (error) {
      logger.error('Specialized generation failed', { error });
      return this.createErrorResponse(error);
    }
  }

  /**
   * Check the per-type required fields that the schema cannot express.
   *
   * @param {object} params - Validated tool input.
   * @returns {{isValid: boolean, error?: string}} `isValid: false` plus a
   *   message when `tts` lacks `text` or `music` lacks `prompt`.
   */
  validateTypeRequirements(params) {
    switch (params.type) {
      case 'tts':
        if (!params.text) {
          return { isValid: false, error: 'Text-to-speech generation requires text parameter' };
        }
        break;
      case 'music':
        if (!params.prompt) {
          return { isValid: false, error: 'Music generation requires prompt parameter' };
        }
        break;
    }
    return { isValid: true };
  }

  /**
   * Choose the model to run.
   *
   * An explicitly requested `params.model` is used if the registry knows it
   * and it belongs to `ModelCategory.SPECIALIZED_GENERATION`. Otherwise a
   * warning is logged and automatic selection is used:
   * - tts: `dia-tts` is preferred when the text contains "[S" (presumably
   *   Dia speaker tags such as [S1], TODO confirm) or any dia-only option
   *   is supplied; otherwise `orpheus-tts`. Each falls back to the other.
   * - music: durations over 30 s only use `minimax-music`; otherwise
   *   `lyria-2` is preferred, with `minimax-music` as the fallback.
   *
   * @param {object} params - Validated tool input.
   * @returns {object|null|undefined} The registry model entry, or a falsy
   *   value (`null`/`undefined`) when no suitable model is registered.
   */
  selectModel(params) {
    if (params.model) {
      const model = modelRegistry.getModel(params.model);
      if (model && model.category === ModelCategory.SPECIALIZED_GENERATION) {
        return model;
      }
      logger.warn(`Model ${params.model} not found or not a specialized model`);
    }

    const specializedModels = modelRegistry.getModelsByCategory(ModelCategory.SPECIALIZED_GENERATION);
    const findById = (id) => specializedModels.find((m) => m.id === id);

    switch (params.type) {
      case 'tts': {
        const prefersDia = params.text?.includes('[S') ||
          params.speed_factor !== undefined ||
          params.input_audio !== undefined ||
          params.cfg_scale !== undefined ||
          params.cfg_filter_top_k !== undefined;
        if (prefersDia) {
          return findById('dia-tts') || findById('orpheus-tts');
        }
        return findById('orpheus-tts') || findById('dia-tts');
      }
      case 'music':
        if (params.duration && params.duration > 30) {
          return findById('minimax-music');
        }
        return findById('lyria-2') || findById('minimax-music');
      default:
        return null;
    }
  }

  /**
   * Build the request payload for one variation.
   *
   * Only the fields relevant to `model.id` are copied from `params`.
   *
   * @param {object} params - Validated tool input.
   * @param {object} model - Registry entry; `id`, `parameters.shape` and
   *   `outputType` are read.
   * @param {number} variationIndex - Zero-based variation number. It is added
   *   to `params.seed`, so each variation gets a distinct seed when a seed is
   *   given.
   * @returns {Promise<object>} Result of `this.mergeWithDefaults(baseParams, model)`.
   *   Presumably this fills in model defaults; it is defined in BaseTool and
   *   not visible here.
   */
  async prepareModelParameters(params, model, variationIndex) {
    const baseParams = {};
    const seed = params.seed !== undefined
      ? params.seed + variationIndex
      : undefined;

    switch (model.id) {
      case 'dia-tts':
        baseParams.text = params.text;
        if (params.speed_factor !== undefined) {
          baseParams.speed_factor = params.speed_factor;
        }
        if (params.top_p !== undefined) {
          baseParams.top_p = params.top_p;
        }
        if (params.temperature !== undefined) {
          baseParams.temperature = params.temperature;
        }
        if (params.max_new_tokens !== undefined) {
          baseParams.max_new_tokens = params.max_new_tokens;
        }
        if (params.cfg_scale !== undefined) {
          baseParams.cfg_scale = params.cfg_scale;
        }
        if (params.cfg_filter_top_k !== undefined) {
          baseParams.cfg_filter_top_k = params.cfg_filter_top_k;
        }
        if (params.input_audio) {
          baseParams.input_audio = params.input_audio;
        }
        break;
      case 'orpheus-tts':
        baseParams.text = params.text;
        if (params.voice) {
          baseParams.voice = params.voice;
        }
        if (params.top_p !== undefined) {
          baseParams.top_p = params.top_p;
        }
        if (params.temperature !== undefined) {
          baseParams.temperature = params.temperature;
        }
        if (params.max_new_tokens !== undefined) {
          baseParams.max_new_tokens = params.max_new_tokens;
        }
        if (params.repetition_penalty !== undefined) {
          baseParams.repetition_penalty = params.repetition_penalty;
        }
        break;
      case 'lyria-2':
        baseParams.prompt = params.prompt;
        if (params.duration) {
          baseParams.duration = params.duration;
        }
        if (params.negative_prompt) {
          baseParams.negative_prompt = params.negative_prompt;
        }
        break;
      case 'minimax-music':
        baseParams.prompt = params.prompt;
        if (params.duration) {
          baseParams.duration = params.duration;
        }
        break;
      default:
        logger.warn(`Unknown model ${model.id} in specialized generation`);
        break;
    }

    // Only send a seed if the model's parameter schema declares one.
    if (seed !== undefined && model.parameters.shape.seed) {
      baseParams.seed = seed;
    }

    // When the model's schema has a `base64` field: audio/video outputs are
    // requested with base64 = false, and all other outputs with base64 = true.
    if (model.parameters.shape.base64 !== undefined) {
      const isMediaOutput = model.outputType === OutputType.AUDIO ||
        model.outputType === OutputType.VIDEO;
      baseParams.base64 = !isMediaOutput;
    }

    return this.mergeWithDefaults(baseParams, model);
  }

  /**
   * Produce the human-readable summary appended after the generated content.
   *
   * @param {object} params - Validated tool input.
   * @param {object} model - Selected registry entry; `name` is read.
   * @returns {string} Summary text. Music summaries state `params.duration`,
   *   or 30 seconds when no duration was given.
   */
  generateSummary(params, model) {
    const typeDescriptions = {
      tts: `speech audio from text`,
      music: `${params.duration || 30}-second music track`,
    };
    return `\nGenerated ${params.num_outputs} ${typeDescriptions[params.type] || params.type}${params.num_outputs > 1 ? ' variations' : ''} using ${model.name}`;
  }
}
// Module-level instance of SpecializedGenerationTool, exported for use by other modules.
export const specializedGenerationTool = new SpecializedGenerationTool();
//# sourceMappingURL=specialized-generation.js.map