@lobehub/chat
Version:
Lobe Chat - an open-source, high-performance chatbot framework that supports speech synthesis, multimodal, and extensible Function Call plugin system. Supports one-click free deployment of your private ChatGPT/LLM web application.
701 lines (604 loc) • 21.1 kB
text/typescript
import type { HeatmapsProps } from '@lobehub/charts';
import dayjs from 'dayjs';
import { count, sql } from 'drizzle-orm';
import { and, asc, desc, eq, gt, inArray, isNotNull, isNull, like } from 'drizzle-orm/expressions';
import { LobeChatDatabase } from '@/database/type';
import {
genEndDateWhere,
genRangeWhere,
genStartDateWhere,
genWhere,
} from '@/database/utils/genWhere';
import { idGenerator } from '@/database/utils/idGenerator';
import {
ChatFileItem,
ChatImageItem,
ChatMessage,
ChatTTS,
ChatToolPayload,
ChatTranslate,
CreateMessageParams,
MessageItem,
ModelRankItem,
NewMessageQueryParams,
UpdateMessageParams,
} from '@/types/message';
import { merge } from '@/utils/merge';
import { today } from '@/utils/time';
import {
MessagePluginItem,
chunks,
documents,
embeddings,
fileChunks,
files,
messagePlugins,
messageQueries,
messageQueryChunks,
messageTTS,
messageTranslates,
messages,
messagesFiles,
} from '../schemas';
/**
 * Pagination and scoping options accepted by `MessageModel.query`.
 */
export interface QueryMessageParams {
  /** Zero-based page index; the row offset is `current * pageSize`. */
  current?: number;
  /** Maximum number of messages returned per page. */
  pageSize?: number;
  /** Session to read from; `null` or omitted selects messages without a session. */
  sessionId?: string | null;
  /** Topic to read from; `null` or omitted selects messages without a topic. */
  topicId?: string | null;
}
/**
 * Data-access model for chat messages and the tables attached to them
 * (plugin calls, translations, TTS, attached files, RAG queries and chunks).
 *
 * Reads and writes are scoped to the `userId` passed to the constructor.
 */
export class MessageModel {
  private userId: string;
  private db: LobeChatDatabase;

  constructor(db: LobeChatDatabase, userId: string) {
    this.userId = userId;
    this.db = db;
  }

  // **************** Query *************** //

