mcp-client
Version:
An MCP client for Node.js
295 lines (254 loc) • 7.72 kB
text/typescript
import {
Client,
ClientOptions,
} from "@modelcontextprotocol/sdk/client/index.js";
import { SSEClientTransport } from "@modelcontextprotocol/sdk/client/sse.js";
import { StdioClientTransport, getDefaultEnvironment } from "@modelcontextprotocol/sdk/client/stdio.js";
import {
CompleteRequest,
CompleteResult,
GetPromptRequest,
GetPromptResult,
Implementation,
ListPromptsResultSchema,
ListResourcesResultSchema,
ListResourceTemplatesResultSchema,
ListToolsResultSchema,
LoggingLevel,
LoggingMessageNotificationSchema,
Progress,
Prompt,
ReadResourceRequest,
ReadResourceResult,
Resource,
ResourceTemplate,
Tool,
type CallToolResult,
} from "@modelcontextprotocol/sdk/types.js";
import EventEmitter from "events";
import { z } from "zod";
import { StrictEventEmitter } from "strict-event-emitter-types";
import { Transport } from "@modelcontextprotocol/sdk/shared/transport.js";
export { ErrorCode, McpError } from "@modelcontextprotocol/sdk/types.js";
/**
 * Listener invoked with each progress update the server reports for an
 * in-flight request.
 */
type ProgressCallback = (update: Progress) => void;
/**
 * Per-call options accepted by the public `MCPClient` methods.
 *
 * Field names use this library's camelCase convention; `transformRequestOptions`
 * maps `onProgress` to the SDK's `onprogress` field.
 */
type RequestOptions = {
  /**
   * Abort handle for the request. Aborting it makes request() fail with an AbortError.
   */
  signal?: AbortSignal;
  /**
   * Timeout for this request, in milliseconds. When it elapses, request() fails with an
   * McpError whose code is `RequestTimeout`.
   *
   * When omitted, `DEFAULT_REQUEST_TIMEOUT_MSEC` is used.
   */
  timeout?: number;
  /**
   * Asks the remote end for progress notifications (where supported) and calls this
   * function for every one that arrives.
   */
  onProgress?: ProgressCallback;
};
/**
 * Converts this library's `RequestOptions` into the option object handed to the SDK.
 *
 * Only the progress callback is renamed (`onProgress` -> `onprogress`). `signal` and
 * `timeout` are copied as-is, and any other keys on the input are not carried over.
 */
const transformRequestOptions = ({
  onProgress,
  signal,
  timeout,
}: RequestOptions) => ({
  onprogress: onProgress,
  signal,
  timeout,
});
/**
 * Payload of the `loggingMessage` event: the server-reported logging level plus
 * any additional fields derived from the notification payload.
 */
type LoggingMessageNotification = {
  level: LoggingLevel;
  [key: string]: unknown;
};
/**
 * Event map for `MCPClient`: event name -> listener signature.
 */
type MCPClientEvents = {
  loggingMessage: (notification: LoggingMessageNotification) => void;
};
/** Constructor type for a Node `EventEmitter` whose events are checked against `MCPClientEvents`. */
type TypedEmitterConstructor = new () => StrictEventEmitter<
  EventEmitter,
  MCPClientEvents
>;

// The runtime value is the plain Node `EventEmitter`; only its static type is narrowed.
const MCPClientEventEmitterBase: TypedEmitterConstructor = EventEmitter;

/** Strictly-typed event emitter base for `MCPClient`. */
class MCPClientEventEmitter extends MCPClientEventEmitterBase {}
/**
 * Issues a paginated list request repeatedly, following `nextCursor` until the
 * server stops returning one, and concatenates the items of every page.
 *
 * @param client - Object exposing `request(request, schema, options)` (the SDK `Client`).
 * @param requestParams - JSON-RPC method plus base params; the caller's params object is never mutated.
 * @param schema - Result schema forwarded to `client.request`.
 * @param getItems - Extracts the page's items from a response.
 * @param requestOptions - Options applied to every page request.
 * @returns All items from all pages, in the order they were received.
 * @throws Error if the server hands back a cursor it already returned, which would
 *   otherwise make this loop forever.
 */
