UNPKG

@wong2/mcp-cli

Version:

A CLI inspector for the Model Context Protocol

287 lines (264 loc) 9.34 kB
import { UnauthorizedError } from '@modelcontextprotocol/sdk/client/auth.js' import { Client } from '@modelcontextprotocol/sdk/client/index.js' import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js' import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js' import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js' import { LoggingMessageNotificationSchema } from '@modelcontextprotocol/sdk/types.js' import getPort from 'get-port' import { isEmpty } from 'lodash-es' import crypto from 'node:crypto' import { existsSync } from 'node:fs' import { readFile } from 'node:fs/promises' import prompts from 'prompts' import colors from 'yoctocolors' import { OAuthCallbackServer } from './oauth/callback.js' import { McpOAuthClientProvider } from './oauth/provider.js' import { createSpinner, getClaudeConfigPath, logger, populateURITemplateParts, prettyPrint, readJSONSchemaInputs, readPromptArgumentInputs, } from './utils.js' async function createClient() { const client = new Client({ name: 'mcp-cli', version: '1.0.0' }, { capabilities: {} }) client.setNotificationHandler(LoggingMessageNotificationSchema, (notification) => { logger.debug('[server log]:', notification.params.data) }) return client } async function listPrimitives(client) { const capabilities = client.getServerCapabilities() const primitives = [] const promises = [] if (capabilities.resources) { promises.push( client.listResources().then(({ resources }) => { resources.forEach((item) => primitives.push({ type: 'resource', value: item })) }), ) promises.push( client.listResourceTemplates().then(({ resourceTemplates }) => { resourceTemplates.forEach((item) => primitives.push({ type: 'resource-template', value: item, }), ) }), ) } if (capabilities.tools) { promises.push( client.listTools().then(({ tools }) => { tools.forEach((item) => primitives.push({ type: 'tool', value: item })) }), ) } if (capabilities.prompts) { promises.push( client.listPrompts().then(({ prompts }) => { prompts.forEach((item) => primitives.push({ type: 'prompt', value: item })) }), ) } await Promise.all(promises) return primitives } async function connectServer(transport) { const spinner = createSpinner('Connecting to server...') let client try { client = await createClient() await client.connect(transport) } catch (err) { spinner.stop() throw err } const primitives = await listPrimitives(client) spinner.success(`Connected, server capabilities: ${Object.keys(client.getServerCapabilities()).join(', ')}`) while (true) { const { primitive } = await prompts( { name: 'primitive', type: 'autocomplete', message: 'Pick a primitive', choices: primitives.map((p) => ({ title: colors.bold(p.type + '(' + p.value.name + ')'), description: p.value.description, value: p, })), }, { onCancel: async () => { await client.close() process.exit(0) }, }, ) let result let spinner if (primitive.type === 'resource') { spinner = createSpinner(`Reading resource ${primitive.value.uri}...`) result = await client.readResource({ uri: primitive.value.uri }).catch((err) => { spinner.error(err.message) spinner = undefined }) } else if (primitive.type === 'resource-template') { const expanded = await populateURITemplateParts(primitive.value.uriTemplate) if (expanded !== null) { spinner = createSpinner(`Reading resource ${expanded}...`) result = await client.readResource({ uri: expanded }).catch((err) => { spinner.error(err.message) spinner = undefined }) } else { logger.log('\n') } } else if (primitive.type === 'tool') { const args = await readJSONSchemaInputs(primitive.value.inputSchema) spinner = createSpinner(`Using tool ${primitive.value.name}...`) result = await client.callTool({ name: primitive.value.name, arguments: args }).catch((err) => { spinner.error(err.message) spinner = undefined }) } else if (primitive.type === 'prompt') { const args = await readPromptArgumentInputs(primitive.value.arguments) spinner = createSpinner(`Using prompt ${primitive.value.name}...`) result = await client.getPrompt({ name: primitive.value.name, arguments: args }).catch((err) => { spinner.error(err.message) spinner = undefined }) } if (spinner) { spinner.success() } if (result) { prettyPrint(result) logger.log('\n') } } } async function readConfig(configFilePath, { silent = false } = {}) { if (!configFilePath || !existsSync(configFilePath)) { throw new Error(`Config file not found: ${configFilePath}`) } if (silent) { const config = await readFile(configFilePath, 'utf-8') return JSON.parse(config) } const spinner = createSpinner(`Loading config from ${configFilePath}`) const config = await readFile(configFilePath, 'utf-8') spinner.success() return JSON.parse(config) } async function pickServer(config) { const { server } = await prompts({ name: 'server', type: 'autocomplete', message: 'Pick a server', choices: Object.keys(config.mcpServers).map((s) => ({ title: s, value: s, })), }) return server } export async function runWithCommand(command, args, env) { const transport = new StdioClientTransport({ command, args, env }) try { await connectServer(transport) } finally { await transport.close() } } export async function runWithConfigNonInteractive(configPath, serverName, command, target, argsString) { try { const defaultConfigFile = getClaudeConfigPath() const config = await readConfig(configPath || defaultConfigFile, { silent: true }) if (!config.mcpServers || isEmpty(config.mcpServers)) { throw new Error('No mcp servers found in config') } const serverConfig = config.mcpServers[serverName] if (!serverConfig) { throw new Error(`Server '${serverName}' not found in config`) } if (serverConfig.env) { serverConfig.env = { ...serverConfig.env, PATH: process.env.PATH } } const transport = new StdioClientTransport(serverConfig) const client = await createClient() await client.connect(transport) let result let args = {} if (argsString) { try { args = JSON.parse(argsString) } catch (err) { throw new Error(`Invalid JSON in --args: ${err.message}`) } } if (command === 'call-tool') { result = await client.callTool({ name: target, arguments: args }) } else if (command === 'read-resource') { result = await client.readResource({ uri: target }) } else if (command === 'get-prompt') { result = await client.getPrompt({ name: target, arguments: args }) } await client.close() console.log(JSON.stringify(result, null, 2)) } catch (err) { console.error(JSON.stringify({ error: err.message }, null, 2)) process.exit(1) } } export async function runWithConfig(configPath) { const defaultConfigFile = getClaudeConfigPath() const config = await readConfig(configPath || defaultConfigFile) if (!config.mcpServers || isEmpty(config.mcpServers)) { throw new Error('No mcp servers found in config') } const server = await pickServer(config) const serverConfig = config.mcpServers[server] if (serverConfig.env) { serverConfig.env = { ...serverConfig.env, PATH: process.env.PATH } } const transport = new StdioClientTransport(serverConfig) try { await connectServer(transport) } finally { await transport.close() } } async function connectRemoteServer(uri, initialTransport) { const oauthConfig = { port: await getPort({ port: 49153 }), path: '/oauth/callback' } const createTransport = () => { const serverId = crypto.createHash('sha256').update(uri).digest('hex') const oauthRedirectUrl = `http://127.0.0.1:${oauthConfig.port}${oauthConfig.path}` const authProvider = new McpOAuthClientProvider(serverId, oauthRedirectUrl) return initialTransport(authProvider) } const transport = createTransport() try { await connectServer(transport) } catch (err) { if (!(err instanceof UnauthorizedError)) { throw err } const spinner = createSpinner('Waiting for authorization...') const callbackServer = new OAuthCallbackServer() const authCode = await callbackServer.listenForCode(oauthConfig.port, oauthConfig.path) await transport.finishAuth(authCode) spinner.success('Authorization successful') // connect again with a new transport await connectServer(createTransport()) } } export async function runWithSSE(uri) { await connectRemoteServer(uri, (authProvider) => new SSEClientTransport(new URL(uri), { authProvider })) } export async function runWithURL(uri) { await connectRemoteServer(uri, (authProvider) => new StreamableHTTPClientTransport(new URL(uri), { authProvider })) }