@antl3x/toolrag
Version:
Context-aware tool retrieval for language models - unlock the full potential of LLM function calling without context window limitations or constraints.
287 lines • 12.1 kB
JavaScript
import { createClient } from '@libsql/client';
import { Client } from '@modelcontextprotocol/sdk/client/index.js';
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
import { log } from '#Utils';
import crypto from 'crypto';
import { z } from 'zod';
import { EmbeddingProviderGoogle } from './EmbeddingProviderGoogle.js';
import { EmbeddingProviderOpenAI } from './EmbeddingProviderOpenAI.js';
import { setupConfig } from './ToolRAGConfig';
// Shape of an MCP tool's `inputSchema`: always an object schema whose
// `properties` map (when present) holds arbitrary per-parameter schemas.
const mcpToolInputSchema = z.object({
  type: z.literal('object'),
  properties: z.record(z.any()).optional(),
});

// Shape of a tool entry as returned by an MCP server's tool listing.
const mcpToolSchema = z.object({
  name: z.string(),
  description: z.string().optional(),
  inputSchema: mcpToolInputSchema,
});
/**
 * Retrieves MCP tools relevant to a natural-language query.
 *
 * Each tool's name, description and parameters are turned into text, embedded
 * with the configured provider, and stored in a libSQL table with a vector
 * index; `listTools` runs a vector search against that index.
 *
 * Create instances with `await ToolRAG.init(config)`: the constructor only
 * stores the config and does not touch the database or any MCP server.
 */
class ToolRAG {
  /** Connected MCP clients, one per registered server URL. */
  _mcpClients = [];
  /** Every tool reported by every registered MCP server. */
  _mcpTools = [];
  /** Tool name -> MCP client exposing it (a later registration overwrites an earlier one). */
  _toolToClientMap = new Map();
  _embeddingProvider = null;
  _db = null;
  _config;
  // NOTE(review): the namespace reads "toolreg" — likely a typo for "toolrag".
  // Left unchanged because existing log filters may depend on the exact string.
  _log = log('toolreg:ToolRAG');
  /**
   * Tail of the embedding-refresh queue. Each refresh is chained onto it so
   * concurrent callers never interleave their read-hashes / write-rows steps
   * (which could otherwise insert the same tool twice).
   */
  _refreshQueue = Promise.resolve();
  /** Embeddings table name; one table per embedding provider. */
  _db_table_name = () => `tool_embeddings_${this._embeddingProvider?.getName()}`;

  /**
   * @param {object} config - Raw user config, normalized by `setupConfig`.
   */
  constructor(config) {
    this._config = setupConfig(config);
  }

  /**
   * Creates a fully initialized instance: embedding provider, database schema,
   * then MCP servers (whose tools are embedded as they register).
   * @param {object} config
   * @returns {Promise<ToolRAG>}
   */
  static async init(config) {
    const toolRAG = new ToolRAG(config);
    await toolRAG._initEmbeddingProvider();
    await toolRAG._initDatabase();
    await toolRAG._initMcpServers();
    toolRAG._log.info('ToolRAG initialized');
    return toolRAG;
  }

  /**
   * Instantiates the embedding provider named by `config.embeddingProvider`.
   * @throws {Error} for any name other than 'openai' or 'google'.
   */
  _initEmbeddingProvider() {
    switch (this._config.embeddingProvider) {
      case 'openai':
        this._embeddingProvider = new EmbeddingProviderOpenAI();
        break;
      case 'google':
        this._embeddingProvider = new EmbeddingProviderGoogle();
        break;
      default:
        throw new Error(`Unsupported embedding provider: ${this._config.embeddingProvider}`);
    }
  }

  /**
   * Registers every configured MCP server concurrently. Safe to run in parallel
   * because the embedding refresh each registration triggers is serialized.
   */
  async _initMcpServers() {
    if (this._config.mcpServers?.length) {
      this._log.info(`Initializing with ${this._config.mcpServers.length} MCP servers`);
      await Promise.all(this._config.mcpServers.map((server) => this._registerMcpServer(server)));
    }
  }