async function fetchAllPages<T>(
  client: any,
  requestParams: { method: string; params?: Record<string, any> },
  schema: any,
  getItems: (response: any) => T[],
  requestOptions?: RequestOptions,
): Promise<T[]> {
  const allItems: T[] = [];
  // The translated options are the same for every page, so build them once.
  const sdkOptions = requestOptions
    ? transformRequestOptions(requestOptions)
    : undefined;
  const seenCursors = new Set<string>();
  let cursor: string | undefined;
  do {
    // Copy the caller's params so adding `cursor` never mutates their object.
    const params: Record<string, any> = { ...(requestParams.params ?? {}) };
    if (cursor) {
      params.cursor = cursor;
    }
    const response = await client.request(
      { method: requestParams.method, params },
      schema,
      sdkOptions,
    );
    allItems.push(...getItems(response));
    cursor = response.nextCursor;
    // A cursor we have already requested would return the same page again, forever.
    if (cursor) {
      if (seenCursors.has(cursor)) {
        throw new Error(
          `Server returned a repeated pagination cursor for "${requestParams.method}"; aborting to avoid an infinite loop`,
        );
      }
      seenCursors.add(cursor);
    }
  } while (cursor);
  return allItems;
}
/**
 * Convenience wrapper around the MCP SDK `Client`.
 *
 * - Connects over SSE or stdio and remembers every transport so `close()` can release them.
 * - Accepts camelCase `RequestOptions` and translates them for every SDK call.
 * - Offers `getAll*` helpers that follow pagination cursors to the end.
 * - Re-emits server `notifications/message` notifications as `loggingMessage` events.
 */
export class MCPClient extends MCPClientEventEmitter {
  private readonly client: Client;
  // Every transport handed to the SDK client by `connect`; released by `close`.
  private transports: Transport[] = [];

  /**
   * @param clientInfo - Implementation info (name/version) passed to the SDK client.
   * @param options - Optional SDK client options, forwarded unchanged.
   */
  constructor(clientInfo: Implementation, options?: ClientOptions) {
    super();
    this.client = new Client(clientInfo, options);
    this.client.setNotificationHandler(
      LoggingMessageNotificationSchema,
      (notification) => {
        if (notification.method !== "notifications/message") {
          return;
        }
        const { level, data } = notification.params;
        // `level` is written last so a `level` key inside `data` cannot replace
        // the notification's real logging level.
        this.emit("loggingMessage", {
          ...MCPClient.logDataToFields(data),
          level,
        });
      },
    );
  }

  /**
   * Maps a log notification's `data` payload to fields of the emitted event.
   *
   * - Plain objects are flattened into the event.
   * - `null`/`undefined` contribute nothing.
   * - Any other value (string, number, boolean, array) is kept whole under `data`.
   *   Spreading such a value would yield index keys ("0", "1", ...) or drop it silently.
   */
  private static logDataToFields(data: unknown): Record<string, unknown> {
    if (data === undefined || data === null) {
      return {};
    }
    if (typeof data === "object" && !Array.isArray(data)) {
      return data as Record<string, unknown>;
    }
    return { data };
  }

  /**
   * Translates optional public request options into the SDK's option shape.
   *
   * Every SDK call goes through this, so `onProgress` is never silently ignored.
   */
  private static toSdkOptions(requestOptions?: RequestOptions) {
    return requestOptions ? transformRequestOptions(requestOptions) : undefined;
  }

  /**
   * Creates a transport and connects the SDK client over it.
   *
   * - `sse`: connects to `url` using `SSEClientTransport`.
   * - `stdio`: spawns `command` with `args`. `env` entries are layered over
   *   `getDefaultEnvironment()`.
   *
   * The transport is recorded before connecting, so `close()` can release it
   * even if connecting fails.
   *
   * @throws Error for an unrecognised `type`, or whatever the SDK raises while connecting.
   */
  async connect(
    options:
      | { type: "sse"; url: string }
      | {
          type: "stdio";
          args: string[];
          command: string;
          env?: Record<string, string>;
        },
  ): Promise<void> {
    let transport: Transport;
    if (options.type === "sse") {
      transport = new SSEClientTransport(new URL(options.url));
    } else if (options.type === "stdio") {
      transport = new StdioClientTransport({
        command: options.command,
        // `?? {}` also tolerates an explicit `null` from untyped callers.
        env: { ...getDefaultEnvironment(), ...(options.env ?? {}) },
        args: options.args,
      });
    } else {
      throw new Error(`Unknown transport type`);
    }
    this.transports.push(transport);
    // Awaited for every transport type, so connection failures reject this call
    // instead of surfacing as unhandled rejections after it has resolved.
    await this.client.connect(transport);
  }

  /** Pings the server; resolves to `null` once the SDK call completes. */
  async ping(options?: { requestOptions?: RequestOptions }): Promise<null> {
    await this.client.ping(MCPClient.toSdkOptions(options?.requestOptions));
    return null;
  }

