UNPKG

@lobehub/chat

Version:

Lobe Chat - an open-source, high-performance chatbot framework that supports speech synthesis, multimodal, and extensible Function Call plugin system. Supports one-click free deployment of your private ChatGPT/LLM web application.

342 lines (283 loc) 9.67 kB
import { count, sum } from 'drizzle-orm';
import { and, asc, desc, eq, ilike, inArray, like, notExists, or } from 'drizzle-orm/expressions';
import type { PgTransaction } from 'drizzle-orm/pg-core';

import { LobeChatDatabase } from '@/database/type';
import { FilesTabs, QueryFileListParams, SortType } from '@/types/files';

import {
  FileItem,
  NewFile,
  NewGlobalFile,
  chunks,
  embeddings,
  fileChunks,
  files,
  globalFiles,
  knowledgeBaseFiles,
} from '../../schemas';

/**
 * Data-access model for the `files` table, scoped to a single user.
 *
 * Every query that reads or mutates `files` rows is filtered by the `userId`
 * given to the constructor. The `globalFiles` table is keyed by content hash
 * and is not user-scoped; rows there are removed only when no `files` row
 * references the same hash any more.
 */
export class FileModel {
  private readonly userId: string;
  private db: LobeChatDatabase;

  /**
   * @param db - database handle used for all queries and transactions
   * @param userId - owner id applied as a filter to user-scoped queries
   */
  constructor(db: LobeChatDatabase, userId: string) {
    this.userId = userId;
    this.db = db;
  }

  /**
   * Insert a file record for the current user inside a single transaction.
   *
   * @param params - file columns; when `knowledgeBaseId` is set, the new file
   *   is also linked to that knowledge base
   * @param insertToGlobalFiles - when true, a `globalFiles` row is inserted
   *   first, using `params.fileHash` as its hash id
   * @returns an object holding the id of the created file
   */
  create = async (
    params: Omit<NewFile, 'id' | 'userId'> & { knowledgeBaseId?: string },
    insertToGlobalFiles?: boolean,
  ) => {
    const created = await this.db.transaction(async (trx) => {
      if (insertToGlobalFiles) {
        await trx.insert(globalFiles).values({
          fileType: params.fileType,
          hashId: params.fileHash!,
          metadata: params.metadata,
          size: params.size,
          url: params.url,
        });
      }

      const inserted = await trx
        .insert(files)
        .values({ ...params, userId: this.userId })
        .returning();

      const item = inserted[0];

      if (params.knowledgeBaseId) {
        await trx
          .insert(knowledgeBaseFiles)
          .values({ fileId: item.id, knowledgeBaseId: params.knowledgeBaseId });
      }

      return item;
    });

    return { id: created.id };
  };

  /**
   * Insert a row into `globalFiles` and return the inserted rows.
   */
  createGlobalFile = async (file: Omit<NewGlobalFile, 'id' | 'userId'>) => {
    return this.db.insert(globalFiles).values(file).returning();
  };

  /**
   * Look up a global file by its hash.
   *
   * @returns `{ isExist: false }` when no row matches, otherwise the stored
   *   `fileType`, `metadata`, `size` and `url` together with `isExist: true`
   */
  checkHash = async (hash: string) => {
    const item = await this.db.query.globalFiles.findFirst({
      where: eq(globalFiles.hashId, hash),
    });
    if (!item) return { isExist: false };

    return {
      fileType: item.fileType,
      isExist: true,
      metadata: item.metadata,
      size: item.size,
      url: item.url,
    };
  };

  /**
   * Delete one file owned by the current user, together with its chunks.
   *
   * @param id - id of the file to delete
   * @param removeGlobalFile - when false, the `globalFiles` row is kept even if
   *   no other file references the hash (the original comment mentions a
   *   `DISABLE_REMOVE_GLOBAL_FILE` flag as the reason to pass false)
   * @returns the deleted file record only when its `globalFiles` row was also
   *   removed; `undefined` otherwise, including when the file was not found
   */
  delete = async (id: string, removeGlobalFile: boolean = true) => {
    const file = await this.findById(id);
    if (!file) return;

    const fileHash = file.fileHash!;

    return await this.db.transaction(async (trx) => {
      // 1. delete the related chunks
      await this.deleteFileChunks(trx as any, [id]);

      // 2. delete the file record
      await trx.delete(files).where(and(eq(files.id, id), eq(files.userId, this.userId)));

      // 3. count the remaining files (of any user) that still share the hash
      const result = await trx
        .select({ count: count() })
        .from(files)
        .where(eq(files.fileHash, fileHash));

      const fileCount = result[0].count;

      // delete the file from global files if it is not used by other files
      if (fileCount === 0 && removeGlobalFile) {
        await trx.delete(globalFiles).where(eq(globalFiles.hashId, fileHash));

        return file;
      }
    });
  };

