// quantum-cli-core
// Version: (not specified)
// Quantum CLI Core - Multi-LLM Collaboration System
// 353 lines • 14.3 kB
// JavaScript
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
/* eslint-disable @typescript-eslint/no-explicit-any */
import { describe, it, expect, vi, beforeEach, afterEach, } from 'vitest';
import { ToolRegistry, sanitizeParameters, } from './tool-registry.js';
import { DiscoveredMCPTool } from './mcp-tool.js';
import { Config, ApprovalMode } from '../config/config.js';
import { BaseTool } from './tools.js';
import { mcpToTool, Type, } from '@google/genai';
import { spawn } from 'node:child_process';
// Create the discoverMcpTools mock through vi.hoisted so that it already exists
// when the (hoisted) vi.mock factory below references it.
const mockDiscoverMcpTools = vi.hoisted(() => vi.fn());
// Replace ./mcp-client.js so tool-registry tests control MCP discovery:
// its discoverMcpTools export becomes the shared mock defined above.
vi.mock('./mcp-client.js', () => ({
discoverMcpTools: mockDiscoverMcpTools,
}));
// Keep every real node:child_process export, but stub execSync and spawn so
// tests can script the discovery-command subprocess via vi.mocked(spawn).
vi.mock('node:child_process', async () => {
const actual = await vi.importActual('node:child_process');
return {
...actual,
execSync: vi.fn(),
spawn: vi.fn(),
};
});
// Stub the MCP SDK client and transports. The factories below dereference these
// module-level mock fns only inside mockImplementation callbacks, i.e. when a
// stub instance is constructed, not when the factory itself runs.
// NOTE(review): ./mcp-client.js is fully mocked above, so these SDK stubs may
// never be loaded by these tests — confirm they are still needed.
const mockMcpClientConnect = vi.fn();
const mockMcpClientOnError = vi.fn();
const mockStdioTransportClose = vi.fn();
const mockSseTransportClose = vi.fn();
vi.mock('@modelcontextprotocol/sdk/client/index.js', () => {
// Every `new Client()` shares the connect mock; assigning `onerror` forwards
// the handler to mockMcpClientOnError so a test can capture it.
const MockClient = vi.fn().mockImplementation(() => ({
connect: mockMcpClientConnect,
set onerror(handler) {
mockMcpClientOnError(handler);
},
}));
return { Client: MockClient };
});
vi.mock('@modelcontextprotocol/sdk/client/stdio.js', () => {
// Stdio transport stub: a fresh stderr.on mock per instance, shared close mock.
const MockStdioClientTransport = vi.fn().mockImplementation(() => ({
stderr: {
on: vi.fn(),
},
close: mockStdioTransportClose,
}));
return { StdioClientTransport: MockStdioClientTransport };
});
vi.mock('@modelcontextprotocol/sdk/client/sse.js', () => {
// SSE transport stub: only a shared close mock.
const MockSSEClientTransport = vi.fn().mockImplementation(() => ({
close: mockSseTransportClose,
}));
return { SSEClientTransport: MockSSEClientTransport };
});
// Stub mcpToTool from @google/genai while keeping every other export real.
// Each call yields a fresh callable tool whose `tool()` resolves to an empty
// declaration list; the suite's beforeEach re-stubs it via vi.mocked(mcpToTool).
vi.mock('@google/genai', async () => {
  const realGenai = await vi.importActual('@google/genai');
  const makeEmptyCallableTool = () => ({
    tool: vi.fn().mockResolvedValue({ functionDeclarations: [] }),
    callTool: vi.fn(),
  });
  return {
    ...realGenai,
    mcpToTool: vi.fn().mockImplementation(makeEmptyCallableTool),
  };
});
/**
 * Build a stand-in for the callable tool object returned by mcpToTool.
 * @param {Array<object>} toolDeclarations - declarations that `tool()` resolves with.
 * @returns {{tool: Function, callTool: Function}} fresh vi.fn() stubs.
 */
