UNPKG

@lobehub/chat

Version:

Lobe Chat - an open-source, high-performance chatbot framework that supports speech synthesis, multimodal, and extensible Function Call plugin system. Supports one-click free deployment of your private ChatGPT/LLM web application.

761 lines (648 loc) 25.5 kB
import { eq, inArray } from 'drizzle-orm/expressions'; import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; import { LobeChatDatabase } from '@/database/type'; import { messages, sessions, topics, users } from '../../schemas'; import { CreateTopicParams, TopicModel } from '../topic'; import { getTestDB } from './_util'; const serverDB: LobeChatDatabase = await getTestDB(); const userId = 'topic-user-test'; const sessionId = 'topic-session'; const topicModel = new TopicModel(serverDB, userId); describe('TopicModel', () => { beforeEach(async () => { await serverDB.delete(users); // 创建测试数据 await serverDB.transaction(async (tx) => { await tx.insert(users).values({ id: userId }); await tx.insert(sessions).values({ id: sessionId, userId }); }); }); afterEach(async () => { // 在每个测试用例之后,清空表 await serverDB.delete(users); }); describe('query', () => { it('should query topics by user ID', async () => { // 创建一些测试数据 await serverDB.transaction(async (tx) => { await tx.insert(users).values([{ id: '456' }]); await tx.insert(topics).values([ { id: '1', userId, sessionId, updatedAt: new Date('2023-01-01') }, { id: '4', userId, sessionId, updatedAt: new Date('2023-03-01') }, { id: '2', userId, sessionId, updatedAt: new Date('2023-02-01'), favorite: true }, { id: '5', userId, sessionId, updatedAt: new Date('2023-05-01'), favorite: true }, { id: '3', userId: '456', sessionId, updatedAt: new Date('2023-03-01') }, ]); }); // 调用 query 方法 const result = await topicModel.query({ sessionId }); // 断言结果 expect(result).toHaveLength(4); expect(result[0].id).toBe('5'); // favorite 的 topic 应该在前面,按照 updatedAt 降序排序 expect(result[1].id).toBe('2'); expect(result[2].id).toBe('4'); // 按照 updatedAt 降序排序 }); it('should query topics with pagination', async () => { // 创建测试数据 await serverDB.insert(topics).values([ { id: '1', userId, updatedAt: new Date('2023-01-01') }, { id: '2', userId, updatedAt: new Date('2023-02-01') }, { id: '3', userId, updatedAt: new Date('2023-03-01') }, ]); // 应该返回 2 个 topics const result1 = await topicModel.query({ current: 0, pageSize: 2 }); expect(result1).toHaveLength(2); // 应该只返回 1 个 topic,并且是第 2 个 const result2 = await topicModel.query({ current: 1, pageSize: 1 }); expect(result2).toHaveLength(1); expect(result2[0].id).toBe('2'); }); it('should query topics by session ID', async () => { // 创建测试数据 await serverDB.transaction(async (tx) => { await tx.insert(sessions).values([ { id: 'session1', userId }, { id: 'session2', userId }, ]); await tx.insert(topics).values([ { id: '1', userId, sessionId: 'session1' }, { id: '2', userId, sessionId: 'session2' }, { id: '3', userId }, // 没有 sessionId ]); }); // 应该只返回属于 session1 的 topic const result = await topicModel.query({ sessionId: 'session1' }); expect(result).toHaveLength(1); expect(result[0].id).toBe('1'); }); it('should return topics based on pagination parameters', async () => { // 创建测试数据 await serverDB.insert(topics).values([ { id: 'topic1', sessionId, userId, updatedAt: new Date('2023-01-01') }, { id: 'topic2', sessionId, userId, updatedAt: new Date('2023-01-02') }, { id: 'topic3', sessionId, userId, updatedAt: new Date('2023-01-03') }, ]); // 调用 query 方法 const result1 = await topicModel.query({ current: 0, pageSize: 2, sessionId }); const result2 = await topicModel.query({ current: 1, pageSize: 2, sessionId }); // 断言返回结果符合分页要求 expect(result1).toHaveLength(2); expect(result1[0].id).toBe('topic3'); expect(result1[1].id).toBe('topic2'); expect(result2).toHaveLength(1); expect(result2[0].id).toBe('topic1'); }); }); describe('findById', () => { it('should return a topic by id', async () => { // 创建测试数据 await serverDB.insert(topics).values({ id: 'topic1', sessionId, userId }); // 调用 findById 方法 const result = await topicModel.findById('topic1'); // 断言返回结果符合预期 expect(result?.id).toBe('topic1'); }); it('should return undefined for non-existent topic', async () => { // 调用 findById 方法 const result = await topicModel.findById('non-existent'); // 断言返回 undefined expect(result).toBeUndefined(); }); }); describe('queryAll', () => { it('should return all topics', async () => { // 创建测试数据 await serverDB.insert(topics).values([ { id: 'topic1', sessionId, userId }, { id: 'topic2', sessionId, userId }, ]); // 调用 queryAll 方法 const result = await topicModel.queryAll(); // 断言返回所有的 topics expect(result).toHaveLength(2); expect(result[0].id).toBe('topic1'); expect(result[1].id).toBe('topic2'); }); }); describe('queryByKeyword', () => { it('should return topics matching topic title keyword', async () => { // 创建测试数据 await serverDB.transaction(async (tx) => { await tx.insert(topics).values([ { id: 'topic1', title: 'Hello world', sessionId, userId }, { id: 'topic2', title: 'Goodbye', sessionId, userId }, ]); await tx .insert(messages) .values([ { id: 'message1', role: 'assistant', content: 'abc there', topicId: 'topic1', userId }, ]); }); // 调用 queryByKeyword 方法 const result = await topicModel.queryByKeyword('hello', sessionId); // 断言返回匹配关键字的 topic expect(result).toHaveLength(1); expect(result[0].id).toBe('topic1'); }); it('should return topics matching message content keyword', async () => { // 创建测试数据 await serverDB.transaction(async (tx) => { await tx.insert(topics).values([ { id: 'topic1', title: 'abc world', sessionId, userId }, { id: 'topic2', title: 'Goodbye', sessionId, userId }, ]); await tx.insert(messages).values([ { id: 'message1', role: 'assistant', content: 'Hello there', topicId: 'topic1', userId, }, ]); }); // 调用 queryByKeyword 方法 const result = await topicModel.queryByKeyword('hello', sessionId); // 断言返回匹配关键字的 topic expect(result).toHaveLength(1); expect(result[0].id).toBe('topic1'); }); it('should return nothing if not match', async () => { // 创建测试数据 await serverDB.insert(topics).values([ { id: 'topic1', title: 'Hello world', userId }, { id: 'topic2', title: 'Goodbye', sessionId, userId }, ]); await serverDB .insert(messages) .values([ { id: 'message1', role: 'assistant', content: 'abc there', topicId: 'topic1', userId }, ]); // 调用 queryByKeyword 方法 const result = await topicModel.queryByKeyword('hello', sessionId); // 断言返回匹配关键字的 topic expect(result).toHaveLength(0); }); }); describe('count', () => { it('should return total number of topics', async () => { // 创建测试数据 await serverDB.insert(topics).values([ { id: 'abc_topic1', sessionId, userId }, { id: 'abc_topic2', sessionId, userId }, ]); // 调用 count 方法 const result = await topicModel.count(); // 断言返回 topics 总数 expect(result).toBe(2); }); }); describe('delete', () => { it('should delete a topic and its associated messages', async () => { const topicId = 'topic1'; await serverDB.transaction(async (tx) => { await tx.insert(users).values({ id: '345' }); await tx.insert(sessions).values([ { id: 'session1', userId }, { id: 'session2', userId: '345' }, ]); await tx.insert(topics).values([ { id: topicId, sessionId: 'session1', userId }, { id: 'topic2', sessionId: 'session2', userId: '345' }, ]); await tx.insert(messages).values([ { id: 'message1', role: 'user', topicId: topicId, userId }, { id: 'message2', role: 'assistant', topicId: topicId, userId }, { id: 'message3', role: 'user', topicId: 'topic2', userId: '345' }, ]); }); // 调用 delete 方法 await topicModel.delete(topicId); // 断言 topic 和关联的 messages 都被删除了 expect( await serverDB.select().from(messages).where(eq(messages.topicId, topicId)), ).toHaveLength(0); expect(await serverDB.select().from(topics)).toHaveLength(1); expect(await serverDB.select().from(messages)).toHaveLength(1); }); }); describe('batchDeleteBySessionId', () => { it('should delete all topics associated with a session', async () => { await serverDB.insert(sessions).values([ { id: 'session1', userId }, { id: 'session2', userId }, ]); await serverDB.insert(topics).values([ { id: 'topic1', sessionId: 'session1', userId }, { id: 'topic2', sessionId: 'session1', userId }, { id: 'topic3', sessionId: 'session2', userId }, { id: 'topic4', userId }, ]); // 调用 batchDeleteBySessionId 方法 await topicModel.batchDeleteBySessionId('session1'); // 断言属于 session1 的 topics 都被删除了 expect( await serverDB.select().from(topics).where(eq(topics.sessionId, 'session1')), ).toHaveLength(0); expect(await serverDB.select().from(topics)).toHaveLength(2); }); it('should delete all topics associated without sessionId', async () => { await serverDB.insert(sessions).values([{ id: 'session1', userId }]); await serverDB.insert(topics).values([ { id: 'topic1', sessionId: 'session1', userId }, { id: 'topic2', sessionId: 'session1', userId }, { id: 'topic4', userId }, ]); // 调用 batchDeleteBySessionId 方法 await topicModel.batchDeleteBySessionId(); // 断言属于 session1 的 topics 都被删除了 expect( await serverDB.select().from(topics).where(eq(topics.sessionId, 'session1')), ).toHaveLength(2); expect(await serverDB.select().from(topics)).toHaveLength(2); }); }); describe('batchDelete', () => { it('should delete multiple topics and their associated messages', async () => { await serverDB.transaction(async (tx) => { await tx.insert(sessions).values({ id: 'session1', userId }); await tx.insert(topics).values([ { id: 'topic1', sessionId: 'session1', userId }, { id: 'topic2', sessionId: 'session1', userId }, { id: 'topic3', sessionId: 'session1', userId }, ]); await tx.insert(messages).values([ { id: 'message1', role: 'user', topicId: 'topic1', userId }, { id: 'message2', role: 'assistant', topicId: 'topic2', userId }, { id: 'message3', role: 'user', topicId: 'topic3', userId }, ]); }); // 调用 batchDelete 方法 await topicModel.batchDelete(['topic1', 'topic2']); // 断言指定的 topics 和关联的 messages 都被删除了 expect(await serverDB.select().from(topics)).toHaveLength(1); expect(await serverDB.select().from(messages)).toHaveLength(1); }); }); describe('deleteAll', () => { it('should delete all topics of the user', async () => { await serverDB.insert(users).values({ id: '345' }); await serverDB.insert(sessions).values([ { id: 'session1', userId }, { id: 'session2', userId: '345' }, ]); await serverDB.insert(topics).values([ { id: 'topic1', sessionId: 'session1', userId }, { id: 'topic2', sessionId: 'session1', userId }, { id: 'topic3', sessionId: 'session2', userId: '345' }, ]); // 调用 deleteAll 方法 await topicModel.deleteAll(); // 断言当前用户的所有 topics 都被删除了 expect(await serverDB.select().from(topics).where(eq(topics.userId, userId))).toHaveLength(0); expect(await serverDB.select().from(topics)).toHaveLength(1); }); }); describe('update', () => { it('should update a topic', async () => { // 创建一个测试 session const topicId = '123'; await serverDB.insert(topics).values({ userId, id: topicId, title: 'Test', favorite: true }); // 调用 update 方法更新 session const item = await topicModel.update(topicId, { title: 'Updated Test', favorite: false, }); // 断言更新后的结果 expect(item).toHaveLength(1); expect(item[0].title).toBe('Updated Test'); expect(item[0].favorite).toBeFalsy(); }); it('should not update a topic if user ID does not match', async () => { // 创建一个测试 topic, 但使用不同的 user ID await serverDB.insert(users).values([{ id: '456' }]); const topicId = '123'; await serverDB .insert(topics) .values({ userId: '456', id: topicId, title: 'Test', favorite: true }); // 尝试更新这个 topic , 应该不会有任何更新 const item = await topicModel.update(topicId, { title: 'Updated Test Session', }); expect(item).toHaveLength(0); }); }); describe('create', () => { it('should create a new topic and associate messages', async () => { const topicData = { title: 'New Topic', favorite: true, sessionId, messages: ['message1', 'message2'], } satisfies CreateTopicParams; const topicId = 'new-topic'; // 预先创建一些 messages await serverDB.insert(messages).values([ { id: 'message1', role: 'user', userId, sessionId }, { id: 'message2', role: 'assistant', userId, sessionId }, { id: 'message3', role: 'user', userId, sessionId }, ]); // 调用 create 方法 const createdTopic = await topicModel.create(topicData, topicId); // 断言返回的 topic 数据正确 expect(createdTopic).toEqual({ id: topicId, title: 'New Topic', favorite: true, sessionId, userId, historySummary: null, metadata: null, clientId: null, createdAt: expect.any(Date), updatedAt: expect.any(Date), accessedAt: expect.any(Date), }); // 断言 topic 已在数据库中创建 const dbTopic = await serverDB.select().from(topics).where(eq(topics.id, topicId)); expect(dbTopic).toHaveLength(1); expect(dbTopic[0]).toEqual(createdTopic); // 断言关联的 messages 的 topicId 已更新 const associatedMessages = await serverDB .select() .from(messages) .where(inArray(messages.id, topicData.messages!)); expect(associatedMessages).toHaveLength(2); expect(associatedMessages.every((msg) => msg.topicId === topicId)).toBe(true); // 断言未关联的 message 的 topicId 没有更新 const unassociatedMessage = await serverDB .select() .from(messages) .where(eq(messages.id, 'message3')); expect(unassociatedMessage[0].topicId).toBeNull(); }); it('should create a new topic without associating messages', async () => { const topicData = { title: 'New Topic', favorite: false, sessionId, }; const topicId = 'new-topic'; // 调用 create 方法 const createdTopic = await topicModel.create(topicData, topicId); // 断言返回的 topic 数据正确 expect(createdTopic).toEqual({ id: topicId, title: 'New Topic', favorite: false, clientId: null, historySummary: null, metadata: null, sessionId, userId, createdAt: expect.any(Date), updatedAt: expect.any(Date), accessedAt: expect.any(Date), }); // 断言 topic 已在数据库中创建 const dbTopic = await serverDB.select().from(topics).where(eq(topics.id, topicId)); expect(dbTopic).toHaveLength(1); expect(dbTopic[0]).toEqual(createdTopic); }); }); describe('batchCreate', () => { it('should batch create topics and update associated messages', async () => { // 准备测试数据 const topicParams = [ { title: 'Topic 1', favorite: true, sessionId, messages: ['message1', 'message2'], }, { title: 'Topic 2', favorite: false, sessionId, messages: ['message3'], }, ]; await serverDB.insert(messages).values([ { id: 'message1', role: 'user', userId }, { id: 'message2', role: 'assistant', userId }, { id: 'message3', role: 'user', userId }, ]); // 调用 batchCreate 方法 const createdTopics = await topicModel.batchCreate(topicParams); // 断言返回的 topics 数据正确 expect(createdTopics).toHaveLength(2); expect(createdTopics[0]).toMatchObject({ title: 'Topic 1', favorite: true, sessionId, userId, }); expect(createdTopics[1]).toMatchObject({ title: 'Topic 2', favorite: false, sessionId, userId, }); // 断言 topics 表中的数据正确 const items = await serverDB.select().from(topics); expect(items).toHaveLength(2); expect(items[0]).toMatchObject({ title: 'Topic 1', favorite: true, sessionId, userId, }); expect(items[1]).toMatchObject({ title: 'Topic 2', favorite: false, sessionId, userId, }); // 断言关联的 messages 的 topicId 被正确更新 const updatedMessages = await serverDB.select().from(messages); expect(updatedMessages).toHaveLength(3); expect(updatedMessages[0].topicId).toBe(createdTopics[0].id); expect(updatedMessages[1].topicId).toBe(createdTopics[0].id); expect(updatedMessages[2].topicId).toBe(createdTopics[1].id); }); it('should generate topic IDs if not provided', async () => { // 准备测试数据 const topicParams = [ { title: 'Topic 1', favorite: true, sessionId, }, { title: 'Topic 2', favorite: false, sessionId, }, ]; // 调用 batchCreate 方法 const createdTopics = await topicModel.batchCreate(topicParams); // 断言生成了正确的 topic ID expect(createdTopics[0].id).toBeDefined(); expect(createdTopics[1].id).toBeDefined(); expect(createdTopics[0].id).not.toBe(createdTopics[1].id); }); }); describe('duplicate', () => { it('should duplicate a topic and its associated messages', async () => { const topicId = 'topic-duplicate'; const newTitle = 'Duplicated Topic'; // 创建原始的 topic 和 messages await serverDB.transaction(async (tx) => { await tx.insert(topics).values({ id: topicId, sessionId, userId, title: 'Original Topic' }); await tx.insert(messages).values([ { id: 'message1', role: 'user', topicId, userId, content: 'User message' }, { id: 'message2', role: 'assistant', topicId, userId, content: 'Assistant message' }, ]); }); // 调用 duplicate 方法 const { topic: duplicatedTopic, messages: duplicatedMessages } = await topicModel.duplicate( topicId, newTitle, ); // 断言复制的 topic 的属性正确 expect(duplicatedTopic.id).not.toBe(topicId); expect(duplicatedTopic.title).toBe(newTitle); expect(duplicatedTopic.sessionId).toBe(sessionId); expect(duplicatedTopic.userId).toBe(userId); // 断言复制的 messages 的属性正确 expect(duplicatedMessages).toHaveLength(2); expect(duplicatedMessages[0].id).not.toBe('message1'); expect(duplicatedMessages[0].topicId).toBe(duplicatedTopic.id); expect(duplicatedMessages[0].content).toBe('User message'); expect(duplicatedMessages[1].id).not.toBe('message2'); expect(duplicatedMessages[1].topicId).toBe(duplicatedTopic.id); expect(duplicatedMessages[1].content).toBe('Assistant message'); }); it('should throw an error if the topic to duplicate does not exist', async () => { const topicId = 'nonexistent-topic'; // 调用 duplicate 方法,期望抛出错误 await expect(topicModel.duplicate(topicId)).rejects.toThrow( `Topic with id ${topicId} not found`, ); }); }); describe('rank', () => { it('should return ranked topics based on message count', async () => { // 创建测试数据 await serverDB.transaction(async (tx) => { await tx.insert(topics).values([ { id: 'topic1', title: 'Topic 1', sessionId, userId }, { id: 'topic2', title: 'Topic 2', sessionId, userId }, { id: 'topic3', title: 'Topic 3', sessionId, userId }, ]); // topic1 有 3 条消息 await tx.insert(messages).values([ { id: 'msg1', role: 'user', topicId: 'topic1', userId }, { id: 'msg2', role: 'assistant', topicId: 'topic1', userId }, { id: 'msg3', role: 'user', topicId: 'topic1', userId }, ]); // topic2 有 2 条消息 await tx.insert(messages).values([ { id: 'msg4', role: 'user', topicId: 'topic2', userId }, { id: 'msg5', role: 'assistant', topicId: 'topic2', userId }, ]); // topic3 有 1 条消息 await tx.insert(messages).values([{ id: 'msg6', role: 'user', topicId: 'topic3', userId }]); }); // 调用 rank 方法 const result = await topicModel.rank(2); // 断言返回结果符合预期 expect(result).toHaveLength(2); expect(result[0]).toMatchObject({ id: 'topic1', title: 'Topic 1', count: 3, sessionId, }); expect(result[1]).toMatchObject({ id: 'topic2', title: 'Topic 2', count: 2, sessionId, }); }); it('should return empty array if no topics exist', async () => { const result = await topicModel.rank(); expect(result).toHaveLength(0); }); it('should respect the limit parameter', async () => { // 创建测试数据 await serverDB.transaction(async (tx) => { await tx.insert(topics).values([ { id: 'topic1', title: 'Topic 1', sessionId, userId }, { id: 'topic2', title: 'Topic 2', sessionId, userId }, ]); await tx.insert(messages).values([ { id: 'msg1', role: 'user', topicId: 'topic1', userId }, { id: 'msg2', role: 'user', topicId: 'topic2', userId }, ]); }); // 使用限制为 1 调用 rank 方法 const result = await topicModel.rank(1); // 断言只返回一个结果 expect(result).toHaveLength(1); }); }); describe('count with date filters', () => { beforeEach(async () => { // 创建测试数据 await serverDB.insert(topics).values([ { id: 'topic1', userId, createdAt: new Date('2023-01-01'), }, { id: 'topic2', userId, createdAt: new Date('2023-02-01'), }, { id: 'topic3', userId, createdAt: new Date('2023-03-01'), }, ]); }); it('should count topics with start date filter', async () => { const result = await topicModel.count({ startDate: '2023-02-01', }); expect(result).toBe(2); // should count topics from Feb 1st onwards }); it('should count topics with end date filter', async () => { const result = await topicModel.count({ endDate: '2023-02-01', }); expect(result).toBe(2); // should count topics up to Feb 1st }); it('should count topics within date range', async () => { const result = await topicModel.count({ range: ['2023-01-15', '2023-02-15'], }); expect(result).toBe(1); // should only count topic2 }); it('should return 0 if no topics match date filters', async () => { const result = await topicModel.count({ range: ['2024-01-01', '2024-12-31'], }); expect(result).toBe(0); }); it('should handle invalid date filters gracefully', async () => { const result = await topicModel.count({ startDate: 'invalid-date', }); expect(result).toBe(3); // should return all topics if date is invalid }); }); });