  /**
   * Load one page of messages for a session/topic, hydrated with plugin,
   * translation, TTS, attached file/image, RAG chunk and RAG query data.
   *
   * A `null`/omitted `sessionId` or `topicId` matches rows where that column
   * IS NULL (see `matchSession` / `matchTopic`).
   *
   * @param params.current zero-based page index (default 0)
   * @param params.pageSize rows per page (default 1000)
   * @param options.postProcessUrl optional hook that rewrites each stored file
   *   URL before it is returned
   * @returns the page ordered by `createdAt` ascending; `[]` when empty
   */
  query = async (
    { current = 0, pageSize = 1000, sessionId, topicId }: QueryMessageParams = {},
    options: {
      postProcessUrl?: (path: string | null, file: { fileType: string }) => Promise<string>;
    } = {},
  ) => {
    const offset = current * pageSize;

    // 1. get basic messages
    const result = await this.db
      .select({
        /* eslint-disable sort-keys-fix/sort-keys-fix*/
        id: messages.id,
        role: messages.role,
        content: messages.content,
        reasoning: messages.reasoning,
        search: messages.search,
        metadata: messages.metadata,
        error: messages.error,

        model: messages.model,
        provider: messages.provider,

        createdAt: messages.createdAt,
        updatedAt: messages.updatedAt,

        parentId: messages.parentId,
        threadId: messages.threadId,

        tools: messages.tools,
        tool_call_id: messagePlugins.toolCallId,

        plugin: {
          apiName: messagePlugins.apiName,
          arguments: messagePlugins.arguments,
          identifier: messagePlugins.identifier,
          type: messagePlugins.type,
        },
        pluginError: messagePlugins.error,
        pluginState: messagePlugins.state,

        translate: {
          content: messageTranslates.content,
          from: messageTranslates.from,
          to: messageTranslates.to,
        },
        ttsId: messageTTS.id,
        ttsContentMd5: messageTTS.contentMd5,
        ttsFile: messageTTS.fileId,
        ttsVoice: messageTTS.voice,
        /* eslint-enable */
      })
      .from(messages)
      .where(
        and(
          eq(messages.userId, this.userId),
          this.matchSession(sessionId),
          this.matchTopic(topicId),
        ),
      )
      .leftJoin(messagePlugins, eq(messagePlugins.id, messages.id))
      .leftJoin(messageTranslates, eq(messageTranslates.id, messages.id))
      .leftJoin(messageTTS, eq(messageTTS.id, messages.id))
      .orderBy(asc(messages.createdAt))
      .limit(pageSize)
      .offset(offset);

    const messageIds = result.map((message) => message.id as string);

    if (messageIds.length === 0) return [];

    // 2. get related files
    const rawRelatedFileList = await this.db
      .select({
        fileType: files.fileType,
        id: messagesFiles.fileId,
        messageId: messagesFiles.messageId,
        name: files.name,
        size: files.size,
        url: files.url,
      })
      .from(messagesFiles)
      .leftJoin(files, eq(files.id, messagesFiles.fileId))
      .where(inArray(messagesFiles.messageId, messageIds));

    const relatedFileList = await Promise.all(
      rawRelatedFileList.map(async (file) => ({
        ...file,
        url: options.postProcessUrl
          ? await options.postProcessUrl(file.url, file as any)
          : (file.url as string),
      })),
    );

    // 2.1 load the stored document content of the related files, keyed by file id
    const fileIds = relatedFileList.map((file) => file.id).filter(Boolean);

    let documentsMap: Record<string, string> = {};

    if (fileIds.length > 0) {
      const documentsList = await this.db
        .select({
          content: documents.content,
          fileId: documents.fileId,
        })
        .from(documents)
        .where(inArray(documents.fileId, fileIds));

      documentsMap = documentsList.reduce(
        (acc, doc) => {
          if (doc.fileId) acc[doc.fileId] = doc.content as string;
          return acc;
        },
        {} as Record<string, string>,
      );
    }

    const imageList = relatedFileList.filter((i) => (i.fileType || '').startsWith('image'));
    const fileList = relatedFileList.filter((i) => !(i.fileType || '').startsWith('image'));

    // 3. get related file chunks
    const chunksList = await this.db
      .select({
        fileId: files.id,
        fileType: files.fileType,
        fileUrl: files.url,
        filename: files.name,
        id: chunks.id,
        messageId: messageQueryChunks.messageId,
        similarity: messageQueryChunks.similarity,
        text: chunks.text,
      })
      .from(messageQueryChunks)
      .leftJoin(chunks, eq(chunks.id, messageQueryChunks.chunkId))
      .leftJoin(fileChunks, eq(fileChunks.chunkId, chunks.id))
      .innerJoin(files, eq(fileChunks.fileId, files.id))
      .where(inArray(messageQueryChunks.messageId, messageIds));

    // 4. get related message queries
    const messageQueriesList = await this.db
      .select({
        id: messageQueries.id,
        messageId: messageQueries.messageId,
        rewriteQuery: messageQueries.rewriteQuery,
        userQuery: messageQueries.userQuery,
      })
      .from(messageQueries)
      .where(inArray(messageQueries.messageId, messageIds));

    // Bucket the related rows by message id once, so assembling each message
    // below is a map lookup instead of a scan over every related list.
    const chunksByMessage = this.groupByMessageId(chunksList);
    const filesByMessage = this.groupByMessageId(fileList);
    const imagesByMessage = this.groupByMessageId(imageList);
    const queriesByMessage = this.groupByMessageId(messageQueriesList);

    return result.map(
      ({ model, provider, translate, ttsId, ttsFile, ttsContentMd5, ttsVoice, ...item }) => {
        // the first matching query row wins, as with Array.prototype.find
        const messageQuery = queriesByMessage.get(item.id)?.[0];

        return {
          ...item,
          chunksList: (chunksByMessage.get(item.id) ?? []).map((c) => ({
            ...c,
            // Number(null) would yield 0, so keep a missing similarity undefined
            similarity: c.similarity == null ? undefined : Number(c.similarity),
          })),

          extra: {
            fromModel: model,
            fromProvider: provider,
            translate,
            tts: ttsId
              ? {
                  contentMd5: ttsContentMd5,
                  file: ttsFile,
                  voice: ttsVoice,
                }
              : undefined,
          },
          fileList: (filesByMessage.get(item.id) ?? []).map<ChatFileItem>(
            ({ id, url, size, fileType, name }) => ({
              content: documentsMap[id],
              fileType: fileType!,
              id,
              name: name!,
              size: size!,
              url,
            }),
          ),

          imageList: (imagesByMessage.get(item.id) ?? []).map<ChatImageItem>(
            ({ id, url, name }) => ({ alt: name!, id, url }),
          ),

          meta: {},
          ragQuery: messageQuery?.rewriteQuery,
          ragQueryId: messageQuery?.id,
          ragRawQuery: messageQuery?.userQuery,
        } as unknown as ChatMessage;
      },
    );
  };

