@langchain/community
Version:
Third-party integrations for LangChain.js
60 lines (59 loc) • 1.86 kB
JavaScript
/* eslint-disable @typescript-eslint/no-explicit-any */
import { FakeEmbeddings } from "@langchain/core/utils/testing";
import { jest, test, expect } from "@jest/globals";
import { PrismaVectorStore } from "../prisma.js";
/**
 * Minimal stand-in for the `Sql` class exposed on the Prisma namespace
 * (presumably mirrors Prisma's tagged-SQL value type — confirm if needed).
 *
 * Every instance gets a single own `strings` data property that starts out
 * as `undefined` and is writable, enumerable and configurable.
 */
class Sql {
  constructor() {
    // Plain assignment on a fresh, extensible object with no inherited
    // setter creates the same own data property descriptor that an explicit
    // Object.defineProperty(..., { writable, enumerable, configurable }) would.
    this.strings = undefined;
  }
}
// Column mapping passed to the store: the `id` and `content` model fields are
// tagged with the static column markers PrismaVectorStore exposes.
const mockColumns = {
  id: PrismaVectorStore.IdColumn,
  content: PrismaVectorStore.ContentColumn,
};

// Jest mocks standing in for the helpers on the `Prisma` namespace object.
// Created in the order sql, raw, join.
const [sql, raw, join] = [jest.fn(), jest.fn(), jest.fn()];
const mockPrismaNamespace = {
  ModelName: {},
  Sql,
  raw,
  join,
  sql,
};

// Jest mocks standing in for the PrismaClient raw-query and transaction methods.
const [$queryRaw, $executeRaw, $transaction] = [
  jest.fn(),
  jest.fn(),
  jest.fn(),
];
const mockPrismaClient = { $queryRaw, $executeRaw, $transaction };
/**
 * Unit tests for PrismaVectorStore that run entirely against mocked Prisma
 * objects, so no database is touched.
 */
describe("Prisma", () => {
  // Reset call history on the module-level jest.fn() mocks between tests.
  beforeEach(() => {
    jest.clearAllMocks();
  });

  test("passes provided filters with similaritySearch", async () => {
    // Use ONE embeddings instance for both the store and the expected query
    // vector, so the assertion below compares output from the same embedder
    // rather than relying on two separate instances happening to agree.
    const embeddings = new FakeEmbeddings();
    const store = new PrismaVectorStore(embeddings, {
      db: mockPrismaClient,
      prisma: mockPrismaNamespace,
      tableName: "test",
      vectorColumnName: "vector",
      columns: mockColumns,
    });

    // Stub out the DB-backed search so only the argument forwarding is tested.
    const similaritySearchVectorWithScoreSpy = jest
      .spyOn(store, "similaritySearchVectorWithScore")
      .mockResolvedValue([]);

    const filter = { id: { equals: "123" } };
    await store.similaritySearch("hello", 1, filter);

    const embeddedQuery = await embeddings.embedQuery("hello");
    expect(similaritySearchVectorWithScoreSpy).toHaveBeenCalledTimes(1);
    expect(similaritySearchVectorWithScoreSpy).toHaveBeenCalledWith(
      embeddedQuery,
      1,
      filter
    );
  });
});