const createMockCallableTool = (toolDeclarations) => {
  const resolvedTool = { functionDeclarations: toolDeclarations };
  return {
    tool: vi.fn().mockResolvedValue(resolvedTool),
    callTool: vi.fn(),
  };
};
/**
 * Minimal concrete BaseTool used by the registry tests. It takes one required
 * string argument, `param`, and echoes it back.
 */
class MockTool extends BaseTool {
  /**
   * @param {string} [name='mock-tool'] - passed as both of BaseTool's first two
   *   constructor arguments (presumably name and display name).
   * @param {string} [description='A mock tool']
   */
  constructor(name = 'mock-tool', description = 'A mock tool') {
    const parameterSchema = {
      type: 'object',
      properties: { param: { type: 'string' } },
      required: ['param'],
    };
    super(name, name, description, parameterSchema);
  }

  /**
   * Echo `params.param` back as both the LLM content and the display text.
   * @param {{param: string}} params
   * @returns {Promise<{llmContent: string, returnDisplay: string}>}
   */
  async execute(params) {
    const echoed = `Executed with ${params.param}`;
    return { llmContent: echoed, returnDisplay: echoed };
  }
}
// Fixed constructor arguments for the Config that the ToolRegistry suite's
// beforeEach builds with `new Config(baseConfigParams)`.
// NOTE(review): paths and model names look like placeholders — confirm Config
// does not touch the filesystem or network with them during these tests.
const baseConfigParams = {
cwd: '/tmp',
model: 'test-model',
embeddingModel: 'test-embedding-model',
sandbox: undefined,
targetDir: '/test/dir',
debugMode: false,
userMemory: '',
quantumMdFileCount: 0,
approvalMode: ApprovalMode.DEFAULT,
sessionId: 'test-session-id',
};
/**
 * ToolRegistry tests:
 * - registering tools;
 * - looking tools up by MCP server;
 * - discovering tools from a discovery command (through the mocked `spawn`);
 * - discovering tools from configured MCP servers (through the mocked
 *   `discoverMcpTools`).
 */
