UNPKG

@lobehub/chat

Version:

Lobe Chat - an open-source, high-performance chatbot framework that supports speech synthesis, multimodal, and extensible Function Call plugin system. Supports one-click free deployment of your private ChatGPT/LLM web application.

243 lines (204 loc) 6.5 kB
import { cosineDistance, count, sql } from 'drizzle-orm'; import { and, asc, desc, eq, inArray, isNull } from 'drizzle-orm/expressions'; import { chunk } from 'lodash-es'; import { LobeChatDatabase } from '@/database/type'; import { ChunkMetadata, FileChunk } from '@/types/chunk'; import { NewChunkItem, NewUnstructuredChunkItem, chunks, embeddings, fileChunks, files, unstructuredChunks, } from '../schemas'; export class ChunkModel { private userId: string; private db: LobeChatDatabase; constructor(db: LobeChatDatabase, userId: string) { this.userId = userId; this.db = db; } bulkCreate = async (params: NewChunkItem[], fileId: string) => { return this.db.transaction(async (trx) => { if (params.length === 0) return []; const result = await trx.insert(chunks).values(params).returning(); const fileChunksData = result.map((chunk) => ({ chunkId: chunk.id, fileId, userId: this.userId, })); if (fileChunksData.length > 0) { await trx.insert(fileChunks).values(fileChunksData); } return result; }); }; bulkCreateUnstructuredChunks = async (params: NewUnstructuredChunkItem[]) => { return this.db.insert(unstructuredChunks).values(params); }; delete = async (id: string) => { return this.db.delete(chunks).where(and(eq(chunks.id, id), eq(chunks.userId, this.userId))); }; deleteOrphanChunks = async () => { const orphanedChunks = await this.db .select({ chunkId: chunks.id }) .from(chunks) .leftJoin(fileChunks, eq(chunks.id, fileChunks.chunkId)) .where(isNull(fileChunks.fileId)); const ids = orphanedChunks.map((chunk) => chunk.chunkId); if (ids.length === 0) return; const list = chunk(ids, 500); await this.db.transaction(async (trx) => { await Promise.all( list.map(async (chunkIds) => { await trx.delete(chunks).where(inArray(chunks.id, chunkIds)); }), ); }); }; findById = async (id: string) => { return this.db.query.chunks.findFirst({ where: and(eq(chunks.id, id)), }); }; findByFileId = async (id: string, page = 0) => { const data = await this.db .select({ abstract: chunks.abstract, createdAt: chunks.createdAt, id: chunks.id, index: chunks.index, metadata: chunks.metadata, text: chunks.text, type: chunks.type, updatedAt: chunks.updatedAt, }) .from(chunks) .innerJoin(fileChunks, eq(chunks.id, fileChunks.chunkId)) .where(and(eq(fileChunks.fileId, id), eq(chunks.userId, this.userId))) .limit(20) .offset(page * 20) .orderBy(asc(chunks.index)); return data.map((item) => { const metadata = item.metadata as ChunkMetadata; return { ...item, metadata, pageNumber: metadata?.pageNumber } as FileChunk; }); }; getChunksTextByFileId = async (id: string): Promise<{ id: string; text: string }[]> => { const data = await this.db .select() .from(chunks) .innerJoin(fileChunks, eq(chunks.id, fileChunks.chunkId)) .where(eq(fileChunks.fileId, id)); return data .map((item) => item.chunks) .map((chunk) => ({ id: chunk.id, text: this.mapChunkText(chunk) })) .filter((chunk) => chunk.text) as { id: string; text: string }[]; }; countByFileIds = async (ids: string[]) => { if (ids.length === 0) return []; return this.db .select({ count: count(fileChunks.chunkId), id: fileChunks.fileId, }) .from(fileChunks) .where(inArray(fileChunks.fileId, ids)) .groupBy(fileChunks.fileId); }; countByFileId = async (ids: string) => { const data = await this.db .select({ count: count(fileChunks.chunkId), id: fileChunks.fileId, }) .from(fileChunks) .where(eq(fileChunks.fileId, ids)) .groupBy(fileChunks.fileId); return data[0]?.count ?? 0; }; semanticSearch = async ({ embedding, fileIds, }: { embedding: number[]; fileIds: string[] | undefined; query: string; }) => { const similarity = sql<number>`1 - (${cosineDistance(embeddings.embeddings, embedding)})`; const data = await this.db .select({ fileId: fileChunks.fileId, fileName: files.name, id: chunks.id, index: chunks.index, metadata: chunks.metadata, similarity, text: chunks.text, type: chunks.type, }) .from(chunks) .leftJoin(embeddings, eq(chunks.id, embeddings.chunkId)) .leftJoin(fileChunks, eq(chunks.id, fileChunks.chunkId)) .leftJoin(files, eq(fileChunks.fileId, files.id)) .where(fileIds ? inArray(fileChunks.fileId, fileIds) : undefined) .orderBy((t) => desc(t.similarity)) .limit(30); return data.map((item) => ({ ...item, metadata: item.metadata as ChunkMetadata, })); }; semanticSearchForChat = async ({ embedding, fileIds, }: { embedding: number[]; fileIds: string[] | undefined; query: string; }) => { const similarity = sql<number>`1 - (${cosineDistance(embeddings.embeddings, embedding)})`; const hasFiles = fileIds && fileIds.length > 0; if (!hasFiles) return []; const result = await this.db .select({ fileId: files.id, fileName: files.name, id: chunks.id, index: chunks.index, metadata: chunks.metadata, similarity, text: chunks.text, type: chunks.type, }) .from(chunks) .leftJoin(embeddings, eq(chunks.id, embeddings.chunkId)) .leftJoin(fileChunks, eq(chunks.id, fileChunks.chunkId)) .leftJoin(files, eq(files.id, fileChunks.fileId)) .where(inArray(fileChunks.fileId, fileIds)) .orderBy((t) => desc(t.similarity)) // 先放宽到 15 .limit(15); return result.map((item) => { return { fileId: item.fileId, fileName: item.fileName, id: item.id, index: item.index, similarity: item.similarity, text: this.mapChunkText(item), }; }); }; private mapChunkText = (chunk: { metadata: any; text: string | null; type: string | null }) => { let text = chunk.text; if (chunk.type === 'Table') { text = `${chunk.text} content in Table html is below: ${(chunk.metadata as ChunkMetadata).text_as_html} `; } return text; }; }