UNPKG

@bratcliffe909/mcp-server-segmind

Version:

Model Context Protocol server for Segmind API - Generate images and videos using AI models

381 lines 16.2 kB
import * as fs from 'fs';
import * as os from 'os';
import * as path from 'path';
import { apiClient } from '../api/client.js';
import { ModelCategory } from '../models/registry.js';
import { OutputType } from '../models/types.js';
import { config } from '../utils/config.js';
import { costTracker } from '../utils/cost-tracker.js';
import { NetworkError, TimeoutError, InsufficientCreditsError, AuthenticationError, mapToSafeError } from '../utils/errors.js';
import { logger } from '../utils/logger.js';

/** Multiplier applied to a model's estimated run time when deriving its request timeout. */
const TIMEOUT_SAFETY_FACTOR = 3;
/** Minimum request timeout (ms) for video and audio models. */
const MEDIA_MIN_TIMEOUT_MS = 300000;
/** Minimum request timeout (ms) for image-enhancement models. */
const ENHANCEMENT_MIN_TIMEOUT_MS = 180000;
/** Default polling budget for v2 long-running jobs: 60 attempts x 5 s, i.e. about 5 minutes. */
const JOB_POLL_MAX_ATTEMPTS = 60;
const JOB_POLL_INTERVAL_MS = 5000;
/** Captures the payload of a `data:<mime>;base64,<payload>` URI; the payload may contain line breaks. */
const DATA_URI_PATTERN = /^data:[^;]+;base64,([\s\S]+)$/;

/** Resolve after `ms` milliseconds. */
const sleep = (ms) => new Promise((resolve) => setTimeout(resolve, ms));

/**
 * Return the base64 payload of a data URI, or the input unchanged if it is not one.
 * Non-string values pass through untouched so the caller's save step can fail gracefully
 * instead of throwing a TypeError here.
 */
function stripDataUriPrefix(data) {
  if (typeof data !== 'string') {
    return data;
  }
  const match = data.match(DATA_URI_PATTERN);
  return match ? match[1] : data;
}

/** File extension for an image MIME type: "image/png" -> "png", "image/svg+xml" -> "svg"; defaults to "jpg". */
function imageExtension(mimeType) {
  const subtype = String(mimeType).split('/')[1] ?? '';
  // Drop MIME parameters ("; charset=...") and structured-syntax suffixes ("+xml").
  const extension = subtype.split(/[+;]/)[0].trim();
  return extension || 'jpg';
}

/** File extension for an audio MIME type; defaults to "mp3". */
function audioExtension(mimeType) {
  const mime = String(mimeType).toLowerCase();
  if (mime.includes('mpeg')) return 'mp3';
  if (mime.includes('wav')) return 'wav';
  if (mime.includes('ogg')) return 'ogg';
  return 'mp3';
}

/**
 * File extension for a video MIME type; defaults to "mp4".
 * The registered MIME types for AVI and MOV are "video/x-msvideo" and "video/quicktime",
 * which the bare "avi"/"mov" checks alone never match.
 */
function videoExtension(mimeType) {
  const mime = String(mimeType).toLowerCase();
  if (mime.includes('mp4')) return 'mp4';
  if (mime.includes('webm')) return 'webm';
  if (mime.includes('avi') || mime.includes('msvideo')) return 'avi';
  if (mime.includes('mov') || mime.includes('quicktime')) return 'mov';
  return 'mp4';
}

/** Path `dir/stem.ext`, or `dir/stem-N.ext` with the first N that does not exist yet. */
function uniqueFilePath(dir, stem, extension) {
  let candidate = path.join(dir, `${stem}.${extension}`);
  for (let suffix = 1; fs.existsSync(candidate); suffix++) {
    candidate = path.join(dir, `${stem}-${suffix}.${extension}`);
  }
  return candidate;
}

/**
 * Decide where to write a generated file.
 * An override with a file extension is used verbatim as the target file. Otherwise the
 * override (or the configured save location, or the OS temp dir) is a directory and a
 * `<model.id>-<timestamp>.<extension>` name is generated, de-duplicated so files saved
 * within the same millisecond do not overwrite each other.
 */
