@lobehub/chat
Version:
Lobe Chat is an open-source, high-performance chatbot framework that supports speech synthesis, multimodal input, and an extensible Function Call plugin system. It supports one-click free deployment of your private ChatGPT/LLM web application.
476 lines (403 loc) • 14 kB
text/typescript
import { Column, count, sql } from 'drizzle-orm';
import { and, asc, desc, eq, gt, inArray, isNull, like, not, or } from 'drizzle-orm/expressions';
import { DeepPartial } from 'utility-types';
import { DEFAULT_INBOX_AVATAR } from '@/const/meta';
import { INBOX_SESSION_ID } from '@/const/session';
import { DEFAULT_AGENT_CONFIG } from '@/const/settings';
import { LobeChatDatabase } from '@/database/type';
import {
genEndDateWhere,
genRangeWhere,
genStartDateWhere,
genWhere,
} from '@/database/utils/genWhere';
import { idGenerator } from '@/database/utils/idGenerator';
import { LobeAgentConfig } from '@/types/agent';
import { ChatSessionList, LobeAgentSession, SessionRankItem } from '@/types/session';
import { merge } from '@/utils/merge';
import {
AgentItem,
NewAgent,
NewSession,
SessionItem,
agents,
agentsToSessions,
sessionGroups,
sessions,
topics,
} from '../schemas';
/**
 * Data-access model for the chat sessions of a single user.
 *
 * Every statement issued by this model is scoped to the `userId` passed at construction time.
 * A session is linked to its agent configuration through the `agentsToSessions` join table.
 */
export class SessionModel {
  private readonly userId: string;
  private readonly db: LobeChatDatabase;

  constructor(db: LobeChatDatabase, userId: string) {
    this.userId = userId;
    this.db = db;
  }

  // **************** Query *************** //

  /**
   * List the user's sessions except the inbox, most recently updated first, including each
   * session's linked agents and its group.
   *
   * @param current - zero-based page index (default 0)
   * @param pageSize - rows per page (default 9999, i.e. effectively "all")
   */
  query = async ({ current = 0, pageSize = 9999 } = {}) => {
    const offset = current * pageSize;

    return this.db.query.sessions.findMany({
      limit: pageSize,
      offset,
      orderBy: [desc(sessions.updatedAt)],
      // NOTE(review): in SQL, `NOT (slug = ...)` is NULL for rows whose slug is NULL, so such rows
      // would be excluded as well — confirm `sessions.slug` is non-nullable.
      where: and(eq(sessions.userId, this.userId), not(eq(sessions.slug, INBOX_SESSION_ID))),
      with: { agentsToSessions: { columns: {}, with: { agent: true } }, group: true },
    });
  };

  /**
   * Return all of the user's sessions (mapped to `LobeAgentSession`) together with the user's
   * session groups, ordered by `sort` ascending and then newest first.
   */
  queryWithGroups = async (): Promise<ChatSessionList> => {
    // Query all sessions and the user's groups; the two queries are independent.
    const [result, groups] = await Promise.all([
      this.query(),
      this.db.query.sessionGroups.findMany({
        orderBy: [asc(sessionGroups.sort), desc(sessionGroups.createdAt)],
        // Filter on the session_groups table's own owner column (this previously referenced
        // `sessions.userId`, a column of a different table).
        where: eq(sessionGroups.userId, this.userId),
      }),
    ]);

    return {
      sessionGroups: groups as unknown as ChatSessionList['sessionGroups'],
      sessions: result.map((item) => this.mapSessionItem(item as any)),
    };
  };

  /**
   * Case-insensitive search of the user's sessions by agent title/description.
   * Returns an empty list for an empty keyword.
   */
  queryByKeyword = async (keyword: string) => {
    if (!keyword) return [];

    const keywordLowerCase = keyword.toLowerCase();

    const data = await this.findSessionsByKeywords({ keyword: keywordLowerCase });

    return data.map((item) => this.mapSessionItem(item as any));
  };