  /** Lists every tool, following pagination cursors until exhausted. */
  async getAllTools(options?: {
    requestOptions?: RequestOptions;
  }): Promise<Tool[]> {
    return fetchAllPages(
      this.client,
      { method: "tools/list" },
      ListToolsResultSchema,
      (result) => result.tools,
      options?.requestOptions,
    );
  }

  /** Lists every resource, following pagination cursors until exhausted. */
  async getAllResources(options?: {
    requestOptions?: RequestOptions;
  }): Promise<Resource[]> {
    return fetchAllPages(
      this.client,
      { method: "resources/list" },
      ListResourcesResultSchema,
      (result) => result.resources,
      options?.requestOptions,
    );
  }

  /** Lists every prompt, following pagination cursors until exhausted. */
  async getAllPrompts(options?: {
    requestOptions?: RequestOptions;
  }): Promise<Prompt[]> {
    return fetchAllPages(
      this.client,
      { method: "prompts/list" },
      ListPromptsResultSchema,
      (result) => result.prompts,
      options?.requestOptions,
    );
  }

  /**
   * Invokes a tool on the server.
   *
   * @param invocation - Tool name and optional arguments.
   * @param options.resultSchema - Optional schema used to parse the result. When omitted,
   *   `undefined` is forwarded and the SDK applies its own default.
   * @param options.requestOptions - Per-call options.
   */
  async callTool<
    TResultSchema extends z.ZodType = z.ZodType<CallToolResult>,
    TResult = z.infer<TResultSchema>,
  >(
    invocation: {
      name: string;
      arguments?: Record<string, unknown>;
    },
    options?: {
      resultSchema?: TResultSchema;
      requestOptions?: RequestOptions;
    },
  ): Promise<TResult> {
    const result = await this.client.callTool(
      invocation,
      // The SDK's parameter type is narrower than an arbitrary Zod schema.
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
      options?.resultSchema as any,
      MCPClient.toSdkOptions(options?.requestOptions),
    );
    return result as TResult;
  }

  /** Requests argument completions from the server. */
  async complete(
    params: CompleteRequest["params"],
    options?: {
      requestOptions?: RequestOptions;
    },
  ): Promise<CompleteResult> {
    return await this.client.complete(
      params,
      MCPClient.toSdkOptions(options?.requestOptions),
    );
  }

  /** Reads a single resource. */
  async getResource(
    params: ReadResourceRequest["params"],
    options?: {
      requestOptions?: RequestOptions;
    },
  ): Promise<ReadResourceResult> {
    return await this.client.readResource(
      params,
      MCPClient.toSdkOptions(options?.requestOptions),
    );
  }

  /** Fetches a single prompt, rendered with the supplied arguments. */
  async getPrompt(
    params: GetPromptRequest["params"],
    options?: {
      requestOptions?: RequestOptions;
    },
  ): Promise<GetPromptResult> {
    return await this.client.getPrompt(
      params,
      MCPClient.toSdkOptions(options?.requestOptions),
    );
  }

  /** Lists every resource template, following pagination cursors until exhausted. */
  async getAllResourceTemplates(options?: {
    requestOptions?: RequestOptions;
  }): Promise<ResourceTemplate[]> {
    return fetchAllPages(
      this.client,
      { method: "resources/templates/list" },
      ListResourceTemplatesResultSchema,
      (result) => result.resourceTemplates,
      options?.requestOptions,
    );
  }

  /**
   * Asks the server to send log notifications at or above `level`.
   *
   * @param options.requestOptions - Optional per-call options (new, optional; existing
   *   callers are unaffected).
   */
  async setLoggingLevel(
    level: LoggingLevel,
    options?: { requestOptions?: RequestOptions },
  ): Promise<void> {
    await this.client.setLoggingLevel(
      level,
      MCPClient.toSdkOptions(options?.requestOptions),
    );
  }

  /**
   * Closes every transport opened by `connect`.
   *
   * All transports are attempted even if one fails; the first failure is rethrown
   * once all have been tried. The list is cleared first, so a repeated `close()`
   * does not close the same transports again.
   */
  async close(): Promise<void> {
    const transports = this.transports;
    this.transports = [];
    let failed = false;
    let firstError: unknown;
    for (const transport of transports) {
      try {
        await transport.close();
      } catch (error: unknown) {
        if (!failed) {
          failed = true;
          firstError = error;
        }
      }
    }
    if (failed) {
      throw firstError;
    }
  }
}