UNPKG

huggingface-mcp-server

Version:

MCP Server for HuggingFace inference endpoints with custom LoRA and story generation

271 lines (241 loc) 7.68 kB
#!/usr/bin/env node
/**
 * HuggingFace MCP server.
 *
 * Exposes a small HTTP API:
 *   GET  /                     - health check
 *   POST /v1/tools             - lists the available tools
 *   POST /v1/chat/completions  - executes tool calls found in the last message
 *
 * Usage: <script> --api-key=YOUR_KEY [--port=3000]
 */

const express = require('express');
// axios performs the outbound Hugging Face requests; the start-up self-test
// uses the native fetch available in Node.js 18+.
const axios = require('axios');

const HF_INFERENCE_BASE_URL = 'https://api-inference.huggingface.co/models';
const IMAGE_MODEL = 'stabilityai/stable-diffusion-xl-base-1.0';
const STORY_MODEL = 'mistralai/Mistral-7B-Instruct-v0.2';
const DEFAULT_PORT = 3000;
const MAX_PORT = 65535;
// Upper bound on how much of an upstream error body is echoed back to callers.
const MAX_ERROR_DETAIL_LENGTH = 500;

const app = express();
app.use(express.json());

// --- Command line arguments -------------------------------------------------

// NOTE(review): a key passed on the command line is visible in the process
// list; consider also accepting it through the environment.
let apiKey = null;
let port = DEFAULT_PORT;

for (const arg of process.argv.slice(2)) {
  if (arg.startsWith('--api-key=')) {
    apiKey = arg.substring('--api-key='.length);
  } else if (arg.startsWith('--port=')) {
    port = Number.parseInt(arg.substring('--port='.length), 10);
  }
}

if (!apiKey) {
  console.error('Error: API key is required. Use --api-key=YOUR_KEY');
  process.exit(1);
}

// Fail early with a clear message instead of handing NaN or an out-of-range
// value to app.listen().
if (!Number.isInteger(port) || port < 0 || port > MAX_PORT) {
  console.error(`Error: --port must be an integer between 0 and ${MAX_PORT}`);
  process.exit(1);
}

// --- Tool catalogue ---------------------------------------------------------

// NOTE(review): the generate_image description mentions "Flux", but the
// request is sent to IMAGE_MODEL (SDXL base 1.0) - confirm which is intended.
const TOOLS = [
  {
    type: 'function',
    function: {
      name: 'generate_image',
      description: 'Generate an image based on a text prompt using Flux model, with optional custom LoRA',
      parameters: {
        type: 'object',
        properties: {
          prompt: {
            type: 'string',
            description: 'Description of the image to generate',
          },
          lora_name: {
            type: 'string',
            description: 'Name of the custom LoRA model on HuggingFace (e.g., "username/lora-model-name")',
          },
        },
        required: ['prompt'],
      },
    },
  },
  {
    type: 'function',
    function: {
      name: 'generate_story',
      description: 'Generate a story based on a prompt',
      parameters: {
        type: 'object',
        properties: {
          prompt: {
            type: 'string',
            description: 'Prompt for story generation',
          },
        },
        required: ['prompt'],
      },
    },
  },
];

// --- Routes -----------------------------------------------------------------

// Basic logging middleware
app.use((req, res, next) => {
  console.log(`${new Date().toISOString()} - ${req.method} ${req.url}`);
  next();
});

// Health check endpoint
app.get('/', (req, res) => {
  res.json({ message: 'HuggingFace MCP Server is running' });
});

// Tools endpoint
app.post('/v1/tools', (req, res) => {
  console.log('Tools requested. Returning:', JSON.stringify(TOOLS, null, 2));
  res.json({ tools: TOOLS });
});

