UNPKG

@lobehub/chat

Version:

Lobe Chat - an open-source, high-performance chatbot framework that supports speech synthesis, multimodal input, and an extensible Function Call plugin system. It supports one-click free deployment of your private ChatGPT/LLM web application.

722 lines (650 loc) 20.7 kB
import { and, eq, inArray } from 'drizzle-orm/expressions';

import * as EXPORT_TABLES from '@/database/schemas';
import { LobeChatDatabase } from '@/database/type';
import { ImportPgDataStructure } from '@/types/export';
import { ImportResultData, ImporterEntryData } from '@/types/importer';
import { uuid } from '@/utils/uuid';

import { DeprecatedDataImporterRepos } from './deprecated';

/** Per-table counters reported back to the caller of `importPgData`. */
interface ImportResult {
  added: number;
  errors: number;
  skips: number;
  updated?: number;
}

type ConflictStrategy = 'skip' | 'override' | 'merge';

/** Declarative description of how one table is imported. */
interface TableImportConfig {
  // Conflict handling strategy
  conflictStrategy?: ConflictStrategy;
  // Field processing functions
  fieldProcessors?: {
    [field: string]: (value: any) => any;
  };
  // Whether the table uses a composite primary key (no standalone id field)
  isCompositeKey?: boolean;
  // Whether to preserve the original ID
  preserveId?: boolean;
  // Relation (foreign key) field definitions
  relations?: {
    field: string;
    sourceField?: string;
    sourceTable: string;
  }[];
  // Self-referencing fields
  selfReferences?: {
    field: string;
    sourceField?: string;
  }[];
  // Table name
  table: string;
  // Fields covered by unique constraints
  uniqueConstraints?: string[];
}

// Import table configuration. Order matters: tables are imported in this order so
// that the ID maps of referenced tables are populated before their dependants.
const IMPORT_TABLE_CONFIG: TableImportConfig[] = [
  {
    conflictStrategy: 'merge',
    preserveId: true, // Special table: its ID equals the user ID
    table: 'userSettings',
    uniqueConstraints: ['id'],
  },
  {
    conflictStrategy: 'merge',
    isCompositeKey: true,
    table: 'userInstalledPlugins',
    uniqueConstraints: ['identifier'],
  },
  {
    conflictStrategy: 'skip',
    preserveId: true,
    table: 'aiProviders',
    uniqueConstraints: ['id'],
  },
  {
    conflictStrategy: 'skip',
    preserveId: true, // The original ID must be preserved
    relations: [
      {
        field: 'providerId',
        sourceTable: 'aiProviders',
      },
    ],
    table: 'aiModels',
    uniqueConstraints: ['id', 'providerId'],
  },
  {
    table: 'sessionGroups',
    uniqueConstraints: [],
  },
  {
    fieldProcessors: {
      slug: (value) => (value ? `${value}-${uuid().slice(0, 8)}` : null),
    },
    table: 'agents',
    uniqueConstraints: ['slug'],
  },
  {
    // Special handling for the slug field
    fieldProcessors: {
      slug: (value) => `${value}-${uuid().slice(0, 8)}`,
    },
    relations: [
      {
        field: 'groupId',
        sourceTable: 'sessionGroups',
      },
    ],
    table: 'sessions',
    uniqueConstraints: ['slug'],
  },
  {
    relations: [
      {
        field: 'sessionId',
        sourceTable: 'sessions',
      },
    ],
    table: 'topics',
  },
  {
    conflictStrategy: 'skip',
    isCompositeKey: true, // Composite primary key [agentId, sessionId]
    relations: [
      {
        field: 'agentId',
        sourceTable: 'agents',
      },
      {
        field: 'sessionId',
        sourceTable: 'sessions',
      },
    ],
    table: 'agentsToSessions',
    uniqueConstraints: ['agentId', 'sessionId'],
  },
  {
    relations: [
      {
        field: 'topicId',
        sourceTable: 'topics',
      },
    ],
    selfReferences: [
      {
        field: 'parentThreadId',
      },
    ],
    table: 'threads',
  },
  {
    relations: [
      {
        field: 'sessionId',
        sourceTable: 'sessions',
      },
      {
        field: 'topicId',
        sourceTable: 'topics',
      },
      {
        field: 'agentId',
        sourceTable: 'agents',
      },
      {
        field: 'threadId',
        sourceTable: 'threads',
      },
    ],
    selfReferences: [
      {
        field: 'parentId',
      },
      {
        field: 'quotaId',
      },
    ],
    table: 'messages',
  },
  {
    conflictStrategy: 'skip',
    preserveId: true, // Uses the message ID as primary key
    relations: [
      {
        field: 'id',
        sourceTable: 'messages',
      },
    ],
    table: 'messagePlugins',
  },
  {
    isCompositeKey: true, // Composite primary key [messageId, chunkId]
    relations: [
      {
        field: 'messageId',
        sourceTable: 'messages',
      },
      {
        field: 'chunkId',
        sourceTable: 'chunks',
      },
    ],
    table: 'messageChunks',
  },
  {
    isCompositeKey: true, // Composite primary key [id, queryId, chunkId]
    relations: [
      {
        field: 'id',
        sourceTable: 'messages',
      },
      {
        field: 'queryId',
        sourceTable: 'messageQueries',
      },
      {
        field: 'chunkId',
        sourceTable: 'chunks',
      },
    ],
    table: 'messageQueryChunks',
  },
  // {
  //   relations: [
  //     {
  //       field: 'messageId',
  //       sourceTable: 'messages',
  //     },
  //     {
  //       field: 'embeddingsId',
  //       sourceTable: 'embeddings',
  //     },
  //   ],
  //   table: 'messageQueries',
  // },
  {
    conflictStrategy: 'skip',
    preserveId: true, // Uses the message ID as primary key
    relations: [
      {
        field: 'id',
        sourceTable: 'messages',
      },
    ],
    table: 'messageTranslates',
  },
  // {
  //   conflictStrategy: 'skip',
  //   preserveId: true, // Uses the message ID as primary key
  //   relations: [
  //     {
  //       field: 'id',
  //       sourceTable: 'messages',
  //     },
  //     {
  //       field: 'fileId',
  //       sourceTable: 'files',
  //     },
  //   ],
  //   table: 'messageTTS',
  // },
];