function resolveSavePath(model, extension, saveLocationOverride) {
  if (saveLocationOverride && path.extname(saveLocationOverride)) {
    return saveLocationOverride;
  }
  const saveDir = saveLocationOverride || config.fileOutput.saveLocation || os.tmpdir();
  return uniqueFilePath(saveDir, `${model.id}-${Date.now()}`, extension);
}

/**
 * Decode base64 `data` and write it to `filePath`, creating parent directories as needed.
 * @returns {number} number of bytes written
 */
function writeBase64File(filePath, data) {
  fs.mkdirSync(path.dirname(filePath), { recursive: true });
  // Node's base64 decoder ignores embedded whitespace such as wrapped lines.
  const buffer = Buffer.from(data, 'base64');
  fs.writeFileSync(filePath, buffer);
  return buffer.length;
}

/** "out.png" -> "out-2.png" for ordinal 2; directory locations and undefined pass through unchanged. */
function indexedFileLocation(location, ordinal) {
  const extension = location ? path.extname(location) : '';
  if (!extension) {
    return location;
  }
  return `${location.slice(0, -extension.length)}-${ordinal}${extension}`;
}

/** Log an unrecognized response payload and build the user-facing notice for it. */
function unexpectedFormatContent(model, data, kind) {
  logger.warn(`Unexpected ${kind} response structure`, {
    modelId: model.id,
    dataKeys: Object.keys(data || {}),
    dataType: typeof data,
  });
  const noun = kind.charAt(0).toUpperCase() + kind.slice(1);
  return {
    type: 'text',
    text: `${noun} generated but response format was unexpected. Please check the logs.`,
  };
}

/** User-facing notice for a save that failed (details were logged by the save method). */
function saveFailedContent(what) {
  return { type: 'text', text: `Failed to save ${what} to disk. Please check the logs.` };
}

/**
 * Shared base for Segmind tools. It calls models through `apiClient` and turns their
 * responses into text content items, saving binary media to disk. It also records credit
 * usage through `costTracker` and maps errors to user-facing messages. Presumably
 * extended by concrete tool classes defined elsewhere (not visible here).
 */
export class BaseTool {
  context;

  /**
   * @param {object} [context] - Per-request context; only `context.requestId` is read here (for logging).
   */
  constructor(context) {
    this.context = context;
  }

  /**
   * Call a model synchronously and convert its response into tool output.
   *
   * Plain image models go through `apiClient.generateImage`. Image-enhancement models and
   * all other output types POST to `model.endpoint`, with a timeout derived from
   * `model.estimatedTime` (multiplied by 1000, so presumably seconds).
   *
   * @param {object} model - Registry entry; reads id, category, outputType, endpoint, estimatedTime, creditsPerUse.
   * @param {object} parameters - Request body sent to the model.
   * @param {string} [saveLocation] - Directory, or file path (if it has an extension), for saved media.
   * @returns {Promise<{content: object[], model: string, creditsUsed: *, processingTime: number, metadata: *}>}
   * @throws Rethrows any API or processing error after logging a sanitized copy.
   */
  async callModel(model, parameters, saveLocation) {
    const startTime = Date.now();
    try {
      logger.info(`Calling model ${model.id}`, {
        model: model.id,
        category: model.category,
        requestId: this.context?.requestId,
        parameters,
      });
      const calculatedTimeout = model.estimatedTime * 1000 * TIMEOUT_SAFETY_FACTOR;
      let timeout = calculatedTimeout;
      if (model.outputType === OutputType.VIDEO || model.outputType === OutputType.AUDIO) {
        timeout = Math.max(timeout, MEDIA_MIN_TIMEOUT_MS);
      } else if (model.category === ModelCategory.IMAGE_ENHANCEMENT) {
        timeout = Math.max(timeout, ENHANCEMENT_MIN_TIMEOUT_MS);
      }
      logger.info(`Model timeout configured`, {
        model: model.id,
        estimatedTime: model.estimatedTime,
        calculatedTimeout,
        finalTimeout: timeout,
        timeoutMinutes: (timeout / 1000 / 60).toFixed(2),
      });
      const usesImageEndpoint = model.outputType === OutputType.IMAGE
        && model.category !== ModelCategory.IMAGE_ENHANCEMENT;
      const response = usesImageEndpoint
        ? await apiClient.generateImage(model.id, parameters)
        : await apiClient.request(model.endpoint, {
          method: 'POST',
          body: parameters,
          timeout,
        });
      const processingTime = Date.now() - startTime;
      logger.info(`Model call successful`, {
        model: model.id,
        processingTime,
        creditsUsed: response.credits?.used,
      });
      if (response.credits?.used) {
        costTracker.recordCost(model.id, response.credits.used, {
          resolution: parameters.img_width && parameters.img_height
            ? `${parameters.img_width}x${parameters.img_height}`
            : undefined,
          quality: parameters.quality,
          samples: parameters.samples || parameters.num_images,
        });
      }
      return {
        content: await this.processModelResponse(response, model, parameters.prompt, saveLocation),
        model: model.id,
        creditsUsed: response.credits?.used || model.creditsPerUse,
        processingTime,
        metadata: response.metadata,
      };
    } catch (error) {
      logger.error(`Model call failed`, {
        model: model.id,
        error: mapToSafeError(error),
        processingTime: Date.now() - startTime,
      });
      throw error;
    }
  }

