/*
 * huggingface-mcp-server
 * MCP server for HuggingFace inference endpoints with custom LoRA and story generation.
 * Language: JavaScript (Node.js 18+).
 */
const express = require('express');
const axios = require('axios');
// Using native fetch in Node.js 18+
// Express application; express.json() parses JSON request bodies into req.body.
const app = express();
app.use(express.json());

// Configuration from command-line flags:
//   --api-key=KEY  HuggingFace API token (required unless HF_TOKEN is set)
//   --port=N       TCP port to listen on (default 3000, must be 1-65535)
// HF_TOKEN is accepted as a fallback so the token does not have to appear on the
// command line, where it is visible in process listings.
const args = process.argv.slice(2);
let apiKey = null;
let port = 3000;
for (const arg of args) {
  if (arg.startsWith('--api-key=')) {
    apiKey = arg.substring('--api-key='.length);
  } else if (arg.startsWith('--port=')) {
    port = Number.parseInt(arg.substring('--port='.length), 10);
  }
}
if (!apiKey) {
  apiKey = process.env.HF_TOKEN || null;
}
if (!apiKey) {
  console.error('Error: API key is required. Use --api-key=YOUR_KEY or set HF_TOKEN');
  process.exit(1);
}
// parseInt yields NaN for non-numeric input; fail fast with a clear message instead of
// handing an invalid port to app.listen(). Port 0 is rejected as well because the
// startup self-test needs the real port number.
if (!Number.isInteger(port) || port < 1 || port > 65535) {
  console.error('Error: --port must be an integer between 1 and 65535');
  process.exit(1);
}
// Request logger: for every incoming request, print an ISO-8601 timestamp, the HTTP
// method and the request URL, then pass control on to the next handler.
app.use(function logRequest(request, _response, proceed) {
  const timestamp = new Date().toISOString();
  console.log(`${timestamp} - ${request.method} ${request.url}`);
  proceed();
});
// Liveness probe: GET / always answers with a fixed JSON status message.
app.get('/', function healthCheck(_req, res) {
  const body = { message: 'HuggingFace MCP Server is running' };
  res.json(body);
});
// Tool catalogue: POST /v1/tools returns the function definitions this server can
// execute (the dispatch side lives in handleToolCalls).
// NOTE(review): the generate_image description says "Flux model", but generateImage
// calls stabilityai/stable-diffusion-xl-base-1.0 — confirm which one is intended.
app.post('/v1/tools', (_req, res) => {
  // Wrap a JSON-schema parameter list in the { type: 'function', function: {...} } envelope.
  const defineTool = (name, description, properties, required) => ({
    type: 'function',
    function: {
      name,
      description,
      parameters: { type: 'object', properties, required },
    },
  });

  const tools = [
    defineTool(
      'generate_image',
      'Generate an image based on a text prompt using Flux model, with optional custom LoRA',
      {
        prompt: { type: 'string', description: 'Description of the image to generate' },
        lora_name: {
          type: 'string',
          description: 'Name of the custom LoRA model on HuggingFace (e.g., "username/lora-model-name")',
        },
      },
      ['prompt'],
    ),
    defineTool(
      'generate_story',
      'Generate a story based on a prompt',
      { prompt: { type: 'string', description: 'Prompt for story generation' } },
      ['prompt'],
    ),
  ];

  console.log('Tools requested. Returning:', JSON.stringify(tools, null, 2));
  res.json({ tools });
});
// Chat completions endpoint (OpenAI-style request/response shape).
// If the last message is an assistant turn that carries tool_calls, those calls are
// executed here and their outputs are returned in a top-level `tool_responses` field.
// Otherwise the reply is a placeholder generate_image tool call asking for a description.
app.post('/v1/chat/completions', async (req, res) => {
  try {
    // req.body can be undefined or lack `messages` (e.g. a non-JSON request). Reject
    // that with 400 rather than crashing on messages.length and answering 500.
    const messages = req.body?.messages;
    if (!Array.isArray(messages) || messages.length === 0) {
      return res.status(400).json({ error: 'messages must be a non-empty array' });
    }
    console.log('Chat completion requested with messages:', JSON.stringify(messages, null, 2));

    const lastMessage = messages[messages.length - 1];
    const toolCalls = lastMessage?.tool_calls;
    // An empty tool_calls array is truthy but requests nothing, so require length > 0.
    if (lastMessage?.role === 'assistant' && Array.isArray(toolCalls) && toolCalls.length > 0) {
      const results = await handleToolCalls(toolCalls);
      return res.json({
        id: 'chatcmpl-456',
        object: 'chat.completion',
        created: Math.floor(Date.now() / 1000),
        model: 'hf-mcp-server',
        choices: [{
          index: 0,
          message: {
            role: 'assistant',
            content: 'I\'ve processed your request.',
            tool_calls: []
          },
          // NOTE(review): finish_reason is 'tool_calls' even though tool_calls is empty;
          // confirm clients expect this rather than 'stop'. Left unchanged for compatibility.
          finish_reason: 'tool_calls'
        }],
        tool_responses: results
      });
    }

    // No executable tool calls: ask the user for an image description via a
    // placeholder generate_image call.
    return res.json({
      id: 'chatcmpl-123',
      object: 'chat.completion',
      created: Math.floor(Date.now() / 1000),
      model: 'hf-mcp-server',
      choices: [{
        index: 0,
        message: {
          role: 'assistant',
          tool_calls: [
            {
              id: 'call_abc123',
              type: 'function',
              function: {
                name: 'generate_image',
                arguments: JSON.stringify({
                  prompt: 'I need a description from you for the image you\'d like to generate.'
                })
              }
            }
          ]
        },
        finish_reason: 'tool_calls'
      }]
    });
  } catch (error) {
    console.error('Error in chat completions:', error);
    res.status(500).json({ error: 'Internal server error' });
  }
});
/**
 * Execute each requested tool call in order and collect one `role: 'tool'` result per call.
 *
 * @param {Array<{id: string, function: {name: string, arguments?: string|object}}>} toolCalls
 *   OpenAI-style tool calls. `arguments` is normally a JSON string; an object is also accepted.
 * @returns {Promise<Array<{tool_call_id: string, role: 'tool', content: string}>>}
 *   Results in the same order as `toolCalls`. Unknown tools, unparseable arguments and
 *   errors thrown by a tool become an `Error: ...` content string for that call instead
 *   of aborting the whole batch, so every tool_call_id gets an answer.
 */