  /**
   * Find one of the user's sessions by id or by slug.
   *
   * @returns the session row with its first linked agent exposed as `agent`, or `undefined`
   *   when no matching session exists.
   */
  findByIdOrSlug = async (
    idOrSlug: string,
  ): Promise<(SessionItem & { agent: AgentItem }) | undefined> => {
    const result = await this.db.query.sessions.findFirst({
      where: and(
        or(eq(sessions.id, idOrSlug), eq(sessions.slug, idOrSlug)),
        eq(sessions.userId, this.userId),
      ),
      with: { agentsToSessions: { columns: {}, with: { agent: true } }, group: true },
    });

    if (!result) return;

    return { ...result, agent: (result?.agentsToSessions?.[0] as any)?.agent } as any;
  };

  /**
   * Count the user's sessions, optionally filtered on `createdAt` through the
   * `genRangeWhere` / `genEndDateWhere` / `genStartDateWhere` helpers.
   */
  count = async (params?: {
    endDate?: string;
    range?: [string, string];
    startDate?: string;
  }): Promise<number> => {
    const result = await this.db
      .select({
        count: count(sessions.id),
      })
      .from(sessions)
      .where(
        genWhere([
          eq(sessions.userId, this.userId),
          params?.range
            ? genRangeWhere(params.range, sessions.createdAt, (date) => date.toDate())
            : undefined,
          params?.endDate
            ? genEndDateWhere(params.endDate, sessions.createdAt, (date) => date.toDate())
            : undefined,
          params?.startDate
            ? genStartDateWhere(params.startDate, sessions.createdAt, (date) => date.toDate())
            : undefined,
        ]),
      );

    return result[0].count;
  };

  /**
   * Rank the user's stored sessions by number of topics, highest first, returning agent
   * avatar/background/title for display. Sessions without topics are excluded (`count > 0`).
   * Topics are joined on `sessionId`, so topics with a null `sessionId` are not counted here.
   *
   * @param limit - maximum number of rows (default 10)
   */
  _rank = async (limit: number = 10): Promise<SessionRankItem[]> => {
    return this.db
      .select({
        avatar: agents.avatar,
        backgroundColor: agents.backgroundColor,
        count: count(topics.id).as('count'),
        id: sessions.id,
        title: agents.title,
      })
      .from(sessions)
      .leftJoin(topics, eq(sessions.id, topics.sessionId))
      .leftJoin(agentsToSessions, eq(sessions.id, agentsToSessions.sessionId))
      .leftJoin(agents, eq(agentsToSessions.agentId, agents.id))
      .where(eq(sessions.userId, this.userId))
      .groupBy(sessions.id, agentsToSessions.agentId, agents.id)
      .having(({ count }) => gt(count, 0))
      .orderBy(desc(sql`count`))
      .limit(limit);
  };

  /**
   * Rank sessions by topic count, including the inbox as a virtual entry.
   *
   * Topics with a null `sessionId` are counted as the inbox, which `_rank` cannot see.
   * TODO: once the inbox id is persisted in the database, `_rank` can be used directly.
   *
   * @param limit - maximum number of entries to return (default 10)
   */
  rank = async (limit: number = 10): Promise<SessionRankItem[]> => {
    const inboxResult = await this.db
      .select({
        count: count(topics.id).as('count'),
      })
      .from(topics)
      .where(and(eq(topics.userId, this.userId), isNull(topics.sessionId)));

    const inboxCount = inboxResult[0]?.count ?? 0;

    if (!inboxCount) return this._rank(limit);

    // Fetch a full `limit` of stored sessions, merge the inbox in, then trim after sorting so the
    // result is the true top-`limit` (as `_rank` would give if the inbox were stored). Asking for
    // `limit - 1` always kept the inbox, and turned `limit = 0` into `_rank`'s default of 10.
    const others = await this._rank(limit);

    return [
      {
        avatar: DEFAULT_INBOX_AVATAR,
        backgroundColor: null,
        count: inboxCount,
        id: INBOX_SESSION_ID,
        title: 'inbox.title',
      },
      ...others,
    ]
      .sort((a, b) => b.count - a.count)
      .slice(0, Math.max(limit, 0));
  };