  /** Find a single message row by id, if it belongs to the current user. */
  findById = async (id: string) => {
    return this.db.query.messages.findFirst({
      where: and(eq(messages.id, id), eq(messages.userId, this.userId)),
    });
  };

  /**
   * Find the RAG query attached to a message, together with its embedding.
   *
   * Scoped to the current user so one user cannot read another user's queries.
   *
   * @returns the first matching row, or `undefined` when none exists
   */
  findMessageQueriesById = async (messageId: string) => {
    const result = await this.db
      .select({
        embeddings: embeddings.embeddings,
        id: messageQueries.id,
        query: messageQueries.rewriteQuery,
        rewriteQuery: messageQueries.rewriteQuery,
        userQuery: messageQueries.userQuery,
      })
      .from(messageQueries)
      .where(
        and(eq(messageQueries.messageId, messageId), eq(messageQueries.userId, this.userId)),
      )
      .leftJoin(embeddings, eq(embeddings.id, messageQueries.embeddingsId));

    if (result.length === 0) return undefined;

    return result[0];
  };

  /** Return every message of the current user, ordered by `createdAt`. */
  queryAll = async () => {
    const result = await this.db
      .select()
      .from(messages)
      .orderBy(messages.createdAt)
      .where(eq(messages.userId, this.userId));

    return result as MessageItem[];
  };

  /**
   * Return all messages of a session, oldest first.
   * A `null`/omitted `sessionId` selects messages that have no session.
   */
  queryBySessionId = async (sessionId?: string | null) => {
    const result = await this.db.query.messages.findMany({
      orderBy: [asc(messages.createdAt)],
      where: and(eq(messages.userId, this.userId), this.matchSession(sessionId)),
    });

    return result as MessageItem[];
  };

  /**
   * Substring search over message content, newest first.
   *
   * The keyword is matched literally: LIKE metacharacters (`%`, `_`, `\`) are
   * escaped with a backslash, which is PostgreSQL's default LIKE escape character.
   *
   * @returns `[]` for an empty keyword
   */
  queryByKeyword = async (keyword: string) => {
    if (!keyword) return [];

    const escapedKeyword = keyword.replace(/[\\%_]/g, '\\$&');

    const result = await this.db.query.messages.findMany({
      orderBy: [desc(messages.createdAt)],
      where: and(eq(messages.userId, this.userId), like(messages.content, `%${escapedKeyword}%`)),
    });

    return result as MessageItem[];
  };

  /**
   * Count the current user's messages, optionally restricted to a date window
   * on `createdAt` (`range`, `startDate`, `endDate` are combined with AND).
   */
  count = async (params?: {
    endDate?: string;
    range?: [string, string];
    startDate?: string;
  }): Promise<number> => {
    const result = await this.db
      .select({
        count: count(messages.id),
      })
      .from(messages)
      .where(genWhere([eq(messages.userId, this.userId), ...this.genDateFilters(params)]));

    return result[0].count;
  };

  /**
   * Sum of `length(content)` across the current user's messages, optionally
   * restricted to a date window on `createdAt`.
   *
   * @returns 0 when no rows match (SQL `sum` yields NULL, and `Number(null)` is 0)
   */
  countWords = async (params?: {
    endDate?: string;
    range?: [string, string];
    startDate?: string;
  }): Promise<number> => {
    const result = await this.db
      .select({
        count: sql<string>`sum(length(${messages.content}))`.as('total_length'),
      })
      .from(messages)
      .where(genWhere([eq(messages.userId, this.userId), ...this.genDateFilters(params)]));

    return Number(result[0].count);
  };

  /**
   * Rank the models used by the current user by message count.
   *
   * @param limit maximum number of models to return (default 10)
   * @returns non-null models ordered by count descending, then model id ascending
   */
  rankModels = async (limit: number = 10): Promise<ModelRankItem[]> => {
    return this.db
      .select({
        count: count(messages.id).as('count'),
        id: messages.model,
      })
      .from(messages)
      .where(and(eq(messages.userId, this.userId), isNotNull(messages.model)))
      .having(({ count }) => gt(count, 0))
      .groupBy(messages.model)
      .orderBy(desc(sql`count`), asc(messages.model))
      .limit(limit);
  };

