UNPKG

@lobehub/chat

Version:

Lobe Chat — an open-source, high-performance chatbot framework that supports speech synthesis, multimodal input, and an extensible function-call plugin system. It supports one-click free deployment of your own private ChatGPT/LLM web application.

341 lines (297 loc) 10.1 kB
import { count, sql } from 'drizzle-orm';
import { and, desc, eq, gt, ilike, inArray, isNull } from 'drizzle-orm/expressions';

import { LobeChatDatabase } from '@/database/type';
import {
  genEndDateWhere,
  genRangeWhere,
  genStartDateWhere,
  genWhere,
} from '@/database/utils/genWhere';
import { idGenerator } from '@/database/utils/idGenerator';
import { MessageItem } from '@/types/message';
import { TopicRankItem } from '@/types/topic';

import { TopicItem, messages, topics } from '../schemas';

export interface CreateTopicParams {
  favorite?: boolean;
  /** Ids of existing messages to move into the new topic. */
  messages?: string[];
  sessionId?: string | null;
  title: string;
}

interface QueryTopicParams {
  current?: number;
  pageSize?: number;
  sessionId?: string | null;
}

/**
 * Escapes the LIKE/ILIKE wildcard characters (`%`, `_`) and the default escape
 * character (`\`) so that a user-supplied keyword is matched literally.
 */
const escapeLikePattern = (value: string): string => value.replace(/[%\\_]/g, '\\$&');

/**
 * Data-access layer for the `topics` table.
 *
 * Every read, write and delete below is scoped to the `userId` given to the
 * constructor, so one instance can only see and modify that user's topics.
 */
export class TopicModel {
  private userId: string;
  private db: LobeChatDatabase;

  constructor(db: LobeChatDatabase, userId: string) {
    this.userId = userId;
    this.db = db;
  }

  // **************** Query *************** //

  /**
   * Lists a page of topics for one session. Favorites come first, then the most
   * recently updated topics.
   *
   * @param current - zero-based page index (default 0)
   * @param pageSize - rows per page (default 9999)
   * @param sessionId - session to filter by; a falsy value selects topics whose
   *   session is NULL
   */
  query = async ({ current = 0, pageSize = 9999, sessionId }: QueryTopicParams = {}) => {
    const offset = current * pageSize;

    return (
      this.db
        .select({
          createdAt: topics.createdAt,
          favorite: topics.favorite,
          historySummary: topics.historySummary,
          id: topics.id,
          metadata: topics.metadata,
          title: topics.title,
          updatedAt: topics.updatedAt,
        })
        .from(topics)
        .where(and(eq(topics.userId, this.userId), this.matchSession(sessionId)))
        // In boolean sorting, false is considered "smaller" than true.
        // So here we use desc to ensure that topics with favorite as true are in front.
        .orderBy(desc(topics.favorite), desc(topics.updatedAt))
        .limit(pageSize)
        .offset(offset)
    );
  };

  /** Returns the topic with the given id if the current user owns it, otherwise `undefined`. */
  findById = async (id: string) => {
    return this.db.query.topics.findFirst({
      where: and(eq(topics.id, id), eq(topics.userId, this.userId)),
    });
  };

  /** Returns all of the current user's topics, ordered by `updatedAt` ascending. */
  queryAll = async (): Promise<TopicItem[]> => {
    return this.db
      .select()
      .from(topics)
      .where(eq(topics.userId, this.userId))
      .orderBy(topics.updatedAt);
  };