  /** Whether the user owns more than `n` sessions (fetches at most `n + 1` ids). */
  hasMoreThanN = async (n: number): Promise<boolean> => {
    const result = await this.db
      .select({ id: sessions.id })
      .from(sessions)
      .where(eq(sessions.userId, this.userId))
      .limit(n + 1);

    return result.length > n;
  };

  // **************** Create *************** //

  /**
   * Create a session and a dedicated agent, linking them through `agentsToSessions`, all in one
   * transaction.
   *
   * When `slug` is given and the user already has a session with that slug, the existing session
   * is returned and nothing is inserted.
   *
   * @param type - session type; defaults to `'agent'`
   */
  create = async ({
    id = idGenerator('sessions'),
    type = 'agent',
    session = {},
    config = {},
    slug,
  }: {
    config?: Partial<NewAgent>;
    id?: string;
    session?: Partial<NewSession>;
    slug?: string;
    type?: 'agent' | 'group';
  }): Promise<SessionItem> => {
    return this.db.transaction(async (trx) => {
      if (slug) {
        const existResult = await trx.query.sessions.findFirst({
          where: and(eq(sessions.slug, slug), eq(sessions.userId, this.userId)),
        });

        if (existResult) return existResult;
      }

      const now = new Date();

      const newAgents = await trx
        .insert(agents)
        .values({
          ...config,
          createdAt: now,
          id: idGenerator('agents'),
          updatedAt: now,
          userId: this.userId,
        })
        .returning();

      const result = await trx
        .insert(sessions)
        .values({
          ...session,
          createdAt: now,
          id,
          slug,
          type,
          updatedAt: now,
          userId: this.userId,
        })
        .returning();

      await trx.insert(agentsToSessions).values({
        agentId: newAgents[0].id,
        sessionId: id,
        userId: this.userId,
      });

      return result[0];
    });
  };

  /**
   * Create the user's inbox session (slug `INBOX_SESSION_ID`) unless it already exists.
   * The agent config is `DEFAULT_AGENT_CONFIG` merged with `defaultAgentConfig`.
   *
   * @returns the created session, or `undefined` if the inbox already existed.
   */
  createInbox = async (defaultAgentConfig: DeepPartial<LobeAgentConfig>) => {
    const item = await this.db.query.sessions.findFirst({
      where: and(eq(sessions.userId, this.userId), eq(sessions.slug, INBOX_SESSION_ID)),
    });
    if (item) return;

    return await this.create({
      config: merge(DEFAULT_AGENT_CONFIG, defaultAgentConfig),
      slug: INBOX_SESSION_ID,
      type: 'agent',
    });
  };

  /**
   * Insert several session rows at once, each with a freshly generated id and owned by this user.
   * No agents or `agentsToSessions` links are created.
   */
  batchCreate = async (newSessions: NewSession[]) => {
    const sessionsToInsert = newSessions.map((s) => {
      return {
        ...s,
        id: this.genId(),
        userId: this.userId,
      };
    });

    return this.db.insert(sessions).values(sessionsToInsert);
  };

  /**
   * Duplicate a session together with a copy of its agent config under a new session id.
   * The source session's `clientId` and slug are not carried over.
   *
   * @param newTitle - title for the copy; falls back to the source title when empty.
   * @returns the new session, or `undefined` if the source session was not found.
   */
  duplicate = async (id: string, newTitle?: string) => {
    const result = await this.findByIdOrSlug(id);

    if (!result) return;

    // eslint-disable-next-line @typescript-eslint/no-unused-vars,unused-imports/no-unused-vars
    const { agent, clientId, ...session } = result;
    const sessionId = this.genId();

    // eslint-disable-next-line @typescript-eslint/no-unused-vars
    const { id: _, slug: __, ...config } = agent;

    // NOTE(review): the copy is always created with type 'agent', even if the source session was
    // a 'group' session — confirm this is intended.
    return this.create({
      config: config,
      id: sessionId,
      session: {
        ...session,
        title: newTitle || session.title,
      },
      type: 'agent',
    });
  };