async function handleToolCalls(toolCalls) {
  const results = [];
  // Calls run one at a time, in request order.
  for (const toolCall of toolCalls) {
    const functionName = toolCall?.function?.name;
    const handler =
      functionName === 'generate_image' ? generateImage
        : functionName === 'generate_story' ? generateStory
          : null;

    let content;
    if (handler === null) {
      content = `Error: Unknown tool "${functionName}"`;
    } else {
      try {
        content = await handler(parseToolArguments(toolCall.function.arguments));
      } catch (error) {
        console.error(`Error running tool ${functionName}:`, error.message);
        content = `Error: ${error.message}`;
      }
    }
    results.push({ tool_call_id: toolCall?.id, role: 'tool', content });
  }
  return results;
}

// Normalise a tool call's `arguments` (JSON string, object, or missing) to a plain object.
// Throws an Error with a descriptive message when the JSON is malformed or not an object.
function parseToolArguments(raw) {
  if (raw === undefined || raw === null || raw === '') {
    return {};
  }
  if (typeof raw === 'object') {
    return raw;
  }
  let parsed;
  try {
    parsed = JSON.parse(raw);
  } catch (error) {
    throw new Error(`Invalid tool arguments JSON: ${error.message}`, { cause: error });
  }
  if (parsed === null || typeof parsed !== 'object') {
    throw new Error('Invalid tool arguments JSON: expected an object');
  }
  return parsed;
}
/**
 * Generate an image with the HuggingFace Inference API and return it as text.
 *
 * @param {{prompt: string, lora_name?: string}} input - Arguments of a generate_image tool call.
 * @returns {Promise<string>} On success, the prompt plus a `data:<mime>;base64,...` URL.
 *   On failure, a string starting with `Error`. Errors are reported in the return value
 *   rather than thrown.
 */
async function generateImage(input) {
  // NOTE(review): the generate_image tool description mentions a "Flux model", but the
  // model called here is Stable Diffusion XL base. Confirm which one is intended.
  const model = 'stabilityai/stable-diffusion-xl-base-1.0';
  const prompt = input?.prompt;
  if (typeof prompt !== 'string' || prompt.trim() === '') {
    return 'Error: prompt is required';
  }
  const loraName = input.lora_name;
  try {
    const payload = { inputs: prompt };
    if (loraName) {
      // NOTE(review): `custom_lora` is forwarded as-is; confirm this endpoint honours it,
      // otherwise it may be silently ignored.
      payload.parameters = { custom_lora: loraName };
    }
    console.log(`Generating image with prompt: "${prompt}"`);
    if (loraName) {
      console.log(`Using custom LoRA: ${loraName}`);
    }
    const response = await axios.post(
      `https://api-inference.huggingface.co/models/${model}`,
      payload,
      {
        headers: {
          Authorization: `Bearer ${apiKey}`,
          'Content-Type': 'application/json',
        },
        // The image comes back as raw bytes.
        responseType: 'arraybuffer',
      },
    );
    if (response.status !== 200) {
      return `Error generating image: ${response.statusText}`;
    }
    const body = Buffer.from(response.data);
    // Use the MIME type the server reports (without parameters) instead of assuming JPEG.
    const contentType = String(response.headers?.['content-type'] || 'image/jpeg')
      .split(';')[0]
      .trim()
      .toLowerCase();
    if (!contentType.startsWith('image/')) {
      // A 200 with a non-image body (e.g. JSON) must not be passed off as image data.
      return `Error generating image: unexpected content type ${contentType}: ${body.toString('utf8').slice(0, 500)}`;
    }
    return `Generated image for prompt: "${prompt}"\nImage data: data:${contentType};base64,${body.toString('base64')}`;
  } catch (error) {
    // With responseType 'arraybuffer', any error details the server sent arrive as
    // bytes; decode them so the caller sees the API's reason, not just the status code.
    const errorBody = error.response?.data;
    const detail = Buffer.isBuffer(errorBody) || errorBody instanceof ArrayBuffer
      ? `: ${Buffer.from(errorBody).toString('utf8').slice(0, 500)}`
      : '';
    console.error('Error generating image:', `${error.message}${detail}`);
    return `Error: ${error.message}${detail}`;
  }
}
/**
 * Generate a short story with the HuggingFace Inference API.
 *
 * @param {{prompt: string}} input - Arguments of a generate_story tool call.
 * @returns {Promise<string>} The generated text on success, or a string starting with
 *   `Error`. Errors are reported in the return value rather than thrown.
 */