  /**
   * Convert an API response into text content items, saving inline media to disk.
   *
   * Every branch emits at least one item. If the payload has an unrecognized shape, or a
   * save fails, a notice is emitted instead of silently returning nothing. A trailing
   * credits line is appended when `response.credits` is present.
   *
   * @param {object} response - API response; `data` carries the payload and `credits` the usage.
   * @param {object} model - Registry entry; `outputType` selects how `data` is interpreted.
   * @param {string} [prompt] - Forwarded to `saveImageToFile`.
   * @param {string} [saveLocation] - Directory, or file path (if it has an extension), for saved media.
   * @returns {Promise<Array<{type: string, text: string}>>}
   */
  async processModelResponse(response, model, prompt, saveLocation) {
    const content = [];
    const data = response?.data;
    logger.debug('Processing model response', {
      modelId: model.id,
      outputType: model.outputType,
      responseKeys: Object.keys(response || {}),
      dataKeys: Object.keys(data || {}),
      hasImage: !!data?.image,
      hasImages: !!data?.images,
      creditsUsed: response?.credits?.used,
    });
    switch (model.outputType) {
      case OutputType.IMAGE:
        if (data?.image) {
          const savedPath = await this.saveImageToFile(
            stripDataUriPrefix(data.image),
            data.mimeType || 'image/png',
            model,
            prompt,
            saveLocation,
          );
          content.push(savedPath
            ? { type: 'text', text: `Image saved to: ${savedPath}` }
            : saveFailedContent('image'));
        } else if (data?.images) {
          const mimeType = data.mimeType || 'image/png';
          const count = data.images.length;
          // Sequential on purpose: each save's de-duplication check must see the files
          // written by the previous iterations.
          for (let index = 0; index < count; index++) {
            const ordinal = index + 1;
            // With several images, an explicit file path would make each image overwrite the
            // previous one, so each image gets its own "-N" suffixed file instead.
            const location = count > 1 ? indexedFileLocation(saveLocation, ordinal) : saveLocation;
            const savedPath = await this.saveImageToFile(
              stripDataUriPrefix(data.images[index]),
              mimeType,
              model,
              prompt ? `${prompt}-${ordinal}` : undefined,
              location,
            );
            content.push(savedPath
              ? { type: 'text', text: `Image ${ordinal} saved to: ${savedPath}` }
              : saveFailedContent(`image ${ordinal}`));
          }
        } else if (data?.url) {
          content.push({
            type: 'text',
            text: `Image generated successfully. View at: ${data.url}`,
          });
        } else {
          content.push(unexpectedFormatContent(model, data, 'image'));
        }
        break;
      case OutputType.VIDEO:
        if (data?.video_url) {
          content.push({
            type: 'text',
            text: `Video generated successfully. View at: ${data.video_url}`,
          });
        } else if (data?.video) {
          const videoPath = await this.saveVideoToFile(
            stripDataUriPrefix(data.video),
            data.mimeType || 'video/mp4',
            model,
            saveLocation,
          );
          content.push(videoPath
            ? { type: 'text', text: `Video saved to: ${videoPath}` }
            : saveFailedContent('video'));
        } else {
          content.push(unexpectedFormatContent(model, data, 'video'));
        }
        break;
      case OutputType.TEXT:
        if (data?.text) {
          content.push({ type: 'text', text: data.text });
        } else {
          content.push(unexpectedFormatContent(model, data, 'text'));
        }
        break;
      case OutputType.AUDIO:
        if (data?.audio_url) {
          content.push({
            type: 'text',
            text: `Audio generated successfully. Download: ${data.audio_url}`,
          });
        } else if (data?.audio) {
          const audioPath = await this.saveAudioToFile(
            stripDataUriPrefix(data.audio),
            data.mimeType || 'audio/mpeg',
            model,
            saveLocation,
          );
          content.push(audioPath
            ? { type: 'text', text: `Audio saved to: ${audioPath}` }
            : saveFailedContent('audio'));
        } else {
          content.push(unexpectedFormatContent(model, data, 'audio'));
        }
        break;
      default:
        content.push({
          type: 'text',
          text: JSON.stringify(data, null, 2),
        });
    }
    if (response?.credits) {
      content.push({
        type: 'text',
        text: `\n\nCredits used: ${response.credits.used} | Remaining: ${response.credits.remaining}`,
      });
    }
    return content;
  }