  /**
   * Delete the `globalFiles` row with the given hash id.
   */
  deleteGlobalFile = async (hashId: string) => {
    return this.db.delete(globalFiles).where(eq(globalFiles.hashId, hashId));
  };

  /**
   * Sum the `size` column over the current user's files.
   *
   * @returns the total as an integer, or 0 when there is nothing to sum
   */
  countUsage = async () => {
    const result = await this.db
      .select({
        totalSize: sum(files.size),
      })
      .from(files)
      .where(eq(files.userId, this.userId));

    // the aggregate is typed as a nullable string; `|| 0` also covers a NaN parse
    return Number.parseInt(result[0].totalSize ?? '0', 10) || 0;
  };

  /**
   * Delete several files owned by the current user, together with their chunks.
   *
   * Only ids that resolve to files of the current user are acted upon, so
   * passing ids of files owned by someone else has no effect.
   *
   * @param ids - ids of the files to delete
   * @param removeGlobalFile - when false, `globalFiles` rows are never removed
   * @returns the deleted files whose `globalFiles` row was removed as well, or
   *   `undefined` when nothing was removed from `globalFiles`
   */
  deleteMany = async (ids: string[], removeGlobalFile: boolean = true) => {
    const fileList = await this.findByIds(ids);
    // nothing owned by this user matches: skip the transaction entirely
    if (fileList.length === 0) return;

    // restrict every following statement to files the user actually owns
    const ownedIds = fileList.map((file) => file.id);
    const hashList = Array.from(new Set(fileList.map((file) => file.fileHash!)));

    return await this.db.transaction(async (trx) => {
      // 1. delete the related chunks
      await this.deleteFileChunks(trx as any, ownedIds);

      // 2. delete the files
      await trx
        .delete(files)
        .where(and(inArray(files.id, ownedIds), eq(files.userId, this.userId)));

      // 3. count the remaining files by hash
      const result = await trx
        .select({
          count: count(),
          hashId: files.fileHash,
        })
        .from(files)
        .where(inArray(files.fileHash, hashList))
        .groupBy(files.fileHash);

      // Create a Map to store the query result
      const countMap = new Map(result.map((item) => [item.hashId, item.count]));

      // Ensure that all incoming hashes have a result, even if it is 0
      const fileHashCounts = hashList.map((hashId) => ({
        count: countMap.get(hashId) || 0,
        hashId: hashId,
      }));

      const needToDeleteList = fileHashCounts.filter((item) => item.count === 0);

      if (needToDeleteList.length === 0 || !removeGlobalFile) return;

      // delete the file from global files if it is not used by other files
      await trx.delete(globalFiles).where(
        inArray(
          globalFiles.hashId,
          needToDeleteList.map((item) => item.hashId!),
        ),
      );

      return fileList.filter((file) =>
        needToDeleteList.some((item) => item.hashId === file.fileHash),
      );
    });
  };

  /**
   * Delete every `files` row owned by the current user.
   */
  clear = async () => {
    return this.db.delete(files).where(eq(files.userId, this.userId));
  };

  /**
   * List the current user's files with optional filtering and sorting.
   *
   * @param category - restricts results by file-type prefix unless it is `FilesTabs.All`
   * @param q - case-insensitive substring matched against the file name
   * @param sortType - compared case-insensitively with `SortType.Asc`; any other value sorts descending
   * @param sorter - one of `createdAt`, `name`, `size`, `updatedAt`; other values are ignored
   * @param knowledgeBaseId - when set, only files linked to that knowledge base are returned
   * @param showFilesInKnowledgeBase - when falsy and no `knowledgeBaseId` is
   *   given, files linked to any knowledge base are excluded
   */
  query = async ({
    category,
    q,
    sortType,
    sorter,
    knowledgeBaseId,
    showFilesInKnowledgeBase,
  }: QueryFileListParams = {}) => {
    // 1. query where
    let whereClause = and(
      q ? ilike(files.name, `%${q}%`) : undefined,
      eq(files.userId, this.userId),
    );

    if (category && category !== FilesTabs.All) {
      const fileTypePrefix = this.getFileTypePrefix(category as FilesTabs);
      whereClause = and(whereClause, ilike(files.fileType, `${fileTypePrefix}%`));
    }

    // 2. order part (newest first unless a valid sorter is given)
    let orderByClause = desc(files.createdAt);

    // create a map for sortable fields
    const sortableFields = {
      createdAt: files.createdAt,
      name: files.name,
      size: files.size,
      updatedAt: files.updatedAt,
    } as const;
    type SortableField = keyof typeof sortableFields;

    // own-property check: `in` would also accept inherited keys such as `toString`
    if (sorter && sortType && Object.prototype.hasOwnProperty.call(sortableFields, sorter)) {
      const sortFunction = sortType.toLowerCase() === SortType.Asc ? asc : desc;
      orderByClause = sortFunction(sortableFields[sorter as SortableField]);
    }

    // 3. build query
    let query = this.db
      .select({
        chunkTaskId: files.chunkTaskId,
        createdAt: files.createdAt,
        embeddingTaskId: files.embeddingTaskId,
        fileType: files.fileType,
        id: files.id,
        name: files.name,
        size: files.size,
        updatedAt: files.updatedAt,
        url: files.url,
      })
      .from(files);

    // 4. add knowledge base query
    if (knowledgeBaseId) {
      // if knowledgeBaseId is provided, it means we are querying files in a knowledge-base
      // @ts-ignore
      query = query.innerJoin(
        knowledgeBaseFiles,
        and(
          eq(files.id, knowledgeBaseFiles.fileId),
          eq(knowledgeBaseFiles.knowledgeBaseId, knowledgeBaseId),
        ),
      );
    }
    // 5. if we don't show files in knowledge base, we need to exclude files in knowledge base
    else if (!showFilesInKnowledgeBase) {
      whereClause = and(
        whereClause,
        notExists(
          this.db.select().from(knowledgeBaseFiles).where(eq(knowledgeBaseFiles.fileId, files.id)),
        ),
      );
    }

    // or we are just filtering in the global files
    return query.where(whereClause).orderBy(orderByClause);
  };