  /** SHA-256 hex digest of the tool's JSON form; a change means re-embed. */
  _hashTool(tool) {
    return crypto.createHash('sha256').update(JSON.stringify(tool)).digest('hex');
  }

  /** @throws {Error} if the database or embedding provider is not set up yet. */
  _ensureInitialized() {
    if (!this._db)
      throw new Error('Database not initialized');
    if (!this._embeddingProvider)
      throw new Error('Embedding provider not initialized');
  }

  /**
   * Name of the vector index for the current provider's table. SQLite index
   * names are schema-global, so they are derived from the (per-provider) table
   * name; a fixed name would make a second provider's `CREATE INDEX IF NOT
   * EXISTS` a no-op and its searches would hit the first provider's index.
   */
  _vectorIndexName() {
    return `idx_${this._db_table_name()}_vector`;
  }

  /**
   * Opens the libSQL database and creates the provider-specific table and
   * indexes if missing. Must run after `_initEmbeddingProvider`, since the
   * table name and vector dimensions come from the provider.
   * NOTE(review): databases created by earlier versions may still contain the
   * old fixed-name indexes (`idx_tool_hash`, `idx_tool_embeddings_vector`);
   * they are left in place.
   */
  async _initDatabase() {
    try {
      this._db = createClient({ url: this._config.database.url });
      const dimensions = this._embeddingProvider?.getDimensions();
      const tableName = this._db_table_name();
      await this._db.execute(`
        CREATE TABLE IF NOT EXISTS ${tableName} (
          id INTEGER PRIMARY KEY,
          tool_name TEXT NOT NULL,
          tool_hash TEXT NOT NULL,
          embedding F32_BLOB(${dimensions}) NOT NULL,
          embedding_text TEXT NOT NULL,
          tool_json TEXT NOT NULL
        )
      `);
      await this._db.execute(`
        CREATE INDEX IF NOT EXISTS idx_${tableName}_hash ON ${tableName}(tool_hash)
      `);
      await this._db.execute(`
        CREATE INDEX IF NOT EXISTS ${this._vectorIndexName()}
        ON ${tableName}(libsql_vector_idx(embedding))
      `);
      this._log.info(`Database initialized at ${this._config.database.url}`);
    }
    catch (error) {
      this._log.error('Failed to initialize database:', error);
      throw error;
    }
  }

  /**
   * Connects to an MCP server over SSE, records its tools, and embeds any new
   * or changed ones.
   * @param {string} url - SSE endpoint of the MCP server.
   */
  async _registerMcpServer(url) {
    const client = new Client({ name: url, version: '0' });
    await client.connect(new SSEClientTransport(new URL(url)));
    this._mcpClients.push(client);
    const res = await client.listTools();
    this._log.info(`Found ${res.tools.length} tools from ${url}`);
    this._log.info(res.tools.map((tool) => tool.name).join(', '));
    // Route future callTool() invocations for these names to this client.
    for (const tool of res.tools) {
      this._toolToClientMap.set(tool.name, client);
    }
    this._mcpTools.push(...res.tools);
    await this._refreshToolsEmbeddings();
  }

  /**
   * Builds the text that gets embedded for a tool: the name with '-'/'_'
   * replaced by spaces, its description, then one line per parameter.
   * Tolerates a missing `properties` map and null per-parameter schemas.
   */
  _formatToolText(tool) {
    const properties = tool.inputSchema?.properties;
    const params = properties
      ? Object.entries(properties)
        .map(([name, schema]) => ` ${name} [${schema?.type}]: ${schema?.description || ''}`)
        .join('\n')
      : '';
    const toolName = tool.name.replaceAll(/-|_/g, ' ');
    return `${toolName}: ${tool.description || ''}\n${params}`;
  }

  /**
   * Embeds each tool's formatted text (requests issued in parallel).
   * @returns {Promise<Array<{tool, toolName, toolHash, embedding, toolText}>>}
   */
  async _generateToolsEmbeddings(tools) {
    this._ensureInitialized();
    return await Promise.all(tools.map(async (tool) => {
      const toolText = this._formatToolText(tool);
      const embedding = await this._embeddingProvider.getEmbedding(toolText);
      return {
        tool,
        toolName: tool.name,
        toolHash: this._hashTool(tool),
        embedding,
        toolText,
      };
    }));
  }

