UNPKG

@juspay/neurolink

Version:

Universal AI Development Platform with working MCP integration, multi-provider support, voice (TTS/STT/realtime), and professional CLI. 58+ external MCP servers discoverable, multimodal file processing, RAG pipelines. Build, test, and deploy AI applications…

326 lines (325 loc) 13.1 kB
/**
 * Vector Query Tool
 *
 * Provides semantic search capabilities for RAG pipelines.
 * Integrates with vector stores and supports metadata filtering and reranking.
 */
import { randomUUID } from "crypto";
import { isDeepStrictEqual } from "util";
import { z } from "zod";
import { ProviderFactory } from "../../factories/providerFactory.js";
import { withSpan } from "../../telemetry/withSpan.js";
import { tracers } from "../../telemetry/tracers.js";
import { logger } from "../../utils/logger.js";
import { rerank } from "../reranker/reranker.js";

/**
 * Creates a vector query tool for semantic search.
 * Follows NeuroLink's factory pattern.
 *
 * @param config - Tool configuration (index name, embedding model, optional
 *   reranker, default topK, and flags controlling filters/vectors/sources)
 * @param vectorStore - Vector store instance, or a resolver function that is
 *   called with the request context and returns a store (or a promise of one)
 * @returns Tool object with `name`, `description`, zod `parameters` and an
 *   async `execute` method
 */
export function createVectorQueryTool(config, vectorStore) {
  const {
    id = `vector-query-${randomUUID().slice(0, 8)}`,
    description = "Access the knowledge base to find information needed to answer user questions",
    indexName,
    embeddingModel,
    enableFilter = false,
    includeVectors = false,
    includeSources = true,
    topK = 10,
    reranker: rerankerConfig,
    providerOptions,
  } = config;

  return {
    name: id,
    description,
    parameters: z.object({
      query: z
        .string()
        .describe("The search query to find relevant information"),
      // The `filter` argument is only part of the schema when filtering is enabled.
      ...(enableFilter
        ? {
            filter: z
              .record(z.string(), z.unknown())
              .optional()
              .describe("Metadata filters to narrow down results"),
          }
        : {}),
      topK: z
        .number()
        .optional()
        .describe(`Number of results to return (default: ${topK})`),
    }),
    /**
     * Execute the vector query.
     * @param params - Query parameters (`query`, optional `topK`, optional `filter`)
     * @param context - Optional request context, forwarded to a store resolver
     * @returns `{ relevantContext, sources, totalResults, metadata }` where
     *   `relevantContext` is the numbered, blank-line-separated result texts
     * @throws Rethrows any embedding, store, or reranking error after logging it
     */
    execute: async (params, context) => {
      // Resolve the result count once so the span attribute and the actual
      // store query always agree (a missing/0/NaN topK falls back to default).
      const effectiveTopK = params.topK || topK;

      return withSpan(
        {
          name: "neurolink.rag.vectorQuery",
          tracer: tracers.rag,
          attributes: {
            "rag.vector.index": indexName,
            "rag.vector.top_k": effectiveTopK,
            "rag.vector.query_length": params.query.length,
          },
        },
        async (span) => {
          const startTime = Date.now();
          try {
            // Resolve the vector store. Awaiting also accepts async resolvers;
            // awaiting a plain store object is a no-op.
            const store = await (typeof vectorStore === "function"
              ? vectorStore(context || {})
              : vectorStore);

            // Generate the query embedding.
            const embeddingProvider = await ProviderFactory.createProvider(
              embeddingModel.provider,
              embeddingModel.modelName,
            );
            if (typeof embeddingProvider.embed !== "function") {
              throw new Error(
                `Provider ${embeddingModel.provider} does not support embeddings`,
              );
            }
            const queryEmbedding = await embeddingProvider.embed(params.query);

            // providerOptions is spread last, so it may override the fields above.
            let results = await store.query({
              indexName,
              queryVector: queryEmbedding,
              topK: effectiveTopK,
              filter: params.filter,
              includeVectors,
              ...providerOptions,
            });

            let reranked = false;
            if (rerankerConfig && results.length > 0) {
              const { model } = rerankerConfig;
              // `model` is either `{ provider, modelName }` or a single string
              // that is used as both the provider and the model name.
              const rerankerModel = await ProviderFactory.createProvider(
                typeof model === "object" ? model.provider : model,
                typeof model === "object" ? model.modelName : model,
              );
              const rerankedResults = await rerank(
                results,
                params.query,
                rerankerModel,
                {
                  weights: rerankerConfig.weights,
                  topK: rerankerConfig.topK,
                  queryEmbedding,
                },
              );
              results = rerankedResults.map((r) => r.result);
              reranked = true;
            }

            // Number each hit so the model can cite it; prefer metadata.text.
            const relevantContext = results
              .map((r, i) => `[${i + 1}] ${r.metadata?.text || r.text || ""}`)
              .join("\n\n");

            const queryTime = Date.now() - startTime;
            const filtered = !!params.filter;

            logger.info("[VectorQueryTool] Query completed", {
              query: params.query.slice(0, 50),
              resultsCount: results.length,
              queryTime,
              reranked,
              filtered,
            });
            span.setAttribute("rag.vector.result_count", results.length);
            span.setAttribute("rag.vector.reranked", reranked);

            return {
              relevantContext,
              sources: includeSources ? results : [],
              totalResults: results.length,
              metadata: {
                queryTime,
                reranked,
                filtered,
              },
            };
          } catch (error) {
            logger.error("[VectorQueryTool] Query failed", {
              query: params.query.slice(0, 50),
              error: error instanceof Error ? error.message : String(error),
            });
            throw error;
          }
        },
      ); // end withSpan
    },
  };
}