  /**
   * Build an MCP error result with a user-facing message chosen by error class.
   * A sanitized copy of the error (via `mapToSafeError`) is logged, not the raw error.
   *
   * @param {unknown} error
   * @returns {{content: Array<{type: string, text: string}>, isError: true}}
   */
  createErrorResponse(error) {
    const safeError = mapToSafeError(error);
    let userMessage = 'An error occurred while processing your request.';
    if (error instanceof NetworkError) {
      userMessage = 'Unable to connect to Segmind API. Please check your internet connection.';
    } else if (error instanceof TimeoutError) {
      userMessage = 'The request took too long to process. Please try again.';
    } else if (error instanceof InsufficientCreditsError) {
      userMessage = 'Insufficient credits. Please add more credits to your Segmind account.';
    } else if (error instanceof AuthenticationError) {
      userMessage = 'Authentication failed. Please check that your SEGMIND_API_KEY is valid and properly configured.';
    } else if (error instanceof Error) {
      userMessage = `Error: ${error.message || 'Unknown error occurred'}`;
    }
    logger.error('Tool execution error', {
      error: safeError,
      userMessage,
    });
    return {
      content: [
        {
          type: 'text',
          text: userMessage,
        },
      ],
      isError: true,
    };
  }

  /**
   * Run a model that may be asynchronous on the API side.
   * For `model.apiVersion === 'v2'`, submit the job and poll `/jobs/<id>` until it finishes.
   * Otherwise delegate to `callModel`.
   *
   * @throws {Error} When a v2 submission returns no `job_id`, or the job reports failure.
   * @throws {TimeoutError} When polling runs out of attempts.
   */
  async handleLongRunningOperation(model, parameters) {
    if (model.apiVersion !== 'v2') {
      return this.callModel(model, parameters);
    }
    logger.info(`Starting long-running operation for ${model.id}`);
    const startResponse = await apiClient.request(model.endpoint, {
      method: 'POST',
      body: parameters,
    });
    const jobData = startResponse.data;
    if (!jobData?.job_id) {
      throw new Error('No job ID returned from API');
    }
    return this.pollForCompletion(jobData.job_id, model);
  }