  /**
   * Full-text-ish search over the topics of one session. A topic matches when its
   * title, or the content of any of its messages, contains `keyword`
   * (case-insensitive, literal match).
   *
   * @returns matching topics without duplicates, most recently updated first;
   *   `[]` for an empty keyword
   */
  queryByKeyword = async (keyword: string, sessionId?: string | null): Promise<TopicItem[]> => {
    if (!keyword) return [];

    // ILIKE is already case-insensitive. Only the wildcards need escaping so that
    // keywords such as "100%" or "snake_case" are matched literally.
    const pattern = `%${escapeLikePattern(keyword)}%`;

    // Topics whose title matches the keyword.
    const topicsByTitle = await this.db.query.topics.findMany({
      orderBy: [desc(topics.updatedAt)],
      where: and(
        eq(topics.userId, this.userId),
        this.matchSession(sessionId),
        ilike(topics.title, pattern),
      ),
    });

    // Ids of topics (in the same session scope) containing a matching message.
    const topicIdsByMessages = await this.db
      .select({ topicId: messages.topicId })
      .from(messages)
      .innerJoin(topics, eq(messages.topicId, topics.id))
      .where(
        and(
          eq(messages.userId, this.userId),
          ilike(messages.content, pattern),
          eq(topics.userId, this.userId),
          this.matchSession(sessionId),
        ),
      )
      .groupBy(messages.topicId);

    // Only fetch topics that the title search has not already returned.
    const seenIds = new Set(topicsByTitle.map((t) => t.id));
    const extraIds = topicIdsByMessages
      .map((row) => row.topicId)
      .filter((id): id is string => typeof id === 'string' && !seenIds.has(id));

    // The title matches are already ordered by updatedAt desc.
    if (extraIds.length === 0) return topicsByTitle;

    const topicsByMessages = await this.db.query.topics.findMany({
      orderBy: [desc(topics.updatedAt)],
      where: and(eq(topics.userId, this.userId), inArray(topics.id, extraIds)),
    });

    // Merge both result sets (disjoint by construction) and order newest first.
    return [...topicsByTitle, ...topicsByMessages].sort(
      (a, b) => new Date(b.updatedAt).getTime() - new Date(a.updatedAt).getTime(),
    );
  };

  /**
   * Counts the current user's topics, optionally restricted by `createdAt`.
   *
   * @param params.range - inclusive date range applied to `createdAt`
   * @param params.startDate - lower bound applied to `createdAt`
   * @param params.endDate - upper bound applied to `createdAt`
   */
  count = async (params?: {
    endDate?: string;
    range?: [string, string];
    startDate?: string;
  }): Promise<number> => {
    const result = await this.db
      .select({
        count: count(topics.id),
      })
      .from(topics)
      .where(
        genWhere([
          eq(topics.userId, this.userId),
          params?.range
            ? genRangeWhere(params.range, topics.createdAt, (date) => date.toDate())
            : undefined,
          params?.endDate
            ? genEndDateWhere(params.endDate, topics.createdAt, (date) => date.toDate())
            : undefined,
          params?.startDate
            ? genStartDateWhere(params.startDate, topics.createdAt, (date) => date.toDate())
            : undefined,
        ]),
      );

    return result[0].count;
  };

  /**
   * Returns the current user's topics with the most messages. Topics without
   * messages are excluded.
   *
   * @param limit - maximum number of topics to return (default 10)
   */
  rank = async (limit: number = 10): Promise<TopicRankItem[]> => {
    return this.db
      .select({
        count: count(messages.id).as('count'),
        id: topics.id,
        sessionId: topics.sessionId,
        title: topics.title,
      })
      .from(topics)
      .leftJoin(messages, eq(topics.id, messages.topicId))
      .where(eq(topics.userId, this.userId))
      .groupBy(topics.id)
      .orderBy(desc(sql`count`))
      .having(({ count }) => gt(count, 0))
      .limit(limit);
  };

  // **************** Create *************** //

  /**
   * Creates a topic and, in the same transaction, moves the given messages
   * (only those owned by the current user) into it.
   *
   * @param id - id for the new topic; generated when omitted
   */
  create = async (
    { messages: messageIds, ...params }: CreateTopicParams,
    id: string = this.genId(),
  ): Promise<TopicItem> => {
    return this.db.transaction(async (tx) => {
      // Insert the new topic row.
      const [topic] = await tx
        .insert(topics)
        .values({
          ...params,
          id: id,
          userId: this.userId,
        })
        .returning();

      // If messages were supplied, point their topicId at the new topic.
      if (messageIds && messageIds.length > 0) {
        await tx
          .update(messages)
          .set({ topicId: topic.id })
          .where(and(eq(messages.userId, this.userId), inArray(messages.id, messageIds)));
      }

      return topic;
    });
  };

  /**
   * Creates several topics in one transaction and moves each topic's listed
   * messages into it.
   *
   * Ids are assigned before the insert. Each message list is then linked through
   * its own topic id, not through the row order returned by
   * `INSERT ... RETURNING`.
   *
   * @returns the inserted topic rows; `[]` when `topicParams` is empty
   */
  batchCreate = async (
    topicParams: (CreateTopicParams & { id?: string })[],
  ): Promise<TopicItem[]> => {
    // Nothing to insert; an empty VALUES list would be rejected.
    if (topicParams.length === 0) return [];

    const entries = topicParams.map((params) => ({
      messageIds: params.messages,
      row: {
        favorite: params.favorite,
        id: params.id || this.genId(),
        sessionId: params.sessionId,
        title: params.title,
        userId: this.userId,
      },
    }));

    return this.db.transaction(async (tx) => {
      const createdTopics = await tx
        .insert(topics)
        .values(entries.map((entry) => entry.row))
        .returning();

      // For every topic with messages, point those messages at the topic.
      await Promise.all(
        entries.map(async ({ messageIds, row }) => {
          if (!messageIds || messageIds.length === 0) return;

          await tx
            .update(messages)
            .set({ topicId: row.id })
            .where(and(eq(messages.userId, this.userId), inArray(messages.id, messageIds)));
        }),
      );

      return createdTopics;
    });
  };