/** Longest `$regex` pattern InMemoryVectorStore will compile. */
const MAX_REGEX_PATTERN_LENGTH = 200;
/** Field values are truncated to this many characters before `$regex` testing. */
const MAX_REGEX_INPUT_LENGTH = 10_000;

/**
 * Read a metadata field from the record's own properties only, so filter keys
 * such as "constructor" or "toString" never resolve through the prototype chain.
 */
function getOwnField(metadata, key) {
  return Object.hasOwn(metadata, key) ? metadata[key] : undefined;
}

/** True for a non-array object that carries at least one `$`-prefixed operator key. */
function isOperatorObject(value) {
  return (
    typeof value === "object" &&
    value !== null &&
    !Array.isArray(value) &&
    Object.keys(value).some((k) => k.startsWith("$"))
  );
}

/**
 * Evaluate a `$regex` condition. Only string fields can match. Over-long
 * patterns and invalid patterns are treated as non-matches.
 *
 * NOTE: the length limits are a mitigation, not a full ReDoS defence — a short
 * pattern such as `(a+)+$` can still backtrack heavily on the truncated input.
 */
function matchesRegex(fieldValue, pattern) {
  if (typeof fieldValue !== "string") {
    return false;
  }
  if (typeof pattern !== "string" || pattern.length > MAX_REGEX_PATTERN_LENGTH) {
    return false;
  }
  try {
    return new RegExp(pattern).test(fieldValue.slice(0, MAX_REGEX_INPUT_LENGTH));
  } catch {
    // Invalid regex pattern — treat as non-match.
    return false;
  }
}

/**
 * Apply every comparison operator present in `ops` to a single field value.
 * Malformed operands (e.g. a non-array `$in`) make the condition fail.
 */
function matchesFieldOperators(fieldValue, ops) {
  if ("$eq" in ops && fieldValue !== ops.$eq) {
    return false;
  }
  if ("$ne" in ops && fieldValue === ops.$ne) {
    return false;
  }
  if ("$gt" in ops && (typeof fieldValue !== "number" || fieldValue <= ops.$gt)) {
    return false;
  }
  if ("$gte" in ops && (typeof fieldValue !== "number" || fieldValue < ops.$gte)) {
    return false;
  }
  if ("$lt" in ops && (typeof fieldValue !== "number" || fieldValue >= ops.$lt)) {
    return false;
  }
  if ("$lte" in ops && (typeof fieldValue !== "number" || fieldValue > ops.$lte)) {
    return false;
  }
  if ("$in" in ops && !(Array.isArray(ops.$in) && ops.$in.includes(fieldValue))) {
    return false;
  }
  if ("$nin" in ops && (!Array.isArray(ops.$nin) || ops.$nin.includes(fieldValue))) {
    return false;
  }
  if (
    "$exists" in ops &&
    (ops.$exists ? fieldValue === undefined : fieldValue !== undefined)
  ) {
    return false;
  }
  if (
    "$contains" in ops &&
    (typeof fieldValue !== "string" || !fieldValue.includes(ops.$contains))
  ) {
    return false;
  }
  if ("$regex" in ops && !matchesRegex(fieldValue, ops.$regex)) {
    return false;
  }
  return true;
}