  /**
   * Per-day message counts from one year ago through today (inclusive), one
   * entry per calendar day; days without messages get `count: 0`.
   * `level` is `ceil(count / 5)` capped at 4.
   */
  getHeatmaps = async (): Promise<HeatmapsProps['data']> => {
    const startDate = today().subtract(1, 'year').startOf('day');
    const endDate = today().endOf('day');

    const result = await this.db
      .select({
        count: count(messages.id),
        date: sql`DATE(${messages.createdAt})`.as('heatmaps_date'),
      })
      .from(messages)
      .where(
        genWhere([
          eq(messages.userId, this.userId),
          genRangeWhere(
            [startDate.format('YYYY-MM-DD'), endDate.add(1, 'day').format('YYYY-MM-DD')],
            messages.createdAt,
            (date) => date.toDate(),
          ),
        ]),
      )
      .groupBy(sql`heatmaps_date`)
      .orderBy(desc(sql`heatmaps_date`));

    const heatmapData: HeatmapsProps['data'] = [];
    let currentDate = startDate.clone();

    const dateCountMap = new Map<string, number>();
    for (const item of result) {
      if (item?.date) {
        const dateStr = dayjs(item.date as string).format('YYYY-MM-DD');
        dateCountMap.set(dateStr, Number(item.count) || 0);
      }
    }

    // walk every calendar day so days with no messages still get an entry
    while (currentDate.isBefore(endDate) || currentDate.isSame(endDate, 'day')) {
      const formattedDate = currentDate.format('YYYY-MM-DD');
      const dayCount = dateCountMap.get(formattedDate) || 0;

      const levelCount = dayCount > 0 ? Math.ceil(dayCount / 5) : 0;
      const level = levelCount > 4 ? 4 : levelCount;

      heatmapData.push({
        count: dayCount,
        date: formattedDate,
        level,
      });

      currentDate = currentDate.add(1, 'day');
    }

    return heatmapData;
  };

  /**
   * Whether the current user has more than `n` messages.
   * Fetches at most `n + 1` ids instead of counting every row.
   */
  hasMoreThanN = async (n: number): Promise<boolean> => {
    const result = await this.db
      .select({ id: messages.id })
      .from(messages)
      .where(eq(messages.userId, this.userId))
      .limit(n + 1);

    return result.length > n;
  };

  // **************** Create *************** //

  /**
   * Insert a message, plus its plugin row (for `role === 'tool'`), its file
   * links and its RAG chunk links, all in one transaction.
   *
   * @param id message id; generated when omitted
   * @returns the inserted message row
   */
  create = async (
    {
      fromModel,
      fromProvider,
      // renamed locally so they do not shadow the `files` / `fileChunks` schema tables
      files: fileIds,
      plugin,
      pluginState,
      fileChunks: chunkRefs,
      ragQueryId,
      updatedAt,
      createdAt,
      ...message
    }: CreateMessageParams,
    id: string = this.genId(),
  ): Promise<MessageItem> => {
    return this.db.transaction(async (trx) => {
      const [item] = (await trx
        .insert(messages)
        .values({
          ...message,
          // TODO: remove this when the client is updated
          createdAt: createdAt ? new Date(createdAt) : undefined,
          id,
          model: fromModel,
          provider: fromProvider,
          updatedAt: updatedAt ? new Date(updatedAt) : undefined,
          userId: this.userId,
        })
        .returning()) as MessageItem[];

      // Insert the plugin data if the message is a tool
      if (message.role === 'tool') {
        await trx.insert(messagePlugins).values({
          apiName: plugin?.apiName,
          arguments: plugin?.arguments,
          id,
          identifier: plugin?.identifier,
          state: pluginState,
          toolCallId: message.tool_call_id,
          type: plugin?.type,
          userId: this.userId,
        });
      }

      if (fileIds && fileIds.length > 0) {
        await trx
          .insert(messagesFiles)
          .values(fileIds.map((fileId) => ({ fileId, messageId: id, userId: this.userId })));
      }

      // chunk links are only stored when they can reference a RAG query
      if (chunkRefs && chunkRefs.length > 0 && ragQueryId) {
        await trx.insert(messageQueryChunks).values(
          chunkRefs.map((chunk) => ({
            chunkId: chunk.id,
            messageId: id,
            queryId: ragQueryId,
            similarity: chunk.similarity?.toString(),
            userId: this.userId,
          })),
        );
      }

      return item;
    });
  };

