UNPKG

@langchain/community

Version:
60 lines (59 loc) 1.86 kB
/* eslint-disable @typescript-eslint/no-explicit-any */
import { FakeEmbeddings } from "@langchain/core/utils/testing";
// Every Jest API used in this file is imported explicitly, so the suite also
// works when Jest is configured with `injectGlobals: false`.
import { jest, describe, beforeEach, test, expect } from "@jest/globals";
import { PrismaVectorStore } from "../prisma.js";

/**
 * Stand-in for the `Sql` class exposed on the mocked Prisma namespace below.
 * It only declares an own, enumerable, writable `strings` property that
 * starts out as `undefined`.
 */
class Sql {
  constructor() {
    Object.defineProperty(this, "strings", {
      enumerable: true,
      configurable: true,
      writable: true,
      value: void 0,
    });
  }
}

/** Column mapping handed to the store: `id` and `content` use the store's column markers. */
const mockColumns = {
  id: PrismaVectorStore.IdColumn,
  content: PrismaVectorStore.ContentColumn,
};

// Mocked members of the Prisma namespace. Their calls are not asserted in
// this suite; they only need to exist on the namespace object.
const sql = jest.fn();
const raw = jest.fn();
const join = jest.fn();

const mockPrismaNamespace = {
  ModelName: {},
  Sql,
  raw,
  join,
  sql,
};

// Mocked Prisma client methods passed to the store as `db`.
const $queryRaw = jest.fn();
const $executeRaw = jest.fn();
const $transaction = jest.fn();

const mockPrismaClient = {
  $queryRaw,
  $executeRaw,
  $transaction,
};

describe("Prisma", () => {
  beforeEach(() => {
    // Reset recorded calls on all mocks so tests do not observe each other.
    jest.clearAllMocks();
  });

  test("passes provided filters with similaritySearch", async () => {
    // Use one embeddings instance for both the store and the expected query
    // vector. The assertion then compares against exactly what the store
    // embeds, instead of relying on two separate instances agreeing.
    const embeddings = new FakeEmbeddings();
    const store = new PrismaVectorStore(embeddings, {
      db: mockPrismaClient,
      prisma: mockPrismaNamespace,
      tableName: "test",
      vectorColumnName: "vector",
      columns: mockColumns,
    });

    // Stub out the vector search so no database call is attempted; only the
    // arguments forwarded by `similaritySearch` are under test.
    const similaritySearchVectorWithScoreSpy = jest
      .spyOn(store, "similaritySearchVectorWithScore")
      .mockResolvedValue([]);

    const filter = { id: { equals: "123" } };
    await store.similaritySearch("hello", 1, filter);

    const embeddedQuery = await embeddings.embedQuery("hello");
    expect(similaritySearchVectorWithScoreSpy).toHaveBeenCalledTimes(1);
    expect(similaritySearchVectorWithScoreSpy).toHaveBeenCalledWith(
      embeddedQuery,
      1,
      filter
    );
  });
});