/**
 * Imports exported user data into the database on behalf of a single user.
 *
 * `importData` delegates to the deprecated importer; `importPgData` walks
 * `IMPORT_TABLE_CONFIG` in order inside one transaction, remapping IDs between
 * the exported records and the rows that end up in the database.
 */
export class DataImporterRepos {
  private userId: string;
  private db: LobeChatDatabase;
  private deprecatedDataImporterRepos: DeprecatedDataImporterRepos;
  /** tableName -> (original id or clientId -> id of the row in the database) */
  private idMaps: Record<string, Record<string, string>> = {};
  /** tableName -> unique-constraint conflicts observed during the current import */
  private conflictRecords: Record<string, { field: string; value: any }[]> = {};

  constructor(db: LobeChatDatabase, userId: string) {
    this.userId = userId;
    this.db = db;
    this.deprecatedDataImporterRepos = new DeprecatedDataImporterRepos(db, userId);
  }

  /** Imports data in the legacy format by delegating to the deprecated importer. */
  importData = async (data: ImporterEntryData): Promise<ImportResultData> => {
    const results = await this.deprecatedDataImporterRepos.importData(data);
    return { results, success: true };
  };

  /**
   * Imports PostgreSQL-format data.
   *
   * @param dbData - Export payload; `dbData.data` is indexed by table name.
   * @param conflictStrategy - Forwarded to the per-table importer, which currently
   *   ignores it in favour of the per-table configuration.
   * @returns Per-table counters for every table with at least one non-zero counter;
   *   `success: false` plus error details when the transaction callback throws.
   */
  async importPgData(
    dbData: ImportPgDataStructure,
    conflictStrategy: ConflictStrategy = 'skip',
  ): Promise<ImportResultData> {
    const results: Record<string, ImportResult> = {};
    const { data } = dbData;

    // Reset the ID maps and conflict records for this run
    this.idMaps = {};
    this.conflictRecords = {};

    try {
      await this.db.transaction(async (trx) => {
        // Import tables in configuration order
        for (const config of IMPORT_TABLE_CONFIG) {
          const { table: tableName } = config;

          // @ts-ignore
          const tableData = data[tableName];

          if (!tableData || tableData.length === 0) {
            continue;
          }

          // NOTE(review): importTableData catches its own errors and reports them via
          // `result.errors`, so a failed statement does not abort this callback.
          // On PostgreSQL a failed statement normally poisons the enclosing
          // transaction — confirm how the driver in use behaves.
          const result = await this.importTableData(trx, config, tableData, conflictStrategy);
          console.log(`imported table: ${tableName}, records: ${tableData.length}`);

          if (Object.values(result).some((value) => value > 0)) {
            results[tableName] = result;
          }
        }
      });

      return { results, success: true };
    } catch (error) {
      console.error('Import failed:', error);

      return {
        error: {
          details: this.extractErrorDetails(error),
          message: (error as any).message,
        },
        results,
        success: false,
      };
    }
  }