  /** Bulk-insert message rows, forcing `userId` to the current user. */
  batchCreate = async (newMessages: MessageItem[]) => {
    const messagesToInsert = newMessages.map((m) => {
      // TODO: need a better way to handle this
      return { ...m, role: m.role as any, userId: this.userId };
    });

    return this.db.insert(messages).values(messagesToInsert);
  };

  /** Insert a RAG query row for the current user and return it. */
  createMessageQuery = async (params: NewMessageQueryParams) => {
    const result = await this.db
      .insert(messageQueries)
      .values({ ...params, userId: this.userId })
      .returning();

    return result[0];
  };

  // **************** Update *************** //

  /**
   * Update a message of the current user; any `imageList` entries are linked
   * to the message as files within the same transaction.
   */
  update = async (id: string, { imageList, ...message }: Partial<UpdateMessageParams>) => {
    return this.db.transaction(async (trx) => {
      // 1. insert message files
      if (imageList && imageList.length > 0) {
        await trx
          .insert(messagesFiles)
          .values(
            imageList.map((file) => ({ fileId: file.id, messageId: id, userId: this.userId })),
          );
      }

      // 2. update the message row itself
      return trx
        .update(messages)
        .set({
          ...message,
          // TODO: need a better way to handle this
          // TODO: but I forget why 🤡
          role: message.role as any,
        })
        .where(and(eq(messages.id, id), eq(messages.userId, this.userId)));
    });
  };

  /**
   * Deep-merge `state` into the stored plugin state of a tool message.
   *
   * @throws Error('Plugin not found') when the current user has no plugin row with this id
   */
  updatePluginState = async (id: string, state: Record<string, any>) => {
    const item = await this.db.query.messagePlugins.findFirst({
      where: and(eq(messagePlugins.id, id), eq(messagePlugins.userId, this.userId)),
    });
    if (!item) throw new Error('Plugin not found');

    return this.db
      .update(messagePlugins)
      .set({ state: merge(item.state || {}, state) })
      .where(and(eq(messagePlugins.id, id), eq(messagePlugins.userId, this.userId)));
  };

  /**
   * Overwrite columns of a tool message's plugin row.
   *
   * @throws Error('Plugin not found') when the current user has no plugin row with this id
   */
  updateMessagePlugin = async (id: string, value: Partial<MessagePluginItem>) => {
    const item = await this.db.query.messagePlugins.findFirst({
      where: and(eq(messagePlugins.id, id), eq(messagePlugins.userId, this.userId)),
    });
    if (!item) throw new Error('Plugin not found');

    return this.db
      .update(messagePlugins)
      .set(value)
      .where(and(eq(messagePlugins.id, id), eq(messagePlugins.userId, this.userId)));
  };

  /** Insert or update the translation attached to a message of the current user. */
  updateTranslate = async (id: string, translate: Partial<ChatTranslate>) => {
    const result = await this.db.query.messageTranslates.findFirst({
      where: and(eq(messageTranslates.id, id), eq(messageTranslates.userId, this.userId)),
    });

    // If the message does not exist in the translate table, insert it
    if (!result) {
      return this.db.insert(messageTranslates).values({ ...translate, id, userId: this.userId });
    }

    // or just update the existing one
    return this.db
      .update(messageTranslates)
      .set(translate)
      .where(and(eq(messageTranslates.id, id), eq(messageTranslates.userId, this.userId)));
  };

  /** Insert or update the TTS record attached to a message of the current user. */
  updateTTS = async (id: string, tts: Partial<ChatTTS>) => {
    const result = await this.db.query.messageTTS.findFirst({
      where: and(eq(messageTTS.id, id), eq(messageTTS.userId, this.userId)),
    });

    // If the message does not exist in the TTS table, insert it
    if (!result) {
      return this.db.insert(messageTTS).values({
        contentMd5: tts.contentMd5,
        fileId: tts.file,
        id,
        userId: this.userId,
        voice: tts.voice,
      });
    }

    // or just update the existing one
    return this.db
      .update(messageTTS)
      .set({ contentMd5: tts.contentMd5, fileId: tts.file, voice: tts.voice })
      .where(and(eq(messageTTS.id, id), eq(messageTTS.userId, this.userId)));
  };

  // **************** Delete *************** //