  /**
   * Fetch the current user's files whose id is in `ids`.
   *
   * @returns the matching files; an empty array when `ids` is empty
   */
  findByIds = async (ids: string[]) => {
    // avoid building an `IN ()` predicate from an empty list
    if (ids.length === 0) return [];

    return this.db.query.files.findMany({
      where: and(inArray(files.id, ids), eq(files.userId, this.userId)),
    });
  };

  /**
   * Fetch one of the current user's files by id, or `undefined` when not found.
   */
  findById = async (id: string) => {
    return this.db.query.files.findFirst({
      where: and(eq(files.id, id), eq(files.userId, this.userId)),
    });
  };

  /**
   * Count the `files` rows, across all users, that carry the given hash.
   */
  countFilesByHash = async (hash: string) => {
    const result = await this.db
      .select({
        count: count(),
      })
      .from(files)
      .where(eq(files.fileHash, hash));

    return result[0].count;
  };

  /**
   * Update one of the current user's files and refresh its `updatedAt`.
   *
   * `updatedAt` is always overwritten with the current time, even when
   * `value` contains its own `updatedAt`.
   */
  update = async (id: string, value: Partial<FileItem>) =>
    this.db
      .update(files)
      .set({ ...value, updatedAt: new Date() })
      .where(and(eq(files.id, id), eq(files.userId, this.userId)));

  /**
   * get the corresponding file type prefix according to FilesTabs
   *
   * @returns the prefix string, or `''` for categories without a mapping
   */
  private getFileTypePrefix = (category: FilesTabs): string => {
    switch (category) {
      case FilesTabs.Audios: {
        return 'audio';
      }
      case FilesTabs.Documents: {
        return 'application';
      }
      case FilesTabs.Images: {
        return 'image';
      }
      case FilesTabs.Videos: {
        return 'video';
      }
      default: {
        return '';
      }
    }
  };

  /**
   * Fetch the current user's files whose name starts with any of `fileNames`.
   *
   * @returns the matching files; an empty array when `fileNames` is empty
   */
  findByNames = async (fileNames: string[]) => {
    // `or()` with no operands yields no condition at all, which would turn the
    // query into "every file of this user"; an empty input must match nothing.
    if (fileNames.length === 0) return [];

    return this.db.query.files.findMany({
      where: and(
        or(...fileNames.map((name) => like(files.name, `${name}%`))),
        eq(files.userId, this.userId),
      ),
    });
  };

  /**
   * Shared helper that deletes the chunks (and their embeddings) linked to
   * the given files, using the supplied transaction.
   *
   * This helper does not filter by user; callers must pass only file ids
   * they are allowed to act on.
   *
   * @returns the ids of the chunks that were targeted, or `undefined` when
   *   there was nothing to delete
   */
  private deleteFileChunks = async (trx: PgTransaction<any>, fileIds: string[]) => {
    const BATCH_SIZE = 1000; // number of chunk ids handled per batch

    if (fileIds.length === 0) return;

    // 1. collect all related chunk ids
    const relatedChunks = await trx
      .select({ chunkId: fileChunks.chunkId })
      .from(fileChunks)
      .where(inArray(fileChunks.fileId, fileIds));

    const chunkIds = relatedChunks.map((c) => c.chunkId).filter(Boolean) as string[];

    if (chunkIds.length === 0) return;

    // 2. delete in batches; embeddings are deleted before the chunks they reference
    for (let i = 0; i < chunkIds.length; i += BATCH_SIZE) {
      const batchChunkIds = chunkIds.slice(i, i + BATCH_SIZE);

      await trx.delete(embeddings).where(inArray(embeddings.chunkId, batchChunkIds));

      await trx.delete(chunks).where(inArray(chunks.id, batchChunkIds));
    }

    return chunkIds;
  };
}