  // **************** Delete *************** //

  /**
   * Delete a session and its associated agent data if no longer referenced.
   */
  delete = async (id: string) => {
    return this.db.transaction(async (trx) => {
      // First get the agent IDs associated with this session
      const links = await trx
        .select({ agentId: agentsToSessions.agentId })
        .from(agentsToSessions)
        .where(and(eq(agentsToSessions.sessionId, id), eq(agentsToSessions.userId, this.userId)));

      const agentIds = links.map((link) => link.agentId);

      // Delete links in agentsToSessions
      await trx
        .delete(agentsToSessions)
        .where(and(eq(agentsToSessions.sessionId, id), eq(agentsToSessions.userId, this.userId)));

      // Delete the session
      const result = await trx
        .delete(sessions)
        .where(and(eq(sessions.id, id), eq(sessions.userId, this.userId)));

      // Delete orphaned agents
      await this.clearOrphanAgent(agentIds, trx);

      return result;
    });
  };

  /**
   * Batch delete sessions and their associated agent data if no longer referenced.
   * Returns `{ count: 0 }` without touching the database when `ids` is empty.
   */
  batchDelete = async (ids: string[]) => {
    if (ids.length === 0) return { count: 0 };

    return this.db.transaction(async (trx) => {
      // Get agent IDs associated with these sessions
      const links = await trx
        .select({ agentId: agentsToSessions.agentId })
        .from(agentsToSessions)
        .where(
          and(inArray(agentsToSessions.sessionId, ids), eq(agentsToSessions.userId, this.userId)),
        );

      const agentIds = [...new Set(links.map((link) => link.agentId))];

      // Delete links in agentsToSessions
      await trx
        .delete(agentsToSessions)
        .where(
          and(inArray(agentsToSessions.sessionId, ids), eq(agentsToSessions.userId, this.userId)),
        );

      // Delete the sessions
      const result = await trx
        .delete(sessions)
        .where(and(inArray(sessions.id, ids), eq(sessions.userId, this.userId)));

      // Delete orphaned agents
      await this.clearOrphanAgent(agentIds, trx);

      return result;
    });
  };

  /**
   * Delete all sessions and their associated agent data for this user.
   */
  deleteAll = async () => {
    return this.db.transaction(async (trx) => {
      // Delete all agentsToSessions for this user
      await trx.delete(agentsToSessions).where(eq(agentsToSessions.userId, this.userId));
      // Delete every agent owned by this user
      await trx.delete(agents).where(eq(agents.userId, this.userId));
      // Delete all sessions for this user
      return trx.delete(sessions).where(eq(sessions.userId, this.userId));
    });
  };

  /**
   * Delete those of `agentIds` that are no longer linked to any session, restricted to agents
   * owned by this user.
   *
   * Uses one lookup plus one delete instead of a query per agent. The link lookup is deliberately
   * not filtered by user, matching the previous per-agent check.
   *
   * @param agentIds - candidate agent ids (duplicates allowed; empty is a no-op)
   * @param trx - the transaction (or database) handle to run the statements on
   */
  clearOrphanAgent = async (agentIds: string[], trx: any) => {
    const candidateIds = [...new Set(agentIds)];
    if (candidateIds.length === 0) return;

    // Which candidates are still referenced by at least one link?
    const stillLinked: { agentId: string }[] = await trx
      .select({ agentId: agentsToSessions.agentId })
      .from(agentsToSessions)
      .where(inArray(agentsToSessions.agentId, candidateIds))
      .groupBy(agentsToSessions.agentId);

    const linkedIds = new Set(stillLinked.map((row) => row.agentId));
    const orphanIds = candidateIds.filter((agentId) => !linkedIds.has(agentId));
    if (orphanIds.length === 0) return;

    await trx
      .delete(agents)
      .where(and(inArray(agents.id, orphanIds), eq(agents.userId, this.userId)));
  };