// Chat completions endpoint
app.post('/v1/chat/completions', async (req, res) => {
  try {
    const messages = req.body?.messages;

    // A missing or empty message list is a client error, not a server fault.
    if (!Array.isArray(messages) || messages.length === 0) {
      return res.status(400).json({ error: '"messages" must be a non-empty array' });
    }

    console.log('Chat completion requested with messages:', JSON.stringify(messages, null, 2));

    // Only the most recent message is inspected for tool calls.
    const lastMessage = messages[messages.length - 1];

    if (lastMessage?.role === 'assistant' && Array.isArray(lastMessage.tool_calls)) {
      const results = await handleToolCalls(lastMessage.tool_calls);

      return res.json({
        id: 'chatcmpl-456',
        object: 'chat.completion',
        created: Math.floor(Date.now() / 1000),
        model: 'hf-mcp-server',
        choices: [
          {
            index: 0,
            message: {
              role: 'assistant',
              content: 'I\'ve processed your request.',
              tool_calls: [],
            },
            finish_reason: 'tool_calls',
          },
        ],
        tool_responses: results,
      });
    }

    // No tool calls to execute: reply with a placeholder generate_image call
    // that asks the user for a description.
    return res.json({
      id: 'chatcmpl-123',
      object: 'chat.completion',
      created: Math.floor(Date.now() / 1000),
      model: 'hf-mcp-server',
      choices: [
        {
          index: 0,
          message: {
            role: 'assistant',
            tool_calls: [
              {
                id: 'call_abc123',
                type: 'function',
                function: {
                  name: 'generate_image',
                  arguments: JSON.stringify({
                    prompt: 'I need a description from you for the image you\'d like to generate.',
                  }),
                },
              },
            ],
          },
          finish_reason: 'tool_calls',
        },
      ],
    });
  } catch (error) {
    console.error('Error in chat completions:', error);
    return res.status(500).json({ error: 'Internal server error' });
  }
});

// --- Tool execution ---------------------------------------------------------

// Maps a tool name to its implementation. Function declarations are hoisted,
// so referencing them before their definition is safe.
const TOOL_HANDLERS = new Map([
  ['generate_image', generateImage],
  ['generate_story', generateStory],
]);

/**
 * Runs every tool call in order and collects one tool message per call.
 *
 * Calls run sequentially, and a failure in one call is reported in that
 * call's `content` instead of aborting the remaining calls.
 *
 * @param {Array<object>} toolCalls - Tool calls taken from the last chat message.
 * @returns {Promise<Array<{tool_call_id: *, role: string, content: string}>>}
 */
async function handleToolCalls(toolCalls) {
  const results = [];

  for (const toolCall of toolCalls) {
    results.push({
      tool_call_id: toolCall?.id,
      role: 'tool',
      content: await runToolCall(toolCall),
    });
  }

  return results;
}

/**
 * Executes a single tool call and returns its textual result.
 *
 * Unknown tool names and malformed argument JSON yield an "Error: ..." string
 * rather than throwing.
 *
 * @param {object} toolCall - Entry with `function.name` and JSON `function.arguments`.
 * @returns {Promise<string>} Tool output or an error description.
 */
async function runToolCall(toolCall) {
  const functionName = toolCall?.function?.name;
  const handler = TOOL_HANDLERS.get(functionName);

  if (!handler) {
    return `Error: Unknown tool "${functionName}"`;
  }

  let parsedArguments;
  try {
    parsedArguments = JSON.parse(toolCall.function.arguments);
  } catch (error) {
    return `Error: Invalid JSON arguments for tool "${functionName}": ${error.message}`;
  }

  return handler(parsedArguments);
}

/**
 * Builds an error description that includes the upstream response body when
 * one is available, since the generic axios message omits the API's reason.
 *
 * @param {Error & {response?: {data?: *}}} error - Error thrown by axios.
 * @returns {string} Human-readable error text.
 */
function describeApiError(error) {
  const body = error.response?.data;
  if (body == null) {
    return error.message;
  }

  let detail;
  if (Buffer.isBuffer(body)) {
    // Requests made with responseType 'arraybuffer' deliver error bodies as bytes.
    detail = body.toString('utf8');
  } else if (typeof body === 'string') {
    detail = body;
  } else {
    detail = JSON.stringify(body);
  }

  return detail
    ? `${error.message} - ${detail.slice(0, MAX_ERROR_DETAIL_LENGTH)}`
    : error.message;
}

/**
 * Generates an image for `input.prompt` through the Hugging Face inference API.
 *
 * @param {{prompt: string, lora_name?: string}} input - Tool arguments.
 * @returns {Promise<string>} Text containing a base64 data URI on success, or
 *   a string starting with "Error" on failure. Never rejects.
 */