  /**
   * Copies a topic and all of its messages in one transaction.
   *
   * The copy gets a new id and `clientId: null`. Its title is `newTitle` when
   * given, otherwise the original title. Every copied message gets a new id,
   * `clientId: null`, and the new topic id. All other columns are copied
   * verbatim.
   *
   * NOTE(review): because other message columns are copied verbatim, any
   * message-to-message references (if the schema has them) still point at the
   * original topic's messages. Confirm whether they should be remapped.
   *
   * @throws Error when the topic does not exist or is not owned by the current user
   */
  duplicate = async (topicId: string, newTitle?: string) => {
    return this.db.transaction(async (tx) => {
      // Find the original topic.
      const originalTopic = await tx.query.topics.findFirst({
        where: and(eq(topics.id, topicId), eq(topics.userId, this.userId)),
      });

      if (!originalTopic) {
        throw new Error(`Topic with id ${topicId} not found`);
      }

      // Copy the topic.
      const [duplicatedTopic] = await tx
        .insert(topics)
        .values({
          ...originalTopic,
          clientId: null,
          id: this.genId(),
          title: newTitle || originalTopic.title,
        })
        .returning();

      // Find the messages belonging to the original topic.
      const originalMessages = await tx
        .select()
        .from(messages)
        .where(and(eq(messages.topicId, topicId), eq(messages.userId, this.userId)));

      // Copy the messages into the new topic.
      const duplicatedMessages = await Promise.all(
        originalMessages.map(async (message) => {
          const result = (await tx
            .insert(messages)
            .values({
              ...message,
              clientId: null,
              id: idGenerator('messages'),
              topicId: duplicatedTopic.id,
            })
            .returning()) as MessageItem[];

          return result[0];
        }),
      );

      return {
        messages: duplicatedMessages,
        topic: duplicatedTopic,
      };
    });
  };

  // **************** Delete *************** //

  /**
   * Deletes one topic owned by the current user.
   *
   * Only the `topics` row is deleted here. Whether its messages are removed as
   * well depends on the schema's foreign-key rules, which are not visible in this
   * file. NOTE(review): confirm the cascade behavior.
   */
  delete = async (id: string) => {
    return this.db.delete(topics).where(and(eq(topics.id, id), eq(topics.userId, this.userId)));
  };

  /**
   * Deletes all of the current user's topics in one session. A falsy `sessionId`
   * targets topics whose session is NULL.
   */
  batchDeleteBySessionId = async (sessionId?: string | null) => {
    return this.db
      .delete(topics)
      .where(and(this.matchSession(sessionId), eq(topics.userId, this.userId)));
  };

  /**
   * Deletes the listed topics owned by the current user with a single DELETE
   * statement.
   *
   * Messages are not deleted explicitly here. Whether they go with their topics
   * depends on the schema's foreign-key rules. NOTE(review): confirm.
   */
  batchDelete = async (ids: string[]) => {
    return this.db
      .delete(topics)
      .where(and(inArray(topics.id, ids), eq(topics.userId, this.userId)));
  };

  /** Deletes every topic owned by the current user. */
  deleteAll = async () => {
    return this.db.delete(topics).where(eq(topics.userId, this.userId));
  };

  // **************** Update *************** //

  /**
   * Applies `data` to one topic owned by the current user, always refreshing
   * `updatedAt`.
   *
   * @returns the updated rows (empty when no such topic exists)
   */
  update = async (id: string, data: Partial<TopicItem>) => {
    return this.db
      .update(topics)
      .set({ ...data, updatedAt: new Date() })
      .where(and(eq(topics.id, id), eq(topics.userId, this.userId)))
      .returning();
  };

  // **************** Helper *************** //

  /** Generates a new topic id. */
  private genId = () => idGenerator('topics');

  /** Session filter: the given session, or NULL-session topics when `sessionId` is falsy. */
  private matchSession = (sessionId?: string | null) =>
    sessionId ? eq(topics.sessionId, sessionId) : isNull(topics.sessionId);
}