  /**
   * Parses a PostgreSQL unique-violation error (code 23505) into the conflicting
   * field and value, or returns `undefined` for any other error.
   */
  private parseUniqueViolation(error: any): { field: string; value: string } | undefined {
    if (error?.code !== '23505') return undefined;

    const match = error.detail?.match(/Key \((.+?)\)=\((.+?)\) already exists/);
    return match ? { field: match[1], value: match[2] } : undefined;
  }

  /**
   * Extracts detail information from an error.
   */
  private extractErrorDetails(error: any) {
    const violation = this.parseUniqueViolation(error);
    if (violation) {
      return {
        constraintType: 'unique',
        field: violation.field,
        value: violation.value,
      };
    }

    return error?.detail || 'Unknown error details';
  }

  /** Appends a unique-constraint conflict to the per-table conflict log. */
  private recordConflict(tableName: string, field: string, value: any): void {
    if (!this.conflictRecords[tableName]) this.conflictRecords[tableName] = [];
    this.conflictRecords[tableName].push({ field, value });
  }

  /**
   * Restricts `condition` to rows owned by the importing user when the table has a
   * `userId` column; tables without one (e.g. keyed by the user id itself) are
   * returned unchanged.
   */
  private scopeToUser(table: any, condition: any): any {
    return 'userId' in table ? and(eq(table.userId, this.userId), condition) : condition;
  }

