mastra-browser-rag
Version:
The Retrieval-Augmented Generation (RAG) module contains document processing and embedding utilities.
113 lines (106 loc) • 3.34 kB
text/typescript
import { createTool } from '@mastra/core/tools';
import type { EmbeddingModel } from 'ai';
import { z } from 'zod';
import { rerank } from '../rerank';
import type { RerankConfig } from '../rerank';
import {
vectorQuerySearch,
defaultVectorQueryDescription,
filterDescription,
topKDescription,
queryTextDescription,
} from '../utils';
export const createVectorQueryTool = ({
vectorStoreName,
indexName,
model,
enableFilter = false,
reranker,
id,
description,
}: {
vectorStoreName: string;
indexName: string;
model: EmbeddingModel<string>;
enableFilter?: boolean;
reranker?: RerankConfig;
id?: string;
description?: string;
}): ReturnType<typeof createTool> => {
const toolId = id || `VectorQuery ${vectorStoreName} ${indexName} Tool`;
const toolDescription = description || defaultVectorQueryDescription();
// Create base schema with required fields
const baseSchema = {
queryText: z.string().describe(queryTextDescription),
topK: z.coerce.number().describe(topKDescription),
};
const inputSchema = enableFilter
? z
.object({
...baseSchema,
filter: z.coerce.string().describe(filterDescription),
})
.passthrough()
: z.object(baseSchema).passthrough();
return createTool({
id: toolId,
inputSchema,
outputSchema: z.object({
relevantContext: z.any(),
}),
description: toolDescription,
execute: async ({ context: { queryText, topK, filter }, mastra }) => {
const topKValue =
typeof topK === 'number' && !isNaN(topK)
? topK
: typeof topK === 'string' && !isNaN(Number(topK))
? Number(topK)
: 10;
const vectorStore = mastra?.getVector(vectorStoreName);
const logger = mastra?.getLogger();
// Get relevant chunks from the vector database
if (vectorStore) {
let queryFilter = {};
if (enableFilter && filter) {
queryFilter = (() => {
try {
return typeof filter === 'string' ? JSON.parse(filter) : filter;
} catch (error) {
// Log the error and use empty object
if (logger) {
logger.warn('Failed to parse filter as JSON, using empty filter', { filter, error });
}
return {};
}
})();
}
if (logger) {
logger.debug('Using this filter and topK:', { queryFilter, topK: topKValue });
}
const { results } = await vectorQuerySearch({
indexName,
vectorStore,
queryText,
model,
queryFilter: Object.keys(queryFilter || {}).length > 0 ? queryFilter : undefined,
topK: topKValue,
});
if (reranker) {
const rerankedResults = await rerank(results, queryText, reranker.model, {
...reranker.options,
topK: reranker.options?.topK || topKValue,
});
const relevantChunks = rerankedResults.map(({ result }) => result?.metadata);
return { relevantContext: relevantChunks };
}
const relevantChunks = results.map(result => result?.metadata);
return {
relevantContext: relevantChunks,
};
}
return {
relevantContext: [],
};
},
});
};