describe('ToolRegistry', () => {
  let config;
  let toolRegistry;
  let mockConfigGetToolDiscoveryCommand;

  beforeEach(() => {
    config = new Config(baseConfigParams);
    toolRegistry = new ToolRegistry(config);
    // Silence console output produced by the code under test.
    vi.spyOn(console, 'warn').mockImplementation(() => {});
    vi.spyOn(console, 'error').mockImplementation(() => {});
    vi.spyOn(console, 'debug').mockImplementation(() => {});
    vi.spyOn(console, 'log').mockImplementation(() => {});
    // The module-level mocks are shared by all tests; reset them so every
    // test starts from a clean slate.
    mockMcpClientConnect.mockReset().mockResolvedValue(undefined);
    mockStdioTransportClose.mockReset();
    mockSseTransportClose.mockReset();
    vi.mocked(mcpToTool).mockClear();
    vi.mocked(mcpToTool).mockReturnValue(createMockCallableTool([]));
    // Spy on the Config getters so individual tests can override their values.
    mockConfigGetToolDiscoveryCommand = vi.spyOn(config, 'getToolDiscoveryCommand');
    vi.spyOn(config, 'getMcpServers');
    vi.spyOn(config, 'getMcpServerCommand');
    mockDiscoverMcpTools.mockReset().mockResolvedValue(undefined);
  });

  afterEach(() => {
    vi.restoreAllMocks();
  });

  describe('registerTool', () => {
    it('should register a new tool', () => {
      const tool = new MockTool();
      toolRegistry.registerTool(tool);
      expect(toolRegistry.getTool('mock-tool')).toBe(tool);
    });
  });

  describe('getToolsByServer', () => {
    it('should return an empty array if no tools match the server name', () => {
      toolRegistry.registerTool(new MockTool());
      expect(toolRegistry.getToolsByServer('any-mcp-server')).toEqual([]);
    });

    it('should return only tools matching the server name', async () => {
      const server1Name = 'mcp-server-uno';
      const server2Name = 'mcp-server-dos';
      const mockCallable = {};
      // Registered names follow the `<server>__<tool>` pattern, built from the
      // actual server names so each tool is clearly tied to its server.
      const mcpTool1 = new DiscoveredMCPTool(mockCallable, server1Name, `${server1Name}__tool-on-server1`, 'd1', {}, 'tool-on-server1');
      const mcpTool2 = new DiscoveredMCPTool(mockCallable, server2Name, `${server2Name}__tool-on-server2`, 'd2', {}, 'tool-on-server2');
      const nonMcpTool = new MockTool('regular-tool');
      toolRegistry.registerTool(mcpTool1);
      toolRegistry.registerTool(mcpTool2);
      toolRegistry.registerTool(nonMcpTool);
      const toolsFromServer1 = toolRegistry.getToolsByServer(server1Name);
      expect(toolsFromServer1).toHaveLength(1);
      expect(toolsFromServer1[0].name).toBe(mcpTool1.name);
      const toolsFromServer2 = toolRegistry.getToolsByServer(server2Name);
      expect(toolsFromServer2).toHaveLength(1);
      expect(toolsFromServer2[0].name).toBe(mcpTool2.name);
    });
  });

  describe('discoverTools', () => {
    it('should sanitize tool parameters during discovery from command', async () => {
      const discoveryCommand = 'my-discovery-command';
      mockConfigGetToolDiscoveryCommand.mockReturnValue(discoveryCommand);
      const unsanitizedToolDeclaration = {
        name: 'tool-with-bad-format',
        description: 'A tool with an invalid format property',
        parameters: {
          type: Type.OBJECT,
          properties: {
            some_string: {
              type: Type.STRING,
              format: 'uuid', // This is an unsupported format
            },
          },
        },
      };
      const mockSpawn = vi.mocked(spawn);
      const mockChildProcess = {
        stdout: { on: vi.fn() },
        stderr: { on: vi.fn() },
        on: vi.fn(),
      };
      mockSpawn.mockReturnValue(mockChildProcess);
      // As soon as a 'data' listener is attached, emit a single stdout chunk
      // containing the JSON-encoded declaration list.
      mockChildProcess.stdout.on.mockImplementation((event, callback) => {
        if (event === 'data') {
          callback(Buffer.from(JSON.stringify([
            { function_declarations: [unsanitizedToolDeclaration] },
          ])));
        }
        return mockChildProcess;
      });
      // As soon as a 'close' listener is attached, report exit code 0.
      mockChildProcess.on.mockImplementation((event, callback) => {
        if (event === 'close') {
          callback(0);
        }
        return mockChildProcess;
      });
      await toolRegistry.discoverTools();
      const discoveredTool = toolRegistry.getTool('tool-with-bad-format');
      expect(discoveredTool).toBeDefined();
      const registeredParams = discoveredTool.schema.parameters;
      expect(registeredParams.properties?.['some_string']).toBeDefined();
      expect(registeredParams.properties?.['some_string']).toHaveProperty('format', undefined);
    });

    it('should discover tools using MCP servers defined in getMcpServers', async () => {
      mockConfigGetToolDiscoveryCommand.mockReturnValue(undefined);
      vi.spyOn(config, 'getMcpServerCommand').mockReturnValue(undefined);
      const mcpServerConfigVal = {
        'my-mcp-server': {
          command: 'mcp-server-cmd',
          args: ['--port', '1234'],
          trust: true,
        },
      };
      vi.spyOn(config, 'getMcpServers').mockReturnValue(mcpServerConfigVal);
      await toolRegistry.discoverTools();
      // The second argument is expected to be undefined here, matching the
      // stubbed getMcpServerCommand return value.
      expect(mockDiscoverMcpTools).toHaveBeenCalledWith(mcpServerConfigVal, undefined, toolRegistry);
    });
  });
});
/**
 * sanitizeParameters tests. The tests expect the function to edit the given
 * schema in place:
 * - formats treated here as unsupported ('uuid', 'email', 'hostname') must
 *   have `format` cleared;
 * - 'date-time' and 'enum' must survive untouched;
 * - nested object and array-item schemas must be visited;
 * - empty, undefined, and cyclic schemas must not throw.
 */
