UNPKG

@caleblawson/rag

Version:

The Retrieval-Augmented Generation (RAG) module contains document processing and embedding utilities.

330 lines (298 loc) 9.79 kB
import { RuntimeContext } from '@mastra/core/runtime-context'; import { describe, it, expect, vi, beforeEach } from 'vitest'; import { rerank } from '../rerank'; import { vectorQuerySearch } from '../utils'; import { createVectorQueryTool } from './vector-query'; vi.mock('../utils', async importOriginal => { const actual: any = await importOriginal(); return { ...actual, vectorQuerySearch: vi.fn().mockResolvedValue({ results: [{ metadata: { text: 'foo' }, vector: [1, 2, 3] }] }), }; }); vi.mock('../rerank', async importOriginal => { const actual: any = await importOriginal(); return { ...actual, rerank: vi .fn() .mockResolvedValue([ { result: { id: '1', metadata: { text: 'bar' }, score: 1, details: { semantic: 1, vector: 1, position: 1 } } }, ]), }; }); describe('createVectorQueryTool', () => { const mockModel = { name: 'test-model' } as any; const mockMastra = { vectors: { testStore: { // Mock vector store methods }, anotherStore: { // Mock vector store methods }, }, getVector: vi.fn(storeName => ({ [storeName]: { // Mock vector store methods }, })), logger: { debug: vi.fn(), warn: vi.fn(), info: vi.fn(), }, getLogger: vi.fn(() => ({ debug: vi.fn(), warn: vi.fn(), info: vi.fn(), })), }; beforeEach(() => { vi.clearAllMocks(); }); describe('input schema validation', () => { it('should handle filter permissively when enableFilter is false', () => { // Create tool with enableFilter set to false const tool = createVectorQueryTool({ vectorStoreName: 'testStore', indexName: 'testIndex', model: mockModel, enableFilter: false, }); // Get the Zod schema const schema = tool.inputSchema; // Test with no filter (should be valid) const validInput = { queryText: 'test query', topK: 5, }; expect(() => schema?.parse(validInput)).not.toThrow(); // Test with filter (should throw - unexpected property) const inputWithFilter = { ...validInput, filter: '{"field": "value"}', }; expect(() => schema?.parse(inputWithFilter)).not.toThrow(); }); it('should handle filter when enableFilter is true', () => { const tool = createVectorQueryTool({ vectorStoreName: 'testStore', indexName: 'testIndex', model: mockModel, enableFilter: true, }); // Get the Zod schema const schema = tool.inputSchema; // Test various filter inputs that should coerce to string const testCases = [ // String inputs { filter: '{"field": "value"}' }, { filter: '{}' }, { filter: 'simple-string' }, // Empty { filter: '' }, { filter: { field: 'value' } }, { filter: {} }, { filter: 123 }, { filter: null }, { filter: undefined }, ]; testCases.forEach(({ filter }) => { expect(() => schema?.parse({ queryText: 'test query', topK: 5, filter, }), ).not.toThrow(); }); // Verify that all parsed values are strings testCases.forEach(({ filter }) => { const result = schema?.parse({ queryText: 'test query', topK: 5, filter, }); expect(typeof result?.filter).toBe('string'); }); }); it('should not reject unexpected properties in both modes', () => { // Test with enableFilter false const toolWithoutFilter = createVectorQueryTool({ vectorStoreName: 'testStore', indexName: 'testIndex', model: mockModel, enableFilter: false, }); // Should reject unexpected property expect(() => toolWithoutFilter.inputSchema?.parse({ queryText: 'test query', topK: 5, unexpectedProp: 'value', }), ).not.toThrow(); // Test with enableFilter true const toolWithFilter = createVectorQueryTool({ vectorStoreName: 'testStore', indexName: 'testIndex', model: mockModel, enableFilter: true, }); // Should reject unexpected property even with valid filter expect(() => toolWithFilter.inputSchema?.parse({ queryText: 'test query', topK: 5, filter: '{}', unexpectedProp: 'value', }), ).not.toThrow(); }); }); describe('execute function', () => { it('should not process filter when enableFilter is false', async () => { const runtimeContext = new RuntimeContext(); // Create tool with enableFilter set to false const tool = createVectorQueryTool({ vectorStoreName: 'testStore', indexName: 'testIndex', model: mockModel, enableFilter: false, }); // Execute with no filter await tool.execute?.({ context: { queryText: 'test query', topK: 5, }, mastra: mockMastra as any, runtimeContext, }); // Check that vectorQuerySearch was called with undefined queryFilter expect(vectorQuerySearch).toHaveBeenCalledWith( expect.objectContaining({ queryFilter: undefined, }), ); }); it('should process filter when enableFilter is true and filter is provided', async () => { const runtimeContext = new RuntimeContext(); // Create tool with enableFilter set to true const tool = createVectorQueryTool({ vectorStoreName: 'testStore', indexName: 'testIndex', model: mockModel, enableFilter: true, }); const filterJson = '{"field": "value"}'; // Execute with filter await tool.execute?.({ context: { queryText: 'test query', topK: 5, filter: filterJson, }, mastra: mockMastra as any, runtimeContext, }); // Check that vectorQuerySearch was called with the parsed filter expect(vectorQuerySearch).toHaveBeenCalledWith( expect.objectContaining({ queryFilter: { field: 'value' }, }), ); }); it('should handle string filters correctly', async () => { const runtimeContext = new RuntimeContext(); // Create tool with enableFilter set to true const tool = createVectorQueryTool({ vectorStoreName: 'testStore', indexName: 'testIndex', model: mockModel, enableFilter: true, }); const stringFilter = 'string-filter'; // Execute with string filter await tool.execute?.({ context: { queryText: 'test query', topK: 5, filter: stringFilter, }, mastra: mockMastra as any, runtimeContext, }); // Since this is not a valid filter, it should be ignored expect(vectorQuerySearch).toHaveBeenCalledWith( expect.objectContaining({ queryFilter: undefined, }), ); }); }); describe('runtimeContext', () => { it('calls vectorQuerySearch with runtimeContext params', async () => { const tool = createVectorQueryTool({ id: 'test', model: mockModel, indexName: 'testIndex', vectorStoreName: 'testStore', }); const runtimeContext = new RuntimeContext(); runtimeContext.set('indexName', 'anotherIndex'); runtimeContext.set('vectorStoreName', 'anotherStore'); runtimeContext.set('topK', 3); runtimeContext.set('filter', { foo: 'bar' }); runtimeContext.set('includeVectors', true); runtimeContext.set('includeSources', false); const result = await tool.execute({ context: { queryText: 'foo', topK: 6 }, mastra: mockMastra as any, runtimeContext, }); expect(result.relevantContext.length).toBeGreaterThan(0); expect(result.sources).toEqual([]); // includeSources false expect(vectorQuerySearch).toHaveBeenCalledWith( expect.objectContaining({ indexName: 'anotherIndex', vectorStore: { anotherStore: {}, }, queryText: 'foo', model: mockModel, queryFilter: { foo: 'bar' }, topK: 3, includeVectors: true, }), ); }); it('handles reranker from runtimeContext', async () => { const tool = createVectorQueryTool({ id: 'test', model: mockModel, indexName: 'testIndex', vectorStoreName: 'testStore', }); const runtimeContext = new RuntimeContext(); runtimeContext.set('indexName', 'testIndex'); runtimeContext.set('vectorStoreName', 'testStore'); runtimeContext.set('reranker', { model: 'reranker-model', options: { topK: 1 } }); // Mock rerank vi.mocked(rerank).mockResolvedValue([ { result: { id: '1', metadata: { text: 'bar' }, score: 1 }, score: 1, details: { semantic: 1, vector: 1, position: 1 }, }, ]); const result = await tool.execute({ context: { queryText: 'foo', topK: 1 }, mastra: mockMastra as any, runtimeContext, }); expect(result.relevantContext[0]).toEqual({ text: 'bar' }); }); }); });