  /**
   * Poll a v2 job until it completes, fails, or the attempt budget runs out.
   *
   * Only errors from the status request itself are treated as transient: they are logged
   * and retried after the normal interval. A job that reports `status: 'failed'` ends
   * polling immediately with its error; it is never retried. No sleep follows the last
   * attempt.
   *
   * @param {string} jobId
   * @param {object} model - Registry entry; `id` and `creditsPerUse` are read.
   * @param {{maxAttempts?: number, pollInterval?: number}} [options] - Polling budget (default 60 x 5000 ms).
   * @returns {Promise<{content: object[], model: string, creditsUsed: *, processingTime: number, metadata: *}>}
   * @throws {Error} When the job reports failure.
   * @throws {TimeoutError} When the job has not completed after `maxAttempts` polls.
   */
  async pollForCompletion(jobId, model, { maxAttempts = JOB_POLL_MAX_ATTEMPTS, pollInterval = JOB_POLL_INTERVAL_MS } = {}) {
    const startTime = Date.now();
    for (let attempt = 1; attempt <= maxAttempts; attempt++) {
      let statusResponse;
      try {
        statusResponse = await apiClient.request(`/jobs/${jobId}`, {
          method: 'GET',
        });
      } catch (error) {
        logger.warn(`Error polling job ${jobId}, attempt ${attempt}`, { error: mapToSafeError(error) });
      }
      if (statusResponse) {
        const statusData = statusResponse.data;
        logger.debug(`Job ${jobId} status: ${statusData?.status}`);
        if (statusData?.status === 'completed') {
          const processingTime = Date.now() - startTime;
          return {
            content: await this.processModelResponse(statusResponse, model),
            model: model.id,
            creditsUsed: statusResponse.credits?.used || model.creditsPerUse,
            processingTime,
            metadata: statusResponse.metadata,
          };
        }
        if (statusData?.status === 'failed') {
          // Terminal: thrown outside the request's try/catch so it is not mistaken for a transient error.
          throw new Error(statusData?.error || 'Job failed');
        }
      }
      if (attempt < maxAttempts) {
        await sleep(pollInterval);
      }
    }
    throw new TimeoutError('Job did not complete within timeout period');
  }

  /**
   * Shallow-merge caller parameters over the model's defaults (caller values win).
   * Neither input is mutated.
   */
  mergeWithDefaults(params, model) {
    return {
      ...model.defaultParams,
      ...params,
    };
  }

  /**
   * Decode a base64 image and write it to disk.
   *
   * @param {string} base64Data - Raw base64 payload (no data-URI prefix).
   * @param {string} mimeType - Used to choose the extension when a filename is generated.
   * @param {object} model - `model.id` prefixes generated filenames.
   * @param {string} [_prompt] - Currently unused; kept for interface compatibility.
   * @param {string} [saveLocationOverride] - File path (if it has an extension) or directory.
   * @returns {Promise<string|null>} The written path, or null if saving failed (the error is logged).
   */
  async saveImageToFile(base64Data, mimeType, model, _prompt, saveLocationOverride) {
    try {
      const filePath = resolveSavePath(model, imageExtension(mimeType), saveLocationOverride);
      const size = writeBase64File(filePath, base64Data);
      logger.info('Image saved to file', { path: filePath, size, mimeType });
      return filePath;
    } catch (error) {
      logger.error('Failed to save image to file', { error });
      return null;
    }
  }

  /**
   * Decode base64 audio and write it to disk. The extension is mp3, wav, or ogg, chosen from `mimeType`.
   * @returns {Promise<string|null>} The written path, or null if saving failed (the error is logged).
   */
  async saveAudioToFile(base64Data, mimeType, model, saveLocationOverride) {
    try {
      const filePath = resolveSavePath(model, audioExtension(mimeType), saveLocationOverride);
      const size = writeBase64File(filePath, base64Data);
      logger.info('Audio saved to file', { path: filePath, size, mimeType });
      return filePath;
    } catch (error) {
      logger.error('Failed to save audio to file', { error });
      return null;
    }
  }

  /**
   * Decode base64 video and write it to disk. The extension is mp4, webm, avi, or mov, chosen from `mimeType`.
   * @returns {Promise<string|null>} The written path, or null if saving failed (the error is logged).
   */
  async saveVideoToFile(base64Data, mimeType, model, saveLocationOverride) {
    try {
      const filePath = resolveSavePath(model, videoExtension(mimeType), saveLocationOverride);
      const size = writeBase64File(filePath, base64Data);
      logger.info('Video saved to file', { path: filePath, size, mimeType });
      return filePath;
    } catch (error) {
      logger.error('Failed to save video to file', { error });
      return null;
    }
  }
}
//# sourceMappingURL=base.js.map