@lobehub/chat
Version:
Lobe Chat - an open-source, high-performance chatbot framework that supports speech synthesis, multimodal, and extensible Function Call plugin system. Supports one-click free deployment of your private ChatGPT/LLM web application.
327 lines (284 loc) • 11.1 kB
text/typescript
import { sql } from 'drizzle-orm';
import { and, eq, inArray } from 'drizzle-orm/expressions';
import {
agents,
agentsToSessions,
messagePlugins,
messageTranslates,
messages,
sessionGroups,
sessions,
topics,
} from '@/database/schemas';
import { LobeChatDatabase } from '@/database/type';
import { ImportResult } from '@/services/import/_deprecated';
import { ImporterEntryData } from '@/types/importer';
import { sanitizeUTF8 } from '@/utils/sanitizeUTF8';
/**
 * Imports a payload in the deprecated importer format (`ImporterEntryData`) into the database
 * on behalf of a single user.
 *
 * Every imported row stores the payload's own id in `clientId`, while the database assigns the
 * row's `id`. Cross-table references (group -> session -> topic -> message, and
 * message -> parent message) are re-linked by translating payload ids into database ids as each
 * table is imported.
 */
export class DeprecatedDataImporterRepos {
  private userId: string;
  private db: LobeChatDatabase;

  /**
   * The version of the importer that this module supports
   */
  supportVersion = 7;

  constructor(db: LobeChatDatabase, userId: string) {
    this.userId = userId;
    this.db = db;
  }