async function generateImage(input) {
  const prompt = input?.prompt;
  const loraName = input?.lora_name;

  if (typeof prompt !== 'string' || prompt.trim() === '') {
    return 'Error: "prompt" must be a non-empty string';
  }

  try {
    const payload = { inputs: prompt };

    // NOTE(review): `custom_lora` is forwarded as-is; confirm the target
    // endpoint actually honors this parameter.
    if (loraName) {
      payload.parameters = { custom_lora: loraName };
    }

    console.log(`Generating image with prompt: "${prompt}"`);
    if (loraName) {
      console.log(`Using custom LoRA: ${loraName}`);
    }

    const response = await axios.post(`${HF_INFERENCE_BASE_URL}/${IMAGE_MODEL}`, payload, {
      headers: {
        Authorization: `Bearer ${apiKey}`,
        'Content-Type': 'application/json',
      },
      // The image arrives as binary data.
      responseType: 'arraybuffer',
    });

    if (response.status !== 200) {
      return `Error generating image: ${response.statusText}`;
    }

    // Label the data URI with the type the API reported, falling back to the
    // previously hard-coded JPEG.
    const mimeType = response.headers?.['content-type'] || 'image/jpeg';
    const base64Image = Buffer.from(response.data).toString('base64');
    return `Generated image for prompt: "${prompt}"\nImage data: data:${mimeType};base64,${base64Image}`;
  } catch (error) {
    console.error('Error generating image:', error.message);
    return `Error: ${describeApiError(error)}`;
  }
}

/**
 * Generates a short story for `input.prompt` through the Hugging Face
 * inference API.
 *
 * @param {{prompt: string}} input - Tool arguments.
 * @returns {Promise<string>} The generated text on success, or a string
 *   starting with "Error" on failure. Never rejects.
 */
async function generateStory(input) {
  const prompt = input?.prompt;

  if (typeof prompt !== 'string' || prompt.trim() === '') {
    return 'Error: "prompt" must be a non-empty string';
  }

  try {
    console.log(`Generating story with prompt: "${prompt}"`);

    const response = await axios.post(
      `${HF_INFERENCE_BASE_URL}/${STORY_MODEL}`,
      {
        inputs: `Generate a short story based on this prompt: ${prompt}`,
        // NOTE(review): the text-generation task appears to document
        // `max_new_tokens`; confirm `max_length` is honored by this endpoint.
        parameters: {
          max_length: 1000,
          temperature: 0.7,
          top_p: 0.9,
        },
      },
      {
        headers: {
          Authorization: `Bearer ${apiKey}`,
          'Content-Type': 'application/json',
        },
      },
    );

    if (response.status !== 200) {
      return `Error generating story: ${response.statusText}`;
    }

    // Guard against bodies that are not the expected [{ generated_text }] shape.
    const generatedText = Array.isArray(response.data)
      ? response.data[0]?.generated_text
      : undefined;
    if (typeof generatedText !== 'string') {
      return 'Error generating story: unexpected response format';
    }

    return generatedText;
  } catch (error) {
    console.error('Error generating story:', error.message);
    return `Error: ${describeApiError(error)}`;
  }
}

// --- Start-up ---------------------------------------------------------------

/**
 * Posts to the local tools endpoint once to confirm the server answers.
 * Failures are logged and never thrown.
 *
 * @param {number} activePort - Port the server is actually bound to.
 * @returns {Promise<void>}
 */
async function runSelfTest(activePort) {
  console.log('Performing self-test...');
  try {
    const response = await fetch(`http://localhost:${activePort}/v1/tools`, {
      method: 'POST',
      headers: {
        'Content-Type': 'application/json',
      },
      body: '{}',
    });

    if (!response.ok) {
      throw new Error(`Unexpected HTTP status ${response.status}`);
    }

    const data = await response.json();
    console.log('Self-test successful, server returned tools:', JSON.stringify(data, null, 2));
  } catch (error) {
    console.error('Self-test failed:', error);
  }
}

// Start the server - listen on all interfaces (0.0.0.0) to ensure it's accessible
const server = app.listen(port, '0.0.0.0', () => {
  // Read the bound port back so that --port=0 (OS-assigned) is reported and
  // self-tested correctly.
  const activePort = server.address().port;

  console.log(`Server running on port ${activePort}`);
  // Start-up exits earlier when no key is supplied, so it is always present here.
  console.log('API Key is provided');
  console.log('Server is ready to handle requests!');

  // runSelfTest handles its own errors, so the promise is intentionally not awaited.
  void runSelfTest(activePort);
});