async function generateStory(input) {
  const model = 'mistralai/Mistral-7B-Instruct-v0.2';
  const prompt = input?.prompt;
  if (typeof prompt !== 'string' || prompt.trim() === '') {
    return 'Error: prompt is required';
  }
  try {
    console.log(`Generating story with prompt: "${prompt}"`);
    const response = await axios.post(
      `https://api-inference.huggingface.co/models/${model}`,
      {
        inputs: `Generate a short story based on this prompt: ${prompt}`,
        parameters: {
          // NOTE(review): the text-generation task documents `max_new_tokens`;
          // confirm `max_length` is honoured by this endpoint.
          max_length: 1000,
          temperature: 0.7,
          top_p: 0.9,
          // Without this, generated_text starts with the instruction prompt above.
          return_full_text: false,
        },
      },
      {
        headers: {
          Authorization: `Bearer ${apiKey}`,
          'Content-Type': 'application/json',
        },
      },
    );
    if (response.status !== 200) {
      return `Error generating story: ${response.statusText}`;
    }
    // Expected shape is [{ generated_text }]; a bare { generated_text } is accepted too.
    const first = Array.isArray(response.data) ? response.data[0] : response.data;
    if (typeof first?.generated_text !== 'string') {
      return `Error generating story: unexpected response ${String(JSON.stringify(response.data)).slice(0, 500)}`;
    }
    return first.generated_text;
  } catch (error) {
    // Surface the API's own error message (if it sent one) alongside axios's status text.
    const apiError = error.response?.data?.error;
    const detail = typeof apiError === 'string' ? `: ${apiError}` : '';
    console.error('Error generating story:', `${error.message}${detail}`);
    return `Error: ${error.message}${detail}`;
  }
}
// Start the server on all interfaces (0.0.0.0) so it is reachable from outside this
// host or container.
// NOTE(review): nothing in this file authenticates callers, so anyone who can reach this
// port can spend the configured HuggingFace API key. Consider binding to 127.0.0.1
// unless remote access is required.
app
  .listen(port, '0.0.0.0', (error) => {
    // Some Express versions (5.x) may also invoke this callback with a startup error;
    // that case is reported by the 'error' listener below, so only success is handled here.
    if (error) {
      return;
    }
    console.log(`Server running on port ${port}`);
    console.log(`API Key is ${apiKey ? 'provided' : 'missing'}`);
    console.log('Server is ready to handle requests!');

    // Send a preliminary request to the tools endpoint to verify it's working.
    console.log('Performing self-test...');
    // Use the IPv4 loopback explicitly: the server is bound to 0.0.0.0 (IPv4 only),
    // while `localhost` can resolve to the IPv6 address ::1.
    const runSelfTest = async () => {
      const response = await fetch(`http://127.0.0.1:${port}/v1/tools`, {
        method: 'POST',
        headers: {
          'Content-Type': 'application/json',
        },
        body: '{}',
      });
      if (!response.ok) {
        throw new Error(`HTTP ${response.status} ${response.statusText}`);
      }
      const data = await response.json();
      console.log('Self-test successful, server returned tools:', JSON.stringify(data, null, 2));
    };
    runSelfTest().catch((selfTestError) => {
      console.error('Self-test failed:', selfTestError);
    });
  })
  // Without a listener, a bind failure (e.g. EADDRINUSE) surfaces as an unhandled
  // 'error' event with a raw stack trace.
  .on('error', (error) => {
    console.error(`Failed to start server on port ${port}:`, error.message);
    process.exit(1);
  });