  // **************** Update *************** //

  /**
   * Apply `data` to one of the user's sessions and return the updated rows.
   * NOTE(review): `data` is passed straight to `set`, so it may also overwrite columns such as
   * `userId` — confirm callers validate it.
   */
  update = async (id: string, data: Partial<SessionItem>) => {
    return this.db
      .update(sessions)
      .set(data)
      .where(and(eq(sessions.id, id), eq(sessions.userId, this.userId)))
      .returning();
  };

  /**
   * Deep-merge `data` into the agent linked to `sessionId` and write it back.
   * Does nothing for empty/absent `data` or an unknown session.
   *
   * @throws Error if the session exists but has no linked agent.
   */
  updateConfig = async (sessionId: string, data: DeepPartial<AgentItem> | undefined | null) => {
    if (!data || Object.keys(data).length === 0) return;

    const session = await this.findByIdOrSlug(sessionId);
    if (!session) return;

    if (!session.agent) {
      throw new Error(
        'this session is not assign with agent, please contact with admin to fix this issue.',
      );
    }

    // NOTE(review): the merged row includes the agent's previously stored `updatedAt`, so unless
    // `data` supplies one, the SET clause writes the old timestamp back — confirm this is intended.
    const mergedValue = merge(session.agent, data);

    return this.db
      .update(agents)
      .set(mergedValue)
      .where(and(eq(agents.id, session.agent.id), eq(agents.userId, this.userId)));
  };

  // **************** Helper *************** //

  /** Generate a new session id. */
  private genId = () => idGenerator('sessions');

  /**
   * Map a session row (with optional agent links) to a `LobeAgentSession`. Agent fields take
   * precedence over the session's own avatar/backgroundColor/description/title.
   */
  private mapSessionItem = ({
    agentsToSessions,
    title,
    backgroundColor,
    description,
    avatar,
    groupId,
    ...res
  }: SessionItem & { agentsToSessions?: { agent: AgentItem }[] }): LobeAgentSession => {
    // TODO: this needs a better implementation in the future; for now only the first agent is used
    const agent = agentsToSessions?.[0]?.agent;

    return {
      ...res,
      group: groupId,
      meta: {
        avatar: agent?.avatar ?? avatar ?? undefined,
        backgroundColor: agent?.backgroundColor ?? backgroundColor ?? undefined,
        description: agent?.description ?? description ?? undefined,
        tags: agent?.tags ?? undefined,
        title: agent?.title ?? title ?? undefined,
      },
      model: agent?.model,
    } as any;
  };

  /**
   * Find sessions whose agent title or description contains `keyword`, case-insensitively, most
   * recently updated agents first.
   *
   * Agents that are not linked to any session are skipped individually. Previously such an agent
   * threw inside `map`, and the catch block then discarded every match.
   *
   * @returns the first linked session of each matching agent.
   */
  findSessionsByKeywords = async (params: {
    current?: number;
    keyword: string;
    pageSize?: number;
  }) => {
    const { keyword, pageSize = 9999, current = 0 } = params;

    const offset = current * pageSize;
    const pattern = `%${keyword.toLowerCase()}%`;

    const results = await this.db.query.agents.findMany({
      limit: pageSize,
      offset,
      orderBy: [desc(agents.updatedAt)],
      where: and(
        eq(agents.userId, this.userId),
        or(
          like(sql`lower(${agents.title})` as unknown as Column, pattern),
          like(sql`lower(${agents.description})` as unknown as Column, pattern),
        ),
      ),
      with: { agentsToSessions: { columns: {}, with: { session: true } } },
    });

    return results.flatMap((item) => {
      // Only the first linked session is used (same "first link" convention as mapSessionItem).
      const session = (item as any).agentsToSessions?.[0]?.session as
        | SessionItem
        | null
        | undefined;

      return session ? [session] : [];
    });
  };
}