  /**
   * Unified table import routine — handles every kind of table.
   *
   * Steps: (1) find rows that already exist, (2) seed the ID map from them,
   * (3) filter out records that already exist, (4) build the rows to insert,
   * (5) check unique constraints and apply the table's conflict strategy,
   * (6) insert in batches and record the new IDs.
   *
   * @param trx - Transaction handle exposing `query`, `insert` and `update`.
   * @param config - Import configuration for the table.
   * @param tableData - Exported records of the table.
   * @param _userConflictStrategy - Currently unused; the per-table strategy wins.
   * @returns Counters of added / skipped / updated / failed records. Errors are
   *   caught and counted rather than rethrown.
   */
  private async importTableData(
    trx: any,
    config: TableImportConfig,
    tableData: any[],
    // eslint-disable-next-line @typescript-eslint/no-unused-vars
    _userConflictStrategy: ConflictStrategy,
  ): Promise<ImportResult> {
    const {
      table: tableName,
      preserveId,
      isCompositeKey = false,
      uniqueConstraints = [],
      conflictStrategy = 'override',
      fieldProcessors = {},
      relations = [],
      selfReferences = [],
    } = config;

    // @ts-ignore
    const table = EXPORT_TABLES[tableName];
    const result: ImportResult = { added: 0, errors: 0, skips: 0, updated: 0 };

    // Initialise the ID map for this table
    if (!this.idMaps[tableName]) {
      this.idMaps[tableName] = {};
    }

    try {
      // 1. Look up records that already exist (by clientId and userId)
      let existingRecords: any[] = [];

      if ('clientId' in table && 'userId' in table) {
        const clientIds = tableData.map((item) => item.clientId || item.id).filter(Boolean);

        if (clientIds.length > 0) {
          existingRecords = await trx.query[tableName].findMany({
            where: and(eq(table.userId, this.userId), inArray(table.clientId, clientIds)),
          });
        }
      }

      // When the original ID is preserved, also check whether that ID already exists
      if (preserveId && !isCompositeKey) {
        const ids = tableData.map((item) => item.id).filter(Boolean);

        if (ids.length > 0) {
          // Scoped to the importing user so rows owned by other users that happen to
          // share an ID are not mistaken for this user's existing data.
          const idExistingRecords = await trx.query[tableName].findMany({
            where: this.scopeToUser(table, inArray(table.id, ids)),
          });

          // Merge into the set of existing records, de-duplicated by id
          existingRecords = [
            ...existingRecords,
            ...idExistingRecords.filter(
              (record: any) => !existingRecords.some((existing) => existing.id === record.id),
            ),
          ];
        }
      }

      result.skips = existingRecords.length;

      // 2. Build the ID mapping for records that already exist
      for (const record of existingRecords) {
        // Only tables without a composite key need an ID mapping
        if (!isCompositeKey) {
          this.idMaps[tableName][record.id] = record.id;
          if (record.clientId) {
            this.idMaps[tableName][record.clientId] = record.id;
          }

          // Any other identifier the exported record may be known by.
          // The clientId comparison only counts when the DB row actually has a
          // clientId; otherwise null/undefined on both sides would match an
          // unrelated exported record and corrupt the mapping.
          const originalRecord = tableData.find(
            (item) =>
              item.id === record.id ||
              (Boolean(record.clientId) && item.clientId === record.clientId),
          );

          if (originalRecord) {
            // Make sure the original record ID also maps to the database record ID
            this.idMaps[tableName][originalRecord.id] = record.id;
          }
        }
      }

      // 3. Keep only the records that still need to be inserted
      const recordsToInsert = tableData.filter(
        (item) =>
          !existingRecords.some(
            (record) =>
              (record.clientId === (item.clientId || item.id) && record.clientId) ||
              (preserveId && !isCompositeKey && record.id === item.id),
          ),
      );

      if (recordsToInsert.length === 0) {
        return result;
      }

      // 4. Prepare the rows to import
      const preparedData = recordsToInsert.map((item) => {
        const originalId = item.id;

        // Convert date fields
        const dateFields: any = {};
        if (item.createdAt) dateFields.createdAt = new Date(item.createdAt);
        if (item.updatedAt) dateFields.updatedAt = new Date(item.updatedAt);
        if (item.accessedAt) dateFields.accessedAt = new Date(item.accessedAt);

        // Build the new record
        let newRecord: any = {};

        // Shape depends on composite key / preserved ID
        if (isCompositeKey) {
          // Composite-key tables carry no id field
          // eslint-disable-next-line @typescript-eslint/no-unused-vars
          const { id: _, ...rest } = item;
          newRecord = {
            ...rest,
            ...dateFields,
            clientId: item.clientId || item.id,
            userId: this.userId,
          };
        } else {
          // Tables with a single-column key
          newRecord = {
            ...(preserveId ? item : { ...item, id: undefined }),
            ...dateFields,
            clientId: item.clientId || item.id,
            userId: this.userId,
          };
        }

        // Apply field processors
        for (const field in fieldProcessors) {
          if (newRecord[field] !== undefined) {
            newRecord[field] = fieldProcessors[field](newRecord[field]);
          }
        }

        // Special-case table: the settings row is keyed by the user ID
        if (tableName === 'userSettings') {
          newRecord.id = this.userId;
        }

        // Remap relation fields (foreign key references)
        for (const relation of relations) {
          const { field, sourceTable } = relation;

          if (newRecord[field] && this.idMaps[sourceTable]) {
            const mappedId = this.idMaps[sourceTable][newRecord[field]];

            if (mappedId) {
              newRecord[field] = mappedId;
            } else {
              // No mapping found: null the reference
              console.warn(
                `Could not find mapped ID for ${field}=${newRecord[field]} in table ${sourceTable}`,
              );
              newRecord[field] = null;
            }
          }
        }

        // Simplified handling of self-references: set them to null
        for (const selfRef of selfReferences) {
          const { field } = selfRef;
          if (newRecord[field] !== undefined) {
            newRecord[field] = null;
          }
        }

        return { newRecord, originalId };
      });

      // 5. Check unique constraints and apply the conflict strategy
      for (const record of preparedData) {
        if (isCompositeKey && uniqueConstraints.length > 0) {
          // Composite-key tables: all unique-constraint fields form one combined condition
          const whereConditions = uniqueConstraints
            .filter((field) => record.newRecord[field] !== undefined)
            .map((field) => eq(table[field], record.newRecord[field]));

          // Add the userId condition (when the table has a userId column)
          if ('userId' in table) {
            whereConditions.push(eq(table.userId, this.userId));
          }

          if (whereConditions.length > 0) {
            const exists = await trx.query[tableName].findFirst({
              where: and(...whereConditions),
            });

            if (exists) {
              this.recordConflict(
                tableName,
                uniqueConstraints.join(','),
                uniqueConstraints.map((field) => `${field}=${record.newRecord[field]}`).join(','),
              );

              // Apply the conflict strategy
              switch (conflictStrategy) {
                case 'skip': {
                  // Composite-key tables keep no ID mapping, so marking is enough
                  record.newRecord._skip = true;
                  result.skips++;
                  break;
                }
                case 'override': {
                  // No extra work here; the row is passed on to the insert step
                  break;
                }
                case 'merge': {
                  // Merge the data into the existing row
                  await trx
                    .update(table)
                    .set(record.newRecord)
                    .where(and(...whereConditions));
                  record.newRecord._skip = true;
                  result.updated = (result.updated ?? 0) + 1;
                  break;
                }
              }
            }
          }
        } else {
          // Single-field unique constraints
          for (const field of uniqueConstraints) {
            // Already skipped or merged by an earlier constraint: checking further
            // fields would double-count and overwrite the ID mapping.
            if (record.newRecord._skip) break;
            if (!record.newRecord[field]) continue;

            // Check whether the value already exists among this user's rows
            const where = this.scopeToUser(table, eq(table[field], record.newRecord[field]));
            const exists = await trx.query[tableName].findFirst({ where });

            if (exists) {
              this.recordConflict(tableName, field, record.newRecord[field]);

              // Apply the conflict strategy
              switch (conflictStrategy) {
                case 'skip': {
                  record.newRecord._skip = true;
                  result.skips++;

                  // Even when skipping, map the original IDs to the existing row so
                  // dependent tables can still resolve their references
                  if (!isCompositeKey) {
                    this.idMaps[tableName][record.originalId] = exists.id;
                    if (record.newRecord.clientId) {
                      this.idMaps[tableName][record.newRecord.clientId] = exists.id;
                    }
                  }
                  break;
                }
                case 'override': {
                  // Re-apply the field processor to produce a fresh value
                  if (field in fieldProcessors) {
                    record.newRecord[field] = fieldProcessors[field](record.newRecord[field]);
                  }
                  break;
                }
                case 'merge': {
                  // Merge the data into the existing row
                  await trx.update(table).set(record.newRecord).where(where);
                  record.newRecord._skip = true;
                  result.updated = (result.updated ?? 0) + 1;
                  break;
                }
              }
            }
          }
        }
      }

      // Drop the records marked as skipped
      const filteredData = preparedData.filter((record) => !record.newRecord._skip);

      // Remove the temporary marker
      filteredData.forEach((record) => delete record.newRecord._skip);

      // First exported record per id, so the mapping step below does not rescan
      // tableData for every inserted row
      const originalById = new Map<any, any>();
      if (!isCompositeKey) {
        for (const item of tableData) {
          if (!originalById.has(item.id)) originalById.set(item.id, item);
        }
      }

      // 6. Insert in batches
      const BATCH_SIZE = 100;

      for (let i = 0; i < filteredData.length; i += BATCH_SIZE) {
        const batch = filteredData.slice(i, i + BATCH_SIZE).filter(Boolean);

        const itemsToInsert = batch.map((item) => item.newRecord);
        const originalIds = batch.map((item) => item.originalId);

        try {
          // Insert and collect the result
          const insertQuery = trx.insert(table).values(itemsToInsert);

          let insertResult;

          // Only tables without a composite key need the inserted IDs back
          if (!isCompositeKey) {
            const res = await insertQuery.returning();
            insertResult = res.map((item: any) => ({
              clientId: 'clientId' in item ? item.clientId : undefined,
              id: item.id,
            }));
          } else {
            await insertQuery;
            insertResult = itemsToInsert.map(() => ({})); // Empty results keep the count right
          }

          result.added += insertResult.length;

          // Build the ID mapping (only for tables without a composite key)
          if (!isCompositeKey) {
            for (const [j, newRecord] of insertResult.entries()) {
              const originalId = originalIds[j];
              this.idMaps[tableName][originalId] = newRecord.id;

              // Make sure the clientId also maps to the new ID
              const originalRecord = originalById.get(originalId);
              if (originalRecord && originalRecord.clientId) {
                this.idMaps[tableName][originalRecord.clientId] = newRecord.id;
              }
            }
          }
        } catch (error) {
          console.error(`Error batch inserting ${tableName}:`, error);

          // Record unique-constraint violations, then count the whole batch as failed
          const violation = this.parseUniqueViolation(error);
          if (violation) {
            this.recordConflict(tableName, violation.field, violation.value);
          }

          result.errors += batch.length;
        }
      }

      return result;
    } catch (error) {
      console.error(`Error importing table ${tableName}:`, error);
      result.errors = tableData.length;
      return result;
    }
  }
}