describe('sanitizeParameters', () => {
  // Deep structural copy taken before sanitizing, for "schema unchanged" checks.
  const snapshotOf = (schema) => JSON.parse(JSON.stringify(schema));

  it('should remove unsupported format from a simple string property', () => {
    const schema = {
      type: Type.OBJECT,
      properties: {
        name: { type: Type.STRING },
        id: { type: Type.STRING, format: 'uuid' },
      },
    };
    sanitizeParameters(schema);
    const props = schema.properties;
    expect(props?.['id']).toHaveProperty('format', undefined);
    expect(props?.['name']).not.toHaveProperty('format');
  });

  it('should NOT remove supported format values', () => {
    const schema = {
      type: Type.OBJECT,
      properties: {
        date: { type: Type.STRING, format: 'date-time' },
        role: { type: Type.STRING, format: 'enum', enum: ['admin', 'user'] },
      },
    };
    const before = snapshotOf(schema);
    sanitizeParameters(schema);
    expect(schema).toEqual(before);
  });

  it('should handle nested objects recursively', () => {
    const schema = {
      type: Type.OBJECT,
      properties: {
        user: {
          type: Type.OBJECT,
          properties: {
            email: { type: Type.STRING, format: 'email' },
          },
        },
      },
    };
    sanitizeParameters(schema);
    const emailProp = schema.properties?.['user']?.properties?.['email'];
    expect(emailProp).toHaveProperty('format', undefined);
  });

  it('should handle arrays of objects', () => {
    const schema = {
      type: Type.OBJECT,
      properties: {
        items: {
          type: Type.ARRAY,
          items: {
            type: Type.OBJECT,
            properties: {
              itemId: { type: Type.STRING, format: 'uuid' },
            },
          },
        },
      },
    };
    sanitizeParameters(schema);
    const itemIdProp = schema.properties?.['items']?.items?.properties?.['itemId'];
    expect(itemIdProp).toHaveProperty('format', undefined);
  });

  it('should handle schemas with no properties to sanitize', () => {
    const schema = {
      type: Type.OBJECT,
      properties: {
        count: { type: Type.NUMBER },
        isActive: { type: Type.BOOLEAN },
      },
    };
    const before = snapshotOf(schema);
    sanitizeParameters(schema);
    expect(schema).toEqual(before);
  });

  it('should not crash on an empty or undefined schema', () => {
    for (const input of [{}, undefined]) {
      expect(() => sanitizeParameters(input)).not.toThrow();
    }
  });

  it('should handle cyclic schemas without crashing', () => {
    const schema = {
      type: Type.OBJECT,
      properties: {
        name: { type: Type.STRING, format: 'hostname' },
      },
    };
    // Make the schema reference itself through one of its own properties.
    schema.properties.self = schema;
    expect(() => sanitizeParameters(schema)).not.toThrow();
    expect(schema.properties.name).toHaveProperty('format', undefined);
  });

  it('should handle complex nested schemas with cycles', () => {
    const userNode = {
      type: Type.OBJECT,
      properties: {
        id: { type: Type.STRING, format: 'uuid' },
        name: { type: Type.STRING },
        manager: {
          type: Type.OBJECT,
          properties: {
            id: { type: Type.STRING, format: 'uuid' },
          },
        },
      },
    };
    // `reports` is an array whose item schema is userNode itself (a cycle).
    userNode.properties.reports = { type: Type.ARRAY, items: userNode };
    const schema = { type: Type.OBJECT, properties: { ceo: userNode } };
    expect(() => sanitizeParameters(schema)).not.toThrow();
    const ceo = schema.properties?.['ceo'];
    expect(ceo?.properties?.['id']).toHaveProperty('format', undefined);
    expect(ceo?.properties?.['manager']?.properties?.['id']).toHaveProperty('format', undefined);
  });
});
//# sourceMappingURL=tool-registry.test.js.map