  /**
   * Embeds and stores every known tool whose hash is not yet in the table.
   * Calls are queued so that only one refresh touches the table at a time.
   * @returns {Promise<Array>} the embeddings written by this call ([] if none).
   */
  async _refreshToolsEmbeddings() {
    const run = this._refreshQueue.then(() => this._refreshToolsEmbeddingsUnlocked());
    // Keep the queue usable after a failure; the error still reaches this
    // caller through `run`.
    this._refreshQueue = run.catch(() => {});
    return run;
  }

  /**
   * Refresh body; must only be invoked via `_refreshToolsEmbeddings`, because
   * its SELECT-then-UPDATE/INSERT sequence is not safe to interleave.
   */
  async _refreshToolsEmbeddingsUnlocked() {
    this._ensureInitialized();
    this._log.info('Checking for new or updated tools...');
    const tableName = this._db_table_name();
    const existingHashes = await this._db.execute({
      sql: `SELECT tool_hash FROM ${tableName}`,
    });
    const hashSet = new Set(existingHashes.rows.map((row) => row.tool_hash));
    const toolsToUpdate = this._mcpTools.filter((tool) => !hashSet.has(this._hashTool(tool)));
    if (toolsToUpdate.length === 0) {
      this._log.info('All tools are up-to-date, no new embeddings needed');
      return [];
    }
    this._log.info(`Generating embeddings for ${toolsToUpdate.length} new or updated tools...`);
    const newEmbeddings = await this._generateToolsEmbeddings(toolsToUpdate);
    try {
      for (const { toolName, toolText, toolHash, embedding, tool } of newEmbeddings) {
        const toolJson = JSON.stringify(tool);
        const embeddingBuffer = new Float32Array(embedding).buffer;
        // Rows are keyed by tool name: overwrite a changed tool's row in place,
        // insert only when no row for that name exists yet.
        const updateResult = await this._db.execute({
          sql: `UPDATE ${tableName}
            SET tool_hash = ?, embedding = ?, tool_json = ?, embedding_text = ?
            WHERE tool_name = ?`,
          args: [toolHash, embeddingBuffer, toolJson, toolText, toolName],
        });
        if (!updateResult.rowsAffected) {
          await this._db.execute({
            sql: `INSERT INTO ${tableName}
              (tool_name, tool_hash, embedding, tool_json, embedding_text)
              VALUES (?, ?, ?, ?, ?)`,
            args: [toolName, toolHash, embeddingBuffer, toolJson, toolText],
          });
        }
      }
      this._log.info(`Successfully updated embeddings for ${newEmbeddings.length} tools`);
      return newEmbeddings;
    }
    catch (error) {
      this._log.error('Error updating embeddings:', error);
      throw error;
    }
  }

  /**
   * Deletes rows for tools no longer exposed by any registered server. All
   * deletes are sent as one write batch so they succeed or fail together.
   * @returns {Promise<number>} number of rows selected for deletion.
   */
  async _pruneMissingTools() {
    this._ensureInitialized();
    this._log.info('Pruning missing tools from database...');
    try {
      const tableName = this._db_table_name();
      const dbToolsResult = await this._db.execute({
        sql: `SELECT tool_name FROM ${tableName}`,
      });
      const currentToolNames = new Set(this._mcpTools.map((tool) => tool.name));
      const toolsToRemove = dbToolsResult.rows
        .map((row) => row.tool_name)
        .filter((name) => !currentToolNames.has(name));
      if (toolsToRemove.length === 0) {
        this._log.info('No tools to prune, database is in sync');
        return 0;
      }
      this._log.info(`Found ${toolsToRemove.length} tools to remove from database`);
      await this._db.batch(
        toolsToRemove.map((toolName) => ({
          sql: `DELETE FROM ${tableName} WHERE tool_name = ?`,
          args: [toolName],
        })),
        'write',
      );
      this._log.info(`Successfully pruned ${toolsToRemove.length} tools`);
      return toolsToRemove.length;
    }
    catch (error) {
      this._log.error('Error pruning missing tools:', error);
      throw error;
    }
  }

  /**
   * Vector search for the tools closest to `query`, most relevant first.
   * @param {string} query
   * @param {number} [limit=40] - Number of nearest neighbours to fetch.
   * @returns {Promise<Array<{toolName, relevance, tool}>>} relevance = 1 - cosine distance.
   * @throws {RangeError} if `limit` is not a positive integer.
   */
  async _findSimilarToolsByVector(query, limit = 40) {
    this._ensureInitialized();
    // `limit` is interpolated as a literal (as the former hard-coded 40 was),
    // so validate it to keep anything but an integer out of the SQL text.
    if (!Number.isInteger(limit) || limit < 1) {
      throw new RangeError(`limit must be a positive integer, got ${limit}`);
    }
    const queryEmbedding = await this._embeddingProvider.getEmbedding(query);
    const queryEmbeddingBuffer = new Float32Array(queryEmbedding).buffer;
    const tableName = this._db_table_name();
    const result = await this._db.execute({
      sql: `
        SELECT te.id, te.tool_name, te.tool_json,
        vector_distance_cos(te.embedding, ?) as distance
        FROM vector_top_k('${this._vectorIndexName()}', ?, ${limit}) AS vt
        JOIN ${tableName} te ON te.id = vt.id
        ORDER BY distance ASC
      `,
      args: [queryEmbeddingBuffer, queryEmbeddingBuffer],
    });
    this._log.info(`Found ${result.rows.length} similar tools via vector search`);
    return result.rows.map((row) => ({
      toolName: row.tool_name,
      relevance: 1 - row.distance,
      tool: JSON.parse(row.tool_json),
    }));
  }

  /**
   * Converts an MCP tool into an OpenAI function-calling definition, with the
   * search relevance attached.
   * NOTE(review): JSON Schema marks required fields via `inputSchema.required`,
   * not a per-property `optional` flag, so this lists nearly every property as
   * required — which is what `strict: true` expects. Confirm before changing.
   */
  _convertToOpenAIFunction(tool, relevance) {
    const properties = tool.inputSchema?.properties || {};
    return {
      type: 'function',
      name: tool.name,
      description: tool.description || '',
      parameters: {
        type: 'object',
        properties,
        required: Object.entries(properties)
          .filter(([_, schema]) => !schema?.optional)
          .map(([key]) => key),
        additionalProperties: false,
      },
      strict: true,
      relevance,
    };
  }

  /**
   * Returns OpenAI-format function definitions for the tools most relevant to
   * `query`, ordered by descending relevance.
   * @param {string} query
   * @param {{relevanceThreshold?: number, limit?: number}} [options]
   *   relevanceThreshold (default 0.15) drops weaker matches; limit (default 40)
   *   caps how many nearest neighbours are considered.
   */
  async listTools(query, options) {
    const { relevanceThreshold = 0.15, limit = 40 } = options ?? {};
    this._ensureInitialized();
    const countResult = await this._db.execute(`SELECT COUNT(*) as count FROM ${this._db_table_name()}`);
    // Normalize with Number() in case the driver returns COUNT(*) as a bigint.
    const toolCount = Number(countResult.rows[0]?.count ?? 0);
    if (toolCount === 0) {
      this._log.warn('No tool embeddings found in database, storing them now');
      await this._refreshToolsEmbeddings();
    }
    const similarTools = await this._findSimilarToolsByVector(query, limit);
    return similarTools
      .filter(({ relevance }) => relevance >= relevanceThreshold)
      .map(({ tool, relevance }) => this._convertToOpenAIFunction(tool, relevance));
  }

  /**
   * Invokes a tool on the MCP server that registered it.
   * @throws {Error} if the tool or its client is unknown.
   */
  async callTool(toolName, input) {
    this._ensureInitialized();
    const tool = this._mcpTools.find((t) => t.name === toolName);
    if (!tool)
      throw new Error(`Tool ${toolName} not found`);
    const client = this._toolToClientMap.get(toolName);
    if (!client)
      throw new Error(`MCP client for tool ${toolName} not found`);
    return await client.callTool({
      name: toolName,
      arguments: input,
    });
  }
}
export default ToolRAG;
//# sourceMappingURL=ToolRAG.js.map