/**
 * In-memory vector store implementation for testing and development.
 *
 * Vectors live in a Map of index name -> Map of id -> { vector, metadata }.
 * Queries score every stored vector with cosine similarity (brute force).
 */
export class InMemoryVectorStore {
  vectors = new Map();

  /**
   * Add or replace vectors in an index, creating the index on first use.
   * @param indexName - Target index
   * @param items - Items shaped `{ id, vector, metadata? }`; missing metadata
   *   is stored as `{}`
   */
  async upsert(indexName, items) {
    let index = this.vectors.get(indexName);
    if (!index) {
      index = new Map();
      this.vectors.set(indexName, index);
    }
    for (const item of items) {
      index.set(item.id, {
        vector: item.vector,
        metadata: item.metadata || {},
      });
    }
  }

  /**
   * Query vectors by cosine similarity.
   * @param params - `{ indexName, queryVector, topK = 10, filter?, includeVectors = false }`
   * @returns Up to `topK` results sorted by descending score, each shaped
   *   `{ id, score, text, metadata, vector? }`; `[]` for an unknown index
   */
  async query(params) {
    const {
      indexName,
      queryVector,
      topK = 10,
      filter,
      includeVectors = false,
    } = params;
    const index = this.vectors.get(indexName);
    if (!index) {
      return [];
    }

    const results = [];
    for (const [id, data] of index) {
      if (filter && !this.matchesFilter(data.metadata, filter)) {
        continue;
      }
      results.push({
        id,
        score: this.cosineSimilarity(queryVector, data.vector),
        text: data.metadata.text,
        metadata: data.metadata,
        ...(includeVectors ? { vector: data.vector } : {}),
      });
    }

    // Sort by score descending (stable) and keep the top K.
    results.sort((a, b) => b.score - a.score);
    return results.slice(0, topK);
  }

  /**
   * Delete vectors by id from an index; unknown indexes and ids are ignored.
   */
  async delete(indexName, ids) {
    const index = this.vectors.get(indexName);
    if (!index) {
      return;
    }
    for (const id of ids) {
      index.delete(id);
    }
  }

  /**
   * Check whether a metadata record satisfies a MongoDB-style filter.
   *
   * Supported: logical `$and` / `$or` (array of filters) and `$not` (a filter);
   * per-field `$eq`, `$ne`, `$gt`, `$gte`, `$lt`, `$lte`, `$in`, `$nin`,
   * `$exists`, `$contains`, `$regex`; plain values compare with `===`, and
   * array/object values without operator keys compare by deep equality.
   * Unknown `$` logical operators are ignored. Malformed operands make the
   * filter fail to match instead of throwing.
   */
  matchesFilter(metadata, filter) {
    if (typeof filter !== "object" || filter === null) {
      return false;
    }
    for (const [key, value] of Object.entries(filter)) {
      if (key.startsWith("$")) {
        // Logical operators
        switch (key) {
          case "$and":
            if (
              !Array.isArray(value) ||
              !value.every((f) => this.matchesFilter(metadata, f))
            ) {
              return false;
            }
            break;
          case "$or":
            if (
              !Array.isArray(value) ||
              !value.some((f) => this.matchesFilter(metadata, f))
            ) {
              return false;
            }
            break;
          case "$not":
            if (
              typeof value !== "object" ||
              value === null ||
              this.matchesFilter(metadata, value)
            ) {
              return false;
            }
            break;
        }
        continue;
      }

      // Field comparison
      const fieldValue = getOwnField(metadata, key);
      if (isOperatorObject(value)) {
        if (!matchesFieldOperators(fieldValue, value)) {
          return false;
        }
      } else if (typeof value === "object" && value !== null) {
        // Arrays / nested objects without operators: structural equality.
        if (!isDeepStrictEqual(fieldValue, value)) {
          return false;
        }
      } else if (fieldValue !== value) {
        // Direct equality for primitives
        return false;
      }
    }
    return true;
  }

  /**
   * Cosine similarity between two vectors. Returns 0 when the lengths differ
   * or either vector has zero magnitude.
   */
  cosineSimilarity(a, b) {
    if (a.length !== b.length) {
      return 0;
    }
    let dotProduct = 0;
    let normA = 0;
    let normB = 0;
    for (let i = 0; i < a.length; i++) {
      dotProduct += a[i] * b[i];
      normA += a[i] * a[i];
      normB += b[i] * b[i];
    }
    const denominator = Math.sqrt(normA) * Math.sqrt(normB);
    return denominator === 0 ? 0 : dotProduct / denominator;
  }
}