  /**
   * Import session groups, sessions (plus one agent per newly created session), topics and
   * messages inside a single transaction.
   *
   * Rows whose `clientId` already exists for this user are counted as `skips`. For groups,
   * sessions and topics such rows only get their `updatedAt` bumped (upsert). Already-existing
   * messages are not touched at all.
   *
   * @param data - exported payload; `data.version` must not exceed `supportVersion`
   * @returns per-table `{ added, errors, skips }` counters
   * @throws Error('Unsupported version') when `data.version` is newer than `supportVersion`
   */
  importData = async (data: ImporterEntryData) => {
    if (data.version > this.supportVersion) throw new Error('Unsupported version');

    const sessionGroupResult: ImportResult = { added: 0, errors: 0, skips: 0 };
    const sessionResult: ImportResult = { added: 0, errors: 0, skips: 0 };
    const topicResult: ImportResult = { added: 0, errors: 0, skips: 0 };
    const messageResult: ImportResult = { added: 0, errors: 0, skips: 0 };

    // payload id (clientId) -> database id, filled as each table is imported
    let sessionGroupIdMap: Record<string, string> = {};
    let sessionIdMap: Record<string, string> = {};
    let topicIdMap: Record<string, string> = {};

    await this.db.transaction(async (trx) => {
      // import sessionGroups
      if (data.sessionGroups && data.sessionGroups.length > 0) {
        const query = await trx.query.sessionGroups.findMany({
          where: and(
            eq(sessionGroups.userId, this.userId),
            inArray(
              sessionGroups.clientId,
              data.sessionGroups.map(({ id }) => id),
            ),
          ),
        });

        sessionGroupResult.skips = query.length;

        const mapArray = await trx
          .insert(sessionGroups)
          .values(
            data.sessionGroups.map(({ id, createdAt, updatedAt, ...res }) => ({
              ...res,
              clientId: id,
              createdAt: new Date(createdAt),
              updatedAt: new Date(updatedAt),
              userId: this.userId,
            })),
          )
          .onConflictDoUpdate({
            set: { updatedAt: new Date() },
            target: [sessionGroups.clientId, sessionGroups.userId],
          })
          .returning({ clientId: sessionGroups.clientId, id: sessionGroups.id });

        // RETURNING yields both inserted and conflict-updated rows, so new = returned - existing
        sessionGroupResult.added = mapArray.length - query.length;

        sessionGroupIdMap = Object.fromEntries(mapArray.map(({ clientId, id }) => [clientId, id]));
      }

      // import sessions
      if (data.sessions && data.sessions.length > 0) {
        const query = await trx.query.sessions.findMany({
          where: and(
            eq(sessions.userId, this.userId),
            inArray(
              sessions.clientId,
              data.sessions.map(({ id }) => id),
            ),
          ),
        });

        sessionResult.skips = query.length;

        const mapArray = await trx
          .insert(sessions)
          .values(
            data.sessions.map(({ id, createdAt, updatedAt, group, ...res }) => ({
              ...res,
              clientId: id,
              createdAt: new Date(createdAt),
              groupId: group ? sessionGroupIdMap[group] : null,
              updatedAt: new Date(updatedAt),
              userId: this.userId,
            })),
          )
          .onConflictDoUpdate({
            set: { updatedAt: new Date() },
            target: [sessions.clientId, sessions.userId],
          })
          .returning({ clientId: sessions.clientId, id: sessions.id });

        // get the session client-server id map
        sessionIdMap = Object.fromEntries(mapArray.map(({ clientId, id }) => [clientId, id]));

        sessionResult.added = mapArray.length - query.length;

        // Only sessions that did not exist before get a new agent; existing ones keep theirs.
        const existingSessionIds = new Set(query.map((q) => q.clientId));
        const shouldInsertSessionAgents = data.sessions.filter(
          (s) => !existingSessionIds.has(s.id),
        );

        if (shouldInsertSessionAgents.length > 0) {
          const agentMapArray = await trx
            .insert(agents)
            .values(
              shouldInsertSessionAgents.map(({ config, meta }) => ({
                ...config,
                ...meta,
                userId: this.userId,
              })),
            )
            .returning({ id: agents.id });

          // Agents are paired with sessions by position; this assumes RETURNING lists rows in
          // the same order as the VALUES list above.
          await trx.insert(agentsToSessions).values(
            shouldInsertSessionAgents.map(({ id }, index) => ({
              agentId: agentMapArray[index].id,
              sessionId: sessionIdMap[id],
              userId: this.userId,
            })),
          );
        }
      }

      // import topics
      if (data.topics && data.topics.length > 0) {
        const skipQuery = await trx.query.topics.findMany({
          where: and(
            eq(topics.userId, this.userId),
            inArray(
              topics.clientId,
              data.topics.map(({ id }) => id),
            ),
          ),
        });

        topicResult.skips = skipQuery.length;

        const mapArray = await trx
          .insert(topics)
          .values(
            data.topics.map(({ id, createdAt, updatedAt, sessionId, favorite, ...res }) => ({
              ...res,
              clientId: id,
              createdAt: new Date(createdAt),
              favorite: Boolean(favorite),
              sessionId: sessionId ? sessionIdMap[sessionId] : null,
              updatedAt: new Date(updatedAt),
              userId: this.userId,
            })),
          )
          .onConflictDoUpdate({
            set: { updatedAt: new Date() },
            target: [topics.clientId, topics.userId],
          })
          .returning({ clientId: topics.clientId, id: topics.id });

        topicIdMap = Object.fromEntries(mapArray.map(({ clientId, id }) => [clientId, id]));

        topicResult.added = mapArray.length - skipQuery.length;
      }

      // import messages
      if (data.messages && data.messages.length > 0) {
        const messageClientIds = data.messages.map(({ id }) => id);

        // 1. find messages imported before; they are skipped entirely
        const skipQuery = await trx.query.messages.findMany({
          where: and(
            eq(messages.userId, this.userId),
            inArray(messages.clientId, messageClientIds),
          ),
        });

        messageResult.skips = skipQuery.length;

        const existingMessageIds = new Set(skipQuery.map((q) => q.clientId));
        const shouldInsertMessages = data.messages.filter((m) => !existingMessageIds.has(m.id));

        if (shouldInsertMessages.length > 0) {
          // 2. insert messages; parentId is linked in step 3, once every message has a database id
          const insertValues = shouldInsertMessages.map(
            ({ id, extra, createdAt, updatedAt, sessionId, topicId, content, ...res }) => ({
              ...res,
              clientId: id,
              content: sanitizeUTF8(content),
              createdAt: new Date(createdAt),
              model: extra?.fromModel,
              parentId: null,
              provider: extra?.fromProvider,
              sessionId: sessionId ? sessionIdMap[sessionId] : null,
              topicId: topicId ? topicIdMap[topicId] : null,
              updatedAt: new Date(updatedAt),
              userId: this.userId,
            }),
          );

          const BATCH_SIZE = 100; // rows per INSERT statement
          for (let i = 0; i < insertValues.length; i += BATCH_SIZE) {
            await trx.insert(messages).values(insertValues.slice(i, i + BATCH_SIZE));
          }

          // clientId -> database id for every payload message (newly inserted and pre-existing),
          // so a new message may reference a parent that was imported earlier.
          const messageIdArray = await trx
            .select({ clientId: messages.clientId, id: messages.id })
            .from(messages)
            .where(
              and(
                eq(messages.userId, this.userId),
                inArray(messages.clientId, messageClientIds),
              ),
            );

          const messageIdMap = Object.fromEntries(
            messageIdArray.map(({ clientId, id }) => [clientId, id]),
          );

          // 3. link parentId for newly inserted messages whose parent could be resolved
          const parentLinks = shouldInsertMessages.flatMap((msg) => {
            const parentServerId = msg.parentId ? messageIdMap[msg.parentId] : undefined;
            return parentServerId ? [{ childClientId: msg.id, parentServerId }] : [];
          });

          if (parentLinks.length > 0) {
            const parentIdCases = parentLinks.map(
              ({ childClientId, parentServerId }) =>
                sql`WHEN ${messages.clientId} = ${childClientId} THEN ${parentServerId} `,
            );

            // Restrict the update to this user's rows that actually get a parent. Other users may
            // share a clientId, and pre-existing rows must keep their parent; ELSE is a
            // further guard against nulling any row without a matching WHEN.
            await trx
              .update(messages)
              .set({
                parentId: sql`CASE ${sql.join(parentIdCases)} ELSE ${messages.parentId} END`,
              })
              .where(
                and(
                  eq(messages.userId, this.userId),
                  inArray(
                    messages.clientId,
                    parentLinks.map(({ childClientId }) => childClientId),
                  ),
                ),
              );
          }

          // 4. insert message plugins
          const pluginInserts = shouldInsertMessages.filter((msg) => msg.plugin);
          if (pluginInserts.length > 0) {
            await trx.insert(messagePlugins).values(
              pluginInserts.map((msg) => ({
                apiName: msg.plugin?.apiName,
                arguments: msg.plugin?.arguments,
                id: messageIdMap[msg.id],
                identifier: msg.plugin?.identifier,
                state: msg.pluginState,
                toolCallId: msg.tool_call_id,
                type: msg.plugin?.type,
                userId: this.userId,
              })),
            );
          }

          // 5. insert message translate; id/userId come last so the payload cannot override them
          const translateInserts = shouldInsertMessages.filter((msg) => msg.extra?.translate);
          if (translateInserts.length > 0) {
            await trx.insert(messageTranslates).values(
              translateInserts.map((msg) => ({
                ...msg.extra?.translate,
                id: messageIdMap[msg.id],
                userId: this.userId,
              })),
            );
          }

          // TODO: handle TTS and image inserts in the future (they involve files, which are not
          // easy to handle here yet)
        }

        messageResult.added = shouldInsertMessages.length;
      }
    });

    return {
      messages: messageResult,
      sessionGroups: sessionGroupResult,
      sessions: sessionResult,
      topics: topicResult,
    };
  };
}