  /**
   * Delete a message of the current user. If it carries tool calls, also
   * delete the user's tool messages whose `tool_call_id` matches one of them.
   * No-op when the message is not found.
   */
  deleteMessage = async (id: string) => {
    return this.db.transaction(async (tx) => {
      // 1. load the full row of the message to delete
      const [message] = await tx
        .select()
        .from(messages)
        .where(and(eq(messages.id, id), eq(messages.userId, this.userId)))
        .limit(1);

      // nothing to do if the message cannot be found
      if (!message) return;

      // 2. collect the ids of the tool calls this message made, if any
      const tools = (message.tools as ChatToolPayload[] | null) ?? [];
      const toolCallIds = tools.map((tool) => tool.id).filter(Boolean);

      let relatedMessageIds: string[] = [];

      if (toolCallIds.length > 0) {
        // 3. find this user's tool messages linked to those tool calls
        const rows = await tx
          .select({ id: messagePlugins.id })
          .from(messagePlugins)
          .where(
            and(
              eq(messagePlugins.userId, this.userId),
              inArray(messagePlugins.toolCallId, toolCallIds),
            ),
          );

        relatedMessageIds = rows.map((row) => row.id);
      }

      // 4. merge the ids to delete
      const messageIdsToDelete = [id, ...relatedMessageIds];

      // 5. delete them; the userId filter ensures a colliding tool-call id can
      // never remove another user's messages
      await tx
        .delete(messages)
        .where(and(eq(messages.userId, this.userId), inArray(messages.id, messageIdsToDelete)));
    });
  };

  /** Delete the given messages, limited to those owned by the current user. */
  deleteMessages = async (ids: string[]) =>
    this.db
      .delete(messages)
      .where(and(eq(messages.userId, this.userId), inArray(messages.id, ids)));

  /** Delete the translation attached to a message of the current user. */
  deleteMessageTranslate = async (id: string) =>
    this.db
      .delete(messageTranslates)
      .where(and(eq(messageTranslates.id, id), eq(messageTranslates.userId, this.userId)));

  /** Delete the TTS record attached to a message of the current user. */
  deleteMessageTTS = async (id: string) =>
    this.db
      .delete(messageTTS)
      .where(and(eq(messageTTS.id, id), eq(messageTTS.userId, this.userId)));

  /** Delete a RAG query row of the current user. */
  deleteMessageQuery = async (id: string) =>
    this.db
      .delete(messageQueries)
      .where(and(eq(messageQueries.id, id), eq(messageQueries.userId, this.userId)));

  /**
   * Delete every message of a session/topic pair; a `null`/omitted id matches
   * rows where that column IS NULL.
   */
  deleteMessagesBySession = async (sessionId?: string | null, topicId?: string | null) =>
    this.db
      .delete(messages)
      .where(
        and(
          eq(messages.userId, this.userId),
          this.matchSession(sessionId),
          this.matchTopic(topicId),
        ),
      );

  /** Delete every message of the current user. */
  deleteAllMessages = async () => {
    return this.db.delete(messages).where(eq(messages.userId, this.userId));
  };

  // **************** Helper *************** //

  private genId = () => idGenerator('messages', 14);

  /** A truthy session id matches that session; otherwise match `session_id IS NULL`. */
  private matchSession = (sessionId?: string | null) =>
    sessionId ? eq(messages.sessionId, sessionId) : isNull(messages.sessionId);

  /** A truthy topic id matches that topic; otherwise match `topic_id IS NULL`. */
  private matchTopic = (topicId?: string | null) =>
    topicId ? eq(messages.topicId, topicId) : isNull(messages.topicId);

  /** Optional `createdAt` filters shared by `count` and `countWords`; absent params yield `undefined`. */
  private genDateFilters = (params?: {
    endDate?: string;
    range?: [string, string];
    startDate?: string;
  }) => [
    params?.range
      ? genRangeWhere(params.range, messages.createdAt, (date) => date.toDate())
      : undefined,
    params?.endDate
      ? genEndDateWhere(params.endDate, messages.createdAt, (date) => date.toDate())
      : undefined,
    params?.startDate
      ? genStartDateWhere(params.startDate, messages.createdAt, (date) => date.toDate())
      : undefined,
  ];

  /** Bucket rows by `messageId`, preserving the original row order inside each bucket. */
  private groupByMessageId = <T extends { messageId: string | null }>(rows: T[]) => {
    const groups = new Map<string | null, T[]>();

    for (const row of rows) {
      const bucket = groups.get(row.messageId);
      if (bucket) bucket.push(row);
      else groups.set(row.messageId, [row]);